import time

# Input CSV to evaluate; also passed as ``input_file`` when predictions are saved.
DATASET_FILE = 'train_sample.csv'

# Settings read by the evaluation entry points in this script.
STANDARD_CONFIG = {
    # When True and more than one model is selected, evaluation first tries
    # molactivity.A33_eval_parallel and falls back to sequential on any error.
    'PARALLEL_EVALUATION': False,
    # Forwarded as ``max_workers`` to the parallel evaluator
    # (presumably None lets it choose a default -- confirm in A33_eval_parallel).
    'MAX_WORKERS': None,
    # Pass per-model predictions to the CSV writer in addition to the ensemble.
    'SAVE_INDIVIDUAL_PREDICTIONS': True,

    'evaluation_parameters': {
        # NOTE(review): only printed by this script; not forwarded to A9_evaluate here.
        'device': 'cpu',
        # NOTE(review): not read by any code visible in this script.
        'model_file': 'model_1.dict',
        # Selection spec handed to A9_evaluate.parse_model_selection.
        'models': 'auto',
        # Save ensemble standard deviations with the predictions when available.
        'ensemble': True,
        # Batch size for the prediction data provider.
        'batch_size': 32,
    },

    'output_parameters': {
        # CSV file the predictions are written to.
        'output_file': 'evaluate_sample_with_predictions.csv',
        # Run A9_evaluate.analyze_prediction_quality on the file after a successful save.
        'analyze_quality': True,
    },
}

def _has_items(values):
    """Return True if *values* is not None and holds at least one element.

    ``len`` is used instead of plain truthiness so array-like results whose
    truth value is ambiguous (e.g. NumPy arrays) behave like lists.
    """
    return values is not None and len(values) > 0


def _run_parallel_evaluation(model_files, dataset_file, max_workers):
    """Try the parallel ensemble evaluator from ``molactivity.A33_eval_parallel``.

    Returns:
        ``(predictions, std_devs, individual_predictions)`` on success, or
        ``None`` if the evaluator could not be imported or raised; the caller
        then falls back to sequential evaluation for this run.
    """
    print("using parallel evaluation")
    try:
        from molactivity.A33_eval_parallel import generate_parallel_ensemble_predictions

        predictions, std_devs, individual_predictions = generate_parallel_ensemble_predictions(
            model_files=model_files,
            dataset_file=dataset_file,
            max_workers=max_workers,
        )
        print(f"parallel evaluation completed: {len(predictions)} predictions")
    except Exception as e:  # ImportError included: any failure means "use the sequential path"
        print(f"parallel evaluation failed: {e}")
        print("falling back to sequential evaluation...")
        return None
    return predictions, std_devs, individual_predictions


def _run_sequential_evaluation(evaluator, model_files, dataset_file, batch_size):
    """Load every selected model and predict over *dataset_file* in-process.

    The data provider is built here rather than up front because only this
    path consumes it; the parallel evaluator receives the file name instead.

    Returns:
        Whatever ``evaluator.generate_ensemble_predictions`` returns — unpacked
        by the caller as ``(predictions, std_devs, individual_predictions)``.
    """
    print("using sequential evaluation...")
    data_provider = evaluator.prepare_pure_predicting_dataset(
        dataset_file,
        fingerprint_type='Morgan',
        batch_size=batch_size,
        shuffle=False,
    )
    loaded_models = evaluator.load_multiple_models(model_files)
    return evaluator.generate_ensemble_predictions(
        loaded_models, data_provider, model_files=model_files
    )


def _save_predictions(evaluator, config, dataset_file, predictions, std_devs, individual_predictions):
    """Write predictions to the configured CSV and optionally analyze the file.

    Standard deviations are included only when ensembling is enabled and some
    were produced; per-model predictions only when SAVE_INDIVIDUAL_PREDICTIONS
    is set. Quality analysis runs only after a successful save.
    """
    output_params = config['output_parameters']
    output_file = output_params['output_file']

    if not _has_items(predictions):
        print("no predictions to save")
        return

    individual_preds_to_save = (
        individual_predictions if config['SAVE_INDIVIDUAL_PREDICTIONS'] else None
    )
    if config['evaluation_parameters']['ensemble'] and _has_items(std_devs):
        save_success = evaluator.save_predictions_to_csv(
            predictions, output_file, std_devs, individual_preds_to_save,
            input_file=dataset_file,
        )
    else:
        save_success = evaluator.save_predictions_to_csv(
            predictions, output_file,
            individual_predictions=individual_preds_to_save,
            input_file=dataset_file,
        )

    if not save_success:
        print("failed to save predictions")
    elif output_params['analyze_quality']:
        evaluator.analyze_prediction_quality(output_file)


