"""
2D Burgers equation experiment script.
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import sys
import os
import argparse

# Add src and config to path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root) 
sys.path.append(os.path.join(project_root, 'src'))

from models.problem_models import create_problem_model
from problems.burgers_2d import Burgers2D
from config.burgers_config import BurgersConfig
from utils.config_loader import get_optimized_config, print_config_summary


def run_experiment(config: BurgersConfig, method: str = "rpit"):
    """
    Run 2D Burgers equation experiment.

    Builds the ``Burgers2D`` problem from ``config``, generates training and
    collocation points, and trains a model from ``create_problem_model`` with
    Adam for ``config.num_epochs`` epochs. It then evaluates the model on a
    50 x 50 x 20 (x, y, t) grid against a closed-form reference field. The
    model checkpoint, predictions, loss history and plots are written under
    ``Path(config.results_dir) / f"burgers_2d_{method}"``.

    Args:
        config: Configuration object
        method: Method to use ('rpit', 'standard', 'bayesian'). Values other
            than 'rpit' and 'bayesian' take the 'standard' construction
            branch; ``method`` is always forwarded to ``create_problem_model``.

    Returns:
        dict with keys 'total_mse', 'total_mae', 'u_mse', 'v_mse' (NumPy
        scalars), 'losses_history' (dict mapping loss name to the list of
        values recorded at each logging step) and 'results_dir' (``Path``).
    """
    print(f"Running 2D Burgers experiment with {method} method...")

    # Seed both RNGs so data generation and weight init are reproducible.
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)

    # Create problem
    problem = Burgers2D(
        nu=config.nu,
        x_start=config.x_start,
        x_end=config.x_end,
        y_start=config.y_start,
        y_end=config.y_end,
        t_start=config.t_start,
        t_end=config.t_end,
        device=config.device
    )

    # Generate training data
    print("Generating training data...")
    x_data, y_data = problem.generate_training_data(
        num_initial_points=config.num_initial_points,
        num_boundary_points=config.num_boundary_points,
        num_interior_points=config.num_interior_points,
        noise_std=config.noise_std
    )

    # Generate collocation points
    x_collocation = problem.generate_collocation_points(
        num_points=config.num_collocation_points,
        num_x=config.num_x,
        num_y=config.num_y,
        num_t=config.num_t
    )

    print(f"Training data shape: {x_data.shape}, {y_data.shape}")
    print(f"Collocation points shape: {x_collocation.shape}")

    # Arguments shared by every method; method-specific extras are added below.
    model_kwargs = dict(
        problem_type="burgers",
        method=method,
        input_dim=3,  # x, y, t
        output_dim=2,  # u, v components
        hidden_layers=config.hidden_layers,
        activation=config.activation,
        device=config.device,
        nu=config.nu,
    )
    if method == "rpit":
        model_kwargs.update(
            lambda_sens=config.lambda_sens,
            lambda_var=config.lambda_var,
            noise_std=config.noise_std,
            uncertainty_output=True,
        )
    elif method == "bayesian":
        model_kwargs.update(n_ensemble=5, dropout_rate=0.1)
    model = create_problem_model(**model_kwargs)

    print(f"Model created: {model.get_model_info()}")

    # Set up optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    # Training loop
    print("Starting training...")
    model.train()

    losses_history = {
        'total_loss': [],
        'physics_loss': [],
        'data_loss': [],
    }
    # Pre-register the method-specific term so plotting code can rely on the key.
    if method == "rpit":
        losses_history['sensitivity_loss'] = []
    elif method == "bayesian":
        losses_history['uncertainty_loss'] = []

    # Guard against a zero/negative logging interval (would divide by zero).
    log_every = max(1, int(config.log_frequency))

    for epoch in range(config.num_epochs):
        optimizer.zero_grad()

        # All methods expose the same loss interface.
        losses = model.compute_total_loss(x_collocation, x_data, y_data)

        # Backward pass
        losses['total_loss'].backward()
        optimizer.step()

        # Log losses
        if epoch % log_every == 0:
            print(f"Epoch {epoch:5d}: Total Loss = {losses['total_loss'].item():.6f}")

            for loss_name, loss_value in losses.items():
                # setdefault: the model may report terms not pre-registered above.
                value = loss_value.item() if torch.is_tensor(loss_value) else float(loss_value)
                losses_history.setdefault(loss_name, []).append(value)

    print("Training completed!")

    # Test prediction
    print("Testing predictions...")
    model.eval()

    # Generate test grid
    x_test = torch.linspace(config.x_start, config.x_end, 50, device=config.device)
    y_test = torch.linspace(config.y_start, config.y_end, 50, device=config.device)
    t_test = torch.linspace(config.t_start, config.t_end, 20, device=config.device)

    # 'ij' indexing: axis 0 = x, axis 1 = y, axis 2 = t.
    X_test, Y_test, T_test = torch.meshgrid(x_test, y_test, t_test, indexing='ij')
    test_points = torch.stack([X_test.flatten(), Y_test.flatten(), T_test.flatten()], dim=1)

    # Reference field for the error metrics.
    # NOTE(review): this separable decaying sin/cos mode does not obviously
    # satisfy the nonlinear 2D Burgers equations; confirm it matches the
    # solution Burgers2D trains against before trusting the reported errors.
    decay = torch.exp(-config.nu * np.pi**2 * T_test)
    u_analytical = torch.sin(np.pi * X_test) * torch.sin(np.pi * Y_test) * decay
    v_analytical = torch.cos(np.pi * X_test) * torch.cos(np.pi * Y_test) * decay

    with torch.no_grad():
        if method == "rpit":
            # The predictive variance is not used by the metrics below.
            y_pred_mean, _ = model.predict_mean_variance(test_points)
        else:
            y_pred_mean = model.predict(test_points)

    # Reshape predictions back onto the (x, y, t) grid and move to NumPy.
    grid_shape = X_test.shape
    u_pred = y_pred_mean[:, 0].reshape(grid_shape).cpu().numpy()
    v_pred = y_pred_mean[:, 1].reshape(grid_shape).cpu().numpy()
    u_analytical = u_analytical.cpu().numpy()
    v_analytical = v_analytical.cpu().numpy()
    X_np = X_test.cpu().numpy()
    Y_np = Y_test.cpu().numpy()
    T_np = T_test.cpu().numpy()

    # Compute metrics
    u_err = u_pred - u_analytical
    v_err = v_pred - v_analytical
    u_mse = np.mean(u_err**2)
    v_mse = np.mean(v_err**2)
    total_mse = (u_mse + v_mse) / 2

    u_mae = np.mean(np.abs(u_err))
    v_mae = np.mean(np.abs(v_err))
    total_mae = (u_mae + v_mae) / 2

    print(f"Test MSE (u): {u_mse:.6f}")
    print(f"Test MSE (v): {v_mse:.6f}")
    print(f"Test Total MSE: {total_mse:.6f}")
    print(f"Test Total MAE: {total_mae:.6f}")

    # Save results
    results_dir = Path(config.results_dir) / f"burgers_2d_{method}"
    results_dir.mkdir(parents=True, exist_ok=True)

    # Save model
    model_path = results_dir / "model.pth"
    model.save_checkpoint(str(model_path))

    # Save predictions
    np.save(results_dir / "X_test.npy", X_np)
    np.save(results_dir / "Y_test.npy", Y_np)
    np.save(results_dir / "T_test.npy", T_np)
    np.save(results_dir / "u_pred.npy", u_pred)
    np.save(results_dir / "v_pred.npy", v_pred)
    np.save(results_dir / "u_analytical.npy", u_analytical)
    np.save(results_dir / "v_analytical.npy", v_analytical)

    # The dict is stored as a pickled object array; load with allow_pickle=True.
    np.save(results_dir / "losses_history.npy", losses_history)

    # Create visualization
    create_burgers_visualization(
        X_np, Y_np, T_np,
        u_pred, v_pred, u_analytical, v_analytical,
        losses_history, method, results_dir
    )

    print(f"Results saved to {results_dir}")

    return {
        'total_mse': total_mse,
        'total_mae': total_mae,
        'u_mse': u_mse,
        'v_mse': v_mse,
        'losses_history': losses_history,
        'results_dir': results_dir
    }


def select_time_slices(T):
    """Pick five time indices along axis 2 of ``T`` and label them with their real times.

    Args:
        T: 3D array of time coordinates indexed (x, y, t), as produced by an
            ``indexing='ij'`` meshgrid, so ``T[0, 0, k]`` is the k-th time value.

    Returns:
        ``(indices, labels)``. ``indices`` is ``[0, n//4, n//2, 3n//4, n-1]``
        for ``n = T.shape[2]``. ``labels`` are strings like ``'t=0.26'`` read
        from ``T`` rather than assumed, so they stay correct for any time range
        or grid size.
    """
    n_t = T.shape[2]
    indices = [0, n_t // 4, n_t // 2, 3 * n_t // 4, n_t - 1]
    labels = [f"t={float(T[0, 0, k]):.2f}" for k in indices]
    return indices, labels


def create_burgers_visualization(X, Y, T, u_pred, v_pred, u_analytical, v_analytical,
                                 losses_history, method, results_dir, log_frequency=None):
    """Create visualization for Burgers equation results.

    Writes two images into ``results_dir`` (a ``pathlib.Path``):

    * ``burgers_visualization.png`` is a 4 x 5 grid of filled contour plots.
      The rows are predicted u, reference u, predicted v and reference v. The
      columns are five time slices chosen by ``select_time_slices``.
    * ``loss_history.png`` holds the recorded loss curves on a log scale.

    Args:
        X, Y, T: 3D coordinate arrays indexed (x, y, t).
        u_pred, v_pred, u_analytical, v_analytical: fields on the same grid.
        losses_history: dict mapping loss name to the values recorded at each
            logging step.
        method: method name, used in titles and to pick the extra loss curve.
        results_dir: output directory (``Path``).
        log_frequency: optional epochs between logged losses. When given, the
            x-axis is in epochs; otherwise it is the logging-step index.
    """
    time_indices, time_labels = select_time_slices(T)
    n_cols = len(time_indices)

    # One row per (field, title prefix); rows are drawn top to bottom.
    rows = [
        (u_pred, f'{method.upper()}: u'),
        (u_analytical, 'Analytical: u'),
        (v_pred, f'{method.upper()}: v'),
        (v_analytical, 'Analytical: v'),
    ]

    fig = plt.figure(figsize=(20, 15))
    for row, (field, title_prefix) in enumerate(rows):
        for col, (t_idx, t_label) in enumerate(zip(time_indices, time_labels)):
            ax = fig.add_subplot(len(rows), n_cols, row * n_cols + col + 1)
            im = ax.contourf(X[:, :, t_idx], Y[:, :, t_idx], field[:, :, t_idx],
                             levels=20, cmap='RdBu_r')
            ax.set_title(f'{title_prefix} {t_label}')
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            fig.colorbar(im, ax=ax)

    fig.tight_layout()
    fig.savefig(results_dir / "burgers_visualization.png", dpi=150, bbox_inches='tight')
    plt.close(fig)

    # Plot loss history
    series = [
        ('total_loss', 'Total Loss'),
        ('physics_loss', 'Physics Loss'),
        ('data_loss', 'Data Loss'),
    ]
    if method == "rpit":
        series.append(('sensitivity_loss', 'Sensitivity Loss'))
    elif method == "bayesian":
        series.append(('uncertainty_loss', 'Uncertainty Loss'))

    n_logged = len(losses_history.get('total_loss', []))
    step = log_frequency if log_frequency else 1
    x_values = [i * step for i in range(n_logged)]

    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    for key, label in series:
        values = losses_history.get(key, [])
        # Skip terms the model never reported: an empty or shorter series
        # would make matplotlib raise on mismatched x/y lengths.
        if not values or len(values) != n_logged:
            continue
        ax.plot(x_values, values, label=label, linewidth=2)
    # History entries are logging steps, not raw epochs, unless we know the interval.
    ax.set_xlabel('Epoch' if log_frequency else 'Logging step')
    ax.set_ylabel('Loss')
    ax.set_title(f'{method.upper()} Training Loss History')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')
    fig.tight_layout()
    fig.savefig(results_dir / "loss_history.png", dpi=150, bbox_inches='tight')
    plt.close(fig)


def build_arg_parser():
    """Return the command-line parser for this script (method, epochs, seed, device)."""
    parser = argparse.ArgumentParser(description='Run 2D Burgers equation experiment')
    parser.add_argument('--method', type=str, default='rpit',
                        choices=['rpit', 'standard', 'bayesian'],
                        help='Method to use (default: rpit)')
    parser.add_argument('--epochs', type=int, default=5000,
                        help='Number of training epochs (default: 5000)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed (default: 42)')
    parser.add_argument('--device', type=str, default='auto',
                        choices=['auto', 'cpu', 'cuda'],
                        help='Device to use (default: auto)')
    return parser


def resolve_device(requested):
    """Map a ``--device`` choice to a concrete torch device string.

    'auto' becomes 'cuda' when ``torch.cuda.is_available()`` is true, else 'cpu'.
    Other values are returned unchanged.

    Raises:
        ValueError: if 'cuda' is requested but no CUDA device is available.
            Failing here gives a clear message instead of a crash later in
            ``torch.cuda.get_device_name``.
    """
    if requested == 'auto':
        return "cuda" if torch.cuda.is_available() else "cpu"
    if requested == "cuda" and not torch.cuda.is_available():
        raise ValueError("--device cuda was requested but CUDA is not available")
    return requested


def main():
    """Parse CLI arguments, build the optimized config, and run one experiment."""
    parser = build_arg_parser()
    args = parser.parse_args()

    # Check for GPU availability; report bad requests as a CLI usage error.
    try:
        device = resolve_device(args.device)
    except ValueError as err:
        parser.error(str(err))

    print(f"Using device: {device}")
    if device == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Create configuration with best hyperparameters
    print(f"\n🔧 Loading configuration for {args.method} method...")
    config = get_optimized_config(
        problem_type="burgers",
        method=args.method,
        experiment_name=f"burgers_{args.method}",
        num_epochs=args.epochs,
        log_frequency=250,
        device=device,
        seed=args.seed
    )

    # Print configuration summary
    print_config_summary(config, args.method)

    print("Configuration:")
    print(f"  Experiment: {config.experiment_name}")
    print(f"  Method: {args.method}")
    print(f"  Device: {config.device}")
    print(f"  Epochs: {config.num_epochs}")
    print(f"  Viscosity (nu): {config.nu}")

    # Run experiment with specified method
    print("\n" + "="*50)
    results = run_experiment(config, method=args.method)

    # Print results
    print("\n" + "="*50)
    print("EXPERIMENT RESULTS:")
    print(f"Method: {args.method.upper()}")
    print(f"Total MSE:     {results['total_mse']:.6f}")
    print(f"Total MAE:     {results['total_mae']:.6f}")
    print(f"Results saved to: {results['results_dir']}")


if __name__ == "__main__":
    # Only launch the experiment when run as a script, not on import.
    main()
