# src/unified_comparison_heat.py
# ---------------------------------------------------------------------------
# FINAL UNIFIED SCRIPT for comparing five different control strategies for the 1D Heat Equation.
#
# Methods Compared:
# 1. PDE-OP (Recurrent Controller)
# 2. Adjoint-Based Method (Classical Closed-Loop)
# 3. Linear MPC (LMPC) (Classical Closed-Loop)
# 4. SBTO (Open-Loop Learning Baseline)
# 5. DeepONet MPC (Closed-Loop Learning Baseline)
# ---------------------------------------------------------------------------

import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import argparse
import time

# --- Dependencies for Learning-Based Controllers ---
from data_and_models import PropagatorDeepONet, RecurrentController

# --- Dependencies for Classical Controllers ---
import scipy.sparse as sp
from scipy.sparse.linalg import spsolve
from scipy.optimize import minimize
import cvxpy as cp

# ===========================================================================
# 1. UNIFIED TARGET GENERATION
# ===========================================================================
def generate_target_profile(target_type, x_grid, L):
    """
    Build the target temperature profile of the requested type on ``x_grid``.

    Args:
        target_type: One of 'sine', 'ramp', 'constant', 'step', 'high_freq'
            or 'complex_gaussian'.
        x_grid: 1D array of spatial sample locations.
        L: Domain length used to scale spatial frequencies and positions.

    Returns:
        A NumPy array holding one target value per grid point.

    Raises:
        ValueError: If ``target_type`` is not one of the supported names.
    """
    print(f"Generating '{target_type}' target profile...")

    if target_type == 'sine':
        return 0.6 + 0.3 * np.sin(2 * np.pi * x_grid / L)
    if target_type == 'ramp':
        return np.linspace(0.5, 1.5, len(x_grid))
    if target_type == 'constant':
        return np.full(len(x_grid), 1.0)

    # --- Advanced Targets ---
    if target_type == 'step':
        # Step function: 0.6 between 30% and 70% of the domain, else 0.
        # dtype=float so an integer-valued grid does not truncate 0.6 to 0.
        target_profile = np.zeros_like(x_grid, dtype=float)
        mask = (x_grid > 0.3 * L) & (x_grid < 0.7 * L)
        target_profile[mask] = 0.6
        return target_profile

    if target_type == 'high_freq':
        # Base low frequency + high frequency component.
        base = 0.5 * np.sin(2 * np.pi * x_grid / L)
        hf_comp = 0.2 * np.sin(12 * np.pi * x_grid / L)
        return base + hf_comp

    if target_type == 'complex_gaussian':
        # A private, fixed-seed generator reproduces the same draws as
        # `np.random.seed(42)` without reseeding the process-wide RNG that
        # other code may rely on.
        rng = np.random.RandomState(42)
        target_profile = np.zeros_like(x_grid, dtype=float)
        # Sum of 5 random Gaussian bumps with random sign.
        for _ in range(5):
            # Draw order (uniform, rand, uniform, uniform) is kept fixed so
            # the generated profile stays reproducible.
            amp = rng.uniform(0.5, 1.5) * (1 if rng.rand() > 0.5 else -1)
            mean = rng.uniform(0.1 * L, 0.9 * L)
            sigma = rng.uniform(0.05 * L, 0.15 * L)
            target_profile += amp * np.exp(-((x_grid - mean)**2) / (2 * sigma**2))
        return target_profile

    raise ValueError(f"Invalid target type: {target_type}")

