import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import argparse
from scipy.interpolate import RegularGridInterpolator

# --- CORRECTED IMPORT ---
# Import the 2D versions of the model and data generation functions
from data_and_models_2d import PropagatorDeepONet, solve_pde_2d, generate_grf_time_series

def _build_cosine_basis_2d(config, x_solver, y_solver):
    """Sample the tensor-product cosine basis on the solver grid.

    Returns an array of shape (NX_SOLVER, NY_SOLVER, NUM_BASIS_X * NUM_BASIS_Y).
    The last axis enumerates (x-mode, y-mode) pairs with the y-mode varying fastest.
    """
    modes_x = np.arange(config['NUM_BASIS_X'])
    # The y direction must use its own mode count; reusing NUM_BASIS_X here makes the
    # flattened basis axis disagree with NUM_BASIS_X * NUM_BASIS_Y for non-square bases.
    modes_y = np.arange(config['NUM_BASIS_Y'])
    basis_x = np.cos(modes_x * np.pi * x_solver[:, None] / config['L_X'])
    basis_y = np.cos(modes_y * np.pi * y_solver[:, None] / config['L_Y'])
    outer = basis_x[:, None, :, None] * basis_y[None, :, None, :]
    return outer.reshape(config['NX_SOLVER'], config['NY_SOLVER'], -1)


def _rollout_on_sensors(model, w_case_sequence, sensor_locs, initial_value, num_sensors, num_steps, device):
    """Advance the propagator `num_steps` times, feeding each prediction back as the next state.

    Returns a NumPy array with one row of `num_sensors` predicted values per step.
    """
    # Explicit float dtype: an integer initial value loaded from YAML would otherwise
    # make torch.full return an integer tensor.
    state = torch.full((1, num_sensors), initial_value, dtype=torch.float32, device=device)
    weights = torch.from_numpy(w_case_sequence).float().to(device)

    trajectory = []
    with torch.no_grad():
        for k in range(num_steps):
            w_k = weights[k, :].unsqueeze(0)
            # reshape (rather than a bare squeeze) keeps a 1-D state even when there is
            # only one sensor, and fails loudly if the output size is unexpected.
            next_state = model(state, w_k, sensor_locs).reshape(num_sensors)
            trajectory.append(next_state.cpu().numpy())
            state = next_state.unsqueeze(0)
    return np.array(trajectory)


def _interpolate_truth_to_sensors(true_solution_grid, x_solver, y_solver, sensor_points, num_steps):
    """Linearly interpolate each solver snapshot onto the sensor points.

    A fresh interpolator is built per time step instead of mutating `.values`
    on a shared one, so SciPy's own input validation always runs.
    """
    return np.array([
        RegularGridInterpolator(
            (x_solver, y_solver),
            true_solution_grid[t_step],
            method='linear',
            bounds_error=False,
            fill_value=None,  # None => extrapolate at points falling marginally outside the grid
        )(sensor_points)
        for t_step in range(num_steps)
    ])


def _plot_rollout_snapshots(config, run_id, truth_grid, pred_grid, plot_path):
    """Save a (truth | prediction | absolute error) figure at three rollout times."""
    last_step = config['NT_SOLVER'] - 1
    time_indices = [int(p * last_step) for p in [0.25, 0.50, 1.0]]

    fig, axes = plt.subplots(len(time_indices), 3, figsize=(15, 5 * len(time_indices)))
    fig.suptitle(f'2D Propagator Rollout Evaluation: {run_id}', fontsize=18)

    extent = [0, config['L_X'], 0, config['L_Y']]
    # Truth and prediction share one colour scale so they are visually comparable.
    vmin, vmax = truth_grid.min(), truth_grid.max()
    shared_scale = {'vmin': vmin, 'vmax': vmax, 'cmap': 'viridis'}

    for row, t_idx in enumerate(time_indices):
        t_phys = t_idx * config['T_FINAL'] / last_step
        abs_error = np.abs(truth_grid[t_idx] - pred_grid[t_idx])
        panels = [
            (truth_grid[t_idx], 'Ground Truth', shared_scale),
            (pred_grid[t_idx], 'Prediction', shared_scale),
            (abs_error, 'Absolute Error', {'cmap': 'magma'}),
        ]
        for col, (field, label, style) in enumerate(panels):
            ax = axes[row, col]
            # Arrays are indexed [x, y]; transpose so x runs along the horizontal axis.
            im = ax.imshow(field.T, extent=extent, origin='lower', **style)
            ax.set_title(f'{label} at t={t_phys:.2f}')
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig(plot_path, bbox_inches='tight')
    plt.close()