def apply_configurations():
    """Build an evaluation entry point driven by STANDARD_CONFIG.

    Imports ``molactivity.A9_evaluate``, defines ``configured_evaluation``,
    installs it as ``A9_evaluate.evaluation`` and returns it.

    Returns:
        The ``configured_evaluation`` callable, or ``None`` if
        ``molactivity.A9_evaluate`` could not be imported (the error is printed).
    """
    try:
        from molactivity import A9_evaluate
    except Exception as e:  # ImportError, or any error raised while the module initializes
        print(f"configuration failed: {e}")
        return None

    def configured_evaluation():
        """Run the evaluation described by STANDARD_CONFIG and save the results.

        Parallel evaluation is attempted only when enabled and more than one
        model is selected; if it fails, this run falls back to sequential
        evaluation. STANDARD_CONFIG itself is never modified, so a transient
        failure does not disable parallel evaluation for later calls.
        """
        config = STANDARD_CONFIG
        eval_params = config['evaluation_parameters']
        use_parallel = config['PARALLEL_EVALUATION']

        print(f"dataset: {DATASET_FILE}")
        print(f"parallel evaluation: {use_parallel}")
        if use_parallel and config['MAX_WORKERS']:
            print(f"max workers: {config['MAX_WORKERS']}")

        model_files = A9_evaluate.parse_model_selection(eval_params['models'])
        print(f"using {len(model_files)} models for evaluation")
        if len(model_files) == 0:
            # Nothing to load; avoid handing an empty selection to either evaluator.
            print("no models selected for evaluation")
            return

        results = None
        if use_parallel and len(model_files) > 1:
            results = _run_parallel_evaluation(model_files, DATASET_FILE, config['MAX_WORKERS'])
        if results is None:
            # Parallel was not requested, not applicable (single model), or failed.
            results = _run_sequential_evaluation(
                A9_evaluate, model_files, DATASET_FILE, eval_params['batch_size']
            )

        predictions, std_devs, individual_predictions = results
        _save_predictions(
            A9_evaluate, config, DATASET_FILE, predictions, std_devs, individual_predictions
        )

    # Replace the module's entry point, presumably so other code that calls
    # A9_evaluate.evaluation() runs with this configuration too.
    A9_evaluate.evaluation = configured_evaluation
    return configured_evaluation


def run_evaluation():
    """Show the active STANDARD_CONFIG, then build and run the evaluation.

    If ``apply_configurations()`` returns a falsy value instead of a callable,
    "Configuration failed" is printed and nothing is evaluated.
    """
    settings = STANDARD_CONFIG
    print("CURRENT_CONFIG:")

    evaluation_params = settings['evaluation_parameters']
    output_params = settings['output_parameters']
    summary = (
        f"device: {evaluation_params['device']}",
        f"parallel evaluation: {settings['PARALLEL_EVALUATION']}",
        f"max workers: {settings['MAX_WORKERS']}",
        f"save individual predictions: {settings['SAVE_INDIVIDUAL_PREDICTIONS']}",
        f"models: {evaluation_params['models']}",
        f"ensemble: {evaluation_params['ensemble']}",
        f"batch size: {evaluation_params['batch_size']}",
        f"output file: {output_params['output_file']}",
        f"analyze quality: {output_params['analyze_quality']}",
    )
    for line in summary:
        print(line)

    evaluate = apply_configurations()
    if not evaluate:
        print("Configuration failed")
        return
    evaluate()


if __name__ == "__main__":
    # perf_counter is monotonic and high-resolution, so the reported duration is
    # not distorted by wall-clock adjustments (NTP sync, manual changes) during
    # a long evaluation run, unlike time.time().
    start_time = time.perf_counter()

    run_evaluation()

    total_time = time.perf_counter() - start_time
    print(f"\ntime used: {total_time:.2f} seconds")