# ===========================================================================
# 2. PDE-OP (RECURRENT CONTROLLER) EVALUATION
# ===========================================================================
def run_recurrent_evaluation(args, config, target_T_np):
    """
    Roll out the trained recurrent controller against the DeepONet surrogate.

    The controller checkpoint and its ``hyperparams.yaml`` are read from
    ``<args.output_base_dir>/<args.run_id>``; the surrogate run is located via
    the ``deeponet_run_id`` entry of the controller hyperparameters. The
    closed loop is simulated for ``config['NT_SOLVER']`` steps starting from
    an all-zero state.

    Args:
        args: Parsed CLI namespace providing ``output_base_dir`` and ``run_id``.
        config: Main configuration dictionary.
        target_T_np: Target profile as a NumPy array.

    Returns:
        Tuple ``(states, controls, elapsed_seconds)`` where ``states`` holds
        one flattened row per time step (initial state included) and
        ``controls`` one flattened row per applied control.
    """
    print("\n--- [1/5] Evaluating PDE-OP (Recurrent Controller) ---")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Controller: hyperparameters and weights ---
    ctrl_dir = os.path.join(args.output_base_dir, args.run_id)
    ctrl_weights_path = os.path.join(ctrl_dir, "recurrent_controller_model.pth")
    with open(os.path.join(ctrl_dir, "hyperparams.yaml"), 'r') as fh:
        ctrl_hparams = yaml.safe_load(fh)
    ctrl_arg_names = ('hidden_dim', 'num_layers', 'activation_fn')
    ctrl_kwargs = {name: value for name, value in ctrl_hparams.items() if name in ctrl_arg_names}
    controller = RecurrentController(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        **ctrl_kwargs,
    ).to(device)
    controller.load_state_dict(torch.load(ctrl_weights_path, map_location=device))
    controller.eval()

    # --- Surrogate: the run id is stored with the controller hyperparameters ---
    surrogate_dir = os.path.join(args.output_base_dir, ctrl_hparams['deeponet_run_id'])
    surrogate_weights_path = os.path.join(surrogate_dir, "propagator_deeponet_best.pth")
    with open(os.path.join(surrogate_dir, "hyperparams.yaml"), 'r') as fh:
        surrogate_hparams = yaml.safe_load(fh)
    surrogate_kwargs = {}
    for name in ('branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn'):
        surrogate_kwargs[name] = surrogate_hparams[name]
    surrogate = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **surrogate_kwargs,
    ).to(device)
    surrogate.load_state_dict(torch.load(surrogate_weights_path, map_location=device))
    surrogate.eval()

    print("Successfully loaded recurrent controller and surrogate model.")

    # --- Closed-loop rollout from a zero initial state ---
    target = torch.from_numpy(target_T_np).float().unsqueeze(0).to(device)
    state = torch.full((1, config['M_SENSORS']), 0.0, device=device)
    sensor_locs = torch.linspace(0, config['L'], config['M_SENSORS'], device=device).reshape(1, -1, 1)

    states = [state.cpu().numpy().flatten()]
    controls = []
    tic = time.time()
    with torch.no_grad():
        memory = None  # recurrent hidden state, threaded through every call
        for _ in range(config['NT_SOLVER']):
            action, memory = controller(state, target, memory)
            controls.append(action.cpu().numpy().flatten())
            state = surrogate(state, action, sensor_locs).squeeze(-1)
            states.append(state.cpu().numpy().flatten())
    elapsed = time.time() - tic

    final_mse = np.mean((states[-1] - target_T_np)**2)
    print(f"Recurrent controller finished in {elapsed:.4f}s. Final MSE: {final_mse:.4e}")
    return np.array(states), np.array(controls), elapsed

