# src/unified_comparison_2d_final.py
# ---------------------------------------------------------------------------
# FINAL UNIFIED SCRIPT for comparing five different control strategies for the 2D Heat Equation.
# (Version with Larger Fonts for Plotting)
# ---------------------------------------------------------------------------

import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import argparse
import time

# --- Dependencies ---
from data_and_models_2d import PropagatorDeepONet, RecurrentController
import scipy.sparse as sp
from scipy.sparse import eye, kronsum
from scipy.sparse.linalg import spsolve
from scipy.optimize import minimize
import cvxpy as cp

# ===========================================================================
# 1. HELPER: 2D TARGET & STATE-SPACE MODEL
# ===========================================================================

def generate_target_2d(config, target_type, device, seed=42):
    """Generate a 2D target profile on the sensor grid.

    Args:
        config: Dict providing ``NX_SENSORS``, ``NY_SENSORS``, ``L_X`` and ``L_Y``.
        target_type: One of ``'gaussian_peak'``, ``'sine_wave'`` or
            ``'complex_gaussian'``.
        device: Torch device on which the target tensor is created.
        seed: Seed for the random bumps of ``'complex_gaussian'``; the other
            target types are deterministic and ignore it.

    Returns:
        Tuple ``(target_torch, target_np)``: a ``(1, NX*NY)`` tensor on
        ``device`` and the same values as a flat ``(NX*NY,)`` numpy array.
        The grid is built with ``indexing='ij'`` and flattened row-major, so
        the flat index is ``ix * NY + iy``.

    Raises:
        ValueError: If ``target_type`` is not one of the supported names.
    """
    nx, ny = config['NX_SENSORS'], config['NY_SENSORS']
    lx, ly = config['L_X'], config['L_Y']
    x = torch.linspace(0, lx, nx, device=device)
    y = torch.linspace(0, ly, ny, device=device)
    grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')

    def gaussian(amp, cx, cy, wx, wy):
        # Axis-aligned 2D Gaussian bump of height `amp` centred at (cx, cy).
        return amp * torch.exp(-((grid_x - cx)**2 / (2 * wx**2) + (grid_y - cy)**2 / (2 * wy**2)))

    if target_type == 'gaussian_peak':
        target_2d = gaussian(1.5, lx / 2, ly / 2, lx / 4, ly / 4)
    elif target_type == 'sine_wave':
        amp, mean = 0.7, 0.8
        target_2d = mean + amp * torch.sin(2*torch.pi*grid_x/lx) * torch.sin(2*torch.pi*grid_y/ly)
    elif target_type == 'complex_gaussian':
        # Private generator: the draws are reproducible for a given seed without
        # reseeding the global torch RNG that the rest of the program shares.
        gen = torch.Generator(device=device)
        gen.manual_seed(seed)

        def draw():
            return torch.rand(1, generator=gen, device=device).item()

        target_2d = torch.zeros(nx, ny, device=device)
        for _ in range(5):
            amp = (draw() - 0.5) * 2.0
            cx, cy = draw() * lx, draw() * ly
            wx, wy = (draw() * 0.15 + 0.05) * lx, (draw() * 0.15 + 0.05) * ly
            target_2d += gaussian(amp, cx, cy, wx, wy)
        # Re-centre so the profile has mean 0.75 whatever bumps were drawn.
        target_2d = 0.75 + (target_2d - torch.mean(target_2d))
    else:
        raise ValueError(
            f"Unknown target_type {target_type!r}; expected 'gaussian_peak', "
            f"'sine_wave' or 'complex_gaussian'."
        )

    return target_2d.reshape(1, -1), target_2d.cpu().numpy().flatten()