def main(args):
    """Evaluate a trained 2D PropagatorDeepONet with an autoregressive rollout.

    Loads the configuration and the saved model for `args.run_id`, generates one
    seeded test forcing sequence, solves the PDE numerically for reference, rolls
    the model forward from the constant initial state, and writes the rollout MSE
    (YAML) and a comparison figure (PDF) into the run directory.

    Args:
        args: Namespace with `config_path`, `output_base_dir` and `run_id`.

    Returns:
        None. Returns early (after printing a FATAL message) when the model
        checkpoint or its hyperparameter file is missing.
    """
    # --- 1. Load Configs and Set Up ---
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- 2D Autoregressive Evaluation for Propagator: {args.run_id} ---")
    print(f"Using device: {device}")

    run_dir = os.path.join(args.output_base_dir, args.run_id)
    model_path = os.path.join(run_dir, "propagator_deeponet_2d_best.pth")
    hyperparams_path = os.path.join(run_dir, "hyperparams_propagator_2d.yaml")
    if not os.path.exists(model_path):
        print(f"FATAL: Model not found at {model_path}. Please train the 2D model first.")
        return
    if not os.path.exists(hyperparams_path):
        print(f"FATAL: Hyperparameters not found at {hyperparams_path}. Please train the 2D model first.")
        return

    # --- 2. Load the Trained 2D Propagator Model ---
    with open(hyperparams_path, 'r') as f:
        hyperparams = yaml.safe_load(f)

    nx_sensors, ny_sensors = config['NX_SENSORS'], config['NY_SENSORS']
    num_sensors = nx_sensors * ny_sensors
    num_basis_total = config['NUM_BASIS_X'] * config['NUM_BASIS_Y']

    model_arg_keys = ['branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn']
    model_kwargs = {key: hyperparams[key] for key in model_arg_keys}

    model = PropagatorDeepONet(
        M_sensors=num_sensors,
        num_basis_functions=num_basis_total,
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **model_kwargs
    ).to(device)

    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print("2D PropagatorDeepONet model loaded successfully.")

    # --- 3. Generate a Single, Unseen 2D Test Case ---
    print("\n--- Generating a new 2D test case... ---")
    np.random.seed(42)  # fixed seed so the evaluation case is reproducible
    w_case_sequence = generate_grf_time_series(config, config['NT_SOLVER'], num_basis_total, length_scale=1.5)
    w_case_sequence = np.clip(w_case_sequence, -1.0, 1.0)

    # --- 4. Get the Ground Truth solution from the numerical solver ---
    x_solver = np.linspace(0, config['L_X'], config['NX_SOLVER'])
    y_solver = np.linspace(0, config['L_Y'], config['NY_SOLVER'])
    basis_functions_2d = _build_cosine_basis_2d(config, x_solver, y_solver)

    # Contract basis coefficients (time, basis) with the basis (x, y, basis).
    u_case_xyt = np.einsum('tb,xyb->txy', w_case_sequence, basis_functions_2d)
    true_solution_grid = solve_pde_2d(config, u_case_xyt)  # expected shape (NT, NX, NY) -- TODO confirm against solver

    # --- 5. Perform AUTOREGRESSIVE Prediction (Rollout) ---
    print("--- Performing autoregressive rollout... ---")
    sensor_x = torch.linspace(0, config['L_X'], nx_sensors, device=device)
    sensor_y = torch.linspace(0, config['L_Y'], ny_sensors, device=device)
    grid_x, grid_y = torch.meshgrid(sensor_x, sensor_y, indexing='ij')
    sensor_locs_model_input = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1).unsqueeze(0)

    predicted_solution_flat = _rollout_on_sensors(
        model,
        w_case_sequence,
        sensor_locs_model_input,
        config['INITIAL_STATE_VAL'],
        num_sensors,
        config['NT_SOLVER'] - 1,
        device,
    )

    # --- 6. Calculate Final MSE and Save Results ---
    sensor_x_np, sensor_y_np = sensor_x.cpu().numpy(), sensor_y.cpu().numpy()
    sensor_grid_x_np, sensor_grid_y_np = np.meshgrid(sensor_x_np, sensor_y_np, indexing='ij')
    sensor_points_flat = np.stack([sensor_grid_x_np.flatten(), sensor_grid_y_np.flatten()], axis=1)

    true_solution_on_sensors_flat = _interpolate_truth_to_sensors(
        true_solution_grid, x_solver, y_solver, sensor_points_flat, config['NT_SOLVER']
    )

    # The rollout has no prediction for step 0 (the given initial state), so skip it.
    rollout_mse = np.mean((true_solution_on_sensors_flat[1:] - predicted_solution_flat)**2)
    print(f"\nFinal Autoregressive Rollout MSE: {rollout_mse:.6e}")

    results = {'autoregressive_rollout_mse': float(rollout_mse)}
    results_path = os.path.join(run_dir, "evaluation_rollout_results_2d.yaml")
    with open(results_path, 'w') as f:
        yaml.dump(results, f)

    # --- 7. Plotting Results ---
    print("--- Plotting 2D comparison snapshots... ---")
    predicted_solution_grid = predicted_solution_flat.reshape(-1, nx_sensors, ny_sensors)
    # Prepend the known initial state so prediction and truth share time indices.
    initial_state_grid = np.full((1, nx_sensors, ny_sensors), config['INITIAL_STATE_VAL'])
    predicted_solution_grid = np.concatenate((initial_state_grid, predicted_solution_grid), axis=0)

    # Sensor points were flattened x-major, so the grid is (NX_SENSORS, NY_SENSORS).
    true_solution_on_sensors_grid = true_solution_on_sensors_flat.reshape(-1, nx_sensors, ny_sensors)

    plot_path = os.path.join(run_dir, "evaluation_plot_2d.pdf")
    _plot_rollout_snapshots(
        config, args.run_id, true_solution_on_sensors_grid, predicted_solution_grid, plot_path
    )

    print(f"Evaluation PDF saved to {plot_path}")

if __name__ == "__main__":
    # Command-line entry point: collect the config path, the base output
    # directory and the run identifier, then hand them to main().
    cli = argparse.ArgumentParser(
        description="Autoregressive evaluation of a trained 2D PropagatorDeepONet."
    )
    cli.add_argument(
        "--config_path",
        type=str,
        default="config/base_config.yaml",
        help="Path to the 2D configuration YAML file.",
    )
    cli.add_argument(
        "--output_base_dir",
        type=str,
        default="outputs_2d",
        help="The base directory where run folders are stored.",
    )
    cli.add_argument(
        "--run_id",
        type=str,
        required=True,
        help="The specific run_id of the 2D model to evaluate.",
    )
    main(cli.parse_args())