# ===========================================================================
# 3. ADJOINT-BASED CONTROLLER EVALUATION
# ===========================================================================
def run_adjoint_evaluation(config, target_T_np):
    """
    Receding-horizon control of the discretised heat equation, using an
    adjoint (backward) sweep to supply exact gradients to L-BFGS-B.

    At every time step a bounded optimisation over the next ``Np`` controls
    is solved, only the first control is applied, and the remainder of the
    plan is shifted to warm-start the next step.

    Args:
        config: Dictionary providing 'L', 'D', 'ALPHA', 'NUM_BASIS_FUNCTIONS',
            'M_SENSORS', 'T_FINAL' and 'NT_SOLVER'.
        target_T_np: Target profile with one value per grid node
            (length ``M_SENSORS``).

    Returns:
        Tuple ``(states, controls, elapsed_seconds)`` with ``states`` of shape
        ``(NT_SOLVER + 1, M_SENSORS)`` (initial zero state included) and
        ``controls`` of shape ``(NT_SOLVER, NUM_BASIS_FUNCTIONS)``.
    """
    print("\n--- [2/5] Evaluating Adjoint-Based Controller ---")
    L, D, alpha = config['L'], config['D'], config['ALPHA']
    beta, T_amb = 0.5, 0.0          # decay coefficient and ambient value of the source term
    m = config['NUM_BASIS_FUNCTIONS']
    N = config['M_SENSORS'] - 1     # number of grid intervals (N + 1 nodes)
    dx = L / N
    dt = config['T_FINAL'] / config['NT_SOLVER']
    n_steps = config['NT_SOLVER']
    Np = 10                         # maximum prediction horizon
    Q = 2*np.eye(N + 1)             # state tracking weight
    R = 1e-5 * np.eye(m)            # control effort weight
    u_min, u_max = -1.0, 1.0
    inner_maxiter = 30

    def get_state_space_model(N, dx, dt, D, alpha, m, beta=0.0, T_amb=0.0):
        """Return (A, B, g) such that x[k+1] = A x[k] + B u[k] + g."""
        # Half-implicit (Crank-Nicolson style) split: M_L x[k+1] = M_R x[k] + dt * (...)
        r = D * dt / (2 * dx**2)
        s = beta * dt / 2.0
        ml_diag = np.full(N + 1, 1 + 2*r + s)
        mr_diag = np.full(N + 1, 1 - 2*r - s)
        off_diag = np.full(N, -r)
        M_L = sp.diags([off_diag, ml_diag, off_diag], [-1, 0, 1], format='csc')
        M_R = sp.diags([-off_diag, mr_diag, -off_diag], [-1, 0, 1], format='csc')
        # Boundary rows couple to the single interior neighbour with doubled
        # weight (NOTE(review): looks like a mirrored ghost-node zero-flux
        # boundary -- confirm against the PDE specification).
        M_L[0, 1] = -2*r; M_R[0, 1] = 2*r
        M_L[N, N-1] = -2*r; M_R[N, N-1] = 2*r
        # Actuation: cosine modes j = 0..m-1 sampled on the grid, scaled by alpha.
        x_grid_local = np.linspace(0.0, L, N+1)
        B_c = np.zeros((N + 1, m))
        for j in range(m):
            B_c[:, j] = alpha * np.cos(j * np.pi * x_grid_local / L)
        s_vec = beta * T_amb * np.ones(N + 1)
        A = spsolve(M_L, M_R)
        B = spsolve(M_L, dt * B_c)
        g = spsolve(M_L, dt * s_vec)
        return A, B, g

    def horizon_cost_and_grad(u_horizon_flat, x0, horizon_len, x_ref_target, A, B, g):
        """Return the horizon cost and its gradient w.r.t. the flattened controls."""
        U_h = u_horizon_flat.reshape((horizon_len, m))
        # Forward sweep: simulate the state trajectory under the candidate plan.
        X_h = np.zeros((horizon_len + 1, N + 1))
        X_h[0] = x0
        for t in range(horizon_len):
            X_h[t+1] = A @ X_h[t] + B @ U_h[t] + g
        J = sum((X_h[t] - x_ref_target) @ (Q @ (X_h[t] - x_ref_target)) for t in range(horizon_len + 1))
        J += sum(U_h[t] @ (R @ U_h[t]) for t in range(horizon_len))
        # Backward sweep: Lambda[t] accumulates dJ/dx[t] including downstream effects.
        Lambda = np.zeros((horizon_len + 1, N + 1))
        Lambda[-1] = 2.0 * (Q @ (X_h[-1] - x_ref_target))
        for t in reversed(range(horizon_len)):
            Lambda[t] = A.T @ Lambda[t+1] + 2.0 * (Q @ (X_h[t] - x_ref_target))
        # u[t] influences the cost directly (R term) and through x[t+1].
        grad = np.array([2.0 * (R @ U_h[t]) + B.T @ Lambda[t+1] for t in range(horizon_len)])
        return float(J), grad.ravel()

    A, B, g = get_state_space_model(N, dx, dt, D, alpha, m, beta, T_amb)
    x_current = np.zeros(N + 1)
    history_x, history_u = [x_current], []
    prev_horizon_flat = None
    start_time = time.time()
    for k in range(n_steps):
        print(f"Adjoint Step {k+1}/{n_steps}", end='\r')
        horizon_len = min(Np, n_steps - k)  # shrink the horizon near the end of the run

        # Warm start: previous plan shifted by one step and zero-padded.
        u0_flat = np.zeros(horizon_len * m)
        if prev_horizon_flat is not None and prev_horizon_flat.reshape((-1, m)).shape[0] > 1:
            prev = prev_horizon_flat.reshape((-1, m))
            shifted = np.vstack((prev[1:], np.zeros((1, m))))
            if shifted.shape[0] < horizon_len:
                shifted = np.vstack((shifted, np.zeros((horizon_len - shifted.shape[0], m))))
            u0_flat = shifted[:horizon_len].ravel()

        bounds = [(u_min, u_max)] * (horizon_len * m)
        res = minimize(
            horizon_cost_and_grad,
            u0_flat,
            args=(x_current, horizon_len, target_T_np, A, B, g),
            method='L-BFGS-B',
            jac=True,
            bounds=bounds,
            options={'maxiter': inner_maxiter, 'ftol': 1e-6},
        )
        # L-BFGS-B reports success=False (status 1) when it merely exhausts
        # `maxiter`. That iterate respects the bounds and is the improved
        # plan, so keep it instead of reverting to the unoptimised warm start
        # (which would be all zeros at the first step). Only fall back on
        # other, abnormal terminations.
        if res.success or res.status == 1:
            u_opt_flat = res.x
        else:
            u_opt_flat = u0_flat
        prev_horizon_flat = u_opt_flat

        # Apply only the first control of the plan (receding horizon).
        u_apply = u_opt_flat.reshape((horizon_len, m))[0]
        x_next = A @ x_current + B @ u_apply + g
        history_x.append(x_next)
        history_u.append(u_apply)
        x_current = x_next
    total_time = time.time() - start_time
    final_mse = np.mean((history_x[-1] - target_T_np)**2)
    print(f"Adjoint controller finished in {total_time:.4f}s. Final MSE: {final_mse:.4e}")
    return np.array(history_x), np.array(history_u), total_time

