# src/unified_burgers_comparison.py
# ---------------------------------------------------------------------------
# FINAL UNIFIED SCRIPT (for nu(x) Paper Revision)
# Compares all methods on a single, reproducible Burgers' equation problem.
#
# Methods Compared:
# 1. PDE-OP (recurrent controller)
# 2. Adjoint-Based Method (Classical Baseline 1)
# 3. NMPC (POD-Based Classical Baseline 2)
# 4. Open-Loop Surrogate Control (SBTO)
# 5. Neural MPC
# ---------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import argparse
import time

from data_and_models import PropagatorDeepONet, RecurrentController, generate_grf_spatial_series
from scipy.optimize import fsolve, minimize
from scipy import sparse
from scipy.sparse.linalg import spsolve

# ===========================================================================
# 1. UTILITY FUNCTIONS
# ===========================================================================

def generate_target_profile(config, target_type):
    """
    Generates a specific, potentially complex target profile for evaluation.

    The profile is sampled on ``config['M_SENSORS']`` equally spaced points of
    ``[0, config['L']]``. Every profile except ``'zero'`` (which is already
    identically zero) has the straight line through its two end values
    subtracted, so the returned profile satisfies y(0)=0 and y(L)=0.

    Args:
        config: Mapping providing ``'L'`` (domain length) and ``'M_SENSORS'``
            (number of grid points).
        target_type: One of ``'sine'``, ``'parabola'``, ``'triangle'``,
            ``'step'``, ``'sawtooth'``, ``'complex_gaussian'``,
            ``'high_freq'`` or ``'zero'``.

    Returns:
        1-D ``numpy`` array of length ``M_SENSORS``.

    Raises:
        ValueError: If ``target_type`` is not one of the names above.
    """
    print(f"Generating advanced '{target_type}' target profile...")
    L = config['L']
    x_grid = np.linspace(0, L, config['M_SENSORS'])
    target_profile = np.zeros_like(x_grid)

    if target_type == 'sine':
        target_profile = 0.4 * np.sin(2 * np.pi * x_grid / L) + 0.4 * np.sin(4 * np.pi * x_grid / L)
    elif target_type == 'parabola':
        target_profile = 4 * 0.5 * x_grid * (L - x_grid)
    elif target_type == 'triangle':
        # Work in the normalised coordinate x/L so the apex sits at the domain
        # midpoint for any L (a literal 0.5 is only the midpoint when L == 1).
        target_profile = 0.7 * (1.0 - 2 * np.abs(x_grid / L - 0.5))
    elif target_type == 'step':
        target_profile[(x_grid > 0.3 * L) & (x_grid < 0.7 * L)] = 0.6
    elif target_type == 'sawtooth':
        cycles = 4
        target_profile = 1.0 * ((x_grid % (L / cycles)) / (L / cycles))
    elif target_type == 'complex_gaussian':
        # Fixed seed for this shape for reproducibility. A private generator is
        # used (same stream as np.random.seed(42)) so that the caller's global
        # NumPy RNG state is left untouched.
        rng = np.random.RandomState(42)
        for _ in range(5):
            amp = rng.uniform(0.5, 1.5) * (1 if rng.rand() > 0.5 else -1)
            mean = rng.uniform(0.1 * L, 0.9 * L)
            sigma = rng.uniform(0.05 * L, 0.15 * L)
            target_profile += amp * np.exp(-((x_grid - mean)**2) / (2 * sigma**2))
    elif target_type == 'high_freq':
        base = 0.5 * np.sin(2 * np.pi * x_grid / L)
        hf_comp = 0.2 * np.sin(12 * np.pi * x_grid / L)
        target_profile = base + hf_comp
    elif target_type == 'zero':
        return target_profile
    else:
        raise ValueError(f"Unknown target type: {target_type}")

    # Subtract the chord joining the two end values so both ends become zero.
    correction_line = target_profile[0] + (target_profile[-1] - target_profile[0]) * x_grid / L
    corrected_target = target_profile - correction_line
    return corrected_target

def generate_problem_instance(config, target_type, seed):
    """
    Generates a single, reproducible problem instance.

    Args:
        config: Mapping providing ``'L'``, ``'M_SENSORS'``,
            ``'VISCOSITY_LENGTH_SCALE'`` and ``'VISCOSITY_RANGE'`` (a
            ``(nu_min, nu_max)`` pair), plus whatever
            ``generate_grf_spatial_series`` reads from it.
        target_type: Name of the target shape, forwarded to
            ``generate_target_profile``.
        seed: Value used to seed NumPy's global RNG before the viscosity
            field is sampled.

    Returns:
        Tuple ``(initial_cond, target_profile, viscosity_profile)``. The
        initial condition is ``0.5 * sin(2*pi*x/L)`` on the sensor grid with
        both end values forced to zero; the viscosity profile is the raw
        sample squashed through ``tanh`` into ``[nu_min, nu_max]``.
    """
    print(f"--- Generating problem instance for target '{target_type}' with seed {seed} ---")

    x_grid_sensors = np.linspace(0, config['L'], config['M_SENSORS'])
    target_profile = generate_target_profile(config, target_type)

    initial_cond = 0.5 * np.sin(2 * np.pi * x_grid_sensors / config['L'])
    initial_cond[0], initial_cond[-1] = 0.0, 0.0

    # Seed immediately before the stochastic draw so that nothing executed
    # earlier in this function (e.g. target generation) can change the RNG
    # state the viscosity field is sampled from; the field then depends on
    # `seed` alone, independently of `target_type`.
    # NOTE(review): assumes generate_grf_spatial_series draws from NumPy's
    # global RNG -- confirm against its implementation.
    np.random.seed(seed)
    nu_raw = generate_grf_spatial_series(config, 1, config['VISCOSITY_LENGTH_SCALE']).squeeze()
    nu_min, nu_max = config['VISCOSITY_RANGE']
    # tanh maps the raw sample to (-1, 1); rescale that to (nu_min, nu_max).
    viscosity_profile = nu_min + (nu_max - nu_min) * (0.5 * (np.tanh(nu_raw) + 1))

    return initial_cond, target_profile, viscosity_profile

