import time

# Input CSV that the prediction pipeline reads.
DATASET_FILE = 'predict_sample.csv'

# Run-time settings for the fast prediction script.
# Keys that are not read anywhere in this file are kept for compatibility:
# NUMPY_OPTIMIZATION, device, model_file, fast_mode, fast_save.
# NOTE(review): presumably other molactivity modules consume them; confirm.
FAST_CONFIG = dict(
    # Try B20_predict_fast_parallel first when more than one model is selected.
    PARALLEL_PREDICTION=False,
    # Forwarded as max_workers to the parallel predictor.
    # NOTE(review): None presumably means "library default"; confirm.
    MAX_WORKERS=None,
    NUMPY_OPTIMIZATION=True,
    # Also pass each model's individual predictions to the CSV writer.
    SAVE_INDIVIDUAL_PREDICTIONS=True,

    PREDICTION_parameters=dict(
        device='cpu',
        model_file='model_1.dict',
        models='auto',      # handed to parse_model_selection
        ensemble=True,      # pass std devs to the CSV writer when available
        batch_size=32,
        fast_mode=True,
    ),

    # Forwarded as optimization_config to the parallel predictor.
    optimization_parameters=dict(
        use_vectorized_ops=True,
        cache_fingerprints=True,
        batch_processing=True,
        memory_efficient=True,
    ),

    output_parameters=dict(
        output_file='predict_sample_with_predictions.csv',
        fast_save=True,
    ),
)

def apply_fast_configurations():
    """Build a prediction routine driven by FAST_CONFIG and install it.

    Imports ``molactivity.B10_predict`` and defines ``configured_fast_prediction``.
    That is a zero-argument closure which predicts on ``DATASET_FILE`` and
    writes the results to the configured output CSV. The closure is assigned to
    ``B10_predict.prediction`` and also returned.

    Returns:
        The configured prediction callable, or ``None`` if importing
        ``molactivity.B10_predict`` failed. The error is printed in that case.
    """
    try:
        # Catch broadly: importing runs the module's top-level code, which can
        # raise more than ImportError.
        from molactivity import B10_predict as predict_fast
    except Exception as e:
        print(f"configuration failed: {e}")
        return None

    def configured_fast_prediction():
        """Predict on DATASET_FILE with the models selected in FAST_CONFIG and save a CSV.

        The parallel ensemble is tried only when PARALLEL_PREDICTION is set and
        more than one model is selected. Any failure there falls back to
        sequential prediction.
        """
        config = FAST_CONFIG
        prediction_params = config['PREDICTION_parameters']
        batch_size = prediction_params['batch_size']

        data_provider = predict_fast.prepare_pure_predicting_dataset(
            DATASET_FILE,
            fingerprint_type='Morgan',
            batch_size=batch_size,
            shuffle=False,
        )

        model_files = predict_fast.parse_model_selection(prediction_params['models'])
        print(f"using {len(model_files)} models for fast prediction")
        if not model_files:
            print("no models selected, nothing to predict")
            return

        predictions, std_devs, individual_predictions = [], [], {}

        # Track the fallback in a local instead of writing to FAST_CONFIG.
        # Otherwise one failed parallel run would disable parallel prediction
        # for every later call in this process.
        use_parallel = config['PARALLEL_PREDICTION'] and len(model_files) > 1
        if use_parallel:
            try:
                # Optional module: imported lazily so sequential runs do not need it.
                from molactivity.B20_predict_fast_parallel import (
                    generate_fast_parallel_ensemble_predictions,
                )
                predictions, std_devs, individual_predictions = (
                    generate_fast_parallel_ensemble_predictions(
                        model_files=model_files,
                        dataset_file=DATASET_FILE,
                        max_workers=config['MAX_WORKERS'],
                        batch_size=batch_size,
                        optimization_config=config['optimization_parameters'],
                    )
                )
                print(f"prediction completed: {len(predictions)} predictions")
            except Exception as e:  # includes ImportError when the module is absent
                print(f"fast parallel prediction failed: {e}")
                print("falling back to fast sequential prediction...")
                use_parallel = False

        if not use_parallel:
            print("using fast sequential prediction...")
            loaded_models = predict_fast.load_multiple_models(model_files)
            predictions, std_devs, individual_predictions = (
                predict_fast.generate_ensemble_predictions(
                    loaded_models, data_provider, model_files=model_files
                )
            )

        if not predictions:
            print("no predictions to save")
            return

        output_file = config['output_parameters']['output_file']
        individual_preds_to_save = (
            individual_predictions if config['SAVE_INDIVIDUAL_PREDICTIONS'] else None
        )
        if prediction_params['ensemble'] and std_devs:
            save_success = predict_fast.save_predictions_to_csv(
                predictions, output_file, std_devs, individual_preds_to_save,
                input_file=DATASET_FILE,
            )
        else:
            save_success = predict_fast.save_predictions_to_csv(
                predictions, output_file,
                individual_predictions=individual_preds_to_save,
                input_file=DATASET_FILE,
            )
        # NOTE(review): assumes save_predictions_to_csv returns False on failure.
        # Only an explicit False is reported, so a None return is not
        # mistaken for an error.
        if save_success is False:
            print(f"failed to save predictions to {output_file}")

    # Replace B10_predict's entry point so that callers of
    # predict_fast.prediction() run this configured version.
    predict_fast.prediction = configured_fast_prediction
    return configured_fast_prediction

def run_fast_prediction():
    """Print a summary of FAST_CONFIG, then build and run the configured prediction.

    If apply_fast_configurations() returns nothing usable, the function prints
    "Configuration failed" and returns without predicting.
    """
    settings = FAST_CONFIG
    prediction_params = settings['PREDICTION_parameters']

    print("FAST_CONFIG:")
    summary = (
        ("parallel prediction", settings['PARALLEL_PREDICTION']),
        ("max workers", settings['MAX_WORKERS']),
        ("models", prediction_params['models']),
        ("batch size", prediction_params['batch_size']),
        ("dataset", DATASET_FILE),
        ("output file", settings['output_parameters']['output_file']),
    )
    for label, value in summary:
        print(f"{label}: {value}")

    runner = apply_fast_configurations()
    if not runner:
        print("Configuration failed")
        return
    runner()

if __name__ == "__main__":
    # Time the run with perf_counter(): it is monotonic and high-resolution.
    # time.time() follows the wall clock, which can jump on clock adjustments
    # and skew the measured duration.
    start_time = time.perf_counter()

    run_fast_prediction()

    total_time = time.perf_counter() - start_time
    print(f"time used for fast prediction: {total_time:.2f} seconds")