# ===========================================================================
# 4. LMPC CONTROLLER EVALUATION
# ===========================================================================
def run_mpc_evaluation(config, target_T_np):
    """
    Linear MPC for the discretised heat equation, posed once as a
    parametrised cvxpy problem and re-solved with OSQP at every time step.

    Args:
        config: Dictionary providing 'L', 'D', 'ALPHA', 'NUM_BASIS_FUNCTIONS',
            'M_SENSORS', 'T_FINAL' and 'NT_SOLVER'.
        target_T_np: Target profile with one value per grid node.

    Returns:
        Tuple ``(states, controls, elapsed_seconds)``; ``states`` contains the
        initial zero state followed by one row per simulated step and
        ``controls`` one row per applied control.
    """
    print("\n--- [3/5] Evaluating LMPC Controller ---")
    L = config['L']
    D = config['D']
    alpha = config['ALPHA']
    beta, T_amb = 0.5, 0.0
    m = config['NUM_BASIS_FUNCTIONS']
    N = config['M_SENSORS'] - 1
    dx = L / N
    dt = config['T_FINAL'] / config['NT_SOLVER']
    n_steps = config['NT_SOLVER']
    horizon = 10
    terminal_weight = 2*np.eye(N + 1)
    stage_weight = 1e-1 * np.eye(N + 1)
    input_weight = 1e-5 * np.eye(m)
    u_min, u_max = -1.0, 1.0

    def build_discrete_model(n_cells, h, tau, diffusivity, gain, n_modes, decay=0.0, ambient=0.0):
        """Return (A, B, g) of the one-step map x[k+1] = A x[k] + B u[k] + g."""
        half_diff = diffusivity * tau / (2 * h**2)
        half_decay = decay * tau / 2.0
        n_nodes = n_cells + 1
        lhs_main = np.full(n_nodes, 1 + 2*half_diff + half_decay)
        rhs_main = np.full(n_nodes, 1 - 2*half_diff - half_decay)
        coupling = np.full(n_cells, -half_diff)
        lhs = sp.diags([coupling, lhs_main, coupling], [-1, 0, 1], format='csc')
        rhs = sp.diags([-coupling, rhs_main, -coupling], [-1, 0, 1], format='csc')
        # Boundary rows use a doubled coupling to their single neighbour.
        lhs[0, 1] = -2*half_diff
        rhs[0, 1] = 2*half_diff
        lhs[n_cells, n_cells-1] = -2*half_diff
        rhs[n_cells, n_cells-1] = 2*half_diff
        # Cosine actuation modes sampled on the grid.
        nodes = np.linspace(0.0, L, n_cells+1)
        actuation = np.zeros((n_nodes, n_modes))
        for mode in range(n_modes):
            actuation[:, mode] = gain * np.cos(mode * np.pi * nodes / L)
        source = decay * ambient * np.ones(n_nodes)
        A_d = spsolve(lhs, rhs)
        B_d = spsolve(lhs, tau * actuation)
        g_d = spsolve(lhs, tau * source)
        return A_d, B_d, g_d

    A, B, g = build_discrete_model(N, dx, dt, D, alpha, m, beta, T_amb)
    nx, nu = B.shape

    # --- Parametrised QP: built once, only parameter values change per step ---
    x_init = cp.Parameter(nx)
    x_goal = cp.Parameter(nx)
    u_plan = cp.Variable((nu, horizon))
    x_plan = cp.Variable((nx, horizon + 1))
    objective = 0
    constraints = [x_plan[:, 0] == x_init]
    for stage in range(horizon):
        objective += cp.quad_form(x_plan[:, stage] - x_goal, stage_weight) + cp.quad_form(u_plan[:, stage], input_weight)
        constraints += [x_plan[:, stage+1] == A @ x_plan[:, stage] + B @ u_plan[:, stage] + g]
        constraints += [u_min <= u_plan[:, stage], u_plan[:, stage] <= u_max]
    objective += cp.quad_form(x_plan[:, horizon] - x_goal, terminal_weight)
    problem = cp.Problem(cp.Minimize(objective), constraints)

    # --- Closed-loop simulation from a zero initial state ---
    state = np.zeros(nx)
    states = [state]
    controls = []
    tic = time.time()
    for step in range(n_steps):
        print(f"LMPC Step {step+1}/{n_steps}", end='\r')
        x_init.value = state
        x_goal.value = target_T_np
        problem.solve(solver=cp.OSQP, warm_start=True, verbose=False)
        if problem.status in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]:
            u_now = u_plan[:, 0].value
        else:
            # No usable solution: apply zero control for this step.
            u_now = np.zeros(nu)
        state = A @ state + B @ u_now + g
        states.append(state)
        controls.append(u_now)
    elapsed = time.time() - tic

    final_mse = np.mean((states[-1] - target_T_np)**2)
    print(f"MPC controller finished in {elapsed:.4f}s. Final MSE: {final_mse:.4e}")
    return np.array(states), np.array(controls), elapsed

