# src/evaluate_propagator_deeponet.py
# REFACTORED for consistent variable viscosity evaluation.

import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import argparse
from scipy.interpolate import interp1d

# Import from the Burgers' project files
# MODIFIED: Import the spatial GRF generator as well
from data_and_models import PropagatorDeepONet, BurgersSimulator, generate_grf_time_series, generate_grf_spatial_series

def _load_propagator(config, run_dir, device):
    """Rebuild the PropagatorDeepONet of a run and load its saved weights.

    Args:
        config: Parsed project config; 'M_SENSORS', 'NUM_BASIS_FUNCTIONS' and
            'TRUNK_INPUT_DIM' are read from it.
        run_dir: Directory holding "burgers_propagator_best.pth" and
            "hyperparams.yaml".
        device: Torch device string the model is moved to.

    Returns:
        The model in eval mode, or None (after printing a FATAL message) when
        the checkpoint or the hyperparameter file does not exist.

    Raises:
        KeyError: If a required architecture key is absent from the
            hyperparameter file or the config.
    """
    model_path = os.path.join(run_dir, "burgers_propagator_best.pth")
    hyperparams_path = os.path.join(run_dir, "hyperparams.yaml")
    if not os.path.exists(model_path):
        print(f"FATAL: Model not found at {model_path}")
        return None
    # The hyperparameter file is just as mandatory as the checkpoint; report
    # its absence the same way instead of crashing inside open().
    if not os.path.exists(hyperparams_path):
        print(f"FATAL: Hyperparameters not found at {hyperparams_path}")
        return None

    with open(hyperparams_path, 'r') as f:
        hyperparams = yaml.safe_load(f)

    model_arg_keys = ['branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn']
    model_kwargs = {key: hyperparams[key] for key in model_arg_keys}
    model = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **model_kwargs
    ).to(device)
    # NOTE(review): torch.load unpickles the file; only evaluate checkpoints
    # from a trusted run directory (consider weights_only=True where the
    # installed torch version supports it).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print("Burgers' PropagatorDeepONet model loaded successfully.")
    return model


def _generate_test_case(config):
    """Draw one reproducible random test case (NumPy global seed fixed to 42).

    Returns:
        Tuple ``(x_grid_solver, initial_state, control_weights, viscosity)``.
        ``x_grid_solver`` and ``initial_state`` have length
        config['NX_SOLVER']; the shapes of the other two are whatever the
        GRF generators from ``data_and_models`` return.
    """
    np.random.seed(42)  # Use a fixed seed for a consistent test case

    x_grid_solver = np.linspace(0, config['L'], config['NX_SOLVER'])

    # a) Random initial state: a sum of 1-3 sine modes. The amplitude is drawn
    #    before the mode number; keep this order so the seed reproduces the
    #    same case.
    initial_state_np = np.zeros(config['NX_SOLVER'])
    for _ in range(np.random.randint(1, 4)):
        amplitude = np.random.uniform(-0.5, 0.5)
        mode = np.random.randint(1, 4)
        initial_state_np += amplitude * np.sin(mode * np.pi * x_grid_solver)

    # b) Random control weights, scaled and clipped to the control range.
    w_case_sequence = generate_grf_time_series(config, config['NUM_BASIS_FUNCTIONS'])
    w_case_sequence = np.clip(w_case_sequence * 1.5, -config['CONTROL_SCALE'], config['CONTROL_SCALE'])

    # c) Random viscosity profile: a spatial GRF squashed through tanh into
    #    [nu_min, nu_max] (stated in the original to mirror data generation).
    nu_raw = generate_grf_spatial_series(config, 1, config['VISCOSITY_LENGTH_SCALE']).squeeze()
    nu_min, nu_max = config['VISCOSITY_RANGE']
    v_profile_eval = nu_min + (nu_max - nu_min) * (0.5 * (np.tanh(nu_raw) + 1))
    print("Generated random initial state, controls, and viscosity profile.")
    return x_grid_solver, initial_state_np, w_case_sequence, v_profile_eval


