import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import argparse
from torch.utils.data import DataLoader, TensorDataset

from data_and_models import DeepONetWithBias, solve_pde

def _build_model(config, hyperparams, model_path, device):
    """Rebuild a DeepONetWithBias from saved hyperparameters and load its weights.

    Only the architecture keys listed below that are present in ``hyperparams``
    are forwarded to the constructor; absent keys are simply not passed.

    Args:
        config: base config dict; ``M_SENSORS`` and ``TRUNK_INPUT_DIM`` are read.
        hyperparams: dict loaded from the run's ``hyperparams.yaml``.
        model_path: path to the saved ``state_dict``.
        device: torch device string the model is moved to.

    Returns:
        (model, model_kwargs): the model in eval mode on ``device`` and the
        architecture kwargs used to construct it.
    """
    model_arg_keys = (
        'latent_dim',
        'branch_depth',
        'branch_width',
        'trunk_depth',
        'trunk_width',
        'activation_fn',
    )
    model_kwargs = {key: hyperparams[key] for key in model_arg_keys if key in hyperparams}

    model = DeepONetWithBias(
        branch_input_dim=config['M_SENSORS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **model_kwargs,
    ).to(device)
    # NOTE(review): torch.load unpickles the checkpoint, so only load files from
    # trusted run directories (recent PyTorch offers weights_only=True for this).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model, model_kwargs


def _compute_mse(model, data_path, batch_size, device):
    """Return the mean squared error of ``model`` on the ``.npz`` dataset at ``data_path``.

    The file must contain ``branch_inputs``, ``trunk_inputs`` and ``outputs``
    arrays whose first dimensions match. The squared error is averaged over
    every output element, so the result is a true per-element mean even when a
    sample has more than one target value. An empty dataset yields NaN.
    """
    # The NpzFile keeps its underlying file open until closed; the context
    # manager releases it once the arrays have been read into memory.
    with np.load(data_path) as data:
        dataset = TensorDataset(
            torch.from_numpy(data['branch_inputs']).float(),
            torch.from_numpy(data['trunk_inputs']).float(),
            torch.from_numpy(data['outputs']).float(),
        )
    loader = DataLoader(dataset, batch_size=batch_size)

    total_se = 0.0
    total_count = 0
    with torch.no_grad():
        for branch, trunk, target in loader:
            target = target.to(device)
            predictions = model(branch.to(device), trunk.to(device))
            # Align shapes explicitly: (B, 1) vs (B,) would otherwise broadcast
            # to (B, B) and silently inflate the error.
            predictions = predictions.reshape(target.shape)
            total_se += torch.sum((predictions - target) ** 2).item()
            total_count += target.numel()

    if total_count == 0:
        return float('nan')
    return total_se / total_count


def _solve_and_predict_case(model, config, device, seed=42):
    """Draw one seeded random input, then solve it with ``solve_pde`` and with ``model``.

    The input ``u`` is an i.i.d. standard-normal vector on the ``NX_SOLVER``
    grid, and ``v_ref`` is the constant ``V_REF_VAL``. The model is queried on
    the full solver (x, t) grid, using ``u`` linearly interpolated to
    ``M_SENSORS`` equally spaced sensors on [0, L].

    Returns:
        (true_solution, predicted_solution). ``predicted_solution`` has shape
        (NX_SOLVER, NT_SOLVER), with x along rows and t along columns.
        ``true_solution`` is ``solve_pde(...).T`` and is assumed to share that
        layout. TODO: confirm that solve_pde returns (NT_SOLVER, NX_SOLVER).
    """
    nx, nt = config['NX_SOLVER'], config['NT_SOLVER']

    # A local generator seeded like the former global np.random.seed(42)
    # reproduces the same draw without mutating NumPy's global RNG state.
    rng = np.random.RandomState(seed)
    u_case = rng.multivariate_normal(np.zeros(nx), np.eye(nx))
    vref_case = np.full(nx, config['V_REF_VAL'])
    true_solution = solve_pde(config, u_case, vref_case).T

    x_grid = np.linspace(0, config['L'], nx)
    sensor_locs = np.linspace(0, config['L'], config['M_SENSORS'])
    u_sensors = np.interp(sensor_locs, x_grid, u_case)
    branch_single = torch.tensor(u_sensors, dtype=torch.float32).unsqueeze(0).to(device)

    # meshgrid's default 'xy' indexing yields (nt, nx) arrays, so the flattened
    # query points vary fastest in x and slowest in t.
    xv, tv = np.meshgrid(x_grid, np.linspace(0, config['T_FINAL'], nt))
    trunk_grid = torch.tensor(
        np.stack([xv.ravel(), tv.ravel()], axis=-1), dtype=torch.float32
    ).to(device)
    branch_grid = branch_single.repeat(trunk_grid.shape[0], 1)

    with torch.no_grad():
        predictions = model(branch_grid, trunk_grid).cpu().numpy()
    # Undo the flattening back to (nt, nx), then transpose to x-by-t.
    predicted_solution = predictions.reshape(nt, nx).T
    return true_solution, predicted_solution


def _plot_comparison(true_solution, predicted_solution, config, plot_path, font_size=20):
    """Save a three-panel figure (exact, predicted, absolute error) to ``plot_path``.

    The panels are drawn with t on the horizontal axis over [0, T_FINAL] and x
    on the vertical axis over [0, L]. The figure is always closed afterwards.
    """
    extent = [0, config['T_FINAL'], 0, config['L']]
    # The exact and predicted panels share the exact solution's colour range so
    # they are directly comparable; the error panel autoscales.
    shared_limits = {'vmin': true_solution.min(), 'vmax': true_solution.max()}
    panels = (
        ("Exact Solution", true_solution, shared_limits),
        ("Predicted Solution", predicted_solution, shared_limits),
        ("Absolute Error", np.abs(true_solution - predicted_solution), {}),
    )

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    try:
        for ax, (title, field, color_limits) in zip(axes, panels):
            image = ax.imshow(field, extent=extent, origin='lower', aspect='auto',
                              cmap='viridis', **color_limits)
            ax.set_title(title, fontsize=font_size)
            ax.set_xlabel("t", fontsize=font_size)
            ax.set_ylabel("x", fontsize=font_size)
            ax.tick_params(axis='both', which='major', labelsize=font_size)
            colorbar = fig.colorbar(image, ax=ax)
            colorbar.ax.tick_params(labelsize=font_size)
        fig.tight_layout()
        fig.savefig(plot_path, bbox_inches='tight')
    finally:
        # Without this, every call leaves an open figure behind in pyplot.
        plt.close(fig)


def main(args):
    """Evaluate a trained DeepONet run and write its metrics and a comparison plot.

    Steps:
      1. Load the base config and the run's ``hyperparams.yaml``.
      2. Rebuild the model and load ``deeponet_model.pth`` from
         ``<output_base_dir>/<run_id>``.
      3. Compute the MSE on ``<output_base_dir>/data/{train,test}_data.npz``.
         A split whose file is missing is skipped with a warning. Save the
         metrics to ``evaluation_results.yaml`` in the run directory.
      4-5. Solve one seeded random input exactly, predict it with the model,
         and save a three-panel comparison to ``evaluation_plot.pdf``.

    Args:
        args: namespace with ``config_path``, ``output_base_dir`` and ``run_id``.

    If the checkpoint or the hyperparameter file is missing, a message is
    printed and the function returns without evaluating.
    """
    print(f"--- Starting Evaluation for run_id: {args.run_id} ---")

    # --- 1. Load Configuration and Paths ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    run_dir = os.path.join(args.output_base_dir, args.run_id)
    model_path = os.path.join(run_dir, "deeponet_model.pth")
    hyperparams_path = os.path.join(run_dir, "hyperparams.yaml")

    if not os.path.exists(model_path) or not os.path.exists(hyperparams_path):
        print(f"FATAL: Model or hyperparams not found for run_id '{args.run_id}' in {run_dir}")
        return

    with open(hyperparams_path, 'r') as f:
        # An empty YAML file loads as None; treat it as "no architecture overrides".
        hyperparams = yaml.safe_load(f) or {}

    # --- 2. Rebuild the Model ---
    model, model_kwargs = _build_model(config, hyperparams, model_path, device)
    print(f"Successfully loaded model '{args.run_id}' with architecture: {model_kwargs}")

    # --- 3. Evaluate MSE on Test and Train Sets ---
    results = {}
    for dataset_name in ['train', 'test']:
        data_path = os.path.join(args.output_base_dir, "data", f"{dataset_name}_data.npz")
        if not os.path.exists(data_path):
            print(f"Warning: {dataset_name} data not found at {data_path}. Skipping MSE calculation.")
            continue

        mse = _compute_mse(model, data_path, config['BATCH_SIZE'] * 4, device)
        results[f'{dataset_name}_mse'] = mse
        print(f"Mean Squared Error on {dataset_name.capitalize()} Data: {mse:.4e}")

    if 'train_mse' in results and 'test_mse' in results:
        generalization_error = abs(results['train_mse'] - results['test_mse'])
        results['generalization_error'] = generalization_error
        print(f"Generalization Error (|Train MSE - Test MSE|): {generalization_error:.4e}")

    results_path = os.path.join(run_dir, "evaluation_results.yaml")
    with open(results_path, 'w') as f:
        yaml.dump(results, f)
    print(f"Evaluation metrics saved to {results_path}")

    # --- 4. Visualize a Single, New Test Case ---
    print("\nVisualizing a single new test case...")
    true_solution, predicted_solution = _solve_and_predict_case(model, config, device)

    # --- 5. Plotting ---
    plot_path = os.path.join(run_dir, "evaluation_plot.pdf")
    _plot_comparison(true_solution, predicted_solution, config, plot_path)

    print(f"Evaluation PDF saved to {plot_path}")
    print(f"\n--- Evaluation for {args.run_id} Finished ---")


if __name__ == "__main__":
    # Every CLI option is a required string; keep (flag, help) pairs in one table.
    parser = argparse.ArgumentParser(description="Evaluate a trained DeepONet model.")
    for flag, help_text in (
        ("--config_path", "Path to the base YAML config file."),
        ("--output_base_dir", "Base directory where results are stored."),
        ("--run_id", "The specific run_id of the model to evaluate."),
    ):
        parser.add_argument(flag, type=str, required=True, help=help_text)
    main(parser.parse_args())