# ===========================================================================
# 5. SBTO (OPEN-LOOP SURROGATE) EVALUATION
# ===========================================================================
def run_sbto_evaluation(args, config, target_T_np):
    """
    Open-loop baseline: optimise the whole control sequence at test time by
    back-propagating through the DeepONet surrogate with Adam, then replay it.

    Args:
        args: Parsed CLI namespace (``output_base_dir``, ``run_id`` and the
            ``sbto_*`` hyperparameters).
        config: Main configuration dictionary.
        target_T_np: Target profile as a NumPy array.

    Returns:
        Tuple ``(states, controls, elapsed_seconds)``. The elapsed time covers
        the optimisation loop only, not the final replay.
    """
    print("\n--- [4/5] Evaluating SBTO (Open-Loop Surrogate Control) ---")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Locate and load the surrogate referenced by the controller run ---
    ctrl_dir = os.path.join(args.output_base_dir, args.run_id)
    with open(os.path.join(ctrl_dir, "hyperparams.yaml"), 'r') as fh:
        ctrl_hparams = yaml.safe_load(fh)
    surrogate_dir = os.path.join(args.output_base_dir, ctrl_hparams['deeponet_run_id'])
    surrogate_weights_path = os.path.join(surrogate_dir, "propagator_deeponet_best.pth")
    with open(os.path.join(surrogate_dir, "hyperparams.yaml"), 'r') as fh:
        surrogate_hparams = yaml.safe_load(fh)
    surrogate_kwargs = {}
    for name in ('branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn'):
        surrogate_kwargs[name] = surrogate_hparams[name]
    surrogate = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **surrogate_kwargs,
    ).to(device)
    surrogate.load_state_dict(torch.load(surrogate_weights_path, map_location=device))
    surrogate.eval()

    target = torch.from_numpy(target_T_np).float().unsqueeze(0).to(device)
    sensor_locs = torch.linspace(0, config['L'], config['M_SENSORS'], device=device).reshape(1, -1, 1)
    n_steps = config['NT_SOLVER']

    # Unconstrained decision variables; tanh maps them into (-1, 1).
    raw_controls = torch.zeros(1, n_steps, config['NUM_BASIS_FUNCTIONS'], device=device, requires_grad=True)
    optimizer = torch.optim.Adam([raw_controls], lr=args.sbto_lr)
    mse = torch.nn.functional.mse_loss

    print(f"Starting SBTO test-time optimization for {args.sbto_optim_steps} steps...")
    tic = time.time()
    for iteration in range(1, args.sbto_optim_steps + 1):
        optimizer.zero_grad()
        state = torch.full((1, config['M_SENSORS']), 0.0, device=device)
        tracking_sum = 0.0
        effort_sum = 0.0
        bounded_controls = torch.tanh(raw_controls)

        # Differentiable rollout over the full time horizon.
        for k in range(n_steps):
            w_k = bounded_controls[:, k, :]
            state = surrogate(state, w_k, sensor_locs).squeeze(-1)
            tracking_sum += mse(state, target)
            effort_sum += torch.mean(w_k**2)

        terminal = mse(state, target)
        mean_tracking = tracking_sum / n_steps
        mean_effort = effort_sum / n_steps
        loss = (args.sbto_terminal_weight * terminal +
                args.sbto_running_weight * mean_tracking +
                args.sbto_effort_weight * mean_effort)

        loss.backward()
        optimizer.step()
        if iteration % 50 == 0:
            print(f"Step {iteration}/{args.sbto_optim_steps} | Loss: {loss.item():.4e}")

    elapsed = time.time() - tic

    # --- Replay the optimised sequence (gradient-free) to record the trajectory ---
    final_controls = torch.tanh(raw_controls.detach())
    state_eval = torch.full((1, config['M_SENSORS']), 0.0, device=device)
    states = [state_eval.cpu().numpy().flatten()]
    controls = []

    with torch.no_grad():
        for k in range(n_steps):
            w_final = final_controls[:, k, :]
            controls.append(w_final.cpu().numpy().flatten())
            state_eval = surrogate(state_eval, w_final, sensor_locs).squeeze(-1)
            states.append(state_eval.cpu().numpy().flatten())

    final_mse = np.mean((states[-1] - target_T_np)**2)
    print(f"SBTO finished in {elapsed:.4f}s. Final MSE: {final_mse:.4e}")
    return np.array(states), np.array(controls), elapsed

