import time

# Path of the dataset file handed to the dataset loader and, on the parallel
# path, to the parallel trainer.
DATASET_FILE = 'train_sample.csv'

# Settings read by apply_fast_configurations() and echoed by run_fast_training().
FAST_CONFIG = {
    # Parallel training is attempted only when this is True AND num_networks > 1.
    'PARALLEL_TRAINING': False,
    # When True, an existing model is trained further instead of training new ones.
    'CONTINUE_TRAIN': False,
    # Passed as `model_file` to the continue-training routine.
    'CONTINUE_TRAIN_MODEL': "model_1.dict",
    # Passed as `additional_epochs` to the continue-training routine.
    'ADDITIONAL_EPOCHS': 3,

    # Forwarded as a single dict to the network/optimizer initialiser and to
    # the continue/parallel trainers; this file only reads the keys to print them.
    'optimal_parameters': {
        'learning_rate': 0.001,
        'transformer_depth': 6,
        'attention_heads': 8,
        'hidden_dimension': 512
    },

    'model_parameters': {
        # input_features / embedding_size are forwarded when building or
        # reloading a network.
        'input_features': 2048,
        'embedding_size': 512,
        # Forwarded to the individual and parallel training routines.
        'epochs': 1,
        # Used when preparing the dataset and by the parallel trainer.
        'batch_size': 32
    },

    # Number of models trained on the fresh-training path.
    'num_networks': 1,
    # Forwarded unchanged to the network initialiser and the trainers.
    'activation': 'gelu',
    # Forwarded to the parallel trainer; a falsy value is reported as "AUTO".
    'max_workers': None,
}

def _continue_existing_model(config, data_handler):
    """Train a previously saved model further and report the outcome.

    Best-effort: any failure (including a failed import of the
    continue-training module) is printed instead of being raised.

    Args:
        config: Configuration mapping laid out like ``FAST_CONFIG``.
        data_handler: Dataset object returned by the dataset loader.
    """
    model_params = config['model_parameters']
    try:
        from molactivity.B18_train_further import fast_continue_train

        result = fast_continue_train(
            model_file=config['CONTINUE_TRAIN_MODEL'],
            data_handler=data_handler,
            additional_epochs=config['ADDITIONAL_EPOCHS'],
            activation=config['activation'],
            optimal_parameters=config['optimal_parameters'],
            new_model_suffix='_continued',
            input_features=model_params['input_features'],
            embedding_size=model_params['embedding_size'],
        )

        if result is not None:
            # Only the file name is reported; the returned network is unused.
            _, new_model_file = result
            print(f"continue train complete: {new_model_file}")
        else:
            print("continue train failed")
    except Exception as exc:
        # Previously swallowed silently; surface the reason but keep going.
        print(f"continue train failed: {exc}")


def _train_in_parallel(config):
    """Try to train all networks with the parallel trainer.

    Args:
        config: Configuration mapping laid out like ``FAST_CONFIG``.

    Returns:
        True when parallel training finished, False when it failed and the
        caller should fall back to training the networks one by one.
    """
    model_params = config['model_parameters']
    try:
        from molactivity.B19_train_parallel import fast_parallel_training

        trained_model_files = fast_parallel_training(
            num_networks=config['num_networks'],
            optimal_parameters=config['optimal_parameters'],
            activation=config['activation'],
            dataset_file=DATASET_FILE,
            max_workers=config['max_workers'],
            epochs=model_params['epochs'],
            batch_size=model_params['batch_size'],
            input_features=model_params['input_features'],
            embedding_size=model_params['embedding_size'],
        )

        print("PARALLEL_TRAINING complete")
        print(f"model: {trained_model_files}")
        return True
    except Exception as e:
        # ImportError is an Exception subclass, so one handler covers both
        # an unavailable parallel module and a failure during training.
        print(f"PARALLEL_TRAINING failed: {e}")
        print("train one by one...")
        return False


