import time

# Training data file; passed to A28_train as the dataset for every training mode.
DATASET_FILE = 'train_sample.csv'

# Central run configuration, read by apply_configurations() and run_training().
STANDARD_CONFIG = dict(
    # --- run-mode switches ---
    # Train num_networks models through A28_train.parallel_training
    # (only when num_networks > 1).
    PARALLEL_TRAINING=False,
    # Resume CONTINUE_TRAIN_MODEL instead of training new networks.
    CONTINUE_TRAIN=False,
    CONTINUE_TRAIN_MODEL="model_1.dict",
    # Epoch count handed to A28_train.continue_train when CONTINUE_TRAIN is set.
    ADDITIONAL_EPOCHS=1,

    # --- network hyperparameters, forwarded as `optimal_parameters` ---
    optimal_parameters=dict(
        learning_rate=0.001,
        transformer_depth=2,
        attention_heads=2,
        hidden_dimension=64,
    ),

    # --- sizes / schedule forwarded to network construction and training ---
    model_parameters=dict(
        input_features=2048,
        embedding_size=128,
        epochs=2,
        batch_size=32,
    ),

    # How many independent networks to train.
    num_networks=2,
    activation='relu',
    # Only printed by run_training; not passed to A28_train in this file.
    device='cpu',
    # Forwarded to parallel_training; NOTE(review): meaning of None is
    # defined by A28_train (presumably "library default") — confirm.
    max_workers=None,
)

def apply_configurations():
    """Create the config-driven training entry point and install it on A28_train.

    Imports ``molactivity.A28_train``, defines ``configured_training`` (which
    reads ``STANDARD_CONFIG`` and ``DATASET_FILE`` each time it is called),
    assigns it to ``A28_train.training`` and returns it.

    Returns:
        The ``configured_training`` callable, or ``None`` if the import or the
        hook installation raises (the exception message is printed).
    """
    try:
        from molactivity import A28_train

        def configured_training():
            """Run training in the mode selected by ``STANDARD_CONFIG``.

            * ``CONTINUE_TRAIN`` set: resume ``CONTINUE_TRAIN_MODEL`` for
              ``ADDITIONAL_EPOCHS`` more epochs via ``A28_train.continue_train``.
            * ``PARALLEL_TRAINING`` set and ``num_networks > 1``: train via
              ``A28_train.parallel_training``; if that raises, fall back to
              sequential training.
            * otherwise: train ``num_networks`` networks one after another.
            """
            config = STANDARD_CONFIG
            model_params = config['model_parameters']
            optimal_parameters = config['optimal_parameters']

            print(f"dataset: {DATASET_FILE}")

            # Built up front: needed by continue/sequential training and by the
            # sequential fallback if parallel training fails.
            data_handler = A28_train.prepare_pure_training_dataset(
                DATASET_FILE,
                fingerprint_type='Morgan',
                batch_size=model_params['batch_size'],
                shuffle=False,
                balance_data=True,
            )

            if config['CONTINUE_TRAIN']:
                print(f"loading model: {config['CONTINUE_TRAIN_MODEL']}")
                result = A28_train.continue_train(
                    model_file=config['CONTINUE_TRAIN_MODEL'],
                    data_handler=data_handler,
                    additional_epochs=config['ADDITIONAL_EPOCHS'],
                    activation=config['activation'],
                    optimal_parameters=optimal_parameters,
                    new_model_suffix='_continued',
                )
                if result is None:
                    print("failed")
                else:
                    # continue_train is expected to return a (network, model_file)
                    # pair; the values are not used further here.
                    trained_network, new_model_file = result
                return

            # Local flag rather than a write into the shared STANDARD_CONFIG:
            # clearing it on failure makes the sequential path below run as a
            # fallback instead of silently training nothing.
            use_parallel = config['PARALLEL_TRAINING'] and config['num_networks'] > 1

            if use_parallel:
                print("using parallel training")
                try:
                    trained_model_files = A28_train.parallel_training(
                        num_networks=config['num_networks'],
                        optimal_parameters=optimal_parameters,
                        activation=config['activation'],
                        dataset_file=DATASET_FILE,
                        max_workers=config['max_workers'],
                        input_features=model_params['input_features'],
                        embedding_size=model_params['embedding_size'],
                        epochs=model_params['epochs'],
                        batch_size=model_params['batch_size'],
                    )
                except Exception as e:  # includes ImportError; best-effort, fall back below
                    print(f"parallel training failed: {e}")
                    use_parallel = False
                else:
                    print(f"successfully trained {len(trained_model_files)} models")
                    print(f"model file: {trained_model_files}")

            if not use_parallel:
                trained_networks = []
                for network_idx in range(config['num_networks']):
                    print(f'\n--- training model {network_idx+1} ---')
                    network, optimizer = A28_train.initialize_network_and_optimizer(
                        optimal_parameters,
                        config['activation'],
                        model_params['input_features'],
                        model_params['embedding_size'],
                    )
                    # NOTE(review): the literals 0 and 2 are positional arguments whose
                    # meaning is defined by A28_train.conduct_individual_training — confirm.
                    trained_network = A28_train.conduct_individual_training(
                        network, data_handler, optimizer, 0, 2, network_idx,
                        model_params['epochs'],
                    )
                    trained_networks.append(trained_network)

                print(f"successfully trained {len(trained_networks)} models")
                # Assumes conduct_individual_training saves model_<idx+1>.dict — TODO confirm.
                print(f"model file: {[f'model_{i+1}.dict' for i in range(len(trained_networks))]}")

        A28_train.training = configured_training

        return configured_training

    except Exception as e:  # includes ImportError when molactivity is unavailable
        print(f" {e}")
        return None

def run_training():
    """Print a summary of STANDARD_CONFIG, then build and invoke the trainer.

    Prints "Configuration failed" when apply_configurations() yields nothing callable.
    """
    cfg = STANDARD_CONFIG
    hyper = cfg['optimal_parameters']
    sizes = cfg['model_parameters']

    summary_lines = (
        f"activation: {cfg['activation']}",
        f"device: {cfg['device']}",
        f"parallel: {cfg['PARALLEL_TRAINING']}",
        f"continue:{cfg['CONTINUE_TRAIN']}",
        f"lr: {hyper['learning_rate']}",
        f"transformer_depth: {hyper['transformer_depth']}",
        f"attention_heads: {hyper['attention_heads']}",
        f"hidden_dimension: {hyper['hidden_dimension']}",
        f"input_features: {sizes['input_features']}",
        f"embedding_size: {sizes['embedding_size']}",
        f"epochs: {sizes['epochs']}",
        f"batch_size: {sizes['batch_size']}",
    )

    print("CURRENT_CONFIG:")
    for line in summary_lines:
        print(line)

    trainer = apply_configurations()
    if not trainer:
        print("Configuration failed")
        return
    trainer()


if __name__ == "__main__":
    # perf_counter is monotonic and high-resolution, so the reported duration
    # cannot be skewed by system clock adjustments the way time.time() can.
    start_time = time.perf_counter()

    run_training()

    total_time = time.perf_counter() - start_time
    print(f"\ntime used: {total_time:.2f} seconds")