#!/usr/bin/env python3
"""
Complete Enhanced Experiments Runner
====================================

This script runs the complete enhanced hyperparameter optimization experiments
for all three problems (Lorenz, Burgers, Inverse Poisson) and all three methods
(Standard, R-PIT, Bayesian).

The experiments include:
1. Enhanced hyperparameter optimization for Burgers equation
2. Bayesian PINN hyperparameter search for fair comparison
3. Comprehensive evaluation across all problem-method combinations
"""

import json
import os
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path

# Make the parent of this file's directory (the repository root) importable,
# so the project-local `experiments` package imported just below resolves
# when this file is executed directly as a script.
project_root = Path(__file__).parent.parent
sys.path.append(os.fspath(project_root))

from experiments.enhanced_hyperparameter_optimization import EnhancedHyperparameterOptimizer

def _detect_device() -> str:
    """Return ``"cuda"`` when ``nvidia-smi`` exits successfully, else ``"cpu"``.

    ``nvidia-smi`` is executed directly (no shell) with its output discarded via
    ``subprocess.DEVNULL``. Shell redirection to ``/dev/null`` does not work on
    Windows, where that path does not exist.
    NOTE(review): a working ``nvidia-smi`` only shows that an NVIDIA driver is
    installed, not that the training framework was built with CUDA support.
    """
    try:
        probe = subprocess.run(
            ["nvidia-smi"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            timeout=60,
            check=False,
        )
    except (OSError, subprocess.TimeoutExpired):
        # Not on PATH, not executable, or hung (broken driver): use the CPU.
        return "cpu"
    return "cuda" if probe.returncode == 0 else "cpu"


def _run_combination(optimizer, problem, method, num_trials, num_epochs):
    """Run one random hyperparameter search and return its summary record.

    Args:
        optimizer: object exposing ``random_search(problem=..., method=...,
            num_trials=..., num_epochs=...)``.
        problem: problem identifier forwarded to the optimizer.
        method: method identifier forwarded to the optimizer.
        num_trials: number of search trials requested.
        num_epochs: training epochs per trial.

    Returns:
        On success, a dict with the following keys:

        - ``'successful_experiments'``: the number of trials whose
          ``r['metrics']['converged']`` is truthy.
        - ``'total_experiments'``: the number of trials actually returned.
        - ``'results'``: the raw search output.

        If anything in the search or its post-processing raises, the error is
        printed and ``{'successful_experiments': 0, 'total_experiments':
        num_trials, 'error': str(exc)}`` is returned instead.
    """
    # Deliberately broad: one failing combination must not abort a sweep that
    # can take hours; the failure is recorded and reported in the summary.
    try:
        print(f"Starting hyperparameter search for {problem} with {method} method...")
        search_results = optimizer.random_search(
            problem=problem,
            method=method,
            num_trials=num_trials,
            num_epochs=num_epochs,
        )

        converged = [r for r in search_results if r['metrics']['converged']]
        num_returned = len(search_results)
        record = {
            'successful_experiments': len(converged),
            'total_experiments': num_returned,
            'results': search_results,
        }

        if converged:
            best_mse = min(r['metrics']['mse'] for r in converged)
            print(f"✅ Completed: {len(converged)}/{num_returned} successful")
            print(f"   Best MSE: {best_mse:.6f}")
        else:
            print(f"❌ Failed: 0/{num_returned} successful")
        return record
    except Exception as e:
        print(f"❌ Error in {problem} + {method}: {e}")
        return {
            'successful_experiments': 0,
            'total_experiments': num_trials,
            'error': str(e),
        }


def _print_summary(all_results, problems, methods, num_trials):
    """Print per-combination and overall success counts; return the overall rate.

    A combination whose search raised is reported as ``ERROR - <message>``. It
    still contributes its ``total_experiments`` (``num_trials``) as failed
    trials, so errors lower the overall rate instead of vanishing from it.

    Returns:
        The overall success rate as a percentage (0 when no trials were recorded).
    """
    total_successful = 0
    total_trials = 0

    for problem in problems:
        print(f"{problem.upper()}:")
        problem_results = all_results.get(problem, {})
        for method in methods:
            result = problem_results.get(method)
            if result is None:
                continue
            successful = result.get('successful_experiments', 0)
            total = result.get('total_experiments', num_trials)
            total_successful += successful
            total_trials += total
            # Check for 'error' explicitly: error records also carry
            # 'successful_experiments', so that key cannot tell them apart.
            if 'error' in result:
                print(f"  {method}: ERROR - {result['error']}")
            else:
                success_rate = (successful / total * 100) if total > 0 else 0
                print(f"  {method}: {successful}/{total} ({success_rate:.1f}%)")
        print()

    overall_success_rate = (total_successful / total_trials * 100) if total_trials > 0 else 0
    print(f"Overall success rate: {total_successful}/{total_trials} ({overall_success_rate:.1f}%)")
    return overall_success_rate


def main():
    """Run the hyperparameter search for every problem/method pair and save results.

    For each combination of ``problems`` and ``methods``, the function:

    - runs ``optimizer.random_search`` with ``num_trials`` trials of
      ``num_epochs`` epochs each, printing progress as it goes;
    - writes all results plus timing metadata to a timestamped JSON file in
      ``project_root / "data" / "results" / "enhanced_hyperopt_results"``;
    - prints a success-rate summary.
    """
    print("=" * 80)
    print("COMPLETE ENHANCED EXPERIMENTS")
    print("=" * 80)
    print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print()

    # Create results directory
    results_dir = project_root / "data" / "results" / "enhanced_hyperopt_results"
    results_dir.mkdir(parents=True, exist_ok=True)

    # Initialize optimizer
    print("Initializing Enhanced Hyperparameter Optimizer...")
    optimizer = EnhancedHyperparameterOptimizer(device=_detect_device())
    print(f"Using device: {optimizer.device}")
    print()

    # Define experiment configuration
    problems = ["lorenz", "burgers", "inverse_poisson"]
    methods = ["standard", "rpit", "bayesian"]

    # Experiment parameters
    num_trials = 20  # Number of hyperparameter search trials per problem-method
    num_epochs = 1000  # Training epochs

    all_results = {}
    # Fixed denominator for the "[k/N]" progress counter. It is kept separate
    # from per-combination trial counts so it cannot be overwritten mid-run.
    num_combinations = len(problems) * len(methods)
    experiment_count = 0

    print(f"Running {num_combinations} problem-method combinations...")
    print(f"Each with {num_trials} hyperparameter search trials")
    print(f"Training for {num_epochs} epochs per experiment")
    print()

    # monotonic() is unaffected by wall-clock adjustments during a long run.
    start_time = time.monotonic()

    for problem in problems:
        print(f"{'='*60}")
        print(f"PROBLEM: {problem.upper()}")
        print(f"{'='*60}")

        all_results[problem] = {}

        for method in methods:
            experiment_count += 1
            print(f"\n[{experiment_count}/{num_combinations}] {problem.upper()} + {method.upper()}")
            print("-" * 50)
            all_results[problem][method] = _run_combination(
                optimizer, problem, method, num_trials, num_epochs
            )

    # Calculate total time
    total_time = time.monotonic() - start_time

    # Save results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = results_dir / f"enhanced_hyperopt_results_{timestamp}.json"

    summary = {
        'timestamp': timestamp,
        'total_time_seconds': total_time,
        'total_time_hours': total_time / 3600,
        'problems': problems,
        'methods': methods,
        'num_trials_per_combination': num_trials,
        'num_epochs': num_epochs,
        'results': all_results,
    }

    # default=str stringifies any value json cannot encode natively instead of
    # failing the save at the very end of a long run.
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, default=str)

    # Print final summary
    print(f"\n{'='*80}")
    print("EXPERIMENT SUMMARY")
    print(f"{'='*80}")
    print(f"Total time: {total_time/3600:.2f} hours")
    print(f"Results saved to: {results_file}")
    print()

    overall_success_rate = _print_summary(all_results, problems, methods, num_trials)

    if overall_success_rate >= 80:
        print("🎉 Enhanced hyperparameter optimization completed successfully!")
    else:
        print("⚠️  Some experiments failed - check the results for details.")

    print(f"\nResults saved to: {results_file}")
    print("=" * 80)

# Run the experiments only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