def _autoregressive_rollout(model, config, x_grid_solver, initial_state_np,
                            w_case_sequence, v_profile_eval, device):
    """Roll the propagator forward, feeding each prediction back as input.

    Returns:
        Tuple ``(sensor_locs, predicted_trajectory)`` where
        ``predicted_trajectory`` stacks the initial sensor state followed by
        config['NT_SOLVER'] - 1 predicted states.
    """
    sensor_locs_np = np.linspace(0, config['L'], config['M_SENSORS'])

    # Interpolate both initial state and viscosity profile to the sensor grid.
    # NOTE(review): these use interp1d's default (linear) kind while the
    # ground truth is later resampled with kind='cubic' - confirm which one
    # matches the training pipeline.
    initial_state_sensors = interp1d(x_grid_solver, initial_state_np)(sensor_locs_np)
    v_profile_sensors_np = interp1d(x_grid_solver, v_profile_eval)(sensor_locs_np)

    # Convert inputs to PyTorch tensors for the model (leading batch dim of 1).
    predicted_states_sensors = [initial_state_sensors]
    current_state_sensors_torch = torch.from_numpy(initial_state_sensors).float().unsqueeze(0).to(device)
    v_profile_torch = torch.from_numpy(v_profile_sensors_np).float().unsqueeze(0).to(device)
    w_case_tensor = torch.from_numpy(w_case_sequence).float().to(device)
    x_grid_sensors_torch = torch.linspace(0, config['L'], config['M_SENSORS'], device=device).float().unsqueeze(0).unsqueeze(-1)

    with torch.no_grad():
        for k in range(config['NT_SOLVER'] - 1):
            w_k = w_case_tensor[k, :].unsqueeze(0)

            # The same viscosity profile is passed at every step. reshape(1, -1)
            # restores the (1, M_SENSORS) layout without a bare squeeze(), which
            # would collapse the state to a 0-d tensor when M_SENSORS == 1.
            next_state_pred_sensors = model(
                current_state_sensors_torch, w_k, v_profile_torch, x_grid_sensors_torch
            ).reshape(1, -1)

            predicted_states_sensors.append(next_state_pred_sensors.squeeze(0).cpu().numpy())
            current_state_sensors_torch = next_state_pred_sensors

    return sensor_locs_np, np.array(predicted_states_sensors)


def _plot_trajectories(ground_truth_trajectory, predicted_trajectory_sensors,
                       true_solution_on_sensors, config, plot_path, font_size):
    """Save a three-panel PDF: ground truth, rollout and absolute error."""
    fig, axes = plt.subplots(1, 3, figsize=(22, 6))
    extent = [0, config['T_FINAL'], 0, config['L']]

    # Truth and prediction share one colour scale so they compare visually.
    v_min = ground_truth_trajectory.min()
    v_max = ground_truth_trajectory.max()
    error = np.abs(true_solution_on_sensors - predicted_trajectory_sensors)

    im1 = axes[0].imshow(ground_truth_trajectory.T, extent=extent, origin='lower', aspect='auto', vmin=v_min, vmax=v_max, cmap='viridis')
    axes[0].set_title("Exact Solution (Ground Truth)", fontsize=font_size)

    im2 = axes[1].imshow(predicted_trajectory_sensors.T, extent=extent, origin='lower', aspect='auto', vmin=v_min, vmax=v_max, cmap='viridis')
    axes[1].set_title("Predicted Solution (Rollout)", fontsize=font_size)

    im3 = axes[2].imshow(error.T, extent=extent, origin='lower', aspect='auto', cmap='magma')
    axes[2].set_title("Absolute Error", fontsize=font_size)

    for ax, im in zip(axes, (im1, im2, im3)):
        ax.set_xlabel("Time (t)", fontsize=font_size)
        ax.set_ylabel("Space (x)", fontsize=font_size)
        # Colorbars with consistent font sizes.
        fig.colorbar(im, ax=ax).ax.tick_params(labelsize=font_size-2)

    for ax in axes:
        ax.tick_params(axis='both', which='major', labelsize=font_size-2)

    plt.tight_layout()
    plt.savefig(plot_path, bbox_inches='tight')
    plt.close()


def _plot_viscosity_profile(x_grid_solver, v_profile_eval, run_id, visc_plot_path, font_size):
    """Save a PDF line plot of the viscosity profile used for this evaluation."""
    plt.figure(figsize=(10, 5))
    plt.plot(x_grid_solver, v_profile_eval, lw=3)
    plt.title(f"Unseen Viscosity Profile ν(x) for Evaluation (Run: {run_id})", fontsize=font_size)
    plt.xlabel("Space (x)", fontsize=font_size)
    plt.ylabel("Viscosity ν", fontsize=font_size)
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.tick_params(axis='both', which='major', labelsize=font_size-2)
    plt.tight_layout()
    plt.savefig(visc_plot_path, bbox_inches='tight')
    plt.close()