def get_state_space_model_2d(config):
    """Construct the discrete-time state-space model ``x_{k+1} = A x_k + B w_k + g``.

    The PDE ``u_t = D * lap(u) - BETA * (u - V_REF_VAL) + ALPHA * sum_j phi_j w_j``
    is discretised on the ``NX_SENSORS x NY_SENSORS`` sensor grid with a
    5-point Laplacian and homogeneous Neumann (mirrored ghost-point)
    boundaries, and stepped with Crank-Nicolson using
    ``dt = T_FINAL / (NT_SOLVER - 1)``. The control is held constant over a
    step. The actuator basis ``phi_j`` is the tensor product of cosine modes
    ``cos(p*pi*x/L_X) * cos(q*pi*y/L_Y)`` for ``p < NUM_BASIS_X``,
    ``q < NUM_BASIS_Y``. States are flattened with index ``ix * NY + iy``.

    Returns:
        ``(A, B, g)``. ``A`` is an ``(N, N)`` sparse matrix, ``B`` a dense
        ``(N, NUM_BASIS_X*NUM_BASIS_Y)`` array and ``g`` a dense ``(N,)``
        vector, where ``N = NX_SENSORS * NY_SENSORS``.
    """
    from scipy.sparse.linalg import splu

    NX, NY = config['NX_SENSORS'], config['NY_SENSORS']
    dx, dy = config['L_X'] / (NX - 1), config['L_Y'] / (NY - 1)
    dt = config['T_FINAL'] / (config['NT_SOLVER'] - 1)
    N = NX * NY

    def laplacian_1d(n_pts, d_space):
        # Second-difference operator. The doubled boundary coefficients come from
        # mirroring the first interior point (zero-flux Neumann condition).
        D2 = sp.diags([1, -2, 1], [-1, 0, 1], shape=(n_pts, n_pts), format='csc') / d_space**2
        D2[0, 1], D2[-1, -2] = 2/d_space**2, 2/d_space**2
        return D2

    # kronsum(Ly, Lx) = kron(I_NX, Ly) + kron(Lx, I_NY): y is the fast index,
    # matching the row-major flattening of an (NX, NY) 'ij' grid.
    L_2D = kronsum(laplacian_1d(NY, dy), laplacian_1d(NX, dx))
    I = eye(N, format='csc')

    # Crank-Nicolson: A_cn x_{k+1} = B_cn x_k + dt*(forcing terms).
    A_cn = (I - 0.5 * config['D'] * dt * L_2D + 0.5 * config['BETA'] * dt * I).tocsc()
    B_cn = I + 0.5 * config['D'] * dt * L_2D - 0.5 * config['BETA'] * dt * I

    # Factorize A_cn once and reuse it for all three solves. Previously each
    # spsolve call refactorized it, and the sparse right-hand side was solved
    # column by column into a sparse container for an essentially dense result.
    lu = splu(A_cn)
    A_mat = sp.csc_matrix(lu.solve(B_cn.toarray()))

    num_basis_total = config['NUM_BASIS_X'] * config['NUM_BASIS_Y']
    x_grid = np.linspace(0, config['L_X'], NX)
    y_grid = np.linspace(0, config['L_Y'], NY)
    basis_x = np.cos(np.arange(config['NUM_BASIS_X']) * np.pi * x_grid[:, None] / config['L_X'])
    basis_y = np.cos(np.arange(config['NUM_BASIS_Y']) * np.pi * y_grid[:, None] / config['L_Y'])
    # (NX, NY, NBX, NBY) outer product -> (N, NBX*NBY) with state index ix*NY + iy.
    basis_2d = (basis_x[:, None, :, None] * basis_y[None, :, None, :]).reshape(N, num_basis_total)
    B_mat_continuous = config['ALPHA'] * basis_2d
    B_mat = lu.solve(dt * B_mat_continuous)

    # Constant forcing from the relaxation term BETA * V_REF_VAL.
    g_vec_continuous = dt * config['BETA'] * config['V_REF_VAL'] * np.ones(N)
    g_vec = lu.solve(g_vec_continuous)

    return A_mat, B_mat, g_vec

def get_surrogate_model(args, config, device):
    """Load the pre-trained DeepONet propagator used as the learned surrogate.

    The DeepONet run is found indirectly. The controller run at
    ``<output_base_dir>/<run_id>`` records, under ``deeponet_run_id`` in its
    ``hyperparams_controller_2d.yaml``, which DeepONet run it was trained
    against. That run's ``hyperparams_propagator_2d.yaml`` supplies the
    architecture, and ``propagator_deeponet_2d_best.pth`` supplies the weights.

    Returns:
        The ``PropagatorDeepONet`` on ``device``, weights loaded, in eval mode.
    """
    base_dir = args.output_base_dir

    controller_yaml = os.path.join(base_dir, args.run_id, "hyperparams_controller_2d.yaml")
    with open(controller_yaml, 'r') as fh:
        deeponet_run_id = yaml.safe_load(fh)['deeponet_run_id']

    deeponet_dir = os.path.join(base_dir, deeponet_run_id)
    with open(os.path.join(deeponet_dir, "hyperparams_propagator_2d.yaml"), 'r') as fh:
        propagator_hparams = yaml.safe_load(fh)

    # Only architecture hyper-parameters are forwarded to the constructor.
    architecture_keys = ('branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn')
    architecture = {key: value for key, value in propagator_hparams.items() if key in architecture_keys}

    surrogate = PropagatorDeepONet(
        M_sensors=config['NX_SENSORS'] * config['NY_SENSORS'],
        num_basis_functions=config['NUM_BASIS_X'] * config['NUM_BASIS_Y'],
        trunk_input_dim=2,
        **architecture,
    ).to(device)
    checkpoint_path = os.path.join(deeponet_dir, "propagator_deeponet_2d_best.pth")
    surrogate.load_state_dict(torch.load(checkpoint_path, map_location=device))
    surrogate.eval()
    return surrogate

# ===========================================================================
# 3. CONTROLLER EVALUATION FUNCTIONS
# ===========================================================================

