
import time

# Dataset file handed to prepare_predicting_dataset() by run_prediction().
DATASET_FILE = "predict_sample.csv"

# Runtime switches read by run_prediction().
ROCKET_CONFIG = {
    # When true and more than one model file is found, the parallel
    # ensemble path is attempted before the sequential one.
    "PARALLEL_prediction": False,
    # Worker count for the parallel path; None means
    # min(number of model files, 4) is used instead.
    "MAX_WORKERS": None,
    # Path of the CSV that receives the input rows plus prediction columns.
    "output_file": "predict_sample_with_predictions.csv",
}

# Seconds to wait for the parallel worker thread before giving up on it.
# The timeout error message is derived from this value so the two cannot
# drift apart.
_PARALLEL_TIMEOUT_SECONDS = 30000


def _find_model_files():
    """Return the sorted ``*model*.pt`` paths in the current directory.

    Prints one line per model found. When none is found, prints a diagnostic
    listing any other ``.pt`` files and returns an empty list.
    """
    import glob
    import os

    model_files = []
    for model_path in sorted(glob.glob("*model*.pt")):
        # Skip entries that matched the pattern but no longer resolve
        # (for example a dangling symlink).
        if os.path.exists(model_path):
            model_files.append(model_path)
            print(f"Found model file: {model_path}")

    if not model_files:
        print('No trained network models available! Execute training first.')
        all_pt_files = glob.glob("*.pt")
        if all_pt_files:
            print(f"Found .pt files in current directory: {all_pt_files}")
        else:
            print("No .pt files found in current directory")

    return model_files


def _predict_parallel(model_files, max_workers, device_name):
    """Run the parallel ensemble in a daemon thread and return its result.

    Args:
        model_files: Paths of the model files to evaluate.
        max_workers: Worker count forwarded to the parallel implementation.
        device_name: ``'cuda'`` or ``'cpu'``.

    Returns:
        Whatever ``generate_rocket_parallel_ensemble_predictions`` returned;
        the caller unpacks it as (predictions, std_devs, individual_predictions).

    Raises:
        ImportError: The parallel module could not be imported.
        TimeoutError: The worker was still running after
            ``_PARALLEL_TIMEOUT_SECONDS`` seconds.
        RuntimeError: The worker finished without producing a result.
        Exception: Any exception raised inside the worker is re-raised here.
    """
    import queue
    import threading

    from molactivity.C4_evaluate_predict_parallel import generate_rocket_parallel_ensemble_predictions

    result_queue = queue.Queue()
    exception_queue = queue.Queue()

    def parallel_worker():
        # Exceptions cannot cross the thread boundary on their own, so they
        # are handed back through a queue and re-raised by the caller.
        try:
            result = generate_rocket_parallel_ensemble_predictions(
                model_files=model_files,
                dataset_file=DATASET_FILE,
                max_workers=max_workers,
                batch_size=32,
                device=device_name,
            )
            result_queue.put(result)
        except Exception as e:
            exception_queue.put(e)

    # Daemon thread: a worker that never finishes must not keep the
    # interpreter alive at exit.
    worker_thread = threading.Thread(target=parallel_worker, daemon=True)
    worker_thread.start()
    worker_thread.join(timeout=_PARALLEL_TIMEOUT_SECONDS)

    if worker_thread.is_alive():
        # NOTE: the thread cannot be cancelled; it keeps running in the
        # background while the caller falls back to sequential prediction.
        raise TimeoutError(
            f"Parallel prediction timed out after {_PARALLEL_TIMEOUT_SECONDS} seconds"
        )

    if not exception_queue.empty():
        raise exception_queue.get()

    if result_queue.empty():
        raise RuntimeError("Parallel prediction completed but no result returned")

    return result_queue.get()


def _write_predictions(prediction_provider, predictions, std_devs, individual_predictions, output_file):
    """Append the prediction columns to a copy of the input data and save it as CSV.

    The table is taken from ``prediction_provider.dataset.molecular_data``
    (presumably a pandas DataFrame -- it is only used via ``copy``, column
    assignment and ``to_csv``).
    """
    data = prediction_provider.dataset.molecular_data.copy()
    data['ENSEMBLE_PREDICTION'] = predictions

    # The spread column is only written when at least one value is non-zero,
    # i.e. when the ensemble members actually disagree somewhere.
    if len(std_devs) > 0 and any(std > 0 for std in std_devs):
        data['ENSEMBLE_STD_DEV'] = std_devs

    for model_name, model_preds in individual_predictions.items():
        # upper() already turns a 'model_' prefix into 'MODEL_'.
        column_name = model_name.replace('.pt', '').upper()
        data[f'PRED_{column_name}'] = model_preds

    data.to_csv(output_file, index=False)


