import os
import time
import torch
import numpy as np
import shutil
from molactivity.D7_predict_tools import (find_model_files,load_multiple_models, 
                                          get_prediction_data_loader,
                                          generate_ensemble_predictions,
                                          save_prediction_results)

# Folder whose images are scored by run_image_prediction().
DATASET_FOLDER = './predict_images_sample'

# Settings read by run_image_prediction(). Note: 'ensemble', 'image_size' and
# 'save_individual_predictions' are not referenced by run_image_prediction.
IMAGE_PREDICT_CONFIG = dict(
    prediction_parameters=dict(
        device='auto',       # 'auto' -> CUDA when available, else CPU; or an explicit torch device string
        model_files='auto',  # forwarded unchanged to find_model_files()
        ensemble=True,
        batch_size=32,       # forwarded to get_prediction_data_loader()
        image_size=224,
        num_classes=2,       # forwarded to load_multiple_models()
    ),
    output_parameters=dict(
        output_file='image_prediction_results.csv',  # passed to save_prediction_results()
        save_individual_predictions=True,
        show_top_predictions=10,  # number of rows in the printed ranking table
    ),
)

def _resolve_device(device_setting):
    """Return a ``torch.device`` for a config value of ``'auto'`` or an explicit name.

    ``'auto'`` selects ``"cuda"`` when ``torch.cuda.is_available()`` is true,
    otherwise ``"cpu"``; any other value is passed straight to ``torch.device``.
    """
    if device_setting == 'auto':
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(device_setting)


def _cleanup_loader_temp_dir(loader):
    """Delete ``loader.temp_dir`` if the loader exposes one; removal errors are ignored."""
    temp_dir = getattr(loader, 'temp_dir', None)
    if temp_dir is not None:
        shutil.rmtree(temp_dir, ignore_errors=True)


def _print_top_predictions(probabilities, top_n, image_files=None):
    """Print a ranked table of the ``top_n`` highest probabilities.

    Rows are labelled by image basename when ``image_files`` is truthy,
    otherwise by 1-based position in ``probabilities``.
    """
    # Assumes one score per image (1-D probabilities) -- TODO confirm against
    # generate_ensemble_predictions.
    probs = np.asarray(probabilities, dtype=float)
    # Descending order via negation: np.argsort sorts NaN to the end, so
    # reversing an ascending sort would list NaN scores as the "highest".
    # 'stable' keeps tied scores in input order.
    ranked = np.argsort(-probs, kind='stable')[:max(top_n, 0)]

    print(f"\nTop {top_n} Highest Probabilities:")
    if image_files:
        print(f"{'Rank':<5} {'Image_Name':<30} {'Probability':<12}")
        print("-" * 55)
        for rank, idx in enumerate(ranked, start=1):
            image_name = os.path.basename(image_files[idx])
            print(f"{rank:<5} {image_name:<30} {probs[idx]:<12.6f}")
    else:
        print(f"{'Rank':<5} {'Image_Index':<12} {'Probability':<12}")
        print("-" * 35)
        for rank, idx in enumerate(ranked, start=1):
            print(f"{rank:<5} {idx + 1:<12} {probs[idx]:<12.6f}")


def run_image_prediction():
    """Run ensemble prediction over ``DATASET_FOLDER`` and save the results.

    Reads ``IMAGE_PREDICT_CONFIG``. It then:

    1. locates model files (``find_model_files``);
    2. loads them (``load_multiple_models``);
    3. builds a loader for ``DATASET_FOLDER`` (``get_prediction_data_loader``);
    4. runs ``generate_ensemble_predictions``;
    5. prints the top-N probabilities;
    6. writes the results with ``save_prediction_results``.

    Any ``temp_dir`` attached to the loader is always removed, including when
    the run is interrupted.

    Returns:
        ``False`` if no model files are found, no model loads, the loader
        factory returns ``None``, or any ``Exception`` is raised. Otherwise
        ``True``, which includes the case where saving the results fails (a
        message is printed for that).
    """
    config = IMAGE_PREDICT_CONFIG
    pred_params = config['prediction_parameters']
    out_params = config['output_parameters']

    print("\nConfiguration:")
    print(f"Dataset folder: {DATASET_FOLDER}")
    print(f"Batch size: {pred_params['batch_size']}")
    print(f"Output file: {out_params['output_file']}")

    prediction_loader = None
    try:
        # Inside the try so an invalid device string is reported as a failed
        # run instead of escaping the function.
        device = _resolve_device(pred_params['device'])
        print(f"Using device: {device}")

        model_files = find_model_files(pred_params['model_files'])
        if not model_files:
            print("No model files found for prediction!")
            return False

        print(f"\nFound {len(model_files)} model file(s):")
        for position, model_file in enumerate(model_files, start=1):
            file_size_mb = os.path.getsize(model_file) / (1024 * 1024)
            print(f"  {position}. {model_file} ({file_size_mb:.1f} MB)")

        models, successful_model_files = load_multiple_models(
            model_files, pred_params['num_classes'], device
        )
        if not models:
            print("Failed to load any models!")
            return False

        prediction_loader = get_prediction_data_loader(
            DATASET_FOLDER, pred_params['batch_size']
        )
        if prediction_loader is None:
            return False

        probabilities, std_devs, individual_predictions = generate_ensemble_predictions(
            models, successful_model_files, prediction_loader, device
        )
        original_image_files = getattr(prediction_loader, 'original_image_files', None)

        _print_top_predictions(
            probabilities, out_params['show_top_predictions'], original_image_files
        )

        save_success = save_prediction_results(
            probabilities, std_devs, individual_predictions,
            out_params['output_file'], original_image_files
        )
        if save_success:
            print("\nPrediction completed successfully!")
        else:
            print("\nPrediction completed but failed to save results.")
        # NOTE(review): True is returned even when saving failed, so callers
        # cannot tell that case apart from full success.
        return True

    except Exception as e:
        print(f"Prediction failed: {e}")
        return False
    finally:
        # Runs on every exit path, including KeyboardInterrupt/SystemExit,
        # which `except Exception` does not catch, so the temp dir never leaks.
        _cleanup_loader_temp_dir(prediction_loader)

if __name__ == "__main__":
    # perf_counter() is monotonic, so the measured duration is not distorted
    # by system clock adjustments the way time.time() differences can be.
    start_time = time.perf_counter()

    success = run_image_prediction()

    total_time = time.perf_counter() - start_time
    print(f"\nTotal prediction time: {total_time:.2f} seconds")

    # run_image_prediction() returns False on failure; surface that to the
    # calling shell as a non-zero exit status instead of always exiting 0.
    if not success:
        raise SystemExit(1)