def run_recurrent_evaluation(args, config, target_T_torch, target_T_np, physics_simulator):
    """Roll out the trained PDE-OP recurrent controller in closed loop on the surrogate.

    The controller architecture is read from ``hyperparams_controller_2d.yaml``
    and its weights from ``recurrent_controller_2d_model.pth``, both in
    ``<output_base_dir>/<run_id>``. Starting from a uniform state equal to
    ``INITIAL_STATE_VAL``, at each of ``NT_SOLVER`` steps the controller maps
    (state, target, recurrent memory) to a control vector, and the surrogate
    advances the state with it. Everything runs under ``torch.no_grad()``.

    Returns:
        ``(states, controls, elapsed)``: ``NT_SOLVER + 1`` flattened states
        (initial state included), ``NT_SOLVER`` control vectors, and the
        rollout wall time in seconds.
    """
    print("\n--- [1/5] Evaluating PDE-OP (Recurrent) Controller ---")
    device = target_T_torch.device
    run_dir = os.path.join(args.output_base_dir, args.run_id)

    with open(os.path.join(run_dir, "hyperparams_controller_2d.yaml"), 'r') as fh:
        hparams = yaml.safe_load(fh)
    architecture = {key: val for key, val in hparams.items() if key in ('hidden_dim', 'num_layers', 'activation_fn')}

    n_sensors = config['NX_SENSORS'] * config['NY_SENSORS']
    controller = RecurrentController(
        M_sensors=n_sensors,
        num_basis_functions=config['NUM_BASIS_X'] * config['NUM_BASIS_Y'],
        **architecture,
    ).to(device)
    weights_path = os.path.join(run_dir, "recurrent_controller_2d_model.pth")
    controller.load_state_dict(torch.load(weights_path, map_location=device))
    controller.eval()

    # Sensor coordinates as a (1, NX*NY, 2) batch of (x, y) query points.
    xs = torch.linspace(0, config['L_X'], config['NX_SENSORS'], device=device)
    ys = torch.linspace(0, config['L_Y'], config['NY_SENSORS'], device=device)
    mesh_x, mesh_y = torch.meshgrid(xs, ys, indexing='ij')
    sensor_locs = torch.stack([mesh_x.flatten(), mesh_y.flatten()], dim=1).unsqueeze(0)

    state = torch.full((1, n_sensors), config['INITIAL_STATE_VAL'], device=device)
    states = [state.cpu().numpy().flatten()]
    controls = []

    tic = time.time()
    with torch.no_grad():
        memory = None  # recurrent hidden state, initialised by the controller
        for _ in range(config['NT_SOLVER']):
            action, memory = controller(state, target_T_torch, memory)
            controls.append(action.cpu().numpy().flatten())
            state = physics_simulator(state, action, sensor_locs).squeeze(-1)
            states.append(state.cpu().numpy().flatten())
    elapsed = time.time() - tic

    final_mse = np.mean((states[-1] - target_T_np)**2)
    print(f"PDE-OP controller finished in {elapsed:.3f}s. Final MSE: {final_mse:.6e}")
    return np.array(states), np.array(controls), elapsed