# ===========================================================================
# 6. DEEPONET MPC EVALUATION (Renamed from Neural MPC)
# ===========================================================================
def run_deeponet_mpc_evaluation(args, config, target_T_np):
    """
    Closed-loop learning baseline: receding-horizon control in which every
    step's plan is optimised with Adam through the DeepONet surrogate.

    At each time step a fresh zero-initialised plan over the remaining
    horizon is optimised for ``args.nmpc_optim_steps`` iterations; only its
    first action is applied to advance the (surrogate) state.

    Args:
        args: Parsed CLI namespace (``output_base_dir``, ``run_id`` and the
            ``nmpc_*`` hyperparameters).
        config: Main configuration dictionary.
        target_T_np: Target profile as a NumPy array.

    Returns:
        Tuple ``(states, controls, elapsed_seconds)``.
    """
    print("\n--- [5/5] Evaluating DeepONet MPC ---")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Locate and load the surrogate referenced by the controller run ---
    ctrl_dir = os.path.join(args.output_base_dir, args.run_id)
    with open(os.path.join(ctrl_dir, "hyperparams.yaml"), 'r') as fh:
        ctrl_hparams = yaml.safe_load(fh)
    surrogate_dir = os.path.join(args.output_base_dir, ctrl_hparams['deeponet_run_id'])
    surrogate_weights_path = os.path.join(surrogate_dir, "propagator_deeponet_best.pth")
    with open(os.path.join(surrogate_dir, "hyperparams.yaml"), 'r') as fh:
        surrogate_hparams = yaml.safe_load(fh)
    surrogate_kwargs = {}
    for name in ('branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn'):
        surrogate_kwargs[name] = surrogate_hparams[name]
    surrogate = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **surrogate_kwargs,
    ).to(device)
    surrogate.load_state_dict(torch.load(surrogate_weights_path, map_location=device))
    surrogate.eval()

    target = torch.from_numpy(target_T_np).float().unsqueeze(0).to(device)
    sensor_locs = torch.linspace(0, config['L'], config['M_SENSORS'], device=device).reshape(1, -1, 1)
    n_steps = config['NT_SOLVER']
    mse = torch.nn.functional.mse_loss

    state = torch.full((1, config['M_SENSORS']), 0.0, device=device)
    states = [state.cpu().numpy().flatten()]
    controls = []

    tic = time.time()
    for step in range(n_steps):
        print(f"DeepONet MPC Step {step+1}/{n_steps}", end='\r')
        # The horizon shrinks once fewer than `nmpc_horizon` steps remain.
        horizon = min(args.nmpc_horizon, n_steps - step)

        # Fresh plan and optimiser every step (no warm start).
        raw_plan = torch.zeros(1, horizon, config['NUM_BASIS_FUNCTIONS'], device=device, requires_grad=True)
        optimizer = torch.optim.Adam([raw_plan], lr=args.nmpc_lr)

        for _ in range(args.nmpc_optim_steps):
            optimizer.zero_grad()
            rollout_state = state
            tracking_sum, effort_sum = 0.0, 0.0
            bounded_plan = torch.tanh(raw_plan)

            for h in range(horizon):
                w_h = bounded_plan[:, h, :]
                rollout_state = surrogate(rollout_state, w_h, sensor_locs).squeeze(-1)
                tracking_sum += mse(rollout_state, target)
                effort_sum += torch.mean(w_h**2)

            terminal = mse(rollout_state, target)
            loss = (args.nmpc_terminal_weight * terminal +
                    args.nmpc_running_weight * (tracking_sum / horizon) +
                    args.nmpc_effort_weight * (effort_sum / horizon))
            loss.backward()
            optimizer.step()

        # Apply only the first planned action, then re-plan from the new state.
        with torch.no_grad():
            first_action = torch.tanh(raw_plan[:, 0, :].detach())
            controls.append(first_action.cpu().numpy().flatten())
            state = surrogate(state, first_action, sensor_locs).squeeze(-1)
            states.append(state.cpu().numpy().flatten())

    elapsed = time.time() - tic
    final_mse = np.mean((states[-1] - target_T_np)**2)
    print(f"DeepONet MPC finished in {elapsed:.4f}s. Final MSE: {final_mse:.4e}")
    return np.array(states), np.array(controls), elapsed

