

import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import argparse
from scipy.interpolate import interp1d

from data_and_models import PropagatorDeepONet, BurgersSimulator, generate_2d_grf_fast

def _autoregressive_rollout(model, initial_state, controls, x_query, num_steps, device):
    """Roll a one-step propagator forward, feeding back its own predictions.

    Args:
        model: Callable invoked as ``model(state, control, x_query)`` with a
            leading batch dimension of 1 on ``state`` and ``control``; its
            output is flattened and treated as the next state. (Assumes the
            output holds exactly one value per sensor — TODO confirm against
            ``PropagatorDeepONet``.)
        initial_state: 1-D array, the state sampled at the sensor points.
        controls: 2-D array whose row ``k`` is the control applied when
            advancing from step ``k`` to step ``k + 1``; must have at least
            ``num_steps`` rows.
        x_query: Tensor of query coordinates, passed to the model unchanged.
        num_steps: Number of propagator applications.
        device: Torch device the inputs are moved to.

    Returns:
        Array of shape ``(num_steps + 1, len(initial_state))``. Row 0 is
        ``initial_state``; row ``k + 1`` is the model prediction from row ``k``.
    """
    state = torch.from_numpy(np.ascontiguousarray(initial_state)).float().unsqueeze(0).to(device)
    controls_t = torch.from_numpy(np.ascontiguousarray(controls)).float().to(device)
    states = [initial_state]
    with torch.no_grad():
        for k in range(num_steps):
            # reshape(-1) rather than squeeze(): squeeze() would collapse a
            # single-sensor output to a 0-d tensor and break the feedback shape.
            next_state = model(state, controls_t[k].unsqueeze(0), x_query).reshape(-1)
            states.append(next_state.cpu().numpy())
            state = next_state.unsqueeze(0)  # restore the batch dimension
    return np.array(states)


def _plot_results(plot_path, panels, extent, font_size=20):
    """Save a 1 x len(panels) row of space-time heatmaps to ``plot_path``.

    Args:
        plot_path: Output file path (format inferred by Matplotlib from the suffix).
        panels: Sequence of ``(field, title, vmin, vmax)`` tuples. ``field`` is
            indexed ``[time, x]`` and is transposed so time runs along the
            horizontal axis. ``vmin``/``vmax`` may be ``None`` for autoscaling.
        extent: ``[t_min, t_max, x_min, x_max]`` passed to ``imshow``.
        font_size: Font size for titles, axis labels and tick labels.
    """
    fig, axes = plt.subplots(1, len(panels), figsize=(22, 6), squeeze=False)
    try:
        for ax, (field, title, vmin, vmax) in zip(axes[0], panels):
            im = ax.imshow(field.T, extent=extent, origin='lower', aspect='auto',
                           vmin=vmin, vmax=vmax, cmap='viridis')
            ax.set_title(title, fontsize=font_size)
            ax.set_xlabel("t", fontsize=font_size)
            ax.set_ylabel("x", fontsize=font_size)
            cbar = fig.colorbar(im, ax=ax)
            cbar.ax.tick_params(labelsize=font_size)
            ax.tick_params(axis='both', which='major', labelsize=font_size)
        fig.tight_layout()
        fig.savefig(plot_path, bbox_inches='tight')
    finally:
        # Release this specific figure even if layout/saving raised.
        plt.close(fig)


