# src/evaluate_propagator_deeponet.py
# Autoregressive evaluation of a trained PropagatorDeepONet for the Burgers' Equation.
# This version correctly handles different solver and sensor grid resolutions via interpolation.

import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import argparse
from scipy.interpolate import interp1d

# Import from the Burgers' project files
from data_and_models import PropagatorDeepONet, BurgersSimulator, generate_grf_time_series

def _load_propagator(config, run_dir, device):
    """Build the PropagatorDeepONet for a run and load its best checkpoint.

    Args:
        config: Parsed experiment config (provides M_SENSORS,
            NUM_BASIS_FUNCTIONS and TRUNK_INPUT_DIM).
        run_dir: Directory holding ``burgers_propagator_best.pth`` and
            ``hyperparams.yaml`` for the run.
        device: Torch device string the model is moved to.

    Returns:
        The model in eval mode, or ``None`` if either required file is
        missing (a FATAL message is printed in that case).
    """
    model_path = os.path.join(run_dir, "burgers_propagator_best.pth")
    hyperparams_path = os.path.join(run_dir, "hyperparams.yaml")
    if not os.path.exists(model_path):
        print(f"FATAL: Model not found at {model_path}")
        return None
    if not os.path.exists(hyperparams_path):
        print(f"FATAL: Hyperparameters not found at {hyperparams_path}")
        return None

    with open(hyperparams_path, 'r') as f:
        hyperparams = yaml.safe_load(f)

    # Only the architecture-defining hyperparameters are forwarded to the model.
    model_arg_keys = ['branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn']
    model_kwargs = {key: hyperparams[key] for key in model_arg_keys}
    model = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **model_kwargs
    ).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print("Burgers' PropagatorDeepONet model loaded successfully.")
    return model


def _random_initial_state(x_grid):
    """Return a random sum of 1-3 sine modes sampled on ``x_grid``.

    Draws from the global NumPy RNG (amplitude first, then mode number, per
    term), so the caller controls reproducibility by seeding beforehand.
    """
    state = np.zeros(len(x_grid))
    for _ in range(np.random.randint(1, 4)):
        state += np.random.uniform(-0.5, 0.5) * np.sin(np.random.randint(1, 4) * np.pi * x_grid)
    return state


def _simulate_ground_truth(simulator, initial_state, controls, num_steps):
    """Step the numerical solver ``num_steps - 1`` times from ``initial_state``.

    Row ``k + 1`` of the result is ``simulator.step(row_k, controls[k, :])``.

    Returns:
        Array whose first axis is time (``num_steps`` rows, row 0 being a copy
        of ``initial_state``) and whose remaining axes are the solver's state.
    """
    history = [initial_state.copy()]
    for k in range(num_steps - 1):
        history.append(simulator.step(history[-1], controls[k, :]))
    return np.array(history)


def _autoregressive_rollout(model, initial_state_sensors, controls, length, num_sensors, num_steps, device):
    """Roll the propagator forward, feeding each prediction back as the next input.

    Args:
        model: Trained propagator called as ``model(state, w_k, x_query)``.
        initial_state_sensors: Initial field on the ``num_sensors`` sensor grid.
        controls: Control sequence; row ``k`` drives the step ``k -> k + 1``.
        length: Domain length; sensors are ``linspace(0, length, num_sensors)``.
        num_sensors: Number of sensor points (the model's branch input size).
        num_steps: Total number of time levels, including the initial one.
        device: Torch device string used for the rollout.

    Returns:
        Array of shape ``(num_steps, num_sensors)``.
    """
    state = torch.from_numpy(initial_state_sensors).float().unsqueeze(0).to(device)  # (1, M)
    # Trunk query points: the sensor locations, shaped (1, M, 1).
    x_query = torch.linspace(0, length, num_sensors, device=device).float().unsqueeze(0).unsqueeze(-1)
    controls_t = torch.from_numpy(controls).float().to(device)

    states = [initial_state_sensors]
    with torch.no_grad():
        for k in range(num_steps - 1):
            w_k = controls_t[k, :].unsqueeze(0)
            # reshape (not squeeze) keeps M == 1 working and raises if the
            # model ever returns a different number of points than sensors.
            next_state = model(state, w_k, x_query).reshape(num_sensors)
            states.append(next_state.cpu().numpy())
            state = next_state.unsqueeze(0)
    return np.array(states)


