"""
1D Inverse Poisson problem experiment script.
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import sys
import os
import argparse

# Make this script runnable directly (not only as an installed package).
# The project root goes on sys.path for the ``config`` package, and the
# project's ``src`` directory is added after it for the remaining project
# imports below.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.extend([project_root, os.path.join(project_root, 'src')])

from models.problem_models import create_problem_model
from problems.inverse_poisson import InversePoisson
from config.inverse_config import InverseConfig
from utils.config_loader import get_optimized_config, print_config_summary


def _create_model(config: InverseConfig, method: str):
    """Build the network for ``method`` through the problem-specific factory.

    Every method receives the shared architecture arguments (1-D input,
    1-D output, hidden layers, activation, device). R-PIT additionally
    receives ``lambda_sens``, ``lambda_var``, ``noise_std`` and
    ``uncertainty_output=True``.
    """
    rpit_kwargs = {}
    if method == "rpit":
        rpit_kwargs = {
            "lambda_sens": config.lambda_sens,
            "lambda_var": config.lambda_var,
            "noise_std": config.noise_std,
            "uncertainty_output": True,
        }
    return create_problem_model(
        problem_type="inverse_poisson",
        method=method,
        input_dim=1,  # Spatial coordinate only
        output_dim=1,  # Scalar field
        hidden_layers=config.hidden_layers,
        activation=config.activation,
        device=config.device,
        **rpit_kwargs,
    )


def _train_model(model, config: InverseConfig, method: str, x_collocation, x_data, y_data):
    """Optimise ``model`` with Adam and return the sampled loss history.

    Every loss term returned by ``model.compute_total_loss`` is recorded once
    every ``config.log_frequency`` epochs. Entry ``i`` of each list therefore
    belongs to epoch ``i * config.log_frequency``.

    Raises:
        ValueError: if ``config.log_frequency`` is not positive. Previously
            the ``epoch % log_frequency`` test failed with ZeroDivisionError.
    """
    log_every = config.log_frequency
    if log_every <= 0:
        raise ValueError(f"log_frequency must be positive, got {log_every}")

    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    print("Starting training...")
    model.train()

    # Pre-register the keys that the loss-history plot reads for each method.
    losses_history = {
        'total_loss': [],
        'physics_loss': [],
        'data_loss': [],
    }
    if method == "rpit":
        losses_history['sensitivity_loss'] = []
    elif method == "bayesian":
        losses_history['uncertainty_loss'] = []

    for epoch in range(config.num_epochs):
        optimizer.zero_grad()
        losses = model.compute_total_loss(x_collocation, x_data, y_data)
        losses['total_loss'].backward()
        optimizer.step()

        if epoch % log_every == 0:
            print(f"Epoch {epoch:5d}: Total Loss = {losses['total_loss'].item():.6f}")
            for loss_name, loss_value in losses.items():
                # The model may report terms that were not pre-registered
                # above; plain indexing raised KeyError on them. float()
                # accepts both 0-d tensors and plain Python numbers.
                losses_history.setdefault(loss_name, []).append(float(loss_value))

    print("Training completed!")
    return losses_history


def _predict(model, method: str, x_test):
    """Evaluate ``model`` at ``x_test`` without tracking gradients.

    Returns:
        ``(mean, variance, uncertainty)`` tensors. Only R-PIT models are
        queried for variance/uncertainty. For every other method, both are
        zero tensors shaped like the mean.
    """
    with torch.no_grad():
        if method == "rpit":
            mean, var = model.predict_mean_variance(x_test)
            uncertainty = model.get_uncertainty_estimates(x_test)
        else:
            mean = model.predict(x_test)
            var = torch.zeros_like(mean)
            uncertainty = torch.zeros_like(mean)
    return mean, var, uncertainty


def _to_flat_numpy(tensor):
    """Detach ``tensor``, copy it to the CPU and return it as a flat NumPy array."""
    # detach() avoids "Can't call numpy() on Tensor that requires grad" if,
    # e.g., the analytical solution was computed with autograd enabled.
    return tensor.detach().cpu().numpy().flatten()


def run_experiment(config: InverseConfig, method: str = "rpit"):
    """
    Run 1D Inverse Poisson problem experiment.

    The steps are:

    1. Seed torch/NumPy.
    2. Generate (possibly corrupted) training data and collocation points.
    3. Train the model.
    4. Compare it to the analytical solution on 200 evenly spaced test
       points.
    5. Save the checkpoint, predictions, training data, loss history and
       figures under ``config.results_dir / f"inverse_poisson_{method}"``.

    Args:
        config: Configuration object
        method: Method to use ('rpit', 'standard', 'bayesian')

    Returns:
        dict with keys 'mse', 'mae', 'rmse' (test metrics against the
        analytical solution), 'losses_history' and 'results_dir'.

    Raises:
        ValueError: if ``config.log_frequency`` is not positive.
    """
    print(f"Running 1D Inverse Poisson experiment with {method} method...")

    # Set random seeds
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)

    # Create problem
    problem = InversePoisson(
        x_start=config.x_start,
        x_end=config.x_end,
        device=config.device
    )

    # Generate training data
    print("Generating training data...")
    x_data, y_data = problem.generate_training_data(
        num_boundary_points=config.num_boundary_points,
        num_interior_points=config.num_interior_points,
        corruption_level=config.corruption_level,
        outlier_std=config.outlier_std,
        missing_data_ratio=config.missing_data_ratio
    )

    # Generate collocation points
    x_collocation = problem.generate_collocation_points(
        num_points=config.num_collocation_points
    )

    print(f"Training data shape: {x_data.shape}, {y_data.shape}")
    print(f"Collocation points shape: {x_collocation.shape}")
    print(f"Data corruption level: {config.corruption_level:.1%}")
    print(f"Missing data ratio: {config.missing_data_ratio:.1%}")

    model = _create_model(config, method)
    print(f"Model created: {model.get_model_info()}")

    losses_history = _train_model(model, config, method, x_collocation, x_data, y_data)

    # Test prediction
    print("Testing predictions...")
    model.eval()
    x_test = torch.linspace(config.x_start, config.x_end, 200, device=config.device).unsqueeze(1)
    u_analytical = problem.solve_analytical(x_test)
    u_pred_mean, u_pred_var, uncertainty = _predict(model, method, x_test)

    # Convert to numpy
    x_test_np = _to_flat_numpy(x_test)
    u_pred_mean = _to_flat_numpy(u_pred_mean)
    u_pred_var = _to_flat_numpy(u_pred_var)
    u_analytical = _to_flat_numpy(u_analytical)
    uncertainty = _to_flat_numpy(uncertainty)
    x_data_np = x_data.detach().cpu().numpy()
    y_data_np = y_data.detach().cpu().numpy()

    # Compute metrics
    residual = u_pred_mean - u_analytical
    mse = np.mean(residual**2)
    mae = np.mean(np.abs(residual))
    rmse = np.sqrt(mse)

    print(f"Test MSE: {mse:.6f}")
    print(f"Test MAE: {mae:.6f}")
    print(f"Test RMSE: {rmse:.6f}")

    # Save results
    results_dir = Path(config.results_dir) / f"inverse_poisson_{method}"
    results_dir.mkdir(parents=True, exist_ok=True)

    model.save_checkpoint(str(results_dir / "model.pth"))

    arrays_to_save = {
        "x_test": x_test_np,
        "u_pred_mean": u_pred_mean,
        "u_pred_var": u_pred_var,
        "u_analytical": u_analytical,
        "uncertainty": uncertainty,
        "x_data": x_data_np,
        "y_data": y_data_np,
    }
    for stem, array in arrays_to_save.items():
        np.save(results_dir / f"{stem}.npy", array)

    # A dict is stored as a pickled 0-d object array; read it back with
    # np.load(path, allow_pickle=True).item().
    np.save(results_dir / "losses_history.npy", losses_history)

    # Create visualization
    create_inverse_visualization(
        x_test_np, u_pred_mean, u_pred_var, u_analytical, uncertainty,
        x_data_np, y_data_np,
        losses_history, method, results_dir
    )

    print(f"Results saved to {results_dir}")

    return {
        'mse': mse,
        'mae': mae,
        'rmse': rmse,
        'losses_history': losses_history,
        'results_dir': results_dir
    }


def create_inverse_visualization(x_test, u_pred_mean, u_pred_var, u_analytical, uncertainty,
                                x_data, y_data, losses_history, method, results_dir):
    """Create visualization for inverse problem results.

    Two figures are written into ``results_dir``:

    * ``inverse_visualization.png``: a 2x2 panel showing
      - the predicted vs. true solution, with the training data and (for
        R-PIT) a +/-2*uncertainty band;
      - the pointwise absolute error (log scale);
      - the uncertainty estimate;
      - the recorded loss history.
    * ``error_distribution.png``: a histogram of the absolute error.

    Args:
        x_test, u_pred_mean, u_analytical, uncertainty: 1-D arrays on the
            test grid.
        u_pred_var: Accepted for interface compatibility; it is not plotted.
        x_data, y_data: Training observations, drawn as a scatter plot.
        losses_history: Mapping of loss name to list of logged values.
            Terms that are missing, empty, or not the same length as
            ``'total_loss'`` are skipped instead of making the plot raise.
        method: Method name used in titles and labels.
        results_dir: Output directory, as a ``str`` or ``pathlib.Path``.
    """
    results_dir = Path(results_dir)
    title = method.upper()
    error = np.abs(u_pred_mean - u_analytical)
    mean_error = np.mean(error)

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # Plot 1: Solution comparison
    ax1 = axes[0, 0]
    ax1.plot(x_test, u_analytical, 'b-', label='True Solution', linewidth=2, alpha=0.8)
    ax1.plot(x_test, u_pred_mean, 'r--', label=f'{title} Prediction', linewidth=2, alpha=0.8)
    ax1.scatter(x_data, y_data, color='green', s=30, alpha=0.7, label='Training Data', zorder=5)

    if method == "rpit":
        # NOTE(review): the band assumes `uncertainty` is a standard deviation
        # (the label says sigma) - confirm against get_uncertainty_estimates.
        ax1.fill_between(x_test,
                         u_pred_mean - 2 * uncertainty,
                         u_pred_mean + 2 * uncertainty,
                         alpha=0.3, color='red', label='±2σ Uncertainty')

    ax1.set_xlabel('x')
    ax1.set_ylabel('u(x)')
    ax1.set_title(f'{title}: Solution Comparison')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot 2: Error analysis
    ax2 = axes[0, 1]
    ax2.plot(x_test, error, 'r-', linewidth=2, label='Absolute Error')
    ax2.set_xlabel('x')
    ax2.set_ylabel('|u_pred - u_true|')
    ax2.set_title(f'{title}: Error Analysis')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_yscale('log')

    # Plot 3: Uncertainty analysis (R-PIT only)
    ax3 = axes[1, 0]
    if method == "rpit":
        ax3.plot(x_test, uncertainty, 'purple', linewidth=2, label='Uncertainty Estimation')
        ax3.set_xlabel('x')
        ax3.set_ylabel('Uncertainty')
        ax3.set_title(f'{title}: Uncertainty Estimation')
        ax3.legend()
        ax3.grid(True, alpha=0.3)
    else:
        # Name the actual method. The old text said "Standard PINN" even for
        # the 'bayesian' method.
        ax3.text(0.5, 0.5, f'No uncertainty estimate available\nfor {title} predictions',
                 ha='center', va='center', transform=ax3.transAxes, fontsize=12)
        ax3.set_title('Uncertainty Estimation')

    # Plot 4: Loss history
    ax4 = axes[1, 1]
    total_history = losses_history.get('total_loss', [])
    log_steps = range(len(total_history))
    loss_series = [
        ('total_loss', 'Total Loss'),
        ('physics_loss', 'Physics Loss'),
        ('data_loss', 'Data Loss'),
    ]
    if method == "rpit":
        loss_series.append(('sensitivity_loss', 'Sensitivity Loss'))
    elif method == "bayesian":
        loss_series.append(('uncertainty_loss', 'Uncertainty Loss'))
    for key, series_label in loss_series:
        values = losses_history.get(key, [])
        # A term the model never reported (missing or empty) would make
        # ax.plot raise on mismatched x/y lengths. So would a term logged on
        # a different schedule.
        if len(values) == 0 or len(values) != len(total_history):
            continue
        ax4.plot(log_steps, values, label=series_label, linewidth=2)
    # The x values are indices of logged samples, not raw epoch numbers.
    ax4.set_xlabel('Logged Step (every log_frequency epochs)')
    ax4.set_ylabel('Loss Value')
    ax4.set_title(f'{title}: Training Loss History')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    ax4.set_yscale('log')

    fig.tight_layout()
    fig.savefig(results_dir / "inverse_visualization.png", dpi=150, bbox_inches='tight')
    plt.close(fig)

    # Create detailed error analysis
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    ax.hist(error, bins=50, alpha=0.7, density=True, label=f'{title} Error Distribution')
    ax.axvline(mean_error, color='red', linestyle='--',
               label=f'Average Error: {mean_error:.4f}')
    ax.set_xlabel('Absolute Error')
    ax.set_ylabel('Density')
    ax.set_title(f'{title}: Error Distribution')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')

    fig.tight_layout()
    fig.savefig(results_dir / "error_distribution.png", dpi=150, bbox_inches='tight')
    plt.close(fig)


def main(argv=None):
    """Main function to run experiments.

    The function:

    1. Parses command-line options.
    2. Resolves the compute device.
    3. Loads the tuned configuration via ``get_optimized_config``.
    4. Runs ``run_experiment`` once and prints the resulting metrics.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the
            default) means ``sys.argv[1:]``, exactly as before.
    """
    parser = argparse.ArgumentParser(description='Run 1D Inverse Poisson problem experiment')
    parser.add_argument('--method', type=str, default='rpit',
                        choices=['rpit', 'standard', 'bayesian'],
                        help='Method to use (default: rpit)')
    parser.add_argument('--epochs', type=int, default=3000,
                        help='Number of training epochs (default: 3000)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed (default: 42)')
    parser.add_argument('--device', type=str, default='auto',
                        choices=['auto', 'cpu', 'cuda'],
                        help='Device to use (default: auto)')
    # Previously hard-coded to 150; still the default.
    parser.add_argument('--log-frequency', type=int, default=150,
                        help='Print and record losses every N epochs (default: 150)')

    args = parser.parse_args(argv)

    # A non-positive value would break the `epoch % log_frequency` test in training.
    if args.log_frequency <= 0:
        parser.error('--log-frequency must be a positive integer')

    # Check for GPU availability
    if args.device == 'auto':
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    # Give a clean CLI error instead of a CUDA runtime error from
    # torch.cuda.get_device_name() below.
    if device == "cuda" and not torch.cuda.is_available():
        parser.error('--device cuda was requested but CUDA is not available')

    print(f"Using device: {device}")
    if device == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Create configuration with best hyperparameters
    print(f"\n🔧 Loading configuration for {args.method} method...")
    config = get_optimized_config(
        problem_type="inverse_poisson",
        method=args.method,
        experiment_name=f"inverse_{args.method}",
        num_epochs=args.epochs,
        log_frequency=args.log_frequency,
        device=device,
        seed=args.seed
    )

    # Print configuration summary
    print_config_summary(config, args.method)

    print("Configuration:")
    print(f"  Experiment: {config.experiment_name}")
    print(f"  Method: {args.method}")
    print(f"  Device: {config.device}")
    print(f"  Epochs: {config.num_epochs}")
    print(f"  Corruption level: {config.corruption_level:.1%}")
    print(f"  Missing data ratio: {config.missing_data_ratio:.1%}")

    # Run experiment with specified method
    print("\n" + "="*50)
    results = run_experiment(config, method=args.method)

    # Print results
    print("\n" + "="*50)
    print("EXPERIMENT RESULTS:")
    print(f"Method: {args.method.upper()}")
    print(f"MSE:     {results['mse']:.6f}")
    print(f"MAE:     {results['mae']:.6f}")
    print(f"RMSE:    {results['rmse']:.6f}")
    print(f"Results saved to: {results['results_dir']}")


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
