import os
import time
import torch

# User-editable settings for this script. setup_image_config() copies most of
# these values onto molactivity's D1_config; the upper-case top-level keys
# choose between continue-training and fresh training.
IMAGE_CONFIG = {
    'DATASET_FOLDER': './train_images_sample',  # image root; if it does not exist, training is skipped
    'CONTINUE_TRAIN': True,  # True: resume from CONTINUE_TRAIN_MODEL; False: train num_models fresh models
    'CONTINUE_TRAIN_MODEL': "best_trained_model.pth",  # checkpoint to resume; result is saved as "<name>_continued.<ext>"
    'ADDITIONAL_EPOCHS': 20,  # epochs for a continue run (model_parameters['epochs'] is not passed to the trainer then)

    # Sub-folder name -> integer class label, passed to load_datasets().
    # When empty/None, run_image_training() reports that default alphabetical
    # folder sorting is used (presumably done by the data loader — confirm).
    'folder_label_mapping': {
        'inactive_0': 0,
        'active_1': 1,
    },

    # Copied to D1_config.learning_rate / weight_decay / dropout.
    'optimal_parameters': {
        'learning_rate': 5e-5,
        'weight_decay': 0.01,
        'dropout': 0.1
    },

    'model_parameters': {
        'num_classes': 2,  # copied to D1_config.num_classes; also passed to VisionTrans_model
        'epochs': 20,  # copied to D1_config.num_epochs (fresh training)
        'batch_size': 32,  # copied to D1_config.batch_size
        'image_size': 224  # NOTE(review): not copied to D1_config or read anywhere in this file — confirm it is used
    },

    'num_models': 1,  # number of fresh models to train; only used when CONTINUE_TRAIN is False
    'device': 'auto',  # 'auto' -> cuda if available else cpu; otherwise passed to torch.device (continue path only)
    'max_workers': None,  # NOTE(review): not read anywhere in this file
    'save_model_prefix': 'image_model',  # when num_models > 1, models are saved as "<prefix>_<i>.pth"
}

def setup_image_config(config):
    """Push the values of an IMAGE_CONFIG-style dict onto ``molactivity.D1_config``.

    The attributes of ``D1_config`` are overwritten in place, so anything that
    later reads ``molactivity.D1_config`` observes these values.

    Args:
        config: Dict shaped like ``IMAGE_CONFIG``. It must provide
            ``DATASET_FOLDER``, ``model_parameters`` and ``optimal_parameters``.
            ``folder_label_mapping`` is optional and defaults to ``None``.

    Returns:
        The mutated ``molactivity.D1_config`` object.
    """
    from molactivity import D1_config as target

    target.data_dir = config['DATASET_FOLDER']

    model_params = config['model_parameters']
    target.batch_size = model_params['batch_size']

    optim_params = config['optimal_parameters']
    target.learning_rate = optim_params['learning_rate']

    target.num_epochs = model_params['epochs']
    target.num_classes = model_params['num_classes']
    target.weight_decay = optim_params['weight_decay']
    target.dropout = optim_params['dropout']

    # Optional: None tells the loader to fall back to its own labelling.
    target.folder_label_mapping = config.get('folder_label_mapping')

    return target

def load_model_for_continue_training(model, model_file, device):
    """Restore ``model``'s weights from a file saved with ``torch.save``, best-effort.

    Args:
        model: Object exposing ``load_state_dict`` (e.g. a ``torch.nn.Module``).
        model_file: Path to the saved state dict.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        True if the weights were loaded. False if the file is missing or if
        loading raised; the error is printed, never propagated.
    """
    try:
        if not os.path.exists(model_file):
            print(f"Model file {model_file} not found")
            return False
        state = torch.load(model_file, map_location=device)
        model.load_state_dict(state)
        print("Model loaded successfully for continue training")
    except Exception as exc:
        print(f"Failed to load model: {str(exc)}")
        return False
    return True

def _resolve_device(device_setting):
    """Map the config's ``device`` value to a torch.device ('auto' -> cuda if available, else cpu)."""
    if device_setting == 'auto':
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(device_setting)


def _continued_model_path(model_file, suffix):
    """Return ``model_file`` with ``suffix`` inserted before its extension (``.pth`` if it has none)."""
    # splitext only looks at the final path component, so a dot in a
    # directory name (e.g. "runs/v1.2/best") no longer corrupts the path.
    root, ext = os.path.splitext(model_file)
    return f"{root}{suffix}{ext or '.pth'}"


def _mtime_or_none(path):
    """Return the modification time of ``path``, or None if it does not exist."""
    try:
        return os.path.getmtime(path)
    except OSError:
        return None