# ===========================================================================
# 7. PLOTTING (SPLIT GROUPS) - TITLES REMOVED
# ===========================================================================
def generate_comparison_plots_split(all_results, x_grid, config, output_dir, target_subset, filename_suffix):
    """
    Save a single-row figure comparing the final state of every method, with
    one panel per target in ``target_subset``.

    Args:
        all_results: Mapping ``target_type -> results`` where each results
            dict holds 'target_T' and, for each of 'recurrent', 'adjoint',
            'mpc', 'sbto' and 'deeponet_mpc', a dict with a 'states' history.
        x_grid: Spatial grid used for the x-axis.
        config: Main configuration dictionary (currently unused; kept for
            interface compatibility).
        output_dir: Directory the PDF is written to.
        target_subset: Target names to plot, one panel each. Targets missing
            from ``all_results`` leave their panel empty.
        filename_suffix: Suffix of the output file name
            ``comparison_heat_<suffix>.pdf``.
    """
    print(f"\n--- Generating {filename_suffix} Comparison Plot ---")
    font_size = 40
    legend_font_size = 40

    n_panels = len(target_subset)
    if n_panels == 0:
        # plt.subplots cannot create a zero-column grid; nothing to draw.
        print("No targets requested; skipping plot.")
        return

    # One panel per target. squeeze=False always yields a 2D array of axes,
    # so one target or more than three targets index correctly. The width
    # keeps the original 38-inch size for three panels.
    fig, axes = plt.subplots(1, n_panels, figsize=(38 * n_panels / 3, 9), squeeze=False)
    axes = axes.ravel()

    # (results key, legend label, colour, line style) in plotting order.
    method_specs = [
        ('recurrent', 'PDE-OP', 'r', '-'),
        ('adjoint', 'Adjoint Method', 'g', '--'),
        ('mpc', 'LMPC', 'c', '-.'),
        ('sbto', 'SBTO', 'b', ':'),
        ('deeponet_mpc', 'DeepONet MPC', 'm', ':'),
    ]

    for i, target_type in enumerate(target_subset):
        ax = axes[i]
        if target_type not in all_results:
            continue

        results = all_results[target_type]
        target_T_np = results['target_T']

        # Target and initial state (the initial state is taken from the
        # recurrent controller's history).
        ax.plot(x_grid, target_T_np, 'k:', lw=4, label=r'$y_{\text{target}}(x)$')
        ax.plot(x_grid, results['recurrent']['states'][0], 'b--', lw=2, label='$y(0, x)$')

        # Final state reached by each method.
        for key, label, color, linestyle in method_specs:
            ax.plot(x_grid, results[key]['states'][-1],
                    color=color, linestyle=linestyle, lw=2.5, label=label)

        ax.set_xlabel('x', fontsize=font_size)
        ax.grid(True, linestyle='--', alpha=0.6)
        ax.tick_params(axis='both', which='major', labelsize=font_size)

        # Only the first panel carries the y-axis label.
        if i == 0:
            ax.set_ylabel('$y(T, x)$', fontsize=font_size)

    # Shared legend, taken from the first panel that actually has curves so
    # a skipped first target does not produce an empty legend.
    handles, labels = [], []
    for ax in axes:
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            break
    fig.legend(handles, labels, loc='lower center',
               bbox_to_anchor=(0.5, -0.15), ncol=7, fontsize=legend_font_size)

    fig.tight_layout(rect=[0, 0.1, 1, 1])

    save_path = os.path.join(output_dir, f"comparison_heat_{filename_suffix}.pdf")
    # Save through the figure object rather than pyplot's "current figure".
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved plot to: {save_path}")