# ===========================================================================
# 2. EVALUATION METHODS
# ===========================================================================

# --- [Method 1] PDE-OP Evaluation ---
def run_pde_op_evaluation(args, config, initial_x_np, target_x_np, viscosity_profile_np):
    """
    Roll out the saved recurrent controller in closed loop on the saved
    DeepONet propagator.

    Both networks are rebuilt from the ``hyperparams.yaml`` stored in their
    run directories (``args.output_base_dir`` joined with ``args.pdeop_run_id``
    and ``args.deeponet_run_id``) and put in eval mode. At each of the
    ``config['NT_SOLVER'] - 1`` steps the controller produces a control from
    the current state, the target and the viscosity profile, and the
    propagator advances the state by one step.

    Args:
        args: Namespace with ``output_base_dir``, ``pdeop_run_id`` and
            ``deeponet_run_id``.
        config: Problem configuration mapping.
        initial_x_np: Initial state on the sensor grid (1-D array).
        target_x_np: Target state on the sensor grid (1-D array).
        viscosity_profile_np: Viscosity on the sensor grid (1-D array).

    Returns:
        Tuple ``(states, controls, elapsed_seconds)`` where ``states`` stacks
        the initial state and every propagated state, and ``controls`` stacks
        the control emitted at every step.
    """
    print("\n--- [1/5] Evaluating PDE-OP (Recurrent Controller) ---")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def read_hyperparams(run_dir):
        # Each run directory stores the hyperparameters it was trained with.
        with open(os.path.join(run_dir, "hyperparams.yaml"), 'r') as handle:
            return yaml.safe_load(handle)

    def to_batched_tensor(array):
        # numpy (n,) -> float32 torch (1, n) on the evaluation device.
        return torch.from_numpy(array).float().unsqueeze(0).to(device)

    # --- Controller -------------------------------------------------------
    ctrl_dir = os.path.join(args.output_base_dir, args.pdeop_run_id)
    ctrl_weights = os.path.join(ctrl_dir, "burgers_controller_model.pth")
    ctrl_hparams = read_hyperparams(ctrl_dir)
    # Only forward the architecture options that are present in the file.
    accepted_ctrl_keys = ('hidden_dim', 'num_layers', 'activation_fn')
    ctrl_kwargs = {name: value for name, value in ctrl_hparams.items() if name in accepted_ctrl_keys}
    controller = RecurrentController(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        control_scale=config['CONTROL_SCALE'],
        **ctrl_kwargs,
    ).to(device)
    controller.load_state_dict(torch.load(ctrl_weights, map_location=device))
    controller.eval()

    # --- Propagator (surrogate physics) ----------------------------------
    prop_dir = os.path.join(args.output_base_dir, args.deeponet_run_id)
    prop_weights = os.path.join(prop_dir, "burgers_propagator_best.pth")
    prop_hparams = read_hyperparams(prop_dir)
    prop_kwargs = {}
    for name in ('branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn'):
        prop_kwargs[name] = prop_hparams[name]
    propagator = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **prop_kwargs,
    ).to(device)
    propagator.load_state_dict(torch.load(prop_weights, map_location=device))
    propagator.eval()

    # --- Problem data as batched tensors ---------------------------------
    state = to_batched_tensor(initial_x_np)
    target = to_batched_tensor(target_x_np)
    viscosity = to_batched_tensor(viscosity_profile_np)
    trunk_coords = torch.linspace(0, config['L'], config['M_SENSORS'], device=device)
    trunk_coords = trunk_coords.unsqueeze(0).unsqueeze(-1)

    # --- Closed-loop rollout ---------------------------------------------
    state_history = [state.cpu().numpy().flatten()]
    control_history = []
    n_steps = config['NT_SOLVER'] - 1
    start_time = time.time()
    with torch.no_grad():
        hidden_state = None  # the controller creates its own initial hidden state
        step = 0
        while step < n_steps:
            control, hidden_state = controller(state, target, viscosity, hidden_state)
            control_history.append(control.cpu().numpy().flatten())
            state = propagator(state, control, viscosity, trunk_coords).squeeze(-1)
            state_history.append(state.cpu().numpy().flatten())
            step += 1
    total_time = time.time() - start_time

    final_mse = np.mean((state_history[-1] - target_x_np)**2)
    print(f"Recurrent controller finished in {total_time:.2f}s. Final MSE: {final_mse:.4e}")
    return np.array(state_history), np.array(control_history), total_time