def run_adjoint_evaluation(config, target_T_np, A, B, g):
    """Evaluate a receding-horizon controller solved with adjoint-based gradients.

    At each of the ``NT_SOLVER`` steps, a finite-horizon problem of length
    ``T = min(10, steps remaining)`` on the linear model
    ``x_{t+1} = A x_t + B u_t + g`` is solved with L-BFGS-B: at most 30
    iterations, controls bounded to ``[-1, 1]``, initialised at zero. The
    objective is

        sum_{t=0..T} (x_t - x_ref)^T Q (x_t - x_ref) + sum_{t=0..T-1} u_t^T R u_t

    with ``Q = 2 I`` and ``R = 1e-5 I``. Its gradient comes from a backward
    adjoint sweep. Only the first planned control is applied to the model.

    Returns:
        ``(states, controls, elapsed)``: ``NT_SOLVER + 1`` states (initial state
        included), ``NT_SOLVER`` controls, and total wall time in seconds.
    """
    print("\n--- [2/5] Evaluating Adjoint-Based Controller ---")
    m, N = B.shape[1], A.shape[0]
    n_steps, Np = config['NT_SOLVER'], 10
    Q = 2 * eye(N); R = 1e-5 * eye(m)
    u_min, u_max = -1.0, 1.0
    # Loop-invariant operators, hoisted out of the optimizer callback (previously
    # rebuilt on every time step of every L-BFGS-B evaluation). `Q2 @ e` is
    # exactly what the old expression `2.0 * Q @ e` evaluated, so results are unchanged.
    Q2, R2 = 2.0 * Q, 2.0 * R
    A_T, B_T = A.T, B.T

    def cost_and_grad(u_flat, x0, T, x_ref):
        """Return the horizon cost and its gradient w.r.t. the flattened controls."""
        U = u_flat.reshape(T, m)
        X = np.zeros((T + 1, N))
        X[0] = x0
        for t in range(T):
            X[t + 1] = A @ X[t] + B @ U[t] + g
        E = X - x_ref  # tracking error, one row per time step
        cost = sum(E[t] @ (Q @ E[t]) for t in range(T + 1))
        cost += sum(U[t] @ (R @ U[t]) for t in range(T))
        # Backward adjoint sweep: Lambda[t] is the total derivative dJ/dx_t.
        Lambda = np.zeros((T + 1, N))
        Lambda[T] = Q2 @ E[T]
        for t in reversed(range(T)):
            Lambda[t] = A_T @ Lambda[t + 1] + Q2 @ E[t]
        # dJ/du_t = 2 R u_t + B^T Lambda[t+1]
        grad = np.array([R2 @ U[t] + B_T @ Lambda[t + 1] for t in range(T)])
        return cost, grad.ravel()

    x_current = np.full(N, config['INITIAL_STATE_VAL'])
    history_x, history_u = [x_current], []
    start_time = time.time()
    for k in range(n_steps):
        print(f"Adjoint Step {k+1}/{n_steps}", end='\r')
        T = min(Np, n_steps - k)  # horizon shrinks near the end of the episode
        res = minimize(cost_and_grad, np.zeros(T * m), args=(x_current, T, target_T_np),
                       method='L-BFGS-B', jac=True, bounds=[(u_min, u_max)] * (T * m),
                       options={'maxiter': 30})
        # Receding horizon: apply only the first planned control.
        u_apply = res.x.reshape(T, m)[0]
        x_current = A @ x_current + B @ u_apply + g
        history_x.append(x_current); history_u.append(u_apply)
    total_time = time.time() - start_time
    final_mse = np.mean((history_x[-1] - target_T_np)**2)
    # Leading newline so the summary does not overwrite the '\r' progress line.
    print(f"\nAdjoint controller finished in {total_time:.3f}s. Final MSE: {final_mse:.6e}")
    return np.array(history_x), np.array(history_u), total_time

def run_mpc_evaluation(config, target_T_np, A, B, g):
    """Evaluate Linear Model Predictive Control (LMPC) on the discrete state-space model.

    A fixed-horizon (``Np = 10``) quadratic program over the linear dynamics
    ``x_{i+1} = A x_i + B u_i + g`` is built once with cvxpy. It is re-solved
    with OSQP (warm-started) at each of the ``NT_SOLVER`` steps.

    The stage cost is ``(x_i - x_ref)^T (0.1 I) (x_i - x_ref) + u_i^T (1e-5 I) u_i``.
    The terminal cost is ``(x_Np - x_ref)^T (2 I) (x_Np - x_ref)``. Controls
    are box-constrained to ``[-1, 1]``.

    Only the first planned control is applied. If the solver raises, reports a
    non-optimal status, or yields no value, a zero control is applied for that step.

    Returns:
        ``(states, controls, elapsed)``: ``NT_SOLVER + 1`` states (initial state
        included), ``NT_SOLVER`` controls, and total wall time in seconds.
    """
    print("\n--- [3/5] Evaluating LMPC Controller ---")
    m, N = B.shape[1], A.shape[0]
    n_steps, Np = config['NT_SOLVER'], 10
    Q = 2*eye(N); T_Q = 1e-1*eye(N); R = 1e-5*eye(m)
    u_min, u_max = -1.0, 1.0

    # The QP is built once with Parameters, so cvxpy compiles it a single time
    # and each step only updates the measured state.
    u, x = cp.Variable((m, Np)), cp.Variable((N, Np + 1))
    x_k, x_ref = cp.Parameter(N), cp.Parameter(N)
    cost = sum(cp.quad_form(x[:, i] - x_ref, T_Q) + cp.quad_form(u[:, i], R) for i in range(Np))
    cost += cp.quad_form(x[:, Np] - x_ref, Q)
    constraints = [x[:, 0] == x_k] + [x[:, i+1] == A @ x[:, i] + B @ u[:, i] + g for i in range(Np)]
    constraints += [u >= u_min, u <= u_max]
    problem = cp.Problem(cp.Minimize(cost), constraints)
    x_ref.value = target_T_np  # the reference is constant over the episode

    x_current = np.full(N, config['INITIAL_STATE_VAL'])
    history_x, history_u = [x_current], []
    start_time = time.time()
    for k in range(n_steps):
        print(f"LMPC Step {k+1}/{n_steps}", end='\r')
        x_k.value = x_current
        try:
            problem.solve(solver=cp.OSQP, warm_start=True, verbose=False)
            solved = problem.status in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]
        except cp.SolverError:
            # OSQP failures can raise instead of returning a bad status; use the
            # same zero-control fallback rather than aborting the whole benchmark.
            solved = False
        u_first = u[:, 0].value if solved else None
        u_apply = u_first if u_first is not None else np.zeros(m)
        x_current = A @ x_current + B @ u_apply + g
        history_x.append(x_current); history_u.append(u_apply)
    total_time = time.time() - start_time
    final_mse = np.mean((history_x[-1] - target_T_np)**2)
    # Leading newline so the summary does not overwrite the '\r' progress line.
    print(f"\nLMPC controller finished in {total_time:.3f}s. Final MSE: {final_mse:.6e}")
    return np.array(history_x), np.array(history_u), total_time