def main(args):
    """Evaluate a trained PropagatorDeepONet with an autoregressive rollout.

    Loads the config and the trained model of ``args.run_id``, generates one
    seeded random test case, computes the reference trajectory with
    BurgersSimulator, rolls the model out step by step, and writes the rollout
    MSE (YAML) plus two PDF plots into the run directory.

    Args:
        args: Namespace with ``config_path``, ``output_base_dir`` and
            ``run_id`` attributes.

    Returns:
        None. Returns early, after printing a FATAL message, when the model
        checkpoint or its hyperparameter file is missing.

    Raises:
        ValueError: If the reference trajectory resampled on the sensor grid
            and the predicted trajectory do not have the same shape.
    """
    # --- 1. Load Configs and Set Up ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- Autoregressive Evaluation for Burgers' Propagator: {args.run_id} ---")

    run_dir = os.path.join(args.output_base_dir, args.run_id)

    # --- 2. Load the Trained Propagator Model ---
    model = _load_propagator(config, run_dir, device)
    if model is None:
        return

    # --- 3. Generate a Single, New, Unseen Test Case ---
    print("\n--- Generating a new, unseen test case from the training distribution... ---")
    x_grid_solver, initial_state_np, w_case_sequence, v_profile_eval = _generate_test_case(config)

    # --- 4. Get the Ground Truth solution from the HIGH-RES numerical solver ---
    simulator = BurgersSimulator(config)
    ground_truth_trajectory = simulator.run(initial_state_np, w_case_sequence, v_profile_eval)

    # --- 5. Perform AUTOREGRESSIVE Prediction (Rollout) ---
    print("--- Performing autoregressive rollout... ---")
    sensor_locs_np, predicted_trajectory_sensors = _autoregressive_rollout(
        model, config, x_grid_solver, initial_state_np, w_case_sequence, v_profile_eval, device
    )

    # --- 6. Calculate Final MSE and Save Results ---
    # Interpolate ground truth to sensor grid for fair comparison
    interp_gt = interp1d(x_grid_solver, ground_truth_trajectory, axis=1, kind='cubic')
    true_solution_on_sensors = interp_gt(sensor_locs_np)

    # Fail loudly on a time-step mismatch rather than letting NumPy broadcast
    # the subtraction into a meaningless MSE.
    if true_solution_on_sensors.shape != predicted_trajectory_sensors.shape:
        raise ValueError(
            "Ground truth on sensors has shape "
            f"{true_solution_on_sensors.shape} but the rollout has shape "
            f"{predicted_trajectory_sensors.shape}; check NT_SOLVER and M_SENSORS."
        )

    rollout_mse = np.mean((true_solution_on_sensors - predicted_trajectory_sensors)**2)
    print(f"\nFinal Autoregressive Rollout MSE: {rollout_mse:.4e}")
    results = {'autoregressive_rollout_mse': float(rollout_mse)}
    results_path = os.path.join(run_dir, "evaluation_rollout_results.yaml")
    with open(results_path, 'w') as f:
        yaml.dump(results, f, sort_keys=False)

    # --- 7. Plotting Results ---
    print("--- Plotting results... ---")
    font_size = 16  # Adjusted for better readability

    plot_path = os.path.join(run_dir, "evaluation_plot.pdf")
    _plot_trajectories(
        ground_truth_trajectory, predicted_trajectory_sensors,
        true_solution_on_sensors, config, plot_path, font_size
    )

    # Also plot the specific viscosity profile used for this evaluation
    visc_plot_path = os.path.join(run_dir, "evaluation_viscosity_profile.pdf")
    _plot_viscosity_profile(x_grid_solver, v_profile_eval, args.run_id, visc_plot_path, font_size)

    print(f"Evaluation PDF saved to {plot_path}")
    print(f"Viscosity profile plot saved to {visc_plot_path}")

if __name__ == "__main__":
    # Script entry point: three mandatory string options select the config
    # file, the base output directory and the run to evaluate.
    cli = argparse.ArgumentParser(
        description="Autoregressive evaluation of a trained PropagatorDeepONet."
    )
    cli.add_argument("--config_path", required=True, type=str)
    cli.add_argument("--output_base_dir", required=True, type=str)
    cli.add_argument(
        "--run_id",
        required=True,
        type=str,
        help="The specific run_id of the model to evaluate.",
    )
    # Hand the parsed namespace straight to the evaluation routine.
    main(cli.parse_args())