# --- [Method 2] Adjoint-Method Evaluation ---
def run_adjoint_nmpc_evaluation(config, initial_x_np_sensors, target_x_np_sensors, viscosity_profile_np_sensors):
    """
    Receding-horizon control of Burgers' equation with adjoint-computed gradients.

    The sensor-grid inputs are linearly interpolated onto a solver grid of
    ``config['NX_SOLVER']`` points. The PDE is advanced with a Crank-Nicolson
    step solved by Newton iteration (control held constant over the step). At
    every time step an L-BFGS-B problem over the next ``H`` controls (bounded
    to [-1, 1]) is solved, with the gradient obtained from a backward adjoint
    sweep; only the first control of the optimised sequence is applied.

    Args:
        config: Mapping providing ``'L'``, ``'NX_SOLVER'``, ``'NT_SOLVER'``,
            ``'T_FINAL'``, ``'NUM_BASIS_FUNCTIONS'`` and ``'M_SENSORS'``.
        initial_x_np_sensors: Initial state on the sensor grid.
        target_x_np_sensors: Target state on the sensor grid.
        viscosity_profile_np_sensors: Viscosity on the sensor grid.

    Returns:
        Tuple ``(states, controls, elapsed_seconds)``. ``states`` has one row
        per time level on the *solver* grid; ``controls`` has one row of
        basis coefficients per applied step.
    """
    print("\n--- [2/5] Evaluating Adjoint-Based Method ---")
    L, Nx, Nt = config['L'], config['NX_SOLVER'], config['NT_SOLVER']
    dt = config['T_FINAL'] / (Nt - 1)
    # m: number of control basis functions, H: prediction horizon in steps,
    # alpha: weight of the control-effort term in the cost.
    m, H, max_opt_iters, alpha = config['NUM_BASIS_FUNCTIONS'], 10, 40, 5e-5

    x_grid = np.linspace(0, L, Nx)
    sensor_grid = np.linspace(0, L, config['M_SENSORS'])

    u0 = np.interp(x_grid, sensor_grid, initial_x_np_sensors)
    u_ref = np.interp(x_grid, sensor_grid, target_x_np_sensors)
    nu = np.interp(x_grid, sensor_grid, viscosity_profile_np_sensors)

    # Control basis: the first m sine modes, one per column.
    B = np.array([np.sin((j + 1) * np.pi * x_grid / L) for j in range(m)]).T
    dx = L / (Nx - 1)

    # Variable-coefficient diffusion operator using face-averaged viscosity.
    nu_half = (nu[:-1] + nu[1:]) / 2
    d_lower = nu_half
    d_upper = nu_half
    d_main = -(np.concatenate(([0], nu_half)) + np.concatenate((nu_half, [0])))
    L_nu = sparse.diags([d_lower, d_main, d_upper], [-1, 0, 1], shape=(Nx, Nx), format='csr') / dx**2
    # Boundary rows are zeroed: boundary values are not evolved by diffusion.
    L_nu = L_nu.tolil()
    L_nu[0, :], L_nu[-1, :] = 0, 0
    L_nu = L_nu.tocsr()

    # Central-difference first-derivative operator for the advective term.
    D_adv = sparse.diags([-1, 1], [-1, 1], shape=(Nx, Nx), format='csr') / (2 * dx)
    I_sp = sparse.eye(Nx, format='csr')

    def advective(u):
        # Conservative form of the nonlinearity: d/dx (u^2 / 2).
        return D_adv.dot(0.5 * (u**2))

    def dRdu(u):
        # Jacobian of the right-hand side R(u) = -advective(u) + L_nu u + f.
        return -D_adv.dot(sparse.diags(u)) + L_nu

    def enforce_bcs_vec(v):
        v[0], v[-1] = 0.0, 0.0
        return v

    def enforce_bcs_mat(A):
        # Replace the boundary rows by identity rows.
        A = A.tolil()
        A[0, :], A[-1, :] = 0, 0
        A[0, 0], A[-1, -1] = 1.0, 1.0
        return A.tocsr()

    def solve_cn_step(u_n, v_n):
        """Advance one Crank-Nicolson step; returns (u_np1, dRdu(u_n), dRdu(u_np1))."""
        f_n = B @ v_n
        u_np1, Rn = u_n.copy(), -advective(u_n) + L_nu.dot(u_n) + f_n
        for _ in range(25):
            Rnp1 = -advective(u_np1) + L_nu.dot(u_np1) + f_n
            F = enforce_bcs_vec(u_np1 - u_n - 0.5 * dt * (Rn + Rnp1))
            J = enforce_bcs_mat(I_sp - 0.5 * dt * dRdu(u_np1))
            try:
                delta = spsolve(J, -F)
            except Exception:
                # Best-effort: a failed linear solve is reported to the caller
                # through the None Jacobians. KeyboardInterrupt/SystemExit are
                # deliberately not caught so the run can still be aborted.
                return u_n, None, None
            u_np1 += delta
            if np.linalg.norm(delta) < 1e-10:
                break
        return u_np1, dRdu(u_n), dRdu(u_np1)

    def compute_cost_and_grad(v_flat, u_init, Hk):
        """Return (cost, gradient) of the horizon problem for flattened controls."""
        V_seq = v_flat.reshape((Hk, m))
        U = np.zeros((Hk + 1, Nx))
        U[0] = u_init
        A_list, B_list = [], []
        # Forward sweep: store the step Jacobians needed by the adjoint sweep.
        for n in range(Hk):
            u_np1, dRdu_n, dRdu_np1 = solve_cn_step(U[n], V_seq[n])
            if dRdu_n is None:
                # Forward solve failed: report a very large cost and a zero gradient.
                return 1e10, np.zeros_like(v_flat)
            U[n+1] = u_np1
            A_list.append(enforce_bcs_mat(I_sp - 0.5 * dt * dRdu_np1))
            B_list.append(enforce_bcs_mat(-I_sp - 0.5 * dt * dRdu_n))
        J = 0.5 * dt * np.sum((U[1:] - u_ref)**2) + 0.5 * alpha * dt * np.sum(V_seq**2)
        # Backward (adjoint) sweep.
        p_next = dt * (U[-1] - u_ref)
        grad = np.zeros_like(V_seq)
        for n in reversed(range(Hk)):
            q = spsolve(A_list[n].T, p_next)
            grad[n, :] = dt * (alpha * V_seq[n] + (B.T @ q))
            p_next = dt * (U[n] - u_ref) - (B_list[n].T @ q)
        return J, grad.ravel()

    state_history, control_history = [u0], []
    u_current = u0.copy()
    start_time = time.time()
    for k in range(Nt - 1):
        # Shrink the horizon near the end of the simulation window.
        Hk = min(H, Nt - 1 - k)
        res = minimize(
            compute_cost_and_grad,
            np.zeros(Hk * m),
            args=(u_current, Hk),
            method='L-BFGS-B',
            jac=True,  # the objective returns (cost, gradient)
            bounds=[(-1.0, 1.0)] * (Hk * m),
            options={'maxiter': max_opt_iters, 'ftol': 1e-6},
        )
        v_apply = res.x.reshape((Hk, m))[0]
        control_history.append(v_apply)
        u_current, _, _ = solve_cn_step(u_current, v_apply)
        state_history.append(u_current)
        print(f"Adjoint Step {k+1}/{Nt-1}: J={res.fun:.4e}", end='\r')
    total_time = time.time() - start_time

    # The state lives on the solver grid while the target is given on the
    # sensor grid; bring the final state onto the sensor grid before
    # comparing so the MSE is well defined when NX_SOLVER != M_SENSORS.
    final_state_on_sensors = np.interp(sensor_grid, x_grid, state_history[-1])
    final_mse = np.mean((final_state_on_sensors - target_x_np_sensors)**2)
    print(f"Adjoint finished in {total_time:.2f}s. Final MSE: {final_mse:.4e}")
    return np.array(state_history), np.array(control_history), total_time

