
import time
import torch
# CSV file handed to the dataset loader and to the parallel trainer.
DATASET_FILE = "train_sample.csv"

# Single source of truth for a training run; read when training starts.
ROCKET_CONFIG = {
    # Use the parallel trainer, but only when more than one network is requested.
    "PARALLEL_TRAINING": True,
    # When true, continue training CONTINUE_TRAIN_MODEL for ADDITIONAL_EPOCHS
    # instead of starting a new training run.
    "CONTINUE_TRAIN": False,
    "CONTINUE_TRAIN_MODEL": "model_1.pt",
    "ADDITIONAL_EPOCHS": 3,

    # Forwarded unchanged to the network/optimizer construction helpers.
    "optimal_parameters": {
        "learning_rate": 0.001,
        "transformer_depth": 6,
        "attention_heads": 8,
        "hidden_dimension": 2048,
    },

    # 'epochs' and 'batch_size' drive training; within this file
    # 'input_features' and 'embedding_size' are only echoed in the banner.
    "model_parameters": {
        "input_features": 2048,
        "embedding_size": 512,
        "epochs": 2,
        "batch_size": 32,
    },

    # Number of independent networks to train.
    "num_networks": 1,
    # Passed through to the parallel trainer; the banner reports None as "auto".
    "max_workers": None,
    # 'auto' selects "cuda:0" when CUDA is available, otherwise "cpu";
    # any other value is passed to torch.device() as-is.
    "device": "auto",
}

def apply_rocket_configurations():
    """Build the training entry point driven by ``ROCKET_CONFIG``.

    Returns:
        A zero-argument callable. When invoked it reads ``ROCKET_CONFIG`` and
        ``DATASET_FILE`` and then does exactly one of:

        * continue training a saved model (``CONTINUE_TRAIN`` is truthy),
        * train ``num_networks`` networks in parallel (``PARALLEL_TRAINING``
          is truthy and ``num_networks > 1``), falling back to sequential
          training if the parallel trainer raises,
        * train the networks one after another.

    Building the callable has no side effects. The ``molactivity`` helpers are
    imported lazily inside the branch that needs them, so an ``ImportError``
    surfaces from the returned callable, not from this function.
    """

    def configured_rocket_training():
        config = ROCKET_CONFIG
        model_parameters = config['model_parameters']
        optimal_parameters = config['optimal_parameters']

        if config['device'] == 'auto':
            compute_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            compute_device = torch.device(config['device'])

        def load_training_data():
            # Shared by the continue-training and sequential paths, which
            # load the dataset with identical settings.
            from molactivity.C3_utils import prepare_training_dataset
            return prepare_training_dataset(
                DATASET_FILE,
                fingerprint_type='Morgan',
                batch_size=model_parameters['batch_size'],
                shuffle=False,
                balance_data=True
            )

        if config['CONTINUE_TRAIN']:
            print(f"Loading model for continue training: {config['CONTINUE_TRAIN_MODEL']}")
            data_handler = load_training_data()
            from molactivity.C5_train_tools import rocket_continue_train
            result = rocket_continue_train(
                model_file=config['CONTINUE_TRAIN_MODEL'],
                data_handler=data_handler,
                additional_epochs=config['ADDITIONAL_EPOCHS'],
                optimal_parameters=optimal_parameters,
                compute_device=compute_device,
                new_model_suffix='_continued'
            )
            # A None result is how the helper reports failure here; the
            # returned values themselves are not used by this script.
            if result is None:
                print("Continue training failed")
            return

        num_networks = config['num_networks']
        # Track the fallback in a local flag: a failed parallel run must not
        # rewrite the shared module-level ROCKET_CONFIG.
        use_parallel = config['PARALLEL_TRAINING']

        if use_parallel and num_networks > 1:
            try:
                from molactivity.C5_train_tools import rocket_parallel_training
                trained_model_files = rocket_parallel_training(
                    num_networks=num_networks,
                    optimal_parameters=optimal_parameters,
                    compute_device=compute_device,
                    dataset_file=DATASET_FILE,
                    max_workers=config['max_workers'],
                    epochs=model_parameters['epochs'],
                    batch_size=model_parameters['batch_size']
                )

                print(f"Model files: {trained_model_files}")

            except Exception as e:
                # Deliberate best-effort: any parallel failure downgrades
                # this run to sequential training.
                print(f"Parallel training failed: {e}")
                print("Falling back to sequential training...")
                use_parallel = False

        if not use_parallel or num_networks == 1:
            print("Using rocket sequential training")
            from molactivity.C2_train import initialize_network_and_optimizer, conduct_individual_training

            data_handler = load_training_data()

            trained_networks = []
            for network_idx in range(num_networks):
                network, optimizer = initialize_network_and_optimizer(compute_device, optimal_parameters)

                # NOTE(review): the literal 0 is passed positionally; its
                # meaning is defined by conduct_individual_training - confirm.
                trained_network = conduct_individual_training(
                    network, data_handler, compute_device, optimizer, 0,
                    model_parameters['epochs'], network_idx
                )
                trained_networks.append(trained_network)

            # NOTE(review): these file names are assumed from the count of
            # trained networks, not reported by the training helper - verify.
            model_files = [f'model_{i+1}.pt' for i in range(len(trained_networks))]
            print(f"Successfully trained {len(trained_networks)} rocket models")
            print(f"Model files: {model_files}")

    return configured_rocket_training