def run_sbto_evaluation(args, config, target_T_torch, target_T_np, physics_simulator):
    """Evaluate SBTO: open-loop optimisation of the whole control sequence through the surrogate.

    A sequence of ``NT_SOLVER`` control vectors is parameterised as
    ``tanh(control_params)``, so every entry lies in ``(-1, 1)``. It is
    optimised with Adam (``args.sbto_lr``) for ``args.sbto_optim_steps``
    iterations by backpropagating through a full surrogate rollout from the
    uniform initial state. The loss is

        sbto_terminal_weight * MSE(final state, target)
        + sbto_running_weight * mean_k MSE(state_k, target)
        + sbto_effort_weight  * mean_k mean(w_k ** 2)

    The optimised sequence is then replayed open loop (no feedback) on the surrogate.

    The surrogate's parameters are frozen (``requires_grad=False``) during the
    optimisation and restored afterwards. Their ``.grad`` is therefore not touched.

    Returns:
        ``(states, controls, elapsed)``: ``NT_SOLVER + 1`` states (initial state
        included), ``NT_SOLVER`` controls, and the wall time of the optimisation
        phase only (the final replay is not timed).
    """
    print("\n--- [4/5] Evaluating SBTO (Open-Loop Surrogate Control) ---")
    DEVICE = target_T_torch.device
    num_control_steps = config['NT_SOLVER']
    num_basis_total = config['NUM_BASIS_X'] * config['NUM_BASIS_Y']
    M_sensors_total = config['NX_SENSORS'] * config['NY_SENSORS']

    sensor_x = torch.linspace(0, config['L_X'], config['NX_SENSORS'], device=DEVICE)
    sensor_y = torch.linspace(0, config['L_Y'], config['NY_SENSORS'], device=DEVICE)
    grid_x, grid_y = torch.meshgrid(sensor_x, sensor_y, indexing='ij')
    sensor_locs_flat = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1).unsqueeze(0)

    control_params = torch.zeros(1, num_control_steps, num_basis_total, device=DEVICE, requires_grad=True)
    optimizer = optim.Adam([control_params], lr=args.sbto_lr)
    mse_loss_fn = nn.MSELoss()

    # Only the controls are optimised. With the surrogate's weights left
    # trainable, every backward() would also compute (and accumulate into .grad)
    # gradients for all network parameters. That is wasted work that inflates
    # the reported time and mutates the shared model. Gradients w.r.t. the
    # inputs still flow.
    sim_params = list(physics_simulator.parameters())
    saved_requires_grad = [p.requires_grad for p in sim_params]
    for p in sim_params:
        p.requires_grad_(False)

    print(f"Starting SBTO optimization for {args.sbto_optim_steps} steps...")
    start_time = time.time()
    try:
        for i in range(args.sbto_optim_steps):
            optimizer.zero_grad()
            T_current = torch.full((1, M_sensors_total), config['INITIAL_STATE_VAL'], device=DEVICE)
            running_loss, effort_loss = 0.0, 0.0
            control_sequence = torch.tanh(control_params)  # Bounded between -1 and 1

            for k in range(num_control_steps):
                w_k = control_sequence[:, k, :]
                T_current = physics_simulator(T_current, w_k, sensor_locs_flat).squeeze(-1)
                running_loss += mse_loss_fn(T_current, target_T_torch)
                effort_loss += torch.mean(w_k**2)

            terminal_loss = mse_loss_fn(T_current, target_T_torch)
            total_loss = (args.sbto_terminal_weight * terminal_loss +
                          args.sbto_running_weight * (running_loss / num_control_steps) +
                          args.sbto_effort_weight * (effort_loss / num_control_steps))
            total_loss.backward()
            optimizer.step()
            if (i + 1) % 50 == 0: print(f"Step {i+1}/{args.sbto_optim_steps} | Loss: {total_loss.item():.4e}")
    finally:
        for p, flag in zip(sim_params, saved_requires_grad):
            p.requires_grad_(flag)
    total_time = time.time() - start_time

    # Open-loop replay of the optimised sequence.
    final_controls = torch.tanh(control_params.detach())
    state_history, control_history = [], []
    T_eval = torch.full((1, M_sensors_total), config['INITIAL_STATE_VAL'], device=DEVICE)
    state_history.append(T_eval.cpu().numpy().flatten())
    with torch.no_grad():
        for k in range(num_control_steps):
            w_k = final_controls[:, k, :]
            control_history.append(w_k.cpu().numpy().flatten())
            T_eval = physics_simulator(T_eval, w_k, sensor_locs_flat).squeeze(-1)
            state_history.append(T_eval.cpu().numpy().flatten())

    final_mse = np.mean((state_history[-1] - target_T_np)**2)
    print(f"SBTO finished in {total_time:.3f}s. Final MSE: {final_mse:.6e}")
    return np.array(state_history), np.array(control_history), total_time