# --- [Method 3] POD-MPC (Renamed to NMPC in Plot/Summary) ---
def run_pod_mpc_evaluation(config, initial_x_np_sensors, target_x_np_sensors, viscosity_profile_np_sensors,
                           n_horizon=10, n_pod_modes=5, snapshot_batches=100, snapshot_steps=50,
                           snapshot_stride=5):
    """
    Model-predictive control of Burgers' equation with a POD-coordinate cost.

    The sensor-grid inputs are interpolated onto a solver grid of
    ``config['NX_SOLVER']`` points. A POD basis is built from snapshots of a
    single trajectory driven by uniformly random controls in [-1, 1] (drawn
    from NumPy's global RNG). At every time step an SLSQP problem over the
    next ``n_horizon`` controls (bounded to [-1, 1]) is solved; the cost is
    measured in POD coordinates. Each predicted step lifts the POD state to
    the full grid, performs a full-order Crank-Nicolson solve with ``fsolve``
    and projects the result back. Only the first control is applied.

    Args:
        config: Mapping providing ``'L'``, ``'NX_SOLVER'``, ``'NT_SOLVER'``,
            ``'T_FINAL'``, ``'NUM_BASIS_FUNCTIONS'`` and ``'M_SENSORS'``.
        initial_x_np_sensors: Initial state on the sensor grid.
        target_x_np_sensors: Target state on the sensor grid.
        viscosity_profile_np_sensors: Viscosity on the sensor grid.
        n_horizon: Number of steps in the prediction horizon.
        n_pod_modes: Number of leading POD modes retained.
        snapshot_batches: Number of outer passes of the snapshot simulation.
        snapshot_steps: Number of time steps per outer pass.
        snapshot_stride: A snapshot is stored every ``snapshot_stride`` steps
            within each pass.

    Returns:
        Tuple ``(states, controls, elapsed_seconds)``. ``states`` has one row
        per time level on the *solver* grid; ``controls`` has one row of
        basis coefficients per applied step. ``elapsed_seconds`` covers the
        control loop only, not the POD basis construction.

    Raises:
        ValueError: If the snapshot settings produce no snapshots.
    """
    print("\n--- [3/5] Evaluating POD-Based NMPC ---")
    L, Nx, Nt_sim = config['L'], config['NX_SOLVER'], config['NT_SOLVER']
    dt = config['T_FINAL'] / (Nt_sim - 1)
    # Q_weight / R_weight: tracking and control-effort weights of the MPC cost.
    m, N_horizon, Q_weight, R_weight, r_pod_modes = config['NUM_BASIS_FUNCTIONS'], n_horizon, 1.0, 5e-5, n_pod_modes

    x_grid = np.linspace(0, L, Nx)
    sensor_grid = np.linspace(0, L, config['M_SENSORS'])

    u0 = np.interp(x_grid, sensor_grid, initial_x_np_sensors)
    u_ref = np.interp(x_grid, sensor_grid, target_x_np_sensors)
    nu = np.interp(x_grid, sensor_grid, viscosity_profile_np_sensors)

    # Control basis: the first m sine modes, one per column.
    B = np.array([np.sin((j + 1) * np.pi * x_grid / L) for j in range(m)]).T
    dx = L / (Nx - 1)

    # Central-difference first-derivative operator for the advective term.
    D_adv_op = sparse.diags([-1, 1], [-1, 1], shape=(Nx, Nx), format='csr') / (2*dx)

    # Dense variable-coefficient diffusion operator with face-averaged
    # viscosity; boundary rows are zeroed.
    nu_half = (nu[:-1] + nu[1:]) / 2
    main_diag = -(np.concatenate(([0], nu_half)) + np.concatenate((nu_half, [0])))
    L_nu_dense = (np.diag(nu_half, 1) + np.diag(main_diag) + np.diag(nu_half, -1)) / dx**2
    L_nu_dense[0, :], L_nu_dense[-1, :] = 0, 0

    def cn_implicit_eq(u_next, u_prev, f_prev, f_next, dt_local):
        # Crank-Nicolson residual; the boundary entries pin u_next to zero.
        adv_prev, diff_prev = u_prev * (D_adv_op @ u_prev), L_nu_dense @ u_prev
        adv_next, diff_next = u_next * (D_adv_op @ u_next), L_nu_dense @ u_next
        res = (u_next - u_prev) - (dt_local / 2) * ((-adv_prev + diff_prev + f_prev) + (-adv_next + diff_next + f_next))
        res[0], res[-1] = u_next[0], u_next[-1]
        return res

    def full_cn_step(u_prev, f_prev, f_next, dt_local, guess=None):
        return fsolve(cn_implicit_eq, guess if guess is not None else u_prev, args=(u_prev, f_prev, f_next, dt_local))

    print("Generating POD basis... (using problem-specific physics)")
    snapshots_per_batch = len(range(0, snapshot_steps, snapshot_stride))
    U_snap, snap_idx = np.zeros((Nx, snapshot_batches * snapshots_per_batch)), 0
    # The state is carried over between passes: this is one long trajectory.
    u_snap = u0.copy()
    for _ in range(snapshot_batches):
        for i in range(snapshot_steps):
            f = B @ (1.0 * (2 * np.random.rand(m) - 1))
            u_snap = full_cn_step(u_snap, f, f, dt)
            if i % snapshot_stride == 0:
                U_snap[:, snap_idx] = u_snap
                snap_idx += 1
    if snap_idx == 0:
        raise ValueError("POD basis needs at least one snapshot; check the snapshot_* arguments.")
    U_snap = U_snap[:, :snap_idx]

    # POD modes are the leading left singular vectors of the centred snapshots.
    u_mean = np.mean(U_snap, axis=1)
    U_pod, _, _ = np.linalg.svd(U_snap - u_mean[:, np.newaxis], full_matrices=False)
    V_pod = U_pod[:, :r_pod_modes]
    z_ref = V_pod.T @ (u_ref - u_mean)

    def reduced_cn_step(z_prev, f_prev, f_next, dt_local, guess=None):
        # Lift to the full grid, take a full-order step, project back.
        u_prev_full = V_pod @ z_prev + u_mean
        u_next_full_guess = V_pod @ (guess if guess is not None else z_prev) + u_mean
        u_next_full = full_cn_step(u_prev_full, f_prev, f_next, dt_local, guess=u_next_full_guess)
        return V_pod.T @ (u_next_full - u_mean)

    def mpc_cost(v_seq_flat, z_current_arg):
        v_seq = v_seq_flat.reshape((N_horizon, m))
        cost, z_pred, f_prev = 0.0, z_current_arg.copy(), np.zeros(Nx)
        for i in range(N_horizon):
            f_next = B @ v_seq[i]
            z_next = reduced_cn_step(z_pred, f_prev, f_next, dt, guess=z_pred)
            cost += Q_weight * np.sum((z_next - z_ref)**2) + R_weight * np.sum(v_seq[i]**2)
            z_pred, f_prev = z_next, f_next
        # Additional terminal penalty on the last predicted state.
        cost += Q_weight * np.sum((z_pred - z_ref)**2)
        return cost

    state_history, control_history = [u0], []
    u_current = u0.copy()
    start_time = time.time()
    for k in range(Nt_sim - 1):
        z_current = V_pod.T @ (u_current - u_mean)
        res = minimize(
            mpc_cost,
            np.zeros(N_horizon * m),
            args=(z_current,),  # one-element tuple, as documented for `args`
            method='SLSQP',
            bounds=[(-1.0, 1.0)] * (N_horizon * m),
        )
        v_apply = res.x.reshape((N_horizon, m))[0]
        control_history.append(v_apply)
        u_current = full_cn_step(u_current, B @ v_apply, B @ v_apply, dt)
        state_history.append(u_current)
        print(f"POD-NMPC Step {k+1}/{Nt_sim-1} completed.", end='\r')
    total_time = time.time() - start_time

    # The state lives on the solver grid while the target is given on the
    # sensor grid; bring the final state onto the sensor grid before
    # comparing so the MSE is well defined when NX_SOLVER != M_SENSORS.
    final_state_on_sensors = np.interp(sensor_grid, x_grid, state_history[-1])
    final_mse = np.mean((final_state_on_sensors - target_x_np_sensors)**2)
    print(f"POD-NMPC finished in {total_time:.2f}s. Final MSE: {final_mse:.4e}")
    return np.array(state_history), np.array(control_history), total_time

