import time

# Input CSV: handed to the dataset loader and passed again as `input_file`
# when predictions are saved.
DATASET_FILE = "predict_sample.csv"

# Settings read by apply_configurations() and echoed by run_prediction().
STANDARD_CONFIG = {
    # When true and more than one model is selected, the parallel ensemble
    # path is attempted before the sequential one.
    "PARALLEL_prediction": False,
    # Forwarded as `max_workers` to the parallel predictor.
    "MAX_WORKERS": None,
    # When true, per-model predictions are handed to the CSV writer as well.
    "SAVE_INDIVIDUAL_PREDICTIONS": True,

    "prediction_parameters": {
        # Only echoed by run_prediction() in the visible part of this file.
        "device": "cpu",
        # Not read in the visible part of this file.
        "model_file": "model_1.dict",
        # Selection spec passed to A21_predict.parse_model_selection().
        "models": "auto",
        # When true (and std devs are available) they are saved too.
        "ensemble": True,
        # Batch size given to the dataset provider.
        "batch_size": 32,
    },

    "output_parameters": {
        # Destination CSV for the saved predictions.
        "output_file": "predict_sample_with_predictions.csv",
    },
}

def apply_configurations():
    """Build the configured prediction routine and install it on A21_predict.

    Imports ``molactivity.A21_predict``, defines a closure that runs a
    prediction driven by ``STANDARD_CONFIG`` and ``DATASET_FILE``, assigns
    that closure to ``A21_predict.prediction`` and returns it.

    Returns:
        The zero-argument ``configured_prediction`` callable, or ``None``
        when the import (or the installation) fails; the failure reason is
        printed in that case.
    """
    try:
        from molactivity import A21_predict

        def configured_prediction():
            """Run prediction per STANDARD_CONFIG and save the results.

            Tries the parallel ensemble path when it is enabled and more
            than one model is selected; on any failure there it falls back
            to the sequential path. STANDARD_CONFIG itself is never
            modified.
            """
            config = STANDARD_CONFIG
            prediction_params = config['prediction_parameters']

            print(f"dataset: {DATASET_FILE}")
            if config['PARALLEL_prediction'] and config['MAX_WORKERS']:
                print(f"max workers: {config['MAX_WORKERS']}")

            # Built up front even when the parallel path ends up being
            # used: the parallel call below takes the dataset file itself,
            # but the sequential fallback needs this provider.
            data_provider = A21_predict.prepare_pure_predicting_dataset(
                DATASET_FILE,
                fingerprint_type='Morgan',
                batch_size=prediction_params['batch_size'],
                shuffle=False
            )

            model_files = A21_predict.parse_model_selection(prediction_params['models'])
            print(f"using {len(model_files)} models for prediction")

            predictions = []
            std_devs = []
            individual_predictions = {}

            # Local flag: a failed parallel attempt must only affect this
            # run, not flip the shared module-level configuration.
            use_parallel = bool(config['PARALLEL_prediction'])

            if use_parallel and len(model_files) > 1:
                print("using parallel prediction")
                try:
                    from molactivity.A34_predict_parallel import generate_parallel_ensemble_predictions

                    predictions, std_devs, individual_predictions = generate_parallel_ensemble_predictions(
                        model_files=model_files,
                        dataset_file=DATASET_FILE,
                        max_workers=config['MAX_WORKERS'],
                    )

                    print(f"parallel prediction completed: {len(predictions)} predictions")

                except Exception as e:
                    # Deliberate best-effort: a missing parallel module
                    # (ImportError) and any runtime failure are handled the
                    # same way, by retrying sequentially below.
                    print(f"parallel prediction failed: {e}")
                    print("falling back to sequential prediction...")
                    use_parallel = False

            if not use_parallel or len(model_files) == 1:
                print("using sequential prediction...")
                loaded_models = A21_predict.load_multiple_models(model_files)
                predictions, std_devs, individual_predictions = A21_predict.generate_ensemble_predictions(
                    loaded_models, data_provider, model_files=model_files
                )

            output_file = config['output_parameters']['output_file']

            # NOTE(review): this truthiness test assumes `predictions` is a
            # list-like; a multi-element NumPy array would raise here.
            # Confirm against the A21_predict return types.
            if not predictions:
                print("no predictions to save")
                return

            individual_preds_to_save = (
                individual_predictions if config['SAVE_INDIVIDUAL_PREDICTIONS'] else None
            )

            if prediction_params['ensemble'] and std_devs:
                A21_predict.save_predictions_to_csv(
                    predictions, output_file, std_devs, individual_preds_to_save,
                    input_file=DATASET_FILE
                )
            else:
                A21_predict.save_predictions_to_csv(
                    predictions, output_file,
                    individual_predictions=individual_preds_to_save,
                    input_file=DATASET_FILE
                )

        # Replace the module's entry point so code calling
        # A21_predict.prediction runs the configured routine.
        A21_predict.prediction = configured_prediction

        return configured_prediction

    except Exception as e:
        # Covers ImportError (package missing) as well as anything else
        # raised while importing or installing the routine.
        print(f"configuration failed: {e}")
        return None


def run_prediction():
    """Print the active configuration, then run the configured prediction.

    Echoes the relevant STANDARD_CONFIG entries, asks
    apply_configurations() for the prediction routine and calls it. When no
    routine is returned, prints "Configuration failed" instead.
    """
    settings = STANDARD_CONFIG

    print("CURRENT_CONFIG:")
    pred_params = settings['prediction_parameters']
    print(f"device: {pred_params['device']}")
    print(f"parallel prediction: {settings['PARALLEL_prediction']}")
    print(f"max workers: {settings['MAX_WORKERS']}")
    print(f"save individual predictions: {settings['SAVE_INDIVIDUAL_PREDICTIONS']}")
    print(f"models: {pred_params['models']}")
    print(f"ensemble: {pred_params['ensemble']}")
    print(f"batch size: {pred_params['batch_size']}")
    out_params = settings['output_parameters']
    print(f"output file: {out_params['output_file']}")

    runner = apply_configurations()

    # Guard clause: nothing to run when configuration did not succeed.
    if not runner:
        print("Configuration failed")
        return

    runner()


if __name__ == "__main__":
    # Script entry point: run the prediction and report how long it took.
    # perf_counter() is used instead of time() because it is meant for
    # measuring intervals and is not affected by system clock adjustments.
    start_time = time.perf_counter()

    run_prediction()

    total_time = time.perf_counter() - start_time
    print(f"\ntime used: {total_time:.2f} seconds")