def image_continue_train(model_file, additional_epochs, config, new_model_suffix='_continued'):
    """Resume training of a saved image model for a number of extra epochs.

    The weights in ``model_file`` are loaded into a fresh ``VisionTrans_model``.
    The model is then trained with
    ``molactivity.D5_train.train_model_improved_continue``, with
    ``D1_config.model_save_path`` pointed at
    ``<model_file stem><new_model_suffix><ext>``.

    Args:
        model_file: Path of the checkpoint (state dict) to resume from.
        additional_epochs: Epochs for this run. This value temporarily
            replaces ``D1_config.num_epochs``, which is restored afterwards
            even if training raises.
        config: Dict shaped like ``IMAGE_CONFIG``.
        new_model_suffix: Inserted before the extension of ``model_file`` to
            build the output path.

    Returns:
        ``(model, new_model_file)`` if this run wrote ``new_model_file``.
        ``None`` if the checkpoint could not be loaded or nothing new was saved.
    """
    print(f"Starting image continue training using: {model_file}")

    from molactivity.D2_data_loader import load_datasets, get_data_loaders
    from molactivity.D3_model import VisionTrans_model
    from molactivity.D5_train import train_model_improved_continue

    device = _resolve_device(config['device'])
    print(f"Using device: {device}")

    D1_config = setup_image_config(config)

    new_model_file = _continued_model_path(model_file, new_model_suffix)
    D1_config.model_save_path = new_model_file

    model = VisionTrans_model(D1_config.num_classes)
    # Previously the checkpoint was never loaded, so "continue training"
    # silently trained from scratch. Bail out before the (expensive) dataset
    # load if the weights cannot be restored.
    if not load_model_for_continue_training(model, model_file, device):
        print("Continue training aborted: could not load the starting model")
        return None

    train_dataset, val_dataset = load_datasets(D1_config.data_dir, D1_config.folder_label_mapping)
    train_loader, val_loader = get_data_loaders(train_dataset, val_dataset, D1_config.batch_size)

    # Remember any file left by an earlier run so it is not mistaken for
    # this run's output.
    previous_mtime = _mtime_or_none(new_model_file)

    original_epochs = D1_config.num_epochs
    D1_config.num_epochs = additional_epochs
    try:
        train_model_improved_continue(
            model, train_loader, val_loader, D1_config.num_epochs, D1_config.learning_rate, is_continue_training=True
        )
    finally:
        # D1_config is shared module state; always undo the temporary override.
        D1_config.num_epochs = original_epochs

    current_mtime = _mtime_or_none(new_model_file)
    if current_mtime is not None and current_mtime != previous_mtime:
        return model, new_model_file

    print("Continue training completed but no model was saved")
    return None


def apply_image_configurations():
    """Build the IMAGE_CONFIG-driven training entry point.

    ``molactivity.D1_config`` is imported eagerly here, inside the ``try``, so
    a missing or broken installation is reported now. In that case ``None``
    is returned and ``run_image_training`` can fall back. Previously the
    ``try`` only wrapped a nested ``def``, so its ``except`` clauses could
    never fire.

    Returns:
        A zero-argument callable that runs training, or ``None`` if the
        molactivity configuration module could not be imported.
    """
    try:
        from molactivity import D1_config as _d1_config  # noqa: F401  (availability check)
    except ImportError as e:
        print(f"Module import error: {e}")
        return None
    except Exception as e:
        print(f"Configuration error: {e}")
        return None

    def configured_image_training():
        """Run continue-training or fresh training according to IMAGE_CONFIG.

        Returns:
            In continue mode, the ``(model, path)`` result of
            ``image_continue_train`` (or ``None`` on failure). Otherwise, the
            list of ``D1_config.model_save_path`` values, one per trained
            model. ``None`` if the dataset folder is missing.
        """
        config = IMAGE_CONFIG

        if not os.path.exists(config['DATASET_FOLDER']):
            print(f"Error: Dataset folder '{config['DATASET_FOLDER']}' not found!")
            return None

        if config['CONTINUE_TRAIN']:
            result = image_continue_train(
                model_file=config['CONTINUE_TRAIN_MODEL'],
                additional_epochs=config['ADDITIONAL_EPOCHS'],
                config=config,
                new_model_suffix='_continued',
            )
            if result is None:
                print("Continue training failed")
            return result

        num_models = config['num_models']
        trained_models = []
        for model_idx in range(num_models):
            print(f'\n--- Training model {model_idx+1}/{num_models} ---')

            D1_config = setup_image_config(config)

            if num_models > 1:
                # Give each model its own file so later models don't overwrite earlier ones.
                D1_config.model_save_path = f"{config['save_model_prefix']}_{model_idx+1}.pth"

            # Imported after setup_image_config, as before, in case D5_train
            # reads D1_config at import time; repeat imports are cached.
            from molactivity.D5_train import main as original_main
            original_main()

            trained_models.append(D1_config.model_save_path)

        return trained_models

    return configured_image_training

def run_image_training():
    """Print a summary of IMAGE_CONFIG, then run the configured training.

    If ``apply_image_configurations`` returns a falsy value, this falls back to
    ``molactivity.train_image``. Any exception from that fallback is printed
    rather than raised.
    """
    cfg = IMAGE_CONFIG

    print(f"Dataset Folder: {cfg['DATASET_FOLDER']}")
    print(f"Continue Training: {cfg['CONTINUE_TRAIN']}")
    if not cfg['CONTINUE_TRAIN']:
        print(f"Number of Models: {cfg['num_models']}")
    else:
        print(f"Continue Model: {cfg['CONTINUE_TRAIN_MODEL']}")
        print(f"Additional Epochs: {cfg['ADDITIONAL_EPOCHS']}")

    # A missing key and an empty/None mapping are treated the same way.
    label_map = cfg.get('folder_label_mapping')
    if not label_map:
        print("\nUsing default alphabetical folder sorting for labels")
    else:
        print("\nFolder Label Mapping:")
        for folder_name, label_id in label_map.items():
            print(f"  {folder_name} -> label {label_id}")

    training_fn = apply_image_configurations()

    if not training_fn:
        print("Configuration failed, using original image training...")
        try:
            from molactivity import train_image
            train_image()
        except Exception as e:
            print(f"Training failed: {e}")
        return

    print("Starting image training...")
    training_fn()

if __name__ == "__main__":
    # perf_counter() is monotonic and high-resolution. time.time() follows the
    # wall clock, which can jump (NTP/DST adjustments) during a long training run.
    start_time = time.perf_counter()

    run_image_training()

    total_time = time.perf_counter() - start_time
    print(f"\nTime used for image training: {total_time:.2f} seconds")