# --- [Method 4] SBTO Evaluation ---
def run_sbto_evaluation(args, config, initial_x_np, target_x_np, viscosity_profile_np):
    """
    Optimise one open-loop control sequence through the DeepONet surrogate.

    The whole sequence of ``config['NT_SOLVER'] - 1`` controls is
    parameterised as ``tanh(params) * config['CONTROL_SCALE']`` and optimised
    with Adam for ``args.sbto_optim_steps`` iterations. The loss is a weighted
    sum of the terminal tracking error, the mean running tracking error and
    the mean control effort. Afterwards the optimised sequence is replayed
    once without gradients to record the state trajectory.

    Args:
        args: Namespace with ``output_base_dir``, ``deeponet_run_id``,
            ``sbto_lr``, ``sbto_optim_steps``, ``sbto_terminal_weight``,
            ``sbto_running_weight`` and ``sbto_effort_weight``.
        config: Problem configuration mapping.
        initial_x_np: Initial state on the sensor grid (1-D array).
        target_x_np: Target state on the sensor grid (1-D array).
        viscosity_profile_np: Viscosity on the sensor grid (1-D array).

    Returns:
        Tuple ``(states, controls, elapsed_seconds)``; the elapsed time covers
        the optimisation loop only, not the final replay.
    """
    print("\n--- [4/5] Evaluating SBTO (Open-Loop Surrogate Control) ---")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Rebuild the surrogate from its stored hyperparameters ------------
    run_dir = os.path.join(args.output_base_dir, args.deeponet_run_id)
    weights_path = os.path.join(run_dir, "burgers_propagator_best.pth")
    with open(os.path.join(run_dir, "hyperparams.yaml"), 'r') as handle:
        hparams = yaml.safe_load(handle)
    arch_kwargs = {}
    for name in ('branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn'):
        arch_kwargs[name] = hparams[name]
    surrogate = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **arch_kwargs,
    ).to(device)
    surrogate.load_state_dict(torch.load(weights_path, map_location=device))
    surrogate.eval()

    def to_batched_tensor(array):
        # numpy (n,) -> float32 torch (1, n) on the evaluation device.
        return torch.from_numpy(array).float().unsqueeze(0).to(device)

    x_start = to_batched_tensor(initial_x_np)
    target = to_batched_tensor(target_x_np)
    viscosity = to_batched_tensor(viscosity_profile_np)
    trunk_coords = torch.linspace(0, config['L'], config['M_SENSORS'], device=device)
    trunk_coords = trunk_coords.unsqueeze(0).unsqueeze(-1)

    def advance(state, control):
        # One surrogate time step.
        return surrogate(state, control, viscosity, trunk_coords).squeeze(-1)

    n_steps = config['NT_SOLVER'] - 1
    raw_controls = torch.zeros(1, n_steps, config['NUM_BASIS_FUNCTIONS'], device=device, requires_grad=True)

    adam = optim.Adam([raw_controls], lr=args.sbto_lr)
    mse = nn.MSELoss()

    print(f"Starting test-time optimization for {args.sbto_optim_steps} steps...")
    start_time = time.time()
    for iteration in range(1, args.sbto_optim_steps + 1):
        adam.zero_grad()
        # tanh keeps every control inside [-CONTROL_SCALE, CONTROL_SCALE].
        controls = torch.tanh(raw_controls) * config['CONTROL_SCALE']
        state = x_start
        effort_sum = 0
        tracking_sum = 0.0
        for step in range(n_steps):
            u_step = controls[:, step, :]
            effort_sum = effort_sum + torch.mean(u_step**2)
            state = advance(state, u_step)
            tracking_sum = tracking_sum + mse(state, target)

        terminal_term = mse(state, target)
        effort_term = effort_sum / n_steps
        tracking_term = tracking_sum / n_steps
        loss = (args.sbto_terminal_weight * terminal_term +
                args.sbto_running_weight * tracking_term +
                args.sbto_effort_weight * effort_term)
        loss.backward()
        adam.step()
        if iteration % 50 == 0:
            print(f"Step {iteration}/{args.sbto_optim_steps} | Loss: {loss.item():.4e}")
    total_time = time.time() - start_time

    # --- Replay the optimised sequence to record the trajectory ----------
    final_control_sequence = (torch.tanh(raw_controls.detach()) * config['CONTROL_SCALE']).cpu().numpy().squeeze(0)
    state_history = [initial_x_np]
    state = x_start
    with torch.no_grad():
        for control_row in final_control_sequence[:n_steps]:
            u_step = torch.from_numpy(control_row).float().unsqueeze(0).to(device)
            state = advance(state, u_step)
            state_history.append(state.cpu().numpy().flatten())

    final_mse = np.mean((state_history[-1] - target_x_np)**2)
    print(f"SBTO in {total_time:.2f}s. Final MSE: {final_mse:.4e}")
    return np.array(state_history), final_control_sequence, total_time