def run_neural_mpc_evaluation(args, config, target_T_torch, target_T_np, physics_simulator):
    """Evaluate Neural (DeepONet) MPC: receding-horizon optimisation through the surrogate.

    At each of the ``NT_SOLVER`` steps, a fresh zero-initialised plan of
    ``min(args.nmpc_horizon, steps remaining)`` controls is prepared. The plan
    is parameterised as ``tanh(control_params)`` and optimised with Adam
    (``args.nmpc_lr``) for ``args.nmpc_optim_steps`` iterations through a
    surrogate rollout from the current state. The loss is

        nmpc_terminal_weight * MSE(last planned state, target)
        + nmpc_running_weight * mean_h MSE(state_h, target)
        + nmpc_effort_weight  * mean_h mean(w_h ** 2)

    The first planned action is then applied by one surrogate step, which also
    serves as the plant here.

    The surrogate's parameters are frozen (``requires_grad=False``) for the
    duration and restored afterwards. Their ``.grad`` is therefore not touched.

    Returns:
        ``(states, controls, elapsed)``: ``NT_SOLVER + 1`` states (initial state
        included), ``NT_SOLVER`` controls, and total wall time in seconds.
    """
    print("\n--- [5/5] Evaluating Neural MPC ---")
    DEVICE = target_T_torch.device
    num_control_steps = config['NT_SOLVER']
    num_basis_total = config['NUM_BASIS_X'] * config['NUM_BASIS_Y']
    M_sensors_total = config['NX_SENSORS'] * config['NY_SENSORS']

    sensor_x = torch.linspace(0, config['L_X'], config['NX_SENSORS'], device=DEVICE)
    sensor_y = torch.linspace(0, config['L_Y'], config['NY_SENSORS'], device=DEVICE)
    grid_x, grid_y = torch.meshgrid(sensor_x, sensor_y, indexing='ij')
    sensor_locs_flat = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1).unsqueeze(0)

    T_current = torch.full((1, M_sensors_total), config['INITIAL_STATE_VAL'], device=DEVICE)
    state_history, control_history = [T_current.cpu().numpy().flatten()], []

    # Only the plan is optimised. Freeze the surrogate's weights so backward()
    # does not also compute and accumulate parameter gradients (wasted work
    # that inflates the reported time). Gradients w.r.t. the inputs still flow.
    sim_params = list(physics_simulator.parameters())
    saved_requires_grad = [p.requires_grad for p in sim_params]
    for p in sim_params:
        p.requires_grad_(False)

    start_time = time.time()
    try:
        for k in range(num_control_steps):
            print(f"Neural MPC Step {k+1}/{num_control_steps}", end='\r')
            horizon = min(args.nmpc_horizon, num_control_steps - k)
            control_params = torch.zeros(1, horizon, num_basis_total, device=DEVICE, requires_grad=True)
            optimizer = optim.Adam([control_params], lr=args.nmpc_lr)

            for _ in range(args.nmpc_optim_steps):
                optimizer.zero_grad()
                T_rollout = T_current
                running_loss, effort_loss = 0.0, 0.0
                control_sequence = torch.tanh(control_params)

                for h in range(horizon):
                    w_h = control_sequence[:, h, :]
                    T_rollout = physics_simulator(T_rollout, w_h, sensor_locs_flat).squeeze(-1)
                    running_loss += nn.functional.mse_loss(T_rollout, target_T_torch)
                    effort_loss += torch.mean(w_h**2)

                terminal_loss = nn.functional.mse_loss(T_rollout, target_T_torch)
                total_loss = (args.nmpc_terminal_weight * terminal_loss +
                              args.nmpc_running_weight * (running_loss / horizon) +
                              args.nmpc_effort_weight * (effort_loss / horizon))
                total_loss.backward()
                optimizer.step()

            # Receding horizon: apply only the first planned action.
            with torch.no_grad():
                best_action = torch.tanh(control_params[:, 0, :].detach())
                control_history.append(best_action.cpu().numpy().flatten())
                T_current = physics_simulator(T_current, best_action, sensor_locs_flat).squeeze(-1)
                state_history.append(T_current.cpu().numpy().flatten())
    finally:
        for p, flag in zip(sim_params, saved_requires_grad):
            p.requires_grad_(flag)

    total_time = time.time() - start_time
    final_mse = np.mean((state_history[-1] - target_T_np)**2)
    print(f"\nNeural MPC finished in {total_time:.3f}s. Final MSE: {final_mse:.6e}")
    return np.array(state_history), np.array(control_history), total_time

