"""
Lorenz system experiment script.
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import sys
import os
import argparse

# Put the directory one level above this script's folder (project_root) and
# its 'src' subfolder on sys.path so the project imports below can resolve.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.extend([project_root, os.path.join(project_root, 'src')])

from models.problem_models import create_problem_model
from problems.lorenz_system import LorenzSystem
from config.lorenz_config import LorenzConfig
from utils.config_loader import get_optimized_config, print_config_summary


def _create_model(config, method: str):
    """Build the model for ``method`` through ``create_problem_model``."""
    # Keyword arguments passed for every method.
    model_kwargs = dict(
        problem_type="lorenz",
        method=method,
        input_dim=1,  # Time only
        output_dim=3,  # x, y, z components
        hidden_layers=config.hidden_layers,
        activation=config.activation,
        device=config.device,
        sigma=config.sigma,
        rho=config.rho,
        beta=config.beta,
    )

    # Method-specific extras; any other method value (e.g. 'standard') gets
    # only the shared arguments.
    if method == "rpit":
        model_kwargs.update(
            lambda_sens=config.lambda_sens,
            lambda_var=config.lambda_var,
            noise_std=config.noise_std,
            uncertainty_output=True,
        )
    elif method == "bayesian":
        model_kwargs.update(n_ensemble=5, dropout_rate=0.1)

    return create_problem_model(**model_kwargs)


def _loss_to_float(value) -> float:
    """Convert one loss entry (tensor-like or plain number) to a float."""
    return value.item() if hasattr(value, "item") else float(value)


def _save_results_figure(results_dir, method, t_test, y_test, y_pred_mean,
                         uncertainty, losses_history, logged_epochs):
    """Write the 2x2 summary figure (three components + loss curves)."""
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # One panel per state component: true vs. predicted trajectory.
    for i, component in enumerate(['x', 'y', 'z']):
        ax = axes[i // 2, i % 2]
        ax.plot(t_test, y_test[:, i], 'b-', label='True', alpha=0.7)
        ax.plot(t_test, y_pred_mean[:, i], 'r--', label='Predicted', alpha=0.7)

        if method in ["rpit", "bayesian"]:
            ax.fill_between(
                t_test,
                y_pred_mean[:, i] - 2*uncertainty[:, i],
                y_pred_mean[:, i] + 2*uncertainty[:, i],
                alpha=0.3, color='red', label='±2σ'
            )

        ax.set_xlabel('Time')
        ax.set_ylabel(f'{component}')
        ax.set_title(f'Lorenz {component} Component')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Last panel: loss curves.
    loss_curves = [
        ('total_loss', 'Total Loss'),
        ('physics_loss', 'Physics Loss'),
        ('data_loss', 'Data Loss'),
    ]
    if method == "rpit":
        loss_curves.append(('sensitivity_loss', 'Sensitivity Loss'))
    elif method == "bayesian":
        loss_curves.append(('uncertainty_loss', 'Uncertainty Loss'))

    ax = axes[1, 1]
    for key, label in loss_curves:
        series = losses_history[key]
        # Losses are only recorded every log_frequency epochs, so plot them
        # against the recorded epoch numbers instead of the sample index.
        # A pre-seeded key the model never reported stays empty; slicing keeps
        # x and y the same length in that case.
        ax.plot(logged_epochs[:len(series)], series, label=label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training Loss History')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')

    plt.tight_layout()
    plt.savefig(results_dir / "results.png", dpi=150, bbox_inches='tight')
    plt.close()


def run_experiment(config: LorenzConfig, method: str = "rpit"):
    """
    Run Lorenz system experiment.

    Seeds torch and numpy, generates training data and collocation points,
    trains the model with Adam for ``config.num_epochs`` epochs, evaluates it
    on a noise-free test trajectory, and writes the checkpoint, prediction
    arrays, loss history and a summary figure to
    ``<config.results_dir>/lorenz_<method>``.

    Args:
        config: Configuration object
        method: Method to use ('rpit', 'standard', 'bayesian')

    Returns:
        Dict with keys 'mse', 'mae', 'rmse', 'relative_error',
        'losses_history' (loss name -> list of values recorded every
        ``config.log_frequency`` epochs) and 'results_dir' (a ``Path``).
    """
    print(f"Running Lorenz experiment with {method} method...")

    # Set random seeds
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)

    # Create problem
    problem = LorenzSystem(
        sigma=config.sigma,
        rho=config.rho,
        beta=config.beta,
        noise_std=config.noise_std,
        device=config.device
    )

    # Generate training data
    print("Generating training data...")
    x_data, y_data = problem.generate_training_data(
        num_trajectories=config.num_trajectories,
        num_time_points=config.num_time_points,
        t_start=config.t_start,
        t_end=config.t_end,
        corruption_level=config.corruption_level
    )

    # Generate collocation points
    x_collocation = problem.generate_collocation_points(
        num_points=config.num_collocation_points,
        t_start=config.t_start,
        t_end=config.t_end
    )

    print(f"Training data shape: {x_data.shape}, {y_data.shape}")
    print(f"Collocation points shape: {x_collocation.shape}")

    # Create model using problem-specific factory
    model = _create_model(config, method)
    print(f"Model created: {model.get_model_info()}")

    # Set up optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    # Training loop
    print("Starting training...")
    model.train()

    # Pre-seed the series the loss plot expects; further keys returned by the
    # model are added on first sight.
    losses_history = {
        'total_loss': [],
        'physics_loss': [],
        'data_loss': [],
    }
    if method == "rpit":
        losses_history['sensitivity_loss'] = []
    elif method == "bayesian":
        losses_history['uncertainty_loss'] = []

    # Epoch numbers at which losses were recorded (x-axis of the loss plot).
    logged_epochs = []

    for epoch in range(config.num_epochs):
        optimizer.zero_grad()

        # Same call for every method.
        losses = model.compute_total_loss(x_collocation, x_data, y_data)

        # Backward pass
        losses['total_loss'].backward()
        optimizer.step()

        # Log losses
        if epoch % config.log_frequency == 0:
            print(f"Epoch {epoch:5d}: Total Loss = {losses['total_loss'].item():.6f}")

            logged_epochs.append(epoch)
            for loss_name, loss_value in losses.items():
                # setdefault avoids a KeyError when the model reports a loss
                # term that was not pre-seeded above.
                losses_history.setdefault(loss_name, []).append(
                    _loss_to_float(loss_value)
                )

    print("Training completed!")

    # Test prediction
    print("Testing predictions...")
    model.eval()

    # Generate test trajectory
    t_test, y_test = problem.generate_trajectory(
        t_start=config.t_start,
        t_end=config.t_end,
        dt=0.01,
        add_noise=False  # Clean trajectory for testing
    )

    # Column of time values as model input.
    t_test_tensor = torch.tensor(t_test, dtype=torch.float32, device=config.device).unsqueeze(1)

    with torch.no_grad():
        if method == "rpit":
            y_pred_mean, y_pred_var = model.predict_mean_variance(t_test_tensor)
            uncertainty = model.get_uncertainty_estimates(t_test_tensor)
        elif method == "bayesian":
            y_pred_mean, y_pred_var = model.forward_with_uncertainty(t_test_tensor)
            uncertainty = torch.sqrt(y_pred_var)  # Standard deviation
        else:  # standard
            # No uncertainty estimate: use zeros so the saved outputs have the
            # same set of arrays for every method.
            y_pred_mean = model.predict(t_test_tensor)
            y_pred_var = torch.zeros_like(y_pred_mean)
            uncertainty = torch.zeros_like(y_pred_mean)

    # Convert to numpy
    y_pred_mean = y_pred_mean.cpu().numpy()
    y_pred_var = y_pred_var.cpu().numpy()
    uncertainty = uncertainty.cpu().numpy()

    # Compute metrics
    abs_error = np.abs(y_pred_mean - y_test)
    mse = np.mean((y_pred_mean - y_test)**2)
    mae = np.mean(abs_error)
    rmse = np.sqrt(mse)
    # The small constant guards against division by zero where y_test is 0.
    relative_error = np.mean(abs_error / (np.abs(y_test) + 1e-8))

    print(f"Test MSE: {mse:.6f}")
    print(f"Test MAE: {mae:.6f}")
    print(f"Test RMSE: {rmse:.6f}")
    print(f"Test Relative Error: {relative_error:.6f}")

    # Save results
    results_dir = Path(config.results_dir) / f"lorenz_{method}"
    results_dir.mkdir(parents=True, exist_ok=True)

    # Save model
    model_path = results_dir / "model.pth"
    model.save_checkpoint(str(model_path))

    # Save predictions
    np.save(results_dir / "t_test.npy", t_test)
    np.save(results_dir / "y_test.npy", y_test)
    np.save(results_dir / "y_pred_mean.npy", y_pred_mean)
    np.save(results_dir / "y_pred_var.npy", y_pred_var)
    np.save(results_dir / "uncertainty.npy", uncertainty)

    # Save losses. np.save pickles the dict into an object array, so reading
    # it back needs np.load(..., allow_pickle=True).item().
    np.save(results_dir / "losses_history.npy", losses_history)

    # Create simple visualization
    _save_results_figure(
        results_dir, method, t_test, y_test, y_pred_mean,
        uncertainty, losses_history, logged_epochs
    )

    print(f"Results saved to {results_dir}")

    return {
        'mse': mse,
        'mae': mae,
        'rmse': rmse,
        'relative_error': relative_error,
        'losses_history': losses_history,
        'results_dir': results_dir
    }


def main():
    """
    Main function to run experiments.

    Parses the command line (--method, --epochs, --seed, --device), resolves
    the compute device, loads the configuration through
    ``get_optimized_config``, runs ``run_experiment`` and prints the resulting
    error metrics.

    Exits with an argparse usage error when ``--device cuda`` is requested
    but CUDA is not available.
    """
    parser = argparse.ArgumentParser(description='Run Lorenz system experiment')
    parser.add_argument('--method', type=str, default='rpit',
                       choices=['rpit', 'standard', 'bayesian'],
                       help='Method to use (default: rpit)')
    parser.add_argument('--epochs', type=int, default=1000,
                       help='Number of training epochs (default: 1000)')
    parser.add_argument('--seed', type=int, default=42,
                       help='Random seed (default: 42)')
    parser.add_argument('--device', type=str, default='auto',
                       choices=['auto', 'cpu', 'cuda'],
                       help='Device to use (default: auto)')

    args = parser.parse_args()

    # Check for GPU availability
    if args.device == 'auto':
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
        # Fail with a clear message instead of crashing in the CUDA queries
        # below when the GPU was requested explicitly but is not usable.
        if device == "cuda" and not torch.cuda.is_available():
            parser.error("--device cuda was requested but CUDA is not available")

    print(f"Using device: {device}")
    if device == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Create configuration with best hyperparameters
    print(f"\n🔧 Loading configuration for {args.method} method...")
    config = get_optimized_config(
        problem_type="lorenz",
        method=args.method,
        experiment_name=f"lorenz_{args.method}",
        num_epochs=args.epochs,
        log_frequency=100,
        device=device,
        seed=args.seed
    )

    # Print configuration summary
    print_config_summary(config, args.method)

    print("Configuration:")
    print(f"  Experiment: {config.experiment_name}")
    print(f"  Method: {args.method}")
    print(f"  Device: {config.device}")
    print(f"  Epochs: {config.num_epochs}")
    print(f"  Seed: {config.seed}")
    print(f"  Lambda sens: {config.lambda_sens}")
    print(f"  Lambda var: {config.lambda_var}")
    print(f"  Noise std: {config.noise_std}")

    # Run experiment
    print("\n" + "="*50)
    results = run_experiment(config, method=args.method)

    # Print results
    print("\n" + "="*50)
    print("EXPERIMENT RESULTS:")
    print(f"MSE: {results['mse']:.6f}")
    print(f"MAE: {results['mae']:.6f}")
    print(f"RMSE: {results['rmse']:.6f}")
    print(f"Relative Error: {results['relative_error']:.6f}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()