# --- [Method 5] Neural MPC Evaluation ---
def run_neural_mpc_evaluation(args, config, initial_x_np, target_x_np, viscosity_profile_np):
    """
    Receding-horizon control using the DeepONet surrogate as prediction model.

    At each of the ``config['NT_SOLVER'] - 1`` steps a fresh control sequence
    over the next ``min(args.nmpc_horizon, remaining steps)`` steps is
    initialised at zero, parameterised as
    ``tanh(params) * config['CONTROL_SCALE']`` and optimised with Adam for
    ``args.nmpc_optim_steps`` iterations. The loss is a weighted sum of the
    terminal tracking error, the mean running tracking error and the mean
    control effort over the horizon. The first control of the optimised
    sequence is then applied to the surrogate.

    Args:
        args: Namespace with ``output_base_dir``, ``deeponet_run_id``,
            ``nmpc_horizon``, ``nmpc_lr``, ``nmpc_optim_steps``,
            ``nmpc_terminal_weight``, ``nmpc_running_weight`` and
            ``nmpc_effort_weight``.
        config: Problem configuration mapping.
        initial_x_np: Initial state on the sensor grid (1-D array).
        target_x_np: Target state on the sensor grid (1-D array).
        viscosity_profile_np: Viscosity on the sensor grid (1-D array).

    Returns:
        Tuple ``(states, controls, elapsed_seconds)``.
    """
    print("\n--- [5/5] Evaluating Neural MPC ---")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Rebuild the surrogate from its stored hyperparameters ------------
    run_dir = os.path.join(args.output_base_dir, args.deeponet_run_id)
    weights_path = os.path.join(run_dir, "burgers_propagator_best.pth")
    with open(os.path.join(run_dir, "hyperparams.yaml"), 'r') as handle:
        hparams = yaml.safe_load(handle)
    arch_kwargs = {}
    for name in ('branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn'):
        arch_kwargs[name] = hparams[name]
    surrogate = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **arch_kwargs,
    ).to(device)
    surrogate.load_state_dict(torch.load(weights_path, map_location=device))
    surrogate.eval()

    def to_batched_tensor(array):
        # numpy (n,) -> float32 torch (1, n) on the evaluation device.
        return torch.from_numpy(array).float().unsqueeze(0).to(device)

    x_start = to_batched_tensor(initial_x_np)
    target = to_batched_tensor(target_x_np)
    viscosity = to_batched_tensor(viscosity_profile_np)
    trunk_coords = torch.linspace(0, config['L'], config['M_SENSORS'], device=device)
    trunk_coords = trunk_coords.unsqueeze(0).unsqueeze(-1)

    def advance(state, control):
        # One surrogate time step.
        return surrogate(state, control, viscosity, trunk_coords).squeeze(-1)

    state_history = [x_start.cpu().numpy().flatten()]
    control_history = []
    state = x_start
    n_steps = config['NT_SOLVER'] - 1

    start_time = time.time()
    for step in range(n_steps):
        print(f"Neural MPC Step {step+1}/{n_steps}", end='\r')
        # Shrink the horizon near the end of the simulation window.
        horizon = min(args.nmpc_horizon, n_steps - step)
        # The plan is re-initialised at zero on every step (no warm start).
        raw_controls = torch.zeros(1, horizon, config['NUM_BASIS_FUNCTIONS'], device=device, requires_grad=True)
        adam = optim.Adam([raw_controls], lr=args.nmpc_lr)

        for _ in range(args.nmpc_optim_steps):
            adam.zero_grad()
            planned = torch.tanh(raw_controls) * config['CONTROL_SCALE']
            predicted = state
            effort_sum = 0
            tracking_sum = 0.0
            for h in range(horizon):
                u_plan = planned[:, h, :]
                effort_sum = effort_sum + torch.mean(u_plan**2)
                predicted = advance(predicted, u_plan)
                tracking_sum = tracking_sum + nn.functional.mse_loss(predicted, target)

            terminal_term = nn.functional.mse_loss(predicted, target)
            effort_term = effort_sum / horizon
            tracking_term = tracking_sum / horizon
            loss = (args.nmpc_terminal_weight * terminal_term +
                    args.nmpc_running_weight * tracking_term +
                    args.nmpc_effort_weight * effort_term)
            loss.backward()
            adam.step()

        # Apply only the first planned control, then re-plan from the new state.
        with torch.no_grad():
            first_action = (torch.tanh(raw_controls[:, 0, :]) * config['CONTROL_SCALE']).detach()
            control_history.append(first_action.cpu().numpy().flatten())
            state = advance(state, first_action)
            state_history.append(state.cpu().numpy().flatten())
    total_time = time.time() - start_time

    final_mse = np.mean((state_history[-1] - target_x_np)**2)
    print(f"Neural MPC in {total_time:.2f}s. Final MSE: {final_mse:.4e}")
    return np.array(state_history), np.array(control_history), total_time