def _train_one_by_one(train_fast, config, data_handler):
    """Train ``config['num_networks']`` networks sequentially and print a summary.

    Args:
        train_fast: The imported training module providing
            ``initialize_network_and_optimizer`` and
            ``conduct_individual_training``.
        config: Configuration mapping laid out like ``FAST_CONFIG``.
        data_handler: Dataset object returned by the dataset loader.
    """
    model_params = config['model_parameters']
    total = config['num_networks']
    trained_networks = []

    for network_idx in range(total):
        print(f'\n--- train model {network_idx+1}/{total} ---')
        network, optimizer = train_fast.initialize_network_and_optimizer(
            config['optimal_parameters'],
            config['activation'],
            model_params['input_features'],
            model_params['embedding_size'],
        )

        # NOTE(review): the positional 0 and 2 are passed through unchanged;
        # their meaning is defined by conduct_individual_training - confirm there.
        trained_networks.append(
            train_fast.conduct_individual_training(
                network, data_handler, optimizer, 0, 2, network_idx,
                model_params['epochs'],
            )
        )

    print(f"train {len(trained_networks)} models in total")
    # File names are reconstructed from the count, not returned by the trainer.
    model_files = [f'model_{i+1}.dict' for i in range(len(trained_networks))]
    print(f"model: {model_files}")


def apply_fast_configurations():
    """Build the configured training routine and install it on the training module.

    The returned callable reads ``FAST_CONFIG`` when invoked, prepares the
    dataset from ``DATASET_FILE`` and then either continues training an
    existing model, trains in parallel (falling back to sequential training
    on failure), or trains the networks one by one. The callable is also
    assigned to ``train_fast.training``.

    Returns:
        The configured training callable, or ``None`` when the training
        module cannot be imported.
    """
    try:
        from molactivity import B17_train as train_fast
    except Exception as exc:
        # Report why configuration failed rather than hiding it; unlike a
        # bare ``except`` this lets KeyboardInterrupt/SystemExit propagate.
        print(f"failed to load training module: {exc}")
        return None

    def configured_fast_training():
        """Run one training session according to ``FAST_CONFIG``."""
        config = FAST_CONFIG

        # The dataset is prepared up front on every path, as before.
        data_handler = train_fast.prepare_pure_training_dataset(
            DATASET_FILE,
            fingerprint_type='Morgan',
            batch_size=config['model_parameters']['batch_size'],
            shuffle=False,
            balance_data=True,
        )

        if config['CONTINUE_TRAIN']:
            _continue_existing_model(config, data_handler)
            return

        if config['PARALLEL_TRAINING'] and config['num_networks'] > 1:
            if _train_in_parallel(config):
                return

        _train_one_by_one(train_fast, config, data_handler)

    train_fast.training = configured_fast_training

    return configured_fast_training

def run_fast_training():
    """Print the active configuration, then build and run the training routine.

    Echoes ``DATASET_FILE`` and the relevant ``FAST_CONFIG`` entries, obtains
    the training callable from ``apply_fast_configurations()`` and invokes it.
    Prints "Configuration failed" when no callable could be built.
    """
    config = FAST_CONFIG
    hyper_params = config['optimal_parameters']
    model_params = config['model_parameters']

    print("CONFIG:")
    print(f"DATASET: {DATASET_FILE}")
    print(f"PARALLEL_TRAINING: {config['PARALLEL_TRAINING']}")
    print(f"CONTINUE_TRAIN: {config['CONTINUE_TRAIN']}")
    if config['CONTINUE_TRAIN']:
        # Distinct label so the model file is not confused with the flag above.
        print(f"CONTINUE_TRAIN_MODEL: {config['CONTINUE_TRAIN_MODEL']}")
        print(f"train epochs: {config['ADDITIONAL_EPOCHS']}")
    else:
        print(f"models: {config['num_networks']}")
        if config['PARALLEL_TRAINING'] and config['max_workers']:
            print(f"workers: {config['max_workers']}")
        else:
            print("workers: AUTO")
    print(f"activation: {config['activation']}")
    print(f"lr: {hyper_params['learning_rate']}")
    print(f"Transformer depth: {hyper_params['transformer_depth']}")
    print(f"attention_heads: {hyper_params['attention_heads']}")
    # Was a duplicated attention_heads line; hidden_dimension was never shown.
    print(f"hidden_dimension: {hyper_params['hidden_dimension']}")
    print(f"input_features: {model_params['input_features']}")
    print(f"embedding_size: {model_params['embedding_size']}")
    print(f"epochs: {model_params['epochs']}")
    print(f"batch_size: {model_params['batch_size']}")

    configured_training = apply_fast_configurations()

    if configured_training:
        print("start training...")
        configured_training()
    else:
        print("Configuration failed")

if __name__ == "__main__":
    # perf_counter() is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments made while training runs.
    start_time = time.perf_counter()

    run_fast_training()

    total_time = time.perf_counter() - start_time
    print(f"\ntime used: {total_time:.2f} seconds")