def main(args):
    """Evaluate a trained Burgers' propagator by autoregressive rollout.

    Loads ``burgers_propagator_best.pth`` and ``hyperparams.yaml`` from
    ``<output_base_dir>/<run_id>`` and rebuilds the model. It then creates one
    fixed-seed test case: a control field from ``generate_2d_grf_fast``, scaled
    and clipped to ``CONTROL_SCALE``, plus a random sum-of-sines initial state.
    The ground truth comes from ``BurgersSimulator`` on the solver grid. The
    model is rolled forward ``NT_SOLVER - 1`` steps on the sensor grid, feeding
    back its own predictions.

    Writes ``evaluation_rollout_results.yaml`` (the rollout MSE) and
    ``evaluation_plot.pdf`` into the run directory. If the checkpoint or the
    hyperparameter file is missing, prints a FATAL message and returns ``None``.

    Args:
        args: Namespace with ``config_path``, ``output_base_dir`` and ``run_id``.
    """
    # --- 1. Load configs and set up ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- Autoregressive Evaluation (Direct Control) for Burgers' Propagator: {args.run_id} ---")

    run_dir = os.path.join(args.output_base_dir, args.run_id)
    model_path = os.path.join(run_dir, "burgers_propagator_best.pth")
    hyperparams_path = os.path.join(run_dir, "hyperparams.yaml")
    for required_path, label in ((model_path, "Model"), (hyperparams_path, "Hyperparameters")):
        if not os.path.exists(required_path):
            print(f"FATAL: {label} not found at {required_path}")
            return

    # --- 2. Load the trained propagator model ---
    with open(hyperparams_path, 'r') as f:
        hyperparams = yaml.safe_load(f)

    model_arg_keys = ['branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn']
    model_kwargs = {key: hyperparams[key] for key in model_arg_keys}
    model = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],  # Kept for API consistency
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **model_kwargs
    ).to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()
    print("Burgers' PropagatorDeepONet model loaded successfully.")

    # --- 3. Generate a single, new, unseen test case ---
    print("\n--- Generating a new, unseen test case... ---")
    np.random.seed(42)  # fixed seed -> the same test case on every run
    nt = config['NT_SOLVER']
    x_grid_solver = np.linspace(0, config['L'], config['NX_SOLVER'])
    t_grid_solver = np.linspace(0, config['T_FINAL'], nt)

    # Control field indexed [time, x] (see the k-row indexing and x-axis
    # interpolation below). It is amplified by 1.5 and then clipped to the
    # admissible control range.
    u_xt_field_gt_unscaled = generate_2d_grf_fast(x_grid_solver, t_grid_solver, num_samples=1, x_scale=0.5, t_scale=1.5)[0]
    u_xt_field_gt = np.clip(u_xt_field_gt_unscaled * 1.5, -config['CONTROL_SCALE'], config['CONTROL_SCALE'])

    # --- 4. Ground truth from the high-resolution numerical solver ---
    # The simulator is constructed between the GRF draw and the initial-state
    # draws. Keep this order so the seeded RNG sequence (and therefore the test
    # case) does not shift if construction happens to consume random numbers.
    simulator = BurgersSimulator(config)
    initial_state_np = np.zeros(config['NX_SOLVER'])
    for _ in range(np.random.randint(1, 4)):
        initial_state_np += np.random.uniform(-0.5, 0.5) * np.sin(np.random.randint(1, 4) * np.pi * x_grid_solver)

    ground_truth_history = [initial_state_np.copy()]
    for k in range(nt - 1):
        # The simulator takes the control values on the solver grid directly.
        ground_truth_history.append(simulator.step(ground_truth_history[-1], u_xt_field_gt[k, :]))
    ground_truth_trajectory = np.array(ground_truth_history)

    # --- 5. Autoregressive prediction (rollout) on the sensor grid ---
    print("--- Performing autoregressive rollout... ---")
    sensor_locs_np = np.linspace(0, config['L'], config['M_SENSORS'])

    def to_sensors(field):
        """Linearly interpolate ``field`` (x along the last axis) onto the sensors."""
        return interp1d(x_grid_solver, field, axis=-1, kind='linear')(sensor_locs_np)

    # Use ONE interpolation scheme for the model inputs and for the reference
    # solution. Mixing linear inputs with a cubic reference adds a spurious
    # error even at t=0, and cubic splines can overshoot near steep gradients.
    true_solution_on_sensors = to_sensors(ground_truth_trajectory)
    u_case_at_sensors = to_sensors(u_xt_field_gt)
    initial_state_sensors = np.ascontiguousarray(true_solution_on_sensors[0])

    # Trunk query points, shape (1, M_SENSORS, 1).
    x_grid_sensors = torch.linspace(0, config['L'], config['M_SENSORS'], device=DEVICE).float().unsqueeze(0).unsqueeze(-1)
    predicted_trajectory_sensors = _autoregressive_rollout(
        model, initial_state_sensors, u_case_at_sensors, x_grid_sensors, nt - 1, DEVICE
    )

    # --- 6. Final MSE and results ---
    rollout_mse = np.mean((true_solution_on_sensors - predicted_trajectory_sensors)**2)
    print(f"\nFinal Autoregressive Rollout MSE: {rollout_mse:.4e}")
    results = {'autoregressive_rollout_mse': float(rollout_mse)}
    results_path = os.path.join(run_dir, "evaluation_rollout_results.yaml")
    with open(results_path, 'w') as f:
        yaml.dump(results, f, sort_keys=False)

    print("--- Plotting results... ---")
    # Exact and predicted panels share the ground-truth color range.
    v_min = ground_truth_trajectory.min()
    v_max = ground_truth_trajectory.max()
    error = np.abs(true_solution_on_sensors - predicted_trajectory_sensors)
    plot_path = os.path.join(run_dir, "evaluation_plot.pdf")
    _plot_results(
        plot_path,
        [
            (ground_truth_trajectory, "Exact Solution", v_min, v_max),
            (predicted_trajectory_sensors, "Predicted Solution", v_min, v_max),
            (error, "Absolute Error", None, None),
        ],
        extent=[0, config['T_FINAL'], 0, config['L']],
    )
    print(f"Evaluation PDF saved to {plot_path}")

if __name__ == "__main__":
    # Command-line entry point: all three options are mandatory strings.
    cli = argparse.ArgumentParser(
        description="Autoregressive evaluation of a trained PropagatorDeepONet."
    )
    cli.add_argument("--config_path", type=str, required=True)
    cli.add_argument("--output_base_dir", type=str, required=True)
    cli.add_argument(
        "--run_id",
        type=str,
        required=True,
        help="The specific run_id of the model to evaluate.",
    )
    main(cli.parse_args())