# ===========================================================================
# 3. PLOTTING (SPLIT GROUPS) - TITLES REMOVED
# ===========================================================================

def generate_comparison_plots_split(all_results, config, output_dir, target_subset, filename_suffix):
    """
    Generates a one-row comparison figure with one panel per target in the subset.

    Each panel shows the target profile, the initial state and the final
    state of every method stored for that target. Final states are linearly
    interpolated onto the target's grid, so methods that ran on a different
    spatial resolution can be overlaid. The figure is written to
    ``<output_dir>/comparison_<filename_suffix>.pdf``.

    Args:
        all_results: Mapping from target name to a dict with ``'target_x'``,
            ``'initial_x'`` and ``'methods'`` (method name -> dict holding a
            2-D ``'states'`` array whose last row is the final state).
        config: Mapping providing ``'L'`` (domain length).
        output_dir: Existing directory the PDF is written to.
        target_subset: Sequence of target names, one panel each. Targets
            missing from ``all_results`` leave their panel empty.
        filename_suffix: Suffix used in the log messages and the file name.
    """
    print(f"\n--- Generating {filename_suffix} Comparison Plot ---")
    font_size = 40
    legend_font_size = 40

    # One column per requested target. squeeze=False always yields a 2-D axes
    # array, so a single target (or any other count) is indexed the same way.
    # The width is scaled so that three panels give the 38x9 figure.
    n_panels = max(len(target_subset), 1)
    fig, axes_grid = plt.subplots(1, n_panels, figsize=(38 * n_panels / 3, 9), squeeze=False)
    axes = axes_grid[0]

    # Colors and Styles
    # Note: 'NMPC' corresponds to the POD-MPC method results
    colors = {'PDE-OP': 'r', 'Adjoint Method': 'g', 'NMPC': 'c', 'SBTO': 'b', 'DeepONet MPC': 'm'}
    styles = {'PDE-OP': '-', 'Adjoint Method': '--', 'NMPC': '-.', 'SBTO': ':', 'DeepONet MPC': ':'}

    for i, target_type in enumerate(target_subset):
        ax = axes[i]
        if target_type not in all_results:
            continue

        results = all_results[target_type]
        sensor_grid = np.linspace(0, config['L'], len(results['target_x']))

        # Plot target and initial state
        ax.plot(sensor_grid, results['target_x'], 'k:', lw=4, label=r'$y_{\text{target}}(x)$')
        ax.plot(sensor_grid, results['initial_x'], 'b--', lw=2, label='$y(0, x)$')

        # Plot final states for each method
        for name, res in results['methods'].items():
            # Map interpolation if necessary (e.g. Adjoint/POD might be on different grid)
            grid_from = np.linspace(0, config['L'], res['states'].shape[1])
            final_state_interp = np.interp(sensor_grid, grid_from, res['states'][-1])

            ax.plot(sensor_grid, final_state_interp, color=colors.get(name, 'k'),
                    linestyle=styles.get(name, '-'), lw=2.5, label=name)

        # Panel titles are intentionally omitted.
        ax.set_xlabel("x", fontsize=font_size)
        ax.grid(True, linestyle='--', alpha=0.6)
        ax.tick_params(axis='both', which='major', labelsize=font_size)

        # Only set ylabel on the first plot
        if i == 0:
            ax.set_ylabel('$y(T, x)$', fontsize=font_size)

    # Shared legend, taken from the first panel that actually has curves so
    # that a skipped first target does not produce an empty legend.
    handles, labels = [], []
    for ax in axes:
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            break

    fig.legend(handles, labels, loc='lower center',
               bbox_to_anchor=(0.5, -0.15), ncol=6, fontsize=legend_font_size)

    # Reserve the bottom 10% of the figure for the shared legend.
    fig.tight_layout(rect=[0, 0.1, 1, 1])

    save_path = os.path.join(output_dir, f"comparison_{filename_suffix}.pdf")
    plt.savefig(save_path, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved plot to: {save_path}")

# ===========================================================================
# 4. MAIN ORCHESTRATOR
# ===========================================================================

def main(args):
    """
    Runs every enabled control method on every target and reports the results.

    For each target a problem instance is generated from ``args.seed``, the
    methods are evaluated in turn, and their state/control trajectories and
    wall-clock times are collected. A summary table of time and final MSE is
    printed, the raw results are written to
    ``<output_base_dir>/unified_comparison_final_burgers/all_results_burgers.npz``
    and two comparison figures (one per target group) are generated.

    Args:
        args: Parsed command-line namespace; see the argument parser at the
            bottom of this file for the available options.
    """
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)

    # 1. Define Target Groups
    targets_group_1 = ['sine', 'parabola', 'zero']
    targets_group_2 = ['step', 'high_freq', 'complex_gaussian']
    all_targets = targets_group_1 + targets_group_2

    all_results = {}

    # 2. Evaluation Loop
    for target_type in all_targets:
        print("\n" + "#"*80)
        print(f"###   STARTING EVALUATION FOR TARGET: {target_type.upper()}   ###")
        print("#"*80)

        seed = args.seed
        initial_x_np, target_x_np, viscosity_profile_np = generate_problem_instance(config, target_type, seed)

        # Run all methods
        pdeop_res = run_pde_op_evaluation(args, config, initial_x_np, target_x_np, viscosity_profile_np)
        adjoint_res = run_adjoint_nmpc_evaluation(config, initial_x_np, target_x_np, viscosity_profile_np)
        # The POD-based NMPC baseline is currently disabled. To include it,
        # uncomment this call together with the 'NMPC' entry in 'methods' below.
        #pod_mpc_res = run_pod_mpc_evaluation(config, initial_x_np, target_x_np, viscosity_profile_np)
        sbto_res = run_sbto_evaluation(args, config, initial_x_np, target_x_np, viscosity_profile_np)
        nmpc_res = run_neural_mpc_evaluation(args, config, initial_x_np, target_x_np, viscosity_profile_np)

        # Store Results
        # Note: the POD-MPC result is displayed under the key 'NMPC', while the
        # Neural MPC result is displayed under 'DeepONet MPC'.
        all_results[target_type] = {
            'initial_x': initial_x_np,
            'target_x': target_x_np,
            'viscosity_profile': viscosity_profile_np,
            'methods': {
                'PDE-OP': {'states': pdeop_res[0], 'controls': pdeop_res[1], 'time': pdeop_res[2]},
                'Adjoint Method': {'states': adjoint_res[0], 'controls': adjoint_res[1], 'time': adjoint_res[2]},
                #'NMPC': {'states': pod_mpc_res[0], 'controls': pod_mpc_res[1], 'time': pod_mpc_res[2]},
                'SBTO': {'states': sbto_res[0], 'controls': sbto_res[1], 'time': sbto_res[2]},
                'DeepONet MPC': {'states': nmpc_res[0], 'controls': nmpc_res[1], 'time': nmpc_res[2]}
            }
        }

    # 3. Print Summary Table
    print("\n\n" + "="*80)
    print("--- FINAL SUMMARY OF ALL EVALUATIONS ---")
    for target_type, results in all_results.items():
        print(f"\n--- Target: '{target_type.upper()}' ---")
        print(f"{'Method':<20} | {'Time Taken (s)':<20} | {'Final MSE':<20}")
        print("-" * 65)
        # The target grid is the same for every method of this target.
        grid_to = np.linspace(0, config['L'], len(results['target_x']))
        for name, res in results['methods'].items():
            # Interpolate to match target grid for MSE calculation
            grid_from = np.linspace(0, config['L'], res['states'].shape[1])
            final_state_interp = np.interp(grid_to, grid_from, res['states'][-1])

            final_mse = np.mean((final_state_interp - results['target_x'])**2)
            print(f"{name:<20} | {res['time']:<20.4f} | {final_mse:<20.4e}")
    print("="*80)

    output_dir = os.path.join(args.output_base_dir, "unified_comparison_final_burgers")
    os.makedirs(output_dir, exist_ok=True)

    # 4. Save Data
    # Persist the numerical results before plotting so that a failure while
    # rendering the figures cannot discard the (expensive) evaluation output.
    # NOTE: each value is a nested dict, so NumPy stores it as a pickled
    # object array; reading the file back requires np.load(..., allow_pickle=True).
    np.savez_compressed(os.path.join(output_dir, "all_results_burgers.npz"), **all_results)
    print(f"Saved all numerical results to {output_dir}")

    # 5. Generate Split Plots
    generate_comparison_plots_split(all_results, config, output_dir, targets_group_1, "group1_standard")
    generate_comparison_plots_split(all_results, config, output_dir, targets_group_2, "group2_advanced")