def run_prediction():
    """Predict on DATASET_FILE with every model found and write the ensemble CSV.

    Model files matching ``*model*.pt`` in the current directory are used.
    When ``ROCKET_CONFIG['PARALLEL_prediction']`` is true and more than one
    model exists, the parallel implementation is tried first; on any failure
    the function falls back to loading and evaluating the models one by one.
    ``ROCKET_CONFIG`` itself is never modified.

    Returns:
        bool: True when the output CSV was written. False when no model file
        was found, no model could be loaded, or any exception occurred (the
        exception is printed, not raised).
    """
    config = ROCKET_CONFIG
    print(f"Dataset: {DATASET_FILE}")
    print(f"Output file: {config['output_file']}")
    print(f"Parallel prediction: {config['PARALLEL_prediction']}")
    if config['PARALLEL_prediction'] and config['MAX_WORKERS']:
        print(f"Max workers: {config['MAX_WORKERS']}")

    # Tracked locally: a failed parallel attempt must not flip the shared
    # module-level ROCKET_CONFIG for later calls.
    use_parallel = bool(config['PARALLEL_prediction'])

    try:
        # Imported here so a missing dependency is reported through the
        # handler below, before any dataset work starts.
        import torch
        from molactivity.C3_utils import prepare_predicting_dataset, MolecularTransformer
        from molactivity.C1_evaluate_predict import load_trained_network, generate_predictions
        import numpy as np

        compute_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f'Computation device: {compute_device}')

        prediction_provider = prepare_predicting_dataset(DATASET_FILE, fingerprint_type='Morgan',
                                                         batch_size=32, shuffle=False, balance_data=False)

        model_files = _find_model_files()
        if not model_files:
            return False

        print(f"using {len(model_files)} models for prediction")

        predictions = []
        std_devs = []
        individual_predictions = {}

        if use_parallel and len(model_files) > 1:
            print("using parallel prediction")
            try:
                max_workers = config['MAX_WORKERS'] if config['MAX_WORKERS'] is not None else min(len(model_files), 4)
                print(f"parallel workers: {max_workers}")

                predictions, std_devs, individual_predictions = _predict_parallel(
                    model_files,
                    max_workers,
                    'cuda' if compute_device.type == 'cuda' else 'cpu',
                )
                print(f"parallel prediction completed: {len(predictions)} predictions")

            except ImportError as e:
                print(f"parallel prediction failed (import error): {e}")
                print("falling back to sequential prediction...")
                use_parallel = False
            except TimeoutError as e:
                print(f"parallel prediction timed out: {e}")
                print("falling back to sequential prediction...")
                use_parallel = False
            except Exception as e:
                print(f"parallel prediction failed (other error): {e}")
                print("falling back to sequential prediction...")
                use_parallel = False

        if not use_parallel or len(model_files) == 1:
            print("using sequential prediction...")

            # Paths are recorded next to the networks so that a model that
            # fails to load cannot shift the path/prediction pairing below.
            loaded_paths = []
            trained_networks = []
            for model_path in model_files:
                network = MolecularTransformer(input_features=2048, output_features=1,
                                               embedding_size=512, layer_count=6,
                                               head_count=8, hidden_size=2048,
                                               dropout_rate=0.1)
                network = network.to(compute_device)

                network = load_trained_network(network, model_path, compute_device)
                if network is None:
                    print(f'Failed to load model {model_path}')
                    continue

                network.eval()
                loaded_paths.append(model_path)
                trained_networks.append(network)

            if not trained_networks:
                print('No trained network models loaded successfully!')
                return False

            ensemble_predictions = []
            for network_idx, network in enumerate(trained_networks, start=1):
                print(f'Starting predictions with model {network_idx}...')
                ensemble_predictions.append(
                    generate_predictions(network, prediction_provider, device=compute_device)
                )
                print('Predictions completed')

            # Axis 0 indexes the models, so reducing over it gives one
            # mean / standard deviation per sample.
            ensemble_predictions = np.array(ensemble_predictions)
            predictions = np.mean(ensemble_predictions, axis=0).tolist()

            if len(ensemble_predictions) > 1:
                std_devs = np.std(ensemble_predictions, axis=0).tolist()
            else:
                std_devs = np.zeros_like(predictions).tolist()

            individual_predictions = {
                model_path: model_preds.tolist()
                for model_path, model_preds in zip(loaded_paths, ensemble_predictions)
            }

        _write_predictions(prediction_provider, predictions, std_devs,
                           individual_predictions, config['output_file'])

    except Exception as e:
        # Top-level boundary: report the failure and signal it via the
        # return value instead of propagating.
        print(f"prediction failed: {e}")
        return False

    return True

if __name__ == "__main__":
    # perf_counter is monotonic, so the measured duration is not affected by
    # system clock adjustments during the run.
    start_time = time.perf_counter()

    succeeded = run_prediction()

    total_time = time.perf_counter() - start_time

    # run_prediction() reports failure through its return value; reflect it
    # in both the message and the process exit status.
    if succeeded:
        print(f"prediction completed in {total_time:.2f} seconds")
    else:
        print(f"prediction failed after {total_time:.2f} seconds")
        raise SystemExit(1)