# ===========================================================================
# 4. UNIFIED PLOTTING & MAIN ORCHESTRATOR
# ===========================================================================

def generate_comparison_plots_2d(all_results, config, output_dir):
    """Plot the final state of every method next to the target, one row per target.

    Args:
        all_results: Mapping ``target_type -> results``. Each ``results`` holds
            a flat ``'target_T'`` array. It also holds, for each of
            ``'pde_op'``, ``'sbto'``, ``'deeponet_mpc'``, ``'adjoint'`` and
            ``'lmpc'``, a dict whose ``'states'`` entry is a sequence of flat
            states; the last one is plotted.
        config: Dict providing ``NX_SENSORS``, ``NY_SENSORS``, ``L_X``, ``L_Y``.
        output_dir: Directory (created if missing) that receives
            ``comparison_final_state_2D_ALL_METHODS.pdf``.
    """
    print("\n--- Generating Combined 2D Comparison Plot ---")
    os.makedirs(output_dir, exist_ok=True)

    # Font sizes for this figure only (local values, not matplotlib rcParams).
    font_size_title = 48
    font_size_label = 42
    font_size_ticks = 36

    target_types = list(all_results.keys())
    if not target_types:
        # plt.subplots cannot create a figure with zero rows.
        print("No results to plot; skipping comparison figure.")
        return

    method_keys = ['pde_op', 'sbto', 'deeponet_mpc', 'adjoint', 'lmpc']
    method_titles = ['Target', 'PDE-OP (Ours)', 'SBTO', 'DeepONet MPC', 'Adjoint', 'LMPC']
    nx, ny = config['NX_SENSORS'], config['NY_SENSORS']
    n_rows = len(target_types)

    # Height scales with the number of rows to leave room for the large fonts;
    # squeeze=False keeps `axes` 2-D even when there is a single target.
    fig, axes = plt.subplots(n_rows, len(method_titles), figsize=(30, 6.5 * n_rows), squeeze=False)
    try:
        for i, target_type in enumerate(target_types):
            results = all_results[target_type]
            all_grids = [results['target_T'].reshape(nx, ny)] + [results[key]['states'][-1].reshape(nx, ny) for key in method_keys]
            # Shared colour limits within a row so the panels are directly comparable.
            vmin = min(grid.min() for grid in all_grids)
            vmax = max(grid.max() for grid in all_grids)

            for j, (grid, title) in enumerate(zip(all_grids, method_titles)):
                ax = axes[i, j]
                # Grids are (nx, ny) with x along axis 0; transpose so x runs horizontally.
                im = ax.imshow(grid.T, extent=[0, config['L_X'], 0, config['L_Y']], origin='lower', vmin=vmin, vmax=vmax, cmap='viridis')

                # x-axis label only on the bottom row.
                if i == n_rows - 1:
                    ax.set_xlabel('x', fontsize=font_size_label)

                # First column carries the target name; other columns hide y tick labels.
                if j == 0:
                    formatted_target = target_type.replace('_', ' ').title()
                    ax.set_ylabel(f"{formatted_target}\ny", fontsize=font_size_label)
                else:
                    ax.set_yticklabels([])

                # Method names as column titles on the top row.
                if i == 0:
                    ax.set_title(title, fontsize=font_size_title, pad=20)

                ax.tick_params(axis='both', which='major', labelsize=font_size_ticks)

                cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
                cbar.ax.tick_params(labelsize=font_size_ticks)

        fig.tight_layout()
        save_path = os.path.join(output_dir, "comparison_final_state_2D_ALL_METHODS.pdf")
        fig.savefig(save_path, bbox_inches='tight')
    finally:
        # Release the (large) figure even if drawing or saving fails.
        plt.close(fig)
    print(f"Saved combined 2D comparison plot to: {save_path}")

