import os
import time
import torch
from molactivity.D6_evaluate_tools import (find_model_files,load_multiple_models,
                                           get_evaluation_data_loader,
                                           generate_ensemble_predictions,
                                           analyze_prediction_quality,
                                           save_evaluation_results)

# Folder containing the images to evaluate; passed to the data-loader factory.
DATASET_FOLDER = "./train_images_sample"

# Settings consumed by run_image_evaluation() and the helpers it calls.
IMAGE_EVAL_CONFIG = {
    # Folder name -> integer label. When this mapping is missing or empty,
    # the run reports that default alphabetical folder sorting is used.
    "folder_label_mapping": {
        "inactive_0": 0,
        "active_1": 1,
    },
    # Runtime settings. A device of "auto" selects CUDA when available and
    # falls back to CPU. "model_files" is handed to find_model_files();
    # presumably "auto" triggers discovery there -- confirm in the helper.
    # NOTE(review): "ensemble" and "image_size" are not read in this file.
    "evaluation_parameters": {
        "device": "auto",
        "model_files": "auto",
        "ensemble": True,
        "batch_size": 32,
        "image_size": 224,
        "num_classes": 2,
    },
    # Where results are written and whether the quality analysis step runs.
    # NOTE(review): "save_individual_predictions" is not read in this file.
    "output_parameters": {
        "output_file": "image_evaluation_results.csv",
        "analyze_quality": True,
        "save_individual_predictions": True,
    },
    # Not read directly in this file; the whole config dict is passed to
    # analyze_prediction_quality(), which presumably consumes these.
    "analysis_parameters": {
        "show_top_predictions": 10,
        "show_confusion_matrix": True,
        "calculate_metrics": True,
    },
}

def run_image_evaluation():
    """Run the configured models over the images in ``DATASET_FOLDER``.

    Prints the active configuration, resolves the torch device, and then
    walks the pipeline: find model files, load them, build the evaluation
    data loader, generate ensemble predictions, optionally analyze them,
    and save the results to the configured output file.

    Returns:
        bool: ``True`` once predictions were generated and a save was
        attempted, even if the save itself reported failure. ``False`` when
        no model files are found, no model could be loaded, the data loader
        is ``None``, or any exception is raised inside the pipeline.
    """
    config = IMAGE_EVAL_CONFIG

    print("\nConfiguration:")
    print(f"Dataset folder: {DATASET_FOLDER}")
    eval_params = config['evaluation_parameters']
    print(f"Batch size: {eval_params['batch_size']}")
    print(f"Output file: {config['output_parameters']['output_file']}")

    label_mapping = config.get('folder_label_mapping')
    if label_mapping:
        print("\nFolder Label Mapping:")
        for folder_name, class_label in label_mapping.items():
            print(f"  {folder_name} -> label {class_label}")
    else:
        print("\nUsing default alphabetical folder sorting for labels")

    # An explicit device string wins; 'auto' prefers CUDA and falls back to CPU.
    requested_device = eval_params['device']
    if requested_device != 'auto':
        device = torch.device(requested_device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")

    try:
        model_files = find_model_files(eval_params['model_files'])
        if not model_files:
            print("No model files found for evaluation!")
            return False

        print(f"\nFound {len(model_files)} model file(s):")
        for position, path in enumerate(model_files, start=1):
            size_mb = os.path.getsize(path) / (1024*1024)
            print(f"  {position}. {path} ({size_mb:.1f} MB)")

        models, successful_model_files = load_multiple_models(
            model_files, eval_params['num_classes'], device
        )
        if not models:
            print("Failed to load any models!")
            return False

        test_loader = get_evaluation_data_loader(
            DATASET_FOLDER,
            eval_params['batch_size'],
            config.get('folder_label_mapping'),
        )
        # No message is printed here; presumably the helper reports why the
        # loader could not be built -- confirm in D6_evaluate_tools.
        if test_loader is None:
            return False

        (true_labels,
         probabilities,
         std_devs,
         individual_predictions) = generate_ensemble_predictions(
            models, successful_model_files, test_loader, device
        )

        if config['output_parameters']['analyze_quality']:
            # The loader may or may not expose the image file names.
            analyze_prediction_quality(
                true_labels,
                probabilities,
                std_devs,
                config,
                getattr(test_loader, 'image_files', None),
            )

        if save_evaluation_results(
            true_labels,
            probabilities,
            std_devs,
            individual_predictions,
            config['output_parameters']['output_file'],
            test_loader,
        ):
            print("\nEvaluation completed successfully!")
        else:
            print("\nEvaluation completed but failed to save results.")

        # A failed save is reported above but still counts as a completed run.
        return True

    except Exception as exc:
        # Script-level boundary: report the failure and signal it via False.
        print(f"Evaluation failed: {exc!s}")
        return False

if __name__ == "__main__":
    # Script entry point: run the evaluation once and report how long it took.
    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # durations; time.time() follows the wall clock and can jump if the system
    # time is adjusted while the evaluation is running.
    start_time = time.perf_counter()

    # The boolean result is intentionally not turned into an exit status here,
    # so the script's exit behavior is unchanged.
    run_image_evaluation()

    total_time = time.perf_counter() - start_time

    print(f"\nTotal evaluation time: {total_time:.2f} seconds")
