#!/usr/bin/env python3
"""
Example script for solving the Helmholtz equation with the mesh_opt package.

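This example demonstrates how to:
1. Set up and solve the Helmholtz equation (the time-harmonic form of the wave equation)
2. Optimize the mesh for better accuracy against a fine-mesh reference solution
3. Visualize the solutions and save the underlying data for re-plotting

The Helmholtz equation is: -Δu - k²u = f, with u = 0 on the boundary,
where k is the wavenumber parameter.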
"""

import os
import sys
import argparse
import numpy as np
import json

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import mesh_opt
from skfem import Basis, ElementTriP1
from skfem.helpers import dot, grad

# Trainer configuration (YAML) loaded by main() via TrainerConfig.from_yaml and
# recorded verbatim in experiment_metadata.json. Uncomment one of the
# alternatives below to run a different baseline; it is not a CLI option.
# NOTE(review): this relative path is resolved against the current working
# directory, not against this script's location. It reaches the
# baseline_configs folder next to this script's parent directory only when the
# script is run from its own directory. The usage examples at the bottom of
# this file run it from one level up; confirm which is intended.
CONFIG_PATH = "../baseline_configs/center_uniform.yaml"
# CONFIG_PATH = "../baseline_configs/mesh_aware_enhanced.yaml"
# CONFIG_PATH = "../baseline_configs/rejection_method.yaml"
# CONFIG_PATH = "../baseline_configs/convex_combination.yaml"


def _save_history_csv(path, history, value_name):
    """Write a per-iteration history to ``path`` as CSV with header ``iteration,<value_name>``.

    A 1-D history gets a 0-based iteration column prepended, so the data
    actually has the two columns the header announces. A history that is
    already 2-D is written unchanged.
    """
    values = np.asarray(history)
    if values.ndim == 1:
        data = np.column_stack([np.arange(values.size), values])
        fmt = ['%d', '%.18e']  # integer iteration index, full-precision value
    else:
        data = values
        fmt = '%.18e'
    np.savetxt(path, data,
               delimiter=',',
               fmt=fmt,
               header=f'iteration,{value_name}',
               comments='')


def _loss_summary(loss_history):
    """Return ``(initial, final, improvement_pct)`` for a loss history.

    Values that cannot be computed are returned as ``None``. All three are
    ``None`` for an empty history, and the improvement percentage is ``None``
    when the initial loss is zero (the relative improvement is undefined).
    """
    if len(loss_history) == 0:
        return None, None, None
    initial = float(loss_history[0])
    final = float(loss_history[-1])
    if initial == 0:
        return initial, final, None
    return initial, final, (initial - final) / initial * 100


def _mesh_solution_data(mesh, solution):
    """Bundle mesh coordinates, element connectivity and nodal values for ``np.savez``."""
    return {
        'coordinates': mesh_opt.mesh_to_coords(mesh),
        # skfem stores connectivity as (nodes_per_element, n_elements);
        # transpose to one row per element.
        'triangulation': mesh.t.T,
        'solution_values': solution.value,
    }


def _visualize_solution_difference(
    initial_mesh, initial_solution,
    optimized_mesh, optimized_solution,
    output_file=None,
    title="Difference between Initial and Optimized Solutions",
    colorbar_label="Difference",
    figsize=(10, 8),
    grid_size=200,
):
    """
    Visualize the difference between initial and optimized solutions.

    Both nodal solutions are linearly interpolated onto a common
    ``grid_size`` x ``grid_size`` grid over [0, 1]², and their difference
    (optimized - initial) is drawn as a filled contour plot. Grid points
    outside either mesh's convex hull evaluate to NaN (LinearNDInterpolator's
    default fill value).

    The figure is saved to ``output_file`` and closed when a path is given;
    otherwise it is shown interactively.

    Returns a dict with the grid, both interpolated solutions, their
    difference, and the raw mesh points and nodal values, for saving.
    """
    # Deferred imports: only needed when this plot is produced.
    import matplotlib.pyplot as plt
    from scipy.interpolate import LinearNDInterpolator

    # Regular grid over the unit square, the domain assumed for the
    # generated meshes and the source term.
    axis = np.linspace(0, 1, grid_size)
    X, Y = np.meshgrid(axis, axis)
    grid_points = np.vstack([X.flatten(), Y.flatten()]).T

    # mesh.p holds node coordinates row-wise (x in row 0, y in row 1);
    # build one (x, y) row per node for the interpolators.
    initial_mesh_points = np.column_stack([initial_mesh.p[0, :], initial_mesh.p[1, :]])
    initial_values = initial_solution.value
    optimized_mesh_points = np.column_stack([optimized_mesh.p[0, :], optimized_mesh.p[1, :]])
    optimized_values = optimized_solution.value

    # Interpolate both solutions onto the regular grid and take the difference.
    initial_on_grid = LinearNDInterpolator(initial_mesh_points, initial_values)(grid_points)
    optimized_on_grid = LinearNDInterpolator(optimized_mesh_points, optimized_values)(grid_points)
    difference = optimized_on_grid - initial_on_grid

    plt.figure(figsize=figsize)
    plt.contourf(X, Y, difference.reshape(grid_size, grid_size),
                 cmap='RdBu_r', levels=50)
    cbar = plt.colorbar()
    cbar.set_label(colorbar_label, fontsize=12)
    plt.xlabel('x', fontsize=12)
    plt.ylabel('y', fontsize=12)
    plt.title(title, fontsize=14)
    plt.axis('equal')
    plt.tight_layout()

    if output_file:
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

    return {
        'grid_x': X,
        'grid_y': Y,
        'initial_on_grid': initial_on_grid.reshape(grid_size, grid_size),
        'optimized_on_grid': optimized_on_grid.reshape(grid_size, grid_size),
        'difference': difference.reshape(grid_size, grid_size),
        'grid_points': grid_points,
        'initial_mesh_points': initial_mesh_points,
        'optimized_mesh_points': optimized_mesh_points,
        'initial_values': initial_values,
        'optimized_values': optimized_values
    }


def main():
    """Run the Helmholtz equation example.

    Steps:
    - Parse the command-line options.
    - Build (or load) the initial mesh.
    - Solve a 100x100 fine-mesh problem to use as ground truth.
    - Optimize the mesh with a ``mesh_opt.Trainer`` configured from
      ``CONFIG_PATH`` (step size and iteration count overridden from the CLI).
    - Write plots, raw plotting data (CSV/NPZ) and JSON metadata into the
      output directory.
    """
    parser = argparse.ArgumentParser(description="Helmholtz Equation Example")
    parser.add_argument(
        "--mesh",
        type=str,
        default=None,
        help="Path to initial mesh file (optional)"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./helmholtz_output",
        help="Output directory"
    )
    parser.add_argument(
        "--wavenumber",
        type=float,
        default=5.0,
        help="Wavenumber (k) parameter"
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=20,
        help="Number of optimization iterations"
    )
    parser.add_argument(
        "--step-size",
        type=float,
        default=0.01,
        help="Step size for gradient descent"
    )
    parser.add_argument(
        "--dim",
        type=int,
        default=20,
        help="Mesh dimension (number of elements in x and y directions)"
    )

    args = parser.parse_args()

    print("Helmholtz Equation Example")
    print("------------------------")
    print(f"Initial mesh: {'Generated automatically' if args.mesh is None else args.mesh}")
    print(f"Output directory: {args.output}")
    print(f"Wavenumber (k): {args.wavenumber}")
    print(f"Iterations: {args.iterations}")
    print(f"Step size: {args.step_size}")
    print(f"Mesh dimension: {args.dim}")
    print()

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.output, exist_ok=True)

    # Create or load initial mesh
    if args.mesh is None:
        print("Creating uniform mesh...")
        mesh = mesh_opt.create_uniform_mesh(nx=args.dim, ny=args.dim)
    else:
        print(f"Loading mesh from {args.mesh}...")
        mesh = mesh_opt.load_mesh(args.mesh)

    k = args.wavenumber

    def helmholtz_form(u, v, w):
        # Weak form of -Δu - k²u after integrating the Laplacian by parts:
        # ∫ ∇u·∇v dx - k² ∫ u v dx   (w is not used).
        return dot(grad(u), grad(v)) - k**2 * u * v

    def source_function(x, y):
        # Two Gaussian sources of opposite sign, centred at (0.25, 0.25)
        # and (0.75, 0.75).
        source1 = 10 * np.exp(-50 * ((x - 0.25)**2 + (y - 0.25)**2))
        source2 = -10 * np.exp(-50 * ((x - 0.75)**2 + (y - 0.75)**2))
        return source1 + source2

    def load_form(v, w):
        # Linear form ∫ f v dx; w.x holds the quadrature-point coordinates.
        x = w.x
        f = source_function(x[0], x[1])
        return f * v

    weak_form = mesh_opt.WeakForm(bilinear_form=helmholtz_form, linear_form=load_form)

    # Homogeneous Dirichlet boundary condition (u = 0 on the boundary).
    def zero_bc(x):
        return 0.0

    bc = mesh_opt.DirichletBC(zero_bc)
    solver = mesh_opt.PDESolver(weak_form)

    # No closed-form solution is used here, so a fine-mesh solution acts
    # as the reference the optimizer is scored against.
    print("Computing reference solution on fine mesh...")
    fine_mesh = mesh_opt.create_uniform_mesh(nx=100, ny=100)
    fine_basis = Basis(fine_mesh, ElementTriP1())
    fine_solution = solver.solve(fine_basis, bc)

    def ground_truth(points):
        # Interpolate the fine-mesh solution at the requested points.
        return mesh_opt.interpolate_solution(fine_solution, points)

    print("Setting up trainer...")
    trainer_config = mesh_opt.TrainerConfig.from_yaml(CONFIG_PATH)
    trainer_config.step_size = args.step_size
    trainer_config.n_iterations = args.iterations
    trainer_config.tag = f"helmholtz_equation_k={k}_dim={args.dim}_step_size={args.step_size}_iterations={args.iterations}"
    trainer = mesh_opt.Trainer(trainer_config, solver, bc, ground_truth)

    print("Running optimization...")
    result = trainer.run(mesh)
    loss_history = result['loss_history']

    print("Generating reports...")
    mesh_opt.plot_mse_history(
        loss_history,
        output_file=os.path.join(args.output, "mse_history.png")
    )

    mesh_opt.export_optimized_mesh(
        result['final_mesh'],
        output_file=os.path.join(args.output, "optimized_mesh.npz")
    )

    # Re-solve on the initial and optimized meshes for visualization.
    initial_basis = Basis(result['initial_mesh'], ElementTriP1())
    initial_solution = solver.solve(initial_basis, bc)

    optimized_basis = Basis(result['final_mesh'], ElementTriP1())
    optimized_solution = solver.solve(optimized_basis, bc)

    # TODO: move this visualization to the trainer
    mesh_opt.visualize_mesh_solution(
        result['initial_mesh'],
        initial_solution,
        output_file=os.path.join(args.output, "initial_solution.png"),
        title=f"Helmholtz Equation (k={k}) on Initial Mesh"
    )

    mesh_opt.visualize_mesh_solution(
        result['final_mesh'],
        optimized_solution,
        output_file=os.path.join(args.output, "optimized_solution.png"),
        title=f"Helmholtz Equation (k={k}) on Optimized Mesh"
    )

    difference_data = _visualize_solution_difference(
        result['initial_mesh'], initial_solution,
        result['final_mesh'], optimized_solution,
        output_file=os.path.join(args.output, "solution_difference.png"),
        title=f"Difference in Solutions for Helmholtz Equation (k={k})"
    )

    # ========== SAVE ALL DATA FOR PLOTTING ==========
    print("Saving plotting data...")

    # 1. MSE history (iteration index + value, matching the CSV header)
    _save_history_csv(os.path.join(args.output, "mse_history.csv"), loss_history, 'mse')

    # 2-3. Initial and optimized mesh together with their solutions
    np.savez(os.path.join(args.output, "initial_mesh_solution.npz"),
             **_mesh_solution_data(result['initial_mesh'], initial_solution))
    np.savez(os.path.join(args.output, "optimized_mesh_solution.npz"),
             **_mesh_solution_data(result['final_mesh'], optimized_solution))

    # 4. Grid-based difference analysis
    np.savez(os.path.join(args.output, "solution_difference_data.npz"), **difference_data)

    # 5. Configuration and metadata. Undefined statistics (empty history,
    # zero initial loss) are stored as null instead of crashing or
    # emitting non-standard Infinity/NaN JSON.
    initial_loss, final_loss, improvement = _loss_summary(loss_history)
    metadata = {
        'wavenumber': k,
        'mesh_dimension': args.dim,
        'step_size': args.step_size,
        'iterations': args.iterations,
        'initial_mse': initial_loss,
        'final_mse': final_loss,
        'improvement_percentage': improvement,
        'config_path': CONFIG_PATH,
        'trainer_config': {
            'step_size': trainer_config.step_size,
            'fd_radius': trainer_config.fd_radius,
            'n_samples': trainer_config.n_samples,
            'n_iterations': trainer_config.n_iterations,
            'regularization_weight': trainer_config.regularization_weight,
            'estimator_type': trainer_config.estimator_type,
            'gradient_estimator': trainer_config.gradient_estimator
        }
    }

    with open(os.path.join(args.output, "experiment_metadata.json"), 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    # 6. Optional per-iteration histories, if the trainer reported them
    if 'gradient_norm_history' in result:
        _save_history_csv(os.path.join(args.output, "gradient_norm_history.csv"),
                          result['gradient_norm_history'], 'gradient_norm')

    if 'mesh_mse_history' in result:
        _save_history_csv(os.path.join(args.output, "mesh_mse_history.csv"),
                          result['mesh_mse_history'], 'mesh_mse')

    # 7. Fine-mesh reference solution (ground truth)
    np.savez(os.path.join(args.output, "fine_mesh_reference.npz"),
             **_mesh_solution_data(fine_mesh, fine_solution))

    # Index of every data file written above, with a short description.
    data_summary = {
        'files': {
            'mse_history': 'mse_history.csv',
            'gradient_norm_history': 'gradient_norm_history.csv' if 'gradient_norm_history' in result else None,
            'mesh_mse_history': 'mesh_mse_history.csv' if 'mesh_mse_history' in result else None,
            'initial_mesh_solution': 'initial_mesh_solution.npz',
            'optimized_mesh_solution': 'optimized_mesh_solution.npz',
            'solution_difference_data': 'solution_difference_data.npz',
            'fine_mesh_reference': 'fine_mesh_reference.npz',
            'experiment_metadata': 'experiment_metadata.json'
        },
        'description': {
            'mse_history.csv': 'MSE values for each iteration',
            'gradient_norm_history.csv': 'Gradient norm values for each iteration',
            'mesh_mse_history.csv': 'Mesh MSE values for each iteration',
            'initial_mesh_solution.npz': 'Initial mesh coordinates, triangulation, and solution values',
            'optimized_mesh_solution.npz': 'Optimized mesh coordinates, triangulation, and solution values',
            'solution_difference_data.npz': 'Grid-based difference analysis between initial and optimized solutions',
            'fine_mesh_reference.npz': 'Fine mesh used as ground truth reference',
            'experiment_metadata.json': 'Complete experiment configuration and results summary'
        },
        'data_format': {
            'csv_files': 'Plain text, comma-separated values',
            'npz_files': 'Numpy compressed archive, load with np.load()',
            'json_files': 'JSON format, load with json.load()'
        }
    }

    with open(os.path.join(args.output, "data_summary.json"), 'w', encoding='utf-8') as f:
        json.dump(data_summary, f, indent=2)

    print(f"All plotting data saved to: {args.output}")
    print("Data files created:")
    for filename in data_summary['files'].values():
        if filename:  # None marks an optional file that was not written
            print(f"  - {filename}: {data_summary['description'][filename]}")

    def _fmt(value, spec):
        # Format a possibly-undefined statistic for console output.
        return 'n/a' if value is None else format(value, spec)

    print("\nOptimization Results:")
    print(f"Initial MSE: {_fmt(initial_loss, '.6f')}")
    print(f"Final MSE: {_fmt(final_loss, '.6f')}")
    print(f"Improvement: {'n/a' if improvement is None else f'{improvement:.2f}%'}")

    print("\nVisualization Files:")
    print(f"MSE History Plot: {os.path.join(args.output, 'mse_history.png')}")
    print(f"Initial Solution: {os.path.join(args.output, 'initial_solution.png')}")
    print(f"Optimized Solution: {os.path.join(args.output, 'optimized_solution.png')}")
    print(f"Solution Difference: {os.path.join(args.output, 'solution_difference.png')}")
    print(f"Optimized Mesh: {os.path.join(args.output, 'optimized_mesh.npz')}")


if __name__ == "__main__":
    main() 

    # Example invocations. The paths assume the current working directory is
    # the parent of examples/ (the script is invoked as examples/helmholtz_equation.py).
    # NOTE(review): the trainer config is not a CLI option. The baseline runs
    # below presumably require switching CONFIG_PATH at the top of this file to
    # the matching YAML before running; confirm.

    # Original examples:
    # python examples/helmholtz_equation.py --wavenumber 10 --iterations 20000 --output ./results/helmholtz-baselines-center-uniform --step-size 500.0 

    # Baseline 1: 
    # python examples/helmholtz_equation.py --wavenumber 10 --iterations 20000 --output ./results/helmholtz-baselines-rejection-method --step-size 500.0
    # python examples/helmholtz_equation.py --wavenumber 10 --iterations 20000 --output ./results/helmholtz-baselines-rejection-method --step-size 400.0
    # python examples/helmholtz_equation.py --wavenumber 10 --iterations 20000 --output ./results/helmholtz-baselines-rejection-method --step-size 300.0

    # Baseline 2: 
    # python examples/helmholtz_equation.py --wavenumber 10 --iterations 20000 --output ./results/helmholtz-baselines-convex-combination --step-size 500.0 

    # Baseline 3: 
    # python examples/helmholtz_equation.py --wavenumber 10 --iterations 20000 --output ./results/helmholtz-baselines-mesh-aware --step-size 500.0 
    # python examples/helmholtz_equation.py --wavenumber 10 --iterations 20000 --output ./results/helmholtz-baselines-mesh-aware --step-size 400.0 
    # python examples/helmholtz_equation.py --wavenumber 10 --iterations 20000 --output ./results/helmholtz-baselines-mesh-aware --step-size 300.0 