def main(args):
    """Run all five controllers on every 2D target, print summaries and save the results.

    The YAML config, the DeepONet surrogate and the discrete state-space model
    are loaded once. For each target type, the PDE-OP, adjoint, LMPC, SBTO
    and Neural (DeepONet) MPC controllers are run and a per-target final-MSE /
    wall-time table is printed.

    The raw results are written to ``all_simulation_results_2d.npz`` and the
    comparison figure is drawn. Both go under
    ``<output_base_dir>/<run_id>/comparison_plots_2d_all_methods``.
    """
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {DEVICE}")

    print("Pre-loading surrogate model...")
    physics_simulator = get_surrogate_model(args, config, DEVICE)

    print("Pre-computing 2D state-space model for classical methods...")
    A, B, g = get_state_space_model_2d(config)

    all_results = {}
    target_types = ['gaussian_peak', 'sine_wave', 'complex_gaussian']
    # (label in the summary table, key in all_results[target_type]), in print order.
    summary_rows = [
        ('PDE-OP (Ours)', 'pde_op'),
        ('Adjoint Method', 'adjoint'),
        ('LMPC', 'lmpc'),
        ('SBTO', 'sbto'),
        ('DeepONet MPC', 'deeponet_mpc'),
    ]

    for target_type in target_types:
        print(f"\n\n{'='*80}\n--- Starting 2D Evaluation for Target: '{target_type.upper()}' ---\n{'='*80}")

        target_T_torch, target_T_np = generate_target_2d(config, target_type, DEVICE)

        # --- Run all five methods ---
        pde_op_states, _, pde_op_time = run_recurrent_evaluation(args, config, target_T_torch, target_T_np, physics_simulator)
        adjoint_states, _, adjoint_time = run_adjoint_evaluation(config, target_T_np, A, B, g)
        lmpc_states, _, lmpc_time = run_mpc_evaluation(config, target_T_np, A, B, g)
        sbto_states, _, sbto_time = run_sbto_evaluation(args, config, target_T_torch, target_T_np, physics_simulator)
        nmpc_states, _, nmpc_time = run_neural_mpc_evaluation(args, config, target_T_torch, target_T_np, physics_simulator)

        all_results[target_type] = {
            'pde_op': {'states': pde_op_states, 'time': pde_op_time},
            'adjoint': {'states': adjoint_states, 'time': adjoint_time},
            'lmpc': {'states': lmpc_states, 'time': lmpc_time},
            'sbto': {'states': sbto_states, 'time': sbto_time},
            'deeponet_mpc': {'states': nmpc_states, 'time': nmpc_time},
            'target_T': target_T_np
        }

        print(f"\n--- SUMMARY FOR '{target_type.upper()}' ---")
        print(f"{'Method':<20} | {'Final MSE':<15} | {'Time (s)':<10}")
        print("-" * 50)
        for label, key in summary_rows:
            entry = all_results[target_type][key]
            final_mse = np.mean((entry['states'][-1] - target_T_np)**2)
            print(f"{label:<20} | {final_mse:<15.6e} | {entry['time']:<10.3f}")
        print("-" * 50)

    output_dir = os.path.join(args.output_base_dir, args.run_id, "comparison_plots_2d_all_methods")
    # Persist the expensive numerical results BEFORE plotting, so a plotting
    # failure cannot lose them.
    os.makedirs(output_dir, exist_ok=True)
    results_path = os.path.join(output_dir, "all_simulation_results_2d.npz")
    # Each per-target value is a nested dict, which np.savez stores as a pickled
    # 0-d object array; read back with np.load(path, allow_pickle=True)[target].item().
    np.savez_compressed(results_path, **all_results)
    print(f"\nSaved all numerical results to: {results_path}")

    generate_comparison_plots_2d(all_results, config, output_dir)
    print("\n--- Unified 2D evaluation complete. ---")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Unified evaluation of controllers for the 2D Heat Equation.")
    parser.add_argument("--config_path", type=str, default="config/config_2d.yaml", help="Path to the 2D YAML config file.")
    parser.add_argument("--output_base_dir", type=str, default="outputs_2d", help="Base directory where runs are stored.")
    parser.add_argument("--run_id", type=str, required=True, help="The run_id of the trained PDE-OP controller.")

    # --- SBTO (open-loop surrogate trajectory optimisation) hyperparameters ---
    parser.add_argument("--sbto_optim_steps", type=int, default=200, help="Number of optimization steps for SBTO.")
    parser.add_argument("--sbto_lr", type=float, default=1e-2, help="Learning rate for SBTO optimizer.")
    parser.add_argument("--sbto_terminal_weight", type=float, default=1.0,
                        help="SBTO loss weight on the final-state MSE.")
    parser.add_argument("--sbto_running_weight", type=float, default=0.1,
                        help="SBTO loss weight on the per-step state MSE (averaged over steps).")
    parser.add_argument("--sbto_effort_weight", type=float, default=1e-5,
                        help="SBTO loss weight on the mean squared control (averaged over steps).")

    # --- Neural (DeepONet) MPC hyperparameters ---
    parser.add_argument("--nmpc_horizon", type=int, default=10, help="Planning horizon for Neural MPC.")
    parser.add_argument("--nmpc_optim_steps", type=int, default=25, help="Optimization steps per MPC step.")
    parser.add_argument("--nmpc_lr", type=float, default=2e-2, help="Learning rate for Neural MPC optimizer.")
    parser.add_argument("--nmpc_terminal_weight", type=float, default=1.0,
                        help="Neural MPC loss weight on the MSE of the last planned state.")
    parser.add_argument("--nmpc_running_weight", type=float, default=0.1,
                        help="Neural MPC loss weight on the per-step state MSE (averaged over the horizon).")
    parser.add_argument("--nmpc_effort_weight", type=float, default=1e-5,
                        help="Neural MPC loss weight on the mean squared control (averaged over the horizon).")

    args = parser.parse_args()
    main(args)