def _plot_results(ground_truth, prediction, error, config, plot_path, font_size=20):
    """Save a 3-panel (exact / predicted / absolute error) space-time PDF.

    The exact and predicted panels share the colour range of the ground truth;
    the error panel uses its own range. Every panel keeps its own y-label.
    """
    fig, axes = plt.subplots(1, 3, figsize=(22, 6))
    extent = [0, config['T_FINAL'], 0, config['L']]
    shared_scale = {'vmin': ground_truth.min(), 'vmax': ground_truth.max()}
    panels = [
        (ground_truth, "Exact Solution", shared_scale),
        (prediction, "Predicted Solution", shared_scale),
        (error, "Absolute Error", {}),
    ]
    for ax, (field, title, scale) in zip(axes, panels):
        # Fields are (time, space); transpose so t runs along the x-axis.
        im = ax.imshow(field.T, extent=extent, origin='lower', aspect='auto', cmap='viridis', **scale)
        ax.set_title(title, fontsize=font_size)
        ax.set_xlabel("t", fontsize=font_size)
        ax.set_ylabel("x", fontsize=font_size)
        cbar = fig.colorbar(im, ax=ax)
        cbar.ax.tick_params(labelsize=font_size)
        ax.tick_params(axis='both', which='major', labelsize=font_size)

    fig.tight_layout()
    fig.savefig(plot_path, bbox_inches='tight')
    plt.close(fig)


def main(args):
    """Evaluate a trained Burgers' PropagatorDeepONet by autoregressive rollout.

    Loads the run's best checkpoint, generates one seeded test case, and
    simulates the ground truth on the solver grid. It then rolls the model
    out on the sensor grid, writes the rollout MSE to
    ``evaluation_rollout_results.yaml`` and saves ``evaluation_plot.pdf``,
    both inside ``<output_base_dir>/<run_id>``.

    Args:
        args: Namespace with ``config_path``, ``output_base_dir`` and ``run_id``.
    """
    # Kind used to map solver-grid fields onto the sensor grid. The SAME
    # interpolator produces both the model's initial input and the reference
    # solution. Otherwise the t=0 row would carry a spurious error the model
    # never made. Linear also avoids spline overshoot next to steep fronts.
    interp_kind = 'linear'

    # --- 1. Load Configs and Set Up ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- Autoregressive Evaluation for Burgers' Propagator: {args.run_id} ---")
    run_dir = os.path.join(args.output_base_dir, args.run_id)

    # --- 2. Load the Trained Propagator Model ---
    model = _load_propagator(config, run_dir, device)
    if model is None:
        return

    # --- 3. Generate a Single, New, Unseen Test Case ---
    # The global RNG draw order (controls, simulator construction, initial
    # state) is kept fixed so the seeded test case is reproducible.
    print("\n--- Generating a new, unseen test case... ---")
    np.random.seed(42)
    w_case_sequence = generate_grf_time_series(config, config['NUM_BASIS_FUNCTIONS'])
    w_case_sequence = np.clip(w_case_sequence * 1.5, -config['CONTROL_SCALE'], config['CONTROL_SCALE'])
    x_grid_solver = np.linspace(0, config['L'], config['NX_SOLVER'])

    # --- 4. Ground truth from the high-resolution numerical solver ---
    simulator = BurgersSimulator(config)
    initial_state_np = _random_initial_state(x_grid_solver)
    ground_truth_trajectory = _simulate_ground_truth(
        simulator, initial_state_np, w_case_sequence, config['NT_SOLVER'])

    # --- 5. Autoregressive prediction on the sensor grid ---
    print("--- Performing autoregressive rollout... ---")
    sensor_locs_np = np.linspace(0, config['L'], config['M_SENSORS'])
    to_sensors = interp1d(x_grid_solver, ground_truth_trajectory, axis=1, kind=interp_kind)
    true_solution_on_sensors = to_sensors(sensor_locs_np)
    initial_state_sensors = true_solution_on_sensors[0].copy()

    predicted_trajectory_sensors = _autoregressive_rollout(
        model, initial_state_sensors, w_case_sequence,
        config['L'], config['M_SENSORS'], config['NT_SOLVER'], device)

    # --- 6. Final MSE on the common sensor grid ---
    rollout_mse = np.mean((true_solution_on_sensors - predicted_trajectory_sensors)**2)
    print(f"\nFinal Autoregressive Rollout MSE: {rollout_mse:.4e}")
    results = {'autoregressive_rollout_mse': float(rollout_mse)}
    results_path = os.path.join(run_dir, "evaluation_rollout_results.yaml")
    with open(results_path, 'w') as f:
        yaml.dump(results, f, sort_keys=False)

    # --- 7. Plot (saved as PDF for LaTeX) ---
    print("--- Plotting results... ---")
    error = np.abs(true_solution_on_sensors - predicted_trajectory_sensors)
    plot_path = os.path.join(run_dir, "evaluation_plot.pdf")
    _plot_results(ground_truth_trajectory, predicted_trajectory_sensors, error, config, plot_path)
    print(f"Evaluation PDF saved to {plot_path}")

if __name__ == "__main__":
    # Command-line entry point: all three options are mandatory strings that
    # locate the config file and the trained run to evaluate.
    cli = argparse.ArgumentParser(
        description="Autoregressive evaluation of a trained PropagatorDeepONet."
    )
    for flag in ("--config_path", "--output_base_dir"):
        cli.add_argument(flag, type=str, required=True)
    cli.add_argument(
        "--run_id",
        type=str,
        required=True,
        help="The specific run_id of the model to evaluate.",
    )
    main(cli.parse_args())