"""Script to run Raindrop experiments on all splits and aggregate results"""

import os
import json
import numpy as np
from datetime import datetime
from run_raindrop_p12 import run_raindrop_experiment

def _to_jsonable(value):
    """Convert NumPy values to built-in Python types for ``json.dumps``.

    Intended as the ``default`` hook, so it is only invoked for objects the
    JSON encoder cannot serialize by itself.

    Raises:
        TypeError: If ``value`` is not a NumPy scalar or array.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(
        f"Object of type {type(value).__name__} is not JSON serializable"
    )


def run_all_splits_experiment(
    data_path,
    n_splits=5,
    num_epochs=20,
    learning_rate=0.0001,
    batch_size=128,
    n_runs=1
):
    """
    Run Raindrop experiments on all splits and aggregate results.

    Splits are numbered from 1 to ``n_splits``. A split whose experiment
    raises is reported on stdout and skipped; the remaining splits still run.
    When at least one split succeeds, the mean and (population) standard
    deviation of accuracy, AUPRC and AUROC are printed and everything is
    written to ``raindrop_results/raindrop_p12_results_<timestamp>.json``.

    Args:
        data_path: Path to the converted data directory
        n_splits: Number of splits to run
        num_epochs: Number of training epochs
        learning_rate: Learning rate
        batch_size: Batch size
        n_runs: Number of runs per split

    Returns:
        A dict with the keys ``experiment_info``, ``aggregated_metrics`` and
        ``split_results``, or ``None`` if no split completed successfully.

    Raises:
        KeyError: If a split's result lacks ``'accuracy'``, ``'auprc'`` or
            ``'auroc'``.
    """
    split_results = []

    print(f"Running Raindrop experiments on {n_splits} splits...")
    print(f"Parameters: epochs={num_epochs}, lr={learning_rate}, batch_size={batch_size}, runs={n_runs}")

    for split_idx in range(1, n_splits + 1):
        print(f"\n{'='*50}")
        print(f"Running Split {split_idx}/{n_splits}")
        print(f"{'='*50}")

        try:
            results = run_raindrop_experiment(
                data_path=data_path,
                split_idx=split_idx,
                num_epochs=num_epochs,
                learning_rate=learning_rate,
                batch_size=batch_size,
                n_runs=n_runs
            )
        except Exception as e:
            # Deliberate best effort: one failing split must not discard the
            # work already done on the others.
            print(f"Error in split {split_idx}: {e}")
            continue

        split_results.append({
            'split_idx': split_idx,
            'results': results
        })
        print(f"Split {split_idx} completed successfully!")

    if not split_results:
        print("No successful splits completed!")
        return None

    # Aggregate results across splits
    accuracies = [r['results']['accuracy'] for r in split_results]
    auprcs = [r['results']['auprc'] for r in split_results]
    aurocs = [r['results']['auroc'] for r in split_results]

    # Calculate mean and std across splits
    mean_acc = np.mean(accuracies)
    std_acc = np.std(accuracies)
    mean_auprc = np.mean(auprcs)
    std_auprc = np.std(auprcs)
    mean_auroc = np.mean(aurocs)
    std_auroc = np.std(aurocs)

    # NOTE(review): one decimal place suggests the metrics are percentages
    # (0-100) rather than fractions -- confirm against run_raindrop_experiment.
    print(f"\n{'='*60}")
    print(f"AGGREGATED RESULTS ACROSS {len(split_results)} SPLITS")
    print(f"{'='*60}")
    print(f'Accuracy = {mean_acc:.1f} +/- {std_acc:.1f}')
    print(f'AUPRC    = {mean_auprc:.1f} +/- {std_auprc:.1f}')
    print(f'AUROC    = {mean_auroc:.1f} +/- {std_auroc:.1f}')

    # One clock reading, so the stored timestamp and the file name agree.
    finished_at = datetime.now()

    # Create comprehensive results dictionary
    all_results = {
        "experiment_info": {
            "model": "Raindrop",
            "dataset": "PhysioNet2012",
            # Number of splits that actually succeeded, not the number requested.
            "n_splits": len(split_results),
            "num_epochs": num_epochs,
            "learning_rate": learning_rate,
            "batch_size": batch_size,
            "n_runs_per_split": n_runs,
            "timestamp": finished_at.strftime('%Y-%m-%d %H:%M:%S')
        },
        "aggregated_metrics": {
            "accuracy_mean": float(mean_acc),
            "accuracy_std": float(std_acc),
            "auprc_mean": float(mean_auprc),
            "auprc_std": float(std_auprc),
            "auroc_mean": float(mean_auroc),
            "auroc_std": float(std_auroc)
        },
        "split_results": split_results
    }

    # Serialize before touching the file system: per-split results may hold
    # NumPy scalars/arrays, and a failure here must not leave a truncated file.
    payload = json.dumps(all_results, indent=2, default=_to_jsonable)

    # Save results to file
    results_dir = "raindrop_results"
    os.makedirs(results_dir, exist_ok=True)

    timestamp_str = finished_at.strftime('%Y%m%d_%H%M%S')
    results_filename = f"raindrop_p12_results_{timestamp_str}.json"
    results_path = os.path.join(results_dir, results_filename)

    with open(results_path, 'w', encoding='utf-8') as f:
        f.write(payload)

    print(f"\n📁 Results saved to: {results_path}")
    print(f"   Mean AUROC: {mean_auroc:.4f} ± {std_auroc:.4f}")
    print(f"   Mean AUPRC: {mean_auprc:.4f} ± {std_auprc:.4f}")

    return all_results

def compare_with_gman_results(raindrop_results, gman_results_path=None):
    """
    Compare Raindrop results with GMAN results.

    Prints the aggregated Raindrop AUROC/AUPRC and, when a GMAN results file
    can be read, the GMAN values and the (Raindrop - GMAN) differences. The
    GMAN file is expected to be JSON with a ``best_metrics`` object holding
    ``test_auc`` and ``test_auprc``. A metric that is missing or not numeric
    is shown as ``n/a`` instead of being treated as 0, so no fabricated
    difference is reported.

    Args:
        raindrop_results: Results from Raindrop experiments
        gman_results_path: Path to GMAN results file

    Returns:
        None. All output goes to stdout.
    """
    if raindrop_results is None:
        print("No Raindrop results to compare!")
        return

    print(f"\n{'='*60}")
    print("COMPARISON WITH GMAN RESULTS")
    print(f"{'='*60}")

    aggregated = raindrop_results['aggregated_metrics']
    raindrop_auroc = aggregated['auroc_mean']
    raindrop_auprc = aggregated['auprc_mean']

    print("Raindrop Results:")
    print(f"  AUROC: {raindrop_auroc:.4f} ± {aggregated['auroc_std']:.4f}")
    print(f"  AUPRC: {raindrop_auprc:.4f} ± {aggregated['auprc_std']:.4f}")

    if not gman_results_path:
        print("No GMAN results file provided for comparison.")
        return

    # Open directly instead of checking existence first: avoids a
    # check-then-use race and gives each failure its own message.
    try:
        with open(gman_results_path, 'r', encoding='utf-8') as f:
            gman_results = json.load(f)
    except FileNotFoundError:
        print(f"GMAN results file not found: {gman_results_path}")
        return
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error loading GMAN results: {e}")
        return

    best_metrics = (
        gman_results.get('best_metrics')
        if isinstance(gman_results, dict)
        else None
    )
    if not isinstance(best_metrics, dict):
        print("Error loading GMAN results: no 'best_metrics' object found")
        return

    comparisons = []
    for label, key, raindrop_value in (
        ("AUROC", 'test_auc', raindrop_auroc),
        ("AUPRC", 'test_auprc', raindrop_auprc),
    ):
        gman_value = best_metrics.get(key)
        # bool is an int subclass; it is never a meaningful metric value.
        if isinstance(gman_value, bool) or not isinstance(gman_value, (int, float)):
            gman_value = None
        comparisons.append((label, raindrop_value, gman_value))

    print("\nGMAN Results:")
    for label, _, gman_value in comparisons:
        shown = "n/a" if gman_value is None else f"{gman_value:.4f}"
        print(f"  {label}: {shown}")

    print("\nDifference (Raindrop - GMAN):")
    for label, raindrop_value, gman_value in comparisons:
        shown = "n/a" if gman_value is None else f"{raindrop_value - gman_value:.4f}"
        print(f"  {label}: {shown}")

if __name__ == "__main__":
    # Experiment settings, forwarded as keyword arguments to
    # run_all_splits_experiment.
    experiment_config = dict(
        data_path="Raindrop/P12data_converted",  # Path to your converted data
        n_splits=5,
        num_epochs=20,
        learning_rate=0.0001,
        batch_size=128,
        n_runs=1,
    )

    # Run every split and aggregate the metrics (None if all splits failed).
    results = run_all_splits_experiment(**experiment_config)

    # Compare with GMAN results (if available); update the path as needed.
    compare_with_gman_results(
        results,
        "training_results/results_gnan_p12_seed42_20241201_120000.json",
    )

    print("\nExperiment completed!")