if __name__ == "__main__":
    # Command-line interface: four required paths / run ids, the problem
    # seed, and the tunable hyperparameters of the two surrogate-based
    # optimisation baselines.
    cli = argparse.ArgumentParser(description="Unified evaluation for Burgers' Eq with nu(x).")

    for required_flag in ("--config_path", "--output_base_dir", "--pdeop_run_id", "--deeponet_run_id"):
        cli.add_argument(required_flag, type=str, required=True)
    cli.add_argument("--seed", type=int, default=42, help="Base seed for random problem instances.")

    # (flag, type, default) triples, registered in this order.
    optional_hyperparameters = (
        # SBTO Hyperparameters
        ("--sbto_optim_steps", int, 150),
        ("--sbto_lr", float, 1e-2),
        ("--sbto_terminal_weight", float, 1.0),
        ("--sbto_effort_weight", float, 1e-5),
        ("--sbto_running_weight", float, 1e-1),
        # Neural MPC Hyperparameters
        ("--nmpc_horizon", int, 10),
        ("--nmpc_optim_steps", int, 25),
        ("--nmpc_lr", float, 2e-2),
        ("--nmpc_terminal_weight", float, 1.0),
        ("--nmpc_effort_weight", float, 1e-5),
        ("--nmpc_running_weight", float, 1e-1),
    )
    for flag, value_type, default_value in optional_hyperparameters:
        cli.add_argument(flag, type=value_type, default=default_value)

    args = cli.parse_args()
    main(args)