# ===========================================================================
# 8. MAIN ORCHESTRATOR
# ===========================================================================
def main(args):
    """
    Run all five controllers on every target profile, print a per-target
    summary table, save the comparison plots and store the raw results.

    Args:
        args: Parsed CLI namespace; ``config_path``, ``output_base_dir`` and
            ``run_id`` are read here, the rest is forwarded to the evaluators.
    """
    with open(args.config_path, 'r') as fh:
        config = yaml.safe_load(fh)

    # 1. Target groups (each group becomes one figure).
    standard_targets = ['sine', 'ramp', 'constant']
    advanced_targets = ['step', 'high_freq', 'complex_gaussian']

    all_results = {}
    x_grid = np.linspace(0, config['L'], config['M_SENSORS'])

    # (label shown in the summary table, key used in the results dict)
    summary_rows = (
        ('PDE-OP', 'recurrent'),
        ('Adjoint Method', 'adjoint'),
        ('LMPC', 'mpc'),
        ('SBTO', 'sbto'),
        ('DeepONet MPC', 'deeponet_mpc'),
    )
    banner = '=' * 80

    # 2. Evaluation loop.
    for target_type in standard_targets + advanced_targets:
        print(f"\n\n{banner}")
        print(f"--- Starting Unified Evaluation for Target: '{target_type.upper()}' ---")
        print(f"{banner}")

        target_T_np = generate_target_profile(target_type, x_grid, config['L'])

        # Each evaluator returns (states, controls, elapsed_seconds).
        runs = {}
        runs['recurrent'] = run_recurrent_evaluation(args, config, target_T_np)
        runs['adjoint'] = run_adjoint_evaluation(config, target_T_np)
        runs['mpc'] = run_mpc_evaluation(config, target_T_np)
        runs['sbto'] = run_sbto_evaluation(args, config, target_T_np)
        runs['deeponet_mpc'] = run_deeponet_mpc_evaluation(args, config, target_T_np)

        entry = {}
        for key, (states, controls, elapsed) in runs.items():
            entry[key] = {'states': states, 'controls': controls, 'time': elapsed}
        entry['target_T'] = target_T_np
        all_results[target_type] = entry

        # Summary table for this target.
        rule = "-" * 65
        print(f"\n--- SUMMARY FOR '{target_type.upper()}' ---")
        print(f"{'Method':<20} | {'Time Taken (s)':<20} | {'Final MSE':<20}")
        print(rule)
        for label, key in summary_rows:
            states, _, elapsed = runs[key]
            final_mse = np.mean((states[-1] - target_T_np)**2)
            print(f"{label:<20} | {elapsed:<20.4f} | {final_mse:<20.4e}")
        print(rule)

    # 3. One figure per target group.
    output_dir = os.path.join(args.output_base_dir, args.run_id, "comparison_plots_split_heat")
    os.makedirs(output_dir, exist_ok=True)

    for subset, suffix in ((standard_targets, "group1_standard"), (advanced_targets, "group2_advanced")):
        generate_comparison_plots_split(all_results, x_grid, config, output_dir, subset, suffix)

    # 4. Persist the raw numbers. The values are nested dicts, so NumPy
    # stores them as pickled object arrays (loading needs allow_pickle=True).
    results_save_path = os.path.join(output_dir, "all_simulation_results_heat.npz")
    np.savez_compressed(results_save_path, **all_results)
    print(f"\nSaved all numerical results to: {results_save_path}")

    print("\n--- Unified evaluation for all targets complete. ---")

if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser(description="Unified evaluation of multiple controllers for the 1D Heat Equation.")

    # Mandatory string options: (flag, help text).
    required_options = (
        ("--config_path", "Path to the main YAML config file."),
        ("--output_base_dir", "Base directory where runs are stored."),
        ("--run_id", "The run_id of the trained recurrent controller."),
    )
    for flag, help_text in required_options:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    # Optional hyperparameters: (flag, type, default).
    tunable_options = (
        # SBTO Hyperparameters
        ("--sbto_optim_steps", int, 200),
        ("--sbto_lr", float, 1e-2),
        ("--sbto_terminal_weight", float, 1.0),
        ("--sbto_running_weight", float, 0.1),
        ("--sbto_effort_weight", float, 1e-5),
        # DeepONet MPC Hyperparameters
        ("--nmpc_horizon", int, 10),
        ("--nmpc_optim_steps", int, 200),
        ("--nmpc_lr", float, 2e-2),
        ("--nmpc_terminal_weight", float, 1.0),
        ("--nmpc_running_weight", float, 0.1),
        ("--nmpc_effort_weight", float, 1e-5),
    )
    for flag, value_type, default_value in tunable_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    args = parser.parse_args()
    main(args)