def run_rocket_training():
    """Print the active configuration, then build and run the training callable.

    The banner reports the values currently held in ``ROCKET_CONFIG`` and
    ``DATASET_FILE``. Training itself is delegated to the callable returned by
    ``apply_rocket_configurations``.

    Raises:
        RuntimeError: if no training callable was produced
            (``apply_rocket_configurations`` returned ``None``).
    """
    config = ROCKET_CONFIG
    optimal_parameters = config['optimal_parameters']
    model_parameters = config['model_parameters']

    print("CONFIGURATION:")
    print(f"Dataset: {DATASET_FILE}")
    print(f"Parallel Training: {config['PARALLEL_TRAINING']}")
    print(f"Continue Training: {config['CONTINUE_TRAIN']}")
    if config['CONTINUE_TRAIN']:
        print(f"Continue Model: {config['CONTINUE_TRAIN_MODEL']}")
        print(f"Additional Epochs: {config['ADDITIONAL_EPOCHS']}")
    else:
        print(f"Number of Networks: {config['num_networks']}")
        # A falsy max_workers (e.g. None) is reported as "auto".
        if config['PARALLEL_TRAINING'] and config['max_workers']:
            print(f"Max Workers: {config['max_workers']}")
        else:
            print("Max Workers: auto")
    print(f"Device: {config['device']}")
    print(f"Learning Rate: {optimal_parameters['learning_rate']}")
    print(f"Transformer Depth: {optimal_parameters['transformer_depth']}")
    print(f"Attention Heads: {optimal_parameters['attention_heads']}")
    print(f"Hidden Dimension: {optimal_parameters['hidden_dimension']}")
    print(f"Input Features: {model_parameters['input_features']}")
    print(f"Embedding Size: {model_parameters['embedding_size']}")
    print(f"Epochs: {model_parameters['epochs']}")
    print(f"Batch Size: {model_parameters['batch_size']}")

    configured_training = apply_rocket_configurations()
    # Fail with a clear message instead of the opaque
    # "'NoneType' object is not callable" that calling None would raise.
    if configured_training is None:
        raise RuntimeError("Rocket training could not be configured")

    configured_training()


if __name__ == "__main__":
    # perf_counter() is a monotonic clock, so the reported duration cannot be
    # skewed by system clock adjustments the way time.time() differences can.
    start_time = time.perf_counter()

    run_rocket_training()

    total_time = time.perf_counter() - start_time
    print(f"\nTime used for rocket training: {total_time:.2f} seconds")