# src/evaluate_propagator_deeponet.py

import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import argparse
from scipy.interpolate import interp1d

from data_and_models import PropagatorDeepONet, solve_pde_time_varying, generate_grf_time_series

def main(args):
    """Autoregressively evaluate a trained PropagatorDeepONet on one fresh test case.

    Reads the experiment configuration from ``args.config_path`` and the trained
    weights plus architecture hyperparameters from
    ``<args.output_base_dir>/<args.run_id>/``. It then:

    1. generates a single, seeded forcing sequence (so the test case is the same
       on every run),
    2. solves it with ``solve_pde_time_varying`` to obtain the reference solution,
    3. rolls the propagator forward one step at a time, feeding each prediction
       back in as the next input state,
    4. writes ``evaluation_rollout_results.yaml`` (rollout MSE on the sensor grid)
       and ``evaluation_plot.pdf`` (exact / predicted / absolute-error panels)
       into the run directory.

    Args:
        args: Parsed CLI namespace with ``config_path``, ``output_base_dir`` and
            ``run_id`` attributes.

    Returns:
        None. Prints a ``FATAL`` message and returns early if the model
        checkpoint or its hyperparameter file is missing.
    """
    # --- 1. Load Configs and Set Up ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- Autoregressive Evaluation for Propagator: {args.run_id} ---")
    print(f"Using device: {DEVICE}")

    run_dir = os.path.join(args.output_base_dir, args.run_id)
    model_path = os.path.join(run_dir, "propagator_deeponet_best.pth")
    hyperparams_path = os.path.join(run_dir, "hyperparams.yaml")
    # Both artifacts are required; report and return (rather than raise) for
    # either one, so an incomplete run directory is handled uniformly.
    if not os.path.exists(model_path):
        print(f"FATAL: Model not found at {model_path}")
        return
    if not os.path.exists(hyperparams_path):
        print(f"FATAL: Hyperparameters not found at {hyperparams_path}")
        return

    # --- 2. Load the Trained Propagator Model ---
    with open(hyperparams_path, 'r') as f:
        hyperparams = yaml.safe_load(f)

    # Only the architecture-defining hyperparameters are forwarded to the constructor.
    model_arg_keys = ['branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn']
    model_kwargs = {key: hyperparams[key] for key in model_arg_keys}
    model = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **model_kwargs,
    ).to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()
    print("PropagatorDeepONet model loaded successfully.")

    # --- 3. Generate a Single, New, Unseen Test Case ---
    print("\n--- Generating a new test case... ---")
    np.random.seed(42)  # Use a fixed seed for a consistent test case
    w_case_sequence = generate_grf_time_series(config, config['NT_SOLVER'], config['NUM_BASIS_FUNCTIONS'], 1.5)
    w_case_sequence = np.clip(w_case_sequence * 0.7, -1.0, 1.0)

    # --- 4. Get the Ground Truth solution from the numerical solver ---
    x_grid_solver = np.linspace(0, config['L'], config['NX_SOLVER'])
    # Cosine basis evaluated on the solver grid: shape (NX_SOLVER, NUM_BASIS_FUNCTIONS).
    basis_functions = np.cos(np.arange(config['NUM_BASIS_FUNCTIONS']) * np.pi * x_grid_solver[:, None] / config['L'])
    # Presumably w_case_sequence is (NT_SOLVER, NUM_BASIS_FUNCTIONS), making
    # u_case_xt (NT_SOLVER, NX_SOLVER) -- TODO confirm against generate_grf_time_series.
    u_case_xt = w_case_sequence @ basis_functions.T
    true_solution = solve_pde_time_varying(config, u_case_xt)

    # --- 5. Perform AUTOREGRESSIVE Prediction (Rollout) ---
    print("--- Performing autoregressive rollout... ---")
    num_sensors = config['M_SENSORS']
    predicted_states = []
    # Explicit float32: torch.full infers int64 from an integer fill value
    # (e.g. `INITIAL_STATE_VAL: 0` in YAML), which would not match the float32
    # forcing and sensor-location inputs below.
    current_state_sensors = torch.full(
        (1, num_sensors), config['INITIAL_STATE_VAL'], dtype=torch.float32, device=DEVICE
    )
    sensor_locs = torch.linspace(0, config['L'], num_sensors, device=DEVICE).float().unsqueeze(0).unsqueeze(-1)
    w_case_tensor = torch.from_numpy(w_case_sequence).float().to(DEVICE)

    with torch.no_grad():
        for k in range(config['NT_SOLVER'] - 1):
            w_k = w_case_tensor[k, :].unsqueeze(0)
            # Flatten to a 1-D state vector. Unlike .squeeze(), reshape(-1)
            # keeps the sensor axis when num_sensors == 1.
            next_state_pred = model(current_state_sensors, w_k, sensor_locs).reshape(-1)
            predicted_states.append(next_state_pred.cpu().numpy())
            # Feed the prediction back in as the next input state (batch of 1).
            current_state_sensors = next_state_pred.unsqueeze(0)

    # Rows are rollout steps 1..NT_SOLVER-1, columns are sensors.
    predicted_solution_on_sensors = np.array(predicted_states)

    # --- 6. Calculate Final MSE and Save Results ---
    # Build the interpolation targets in float64 with NumPy. The float32 torch
    # sensor grid can put the last sensor a rounding error beyond L (for example,
    # float32(pi) > pi), which makes interp1d raise "above the interpolation
    # range". np.linspace returns exactly `stop` as its last point, matching x_grid_solver.
    sensor_locs_np = np.linspace(0, config['L'], num_sensors)
    interpolator = interp1d(x_grid_solver, true_solution, axis=1, kind='cubic')
    true_solution_on_sensors = interpolator(sensor_locs_np)

    # Compare from the second time step onwards (row 0 is the initial state,
    # which the rollout never predicts).
    rollout_mse = np.mean((true_solution_on_sensors[1:, :] - predicted_solution_on_sensors) ** 2)
    print(f"\nFinal Autoregressive Rollout MSE: {rollout_mse:.4e}")

    results = {'autoregressive_rollout_mse': float(rollout_mse)}
    results_path = os.path.join(run_dir, "evaluation_rollout_results.yaml")
    with open(results_path, 'w') as f:
        yaml.dump(results, f, sort_keys=False)

    # --- 7. Plotting Results ---
    print("--- Plotting results... ---")
    fig, axes = plt.subplots(1, 3, figsize=(22, 6))

    # Define a font size for better readability in publications
    font_size = 20

    # Share the exact solution's colour range between the first two panels so
    # they are visually comparable.
    v_min, v_max = true_solution.min(), true_solution.max()
    # NOTE(review): the predicted/error panels cover rollout steps 1..NT-1 but
    # reuse the [0, T_FINAL] extent of the exact panel, so they are shifted by
    # one time step -- confirm whether the axis should start at dt.
    extent = [0, config['T_FINAL'], 0, config['L']]

    # Plot for the Exact Solution
    im1 = axes[0].imshow(true_solution.T, extent=extent, origin='lower', aspect='auto', vmin=v_min, vmax=v_max, cmap='viridis')
    axes[0].set_title("Exact Solution", fontsize=font_size)
    axes[0].set_xlabel("t", fontsize=font_size)
    axes[0].set_ylabel("x", fontsize=font_size)
    cbar1 = fig.colorbar(im1, ax=axes[0])
    cbar1.ax.tick_params(labelsize=font_size)

    # Plot for the Predicted Solution (Rollout)
    im2 = axes[1].imshow(predicted_solution_on_sensors.T, extent=extent, origin='lower', aspect='auto', vmin=v_min, vmax=v_max, cmap='viridis')
    axes[1].set_title("Predicted Solution", fontsize=font_size)
    axes[1].set_xlabel("t", fontsize=font_size)
    axes[1].set_ylabel("x", fontsize=font_size)
    cbar2 = fig.colorbar(im2, ax=axes[1])
    cbar2.ax.tick_params(labelsize=font_size)

    # Plot for the Absolute Error
    error = np.abs(true_solution_on_sensors[1:, :].T - predicted_solution_on_sensors.T)
    im3 = axes[2].imshow(error, extent=extent, origin='lower', aspect='auto', cmap='viridis')
    axes[2].set_title("Absolute Error", fontsize=font_size)
    axes[2].set_xlabel("t", fontsize=font_size)
    axes[2].set_ylabel("x", fontsize=font_size)
    cbar3 = fig.colorbar(im3, ax=axes[2])
    cbar3.ax.tick_params(labelsize=font_size)

    for ax in axes:
        ax.tick_params(axis='both', which='major', labelsize=font_size)

    plt.tight_layout()

    plot_path = os.path.join(run_dir, "evaluation_plot.pdf")
    plt.savefig(plot_path, bbox_inches='tight')
    # Close this specific figure (not just "the current" one) to free its memory.
    plt.close(fig)

    print(f"Evaluation PDF saved to {plot_path}")

if __name__ == "__main__":
    # Command-line entry point: every flag is a mandatory string argument.
    cli = argparse.ArgumentParser(
        description="Autoregressive evaluation of a trained PropagatorDeepONet."
    )
    for flag in ("--config_path", "--output_base_dir"):
        cli.add_argument(flag, type=str, required=True)
    cli.add_argument(
        "--run_id",
        type=str,
        required=True,
        help="The specific run_id of the model to evaluate.",
    )
    main(cli.parse_args())