# src/unified_burgers_comparison_final.py
# ---------------------------------------------------------------------------
# FINAL UNIFIED SCRIPT (Version 8 - No Titles, Correct Axis Labels)
# Compares control strategies for the 1D viscous Burgers' Equation.
# ---------------------------------------------------------------------------

import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import argparse
import time

from data_and_models import PropagatorDeepONet, RecurrentController
from scipy.optimize import fsolve, minimize
from scipy import sparse
from scipy.sparse.linalg import spsolve

# ===========================================================================
# 1. PROBLEM SETUP UTILITIES
# ===========================================================================

def generate_target_profile(config, target_type, x_grid):
    """Build the desired terminal profile y_target(x) on ``x_grid``.

    Args:
        config: Problem configuration; only ``config['L']`` (domain length) is read.
        target_type: One of 'sine', 'parabola', 'zero', 'step', 'high_freq',
            'complex_gaussian'.
        x_grid: NumPy array of grid coordinates.

    Returns:
        Array shaped like ``x_grid`` whose first and last entries are forced to
        0.0 so the profile respects the homogeneous Dirichlet boundary conditions.

    Raises:
        ValueError: If ``target_type`` is not one of the names above.
    """
    print(f"Generating '{target_type}' target profile...")
    L = config['L']
    profile = np.zeros_like(x_grid)

    # --- Standard Targets ---
    if target_type == 'sine':
        profile = 0.8 * np.sin(np.pi * x_grid / L)
    elif target_type == 'parabola':
        profile = 4 * 0.5 * x_grid * (L - x_grid)
    elif target_type == 'zero':
        profile = np.zeros_like(x_grid)

    # --- Advanced Targets ---
    elif target_type == 'step':
        # 0.6 strictly between 30% and 70% of the domain, 0 elsewhere.
        mask = (x_grid > 0.3 * L) & (x_grid < 0.7 * L)
        profile[mask] = 0.6
    elif target_type == 'high_freq':
        # Low-frequency base plus a high-frequency (6*pi/L) detail component.
        profile = 0.5 * np.sin(np.pi * x_grid / L) + 0.2 * np.sin(6 * np.pi * x_grid / L)
    elif target_type == 'complex_gaussian':
        # Sum of 5 Gaussian bumps with random sign/amplitude/centre/width.
        # A private fixed-seed generator keeps the profile reproducible WITHOUT
        # reseeding NumPy's global RNG (which would silently alter every later
        # np.random call in the process). RandomState(42) produces exactly the
        # same draws as the legacy np.random.seed(42) + np.random.* sequence.
        rng = np.random.RandomState(42)
        for _ in range(5):
            amp = rng.uniform(0.3, 0.8) * (1 if rng.rand() > 0.5 else -1)
            mean = rng.uniform(0.2 * L, 0.8 * L)
            sigma = rng.uniform(0.05 * L, 0.15 * L)
            profile += amp * np.exp(-((x_grid - mean)**2) / (2 * sigma**2))
    else:
        raise ValueError(f"Invalid target type: {target_type}")

    # Enforce Dirichlet BCs explicitly
    profile[0], profile[-1] = 0.0, 0.0
    return profile

def get_initial_condition(config, x_grid):
    """Return the initial state y(0, x) = 0.5 * sin(2*pi*x / L) sampled on ``x_grid``.

    The two end points are overwritten with exactly 0.0 so that the homogeneous
    Dirichlet conditions hold regardless of floating-point round-off in sin().
    """
    domain_length = config['L']
    y0 = 0.5 * np.sin(2 * np.pi * x_grid / domain_length)
    y0[[0, -1]] = 0.0
    return y0

# ===========================================================================
# 2. CONTROLLER EVALUATION FUNCTIONS
# ===========================================================================

def run_recurrent_controller_evaluation(args, config, target_x_np, initial_x_np):
    """Closed-loop rollout of the trained recurrent controller (PDE-OP).

    At each of the NT_SOLVER - 1 steps the recurrent controller proposes a
    control from the current state, the target and its hidden state, and the
    pretrained DeepONet propagator advances the state. Nothing is optimised at
    evaluation time.

    Args:
        args: Parsed CLI arguments; ``output_base_dir`` and ``run_id`` locate
            the controller run, whose hyperparams.yaml names the DeepONet run.
        config: Problem configuration dict (M_SENSORS, NUM_BASIS_FUNCTIONS,
            CONTROL_SCALE, TRUNK_INPUT_DIM, L, NT_SOLVER).
        target_x_np: Target profile on the sensor grid.
        initial_x_np: Initial state on the sensor grid.

    Returns:
        (states, controls, wall_time): ``states`` stacks the initial state plus
        one state per step, ``controls`` stacks the applied controls, and
        ``wall_time`` is the rollout duration in seconds.
    """
    print(f"\n--- [1/4] Evaluating PDE-OP (Recurrent NN) ---")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Recurrent controller: architecture from hyperparams.yaml, weights from .pth ---
    ctrl_dir = os.path.join(args.output_base_dir, args.run_id)
    ctrl_weights_path = os.path.join(ctrl_dir, "burgers_controller_model.pth")
    with open(os.path.join(ctrl_dir, "hyperparams.yaml"), 'r') as fh:
        ctrl_hparams = yaml.safe_load(fh)
    ctrl_arch_keys = ('hidden_dim', 'num_layers', 'activation_fn')
    ctrl_arch = {name: val for name, val in ctrl_hparams.items() if name in ctrl_arch_keys}
    controller = RecurrentController(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        control_scale=config['CONTROL_SCALE'],
        **ctrl_arch,
    ).to(device)
    controller.load_state_dict(torch.load(ctrl_weights_path, map_location=device))
    controller.eval()

    # --- DeepONet propagator referenced by the controller run ---
    onet_dir = os.path.join(args.output_base_dir, ctrl_hparams['deeponet_run_id'])
    onet_weights_path = os.path.join(onet_dir, "burgers_propagator_best.pth")
    with open(os.path.join(onet_dir, "hyperparams.yaml"), 'r') as fh:
        onet_hparams = yaml.safe_load(fh)
    onet_arch_keys = ('branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn')
    onet_arch = {name: onet_hparams[name] for name in onet_arch_keys}
    physics_simulator = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **onet_arch,
    ).to(device)
    physics_simulator.load_state_dict(torch.load(onet_weights_path, map_location=device))
    physics_simulator.eval()

    def as_batch(arr):
        """NumPy vector -> float32 tensor with a leading batch dimension on ``device``."""
        return torch.from_numpy(arr).float().unsqueeze(0).to(device)

    goal = as_batch(target_x_np)
    state = as_batch(initial_x_np)
    query_points = torch.linspace(0, config['L'], config['M_SENSORS'], device=device).unsqueeze(0).unsqueeze(-1)

    states = [state.cpu().numpy().flatten()]
    controls = []
    tic = time.time()
    with torch.no_grad():
        memory = None
        for _ in range(config['NT_SOLVER'] - 1):
            action, memory = controller(state, goal, memory)
            controls.append(action.cpu().numpy().flatten())
            state = physics_simulator(state, action, query_points).squeeze(-1)
            states.append(state.cpu().numpy().flatten())
    elapsed = time.time() - tic

    final_mse = np.mean((states[-1] - target_x_np) ** 2)
    print(f"PDE-OP finished in {elapsed:.4f}s. Final MSE: {final_mse:.4e}")
    return np.array(states), np.array(controls), elapsed

def run_adjoint_nmpc_evaluation(config, target_x_np_sensors, initial_x_np_sensors, *,
                                horizon=10, max_opt_iters=40, alpha=5e-5, control_bound=1.0):
    """Receding-horizon control of the full-order Burgers' model with adjoint gradients.

    The PDE is discretised with central finite differences in space and
    Crank-Nicolson in time (each step solved by Newton's method). At every time
    step a finite-horizon problem over basis-coefficient controls is solved with
    L-BFGS-B, using gradients from the discrete adjoint of the CN scheme. Only
    the first control of the optimised sequence is applied before the horizon
    recedes.

    Args:
        config: Reads 'L', 'VISCOSITY', 'NX_SOLVER', 'NT_SOLVER', 'T_FINAL',
            'NUM_BASIS_FUNCTIONS' and 'M_SENSORS'.
        target_x_np_sensors: Target profile sampled on the M_SENSORS grid.
        initial_x_np_sensors: Initial state sampled on the M_SENSORS grid.
        horizon: Prediction horizon in time steps (shortened near the final time).
        max_opt_iters: L-BFGS-B iteration cap per MPC step.
        alpha: Weight of the control-effort term in the cost.
        control_bound: Box constraint |v_j| <= control_bound on every coefficient.

    Returns:
        (states, controls, wall_time): ``states`` has one row per time level on
        the NX_SOLVER solver grid, ``controls`` one row of NUM_BASIS_FUNCTIONS
        coefficients per applied step, ``wall_time`` the MPC loop duration (s).

    Raises:
        ValueError: If ``horizon`` < 1.
    """
    print("\n--- [2/4] Evaluating Adjoint Method ---")
    if horizon < 1:
        raise ValueError(f"horizon must be >= 1, got {horizon}")
    L, nu, Nx, Nt = config['L'], config['VISCOSITY'], config['NX_SOLVER'], config['NT_SOLVER']
    dt = config['T_FINAL'] / (Nt - 1)
    m = config['NUM_BASIS_FUNCTIONS']
    x_grid = np.linspace(0, L, Nx)
    sensor_grid = np.linspace(0, L, config['M_SENSORS'])
    # Inputs arrive on the sensor grid; the full-order solver works on x_grid.
    u_ref = np.interp(x_grid, sensor_grid, target_x_np_sensors)
    u0 = np.interp(x_grid, sensor_grid, initial_x_np_sensors)
    # Actuation basis: column j is sin((j+1) pi x / L), zero (up to round-off) at both walls.
    B = np.array([np.sin((j + 1) * np.pi * x_grid / L) for j in range(m)]).T
    dx = L / (Nx - 1)
    # Central-difference first- and second-derivative operators.
    D = sparse.diags([-1, 1], [-1, 1], shape=(Nx, Nx), format='csr') / (2 * dx)
    D2 = sparse.diags([1, -2, 1], [-1, 0, 1], shape=(Nx, Nx), format='csr') / dx**2
    I_sp = sparse.eye(Nx, format='csr')

    def advective(u):
        """Advection term d/dx (u^2 / 2)."""
        return D.dot(0.5 * (u**2))

    def dRdu(u):
        """Jacobian of R(u) = -d/dx(u^2/2) + nu * u_xx with respect to u."""
        return -D.dot(sparse.diags(u)) + nu * D2

    def enforce_bcs_vec(v):
        """Zero the two boundary entries of a residual vector (in place)."""
        v[0], v[-1] = 0.0, 0.0
        return v

    def enforce_bcs_mat(A):
        """Replace the first/last rows of A by identity rows (Dirichlet rows)."""
        A = A.tolil()
        A[0, :], A[-1, :] = 0, 0
        A[0, 0], A[-1, -1] = 1.0, 1.0
        return A.tocsr()

    def solve_cn_step(u_n, v_n):
        """Advance one Crank-Nicolson step with Newton's method.

        Returns (u_{n+1}, dR/du(u_n), dR/du(u_{n+1})), or (u_n, None, None) when
        the linear solve fails or yields a non-finite Newton update.
        """
        f_n = B @ v_n
        u_np1 = u_n.copy()
        Rn = -advective(u_n) + nu * D2.dot(u_n) + f_n
        for _ in range(25):
            Rnp1 = -advective(u_np1) + nu * D2.dot(u_np1) + f_n
            F = enforce_bcs_vec(u_np1 - u_n - 0.5 * dt * (Rn + Rnp1))
            J = enforce_bcs_mat(I_sp - 0.5 * dt * dRdu(u_np1))
            try:
                delta = spsolve(J, -F)
            except (RuntimeError, ValueError):
                return u_n, None, None
            # spsolve reports a singular Jacobian with a warning and NaNs rather
            # than an exception; treat any non-finite update as a failed step.
            if not np.all(np.isfinite(delta)):
                return u_n, None, None
            u_np1 += delta
            if np.linalg.norm(delta) < 1e-10:
                break
        return u_np1, dRdu(u_n), dRdu(u_np1)

    def compute_cost_and_grad(v_flat, u_init, Hk):
        """Horizon cost and its gradient w.r.t. the flattened controls.

        J = 0.5*dt*sum_n |u_{n+1} - u_ref|^2 + 0.5*alpha*dt*sum_n |v_n|^2.
        The gradient comes from a backward (adjoint) sweep through the stored
        CN Jacobians. If any forward step fails, (1e10, zeros) is returned so the
        optimizer sees a prohibitively large cost.
        """
        V_seq = v_flat.reshape((Hk, m))
        U = np.zeros((Hk + 1, Nx))
        U[0] = u_init
        A_list, B_list = [], []
        for n in range(Hk):
            u_np1, dRdu_n, dRdu_np1 = solve_cn_step(U[n], V_seq[n])
            if dRdu_n is None:
                return 1e10, np.zeros_like(v_flat)
            U[n + 1] = u_np1
            # Derivatives of the CN residual w.r.t. u_{n+1} (A_n) and u_n (B_n).
            A_list.append(enforce_bcs_mat(I_sp - 0.5 * dt * dRdu_np1))
            B_list.append(enforce_bcs_mat(-I_sp - 0.5 * dt * dRdu_n))
        J = 0.5 * dt * np.sum((U[1:] - u_ref)**2) + 0.5 * alpha * dt * np.sum(V_seq**2)
        p_next = dt * (U[-1] - u_ref)
        grad = np.zeros_like(V_seq)
        for n in reversed(range(Hk)):
            q = spsolve(A_list[n].T, p_next)
            grad[n, :] = dt * (alpha * V_seq[n] + (B.T @ q))
            p_next = dt * (U[n] - u_ref) - (B_list[n].T @ q)
        return J, grad.ravel()

    state_history, control_history = [u0], []
    u_current = u0.copy()
    start_time = time.perf_counter()
    for k in range(Nt - 1):
        Hk = min(horizon, Nt - 1 - k)
        res = minimize(
            lambda v: compute_cost_and_grad(v, u_current, Hk),
            np.zeros(Hk * m),
            method='L-BFGS-B',
            jac=True,
            bounds=[(-control_bound, control_bound)] * (Hk * m),
            options={'maxiter': max_opt_iters, 'ftol': 1e-6},
        )
        v_apply = res.x.reshape((Hk, m))[0]
        control_history.append(v_apply)
        # On a failed step solve_cn_step returns the unchanged state.
        u_current, _, _ = solve_cn_step(u_current, v_apply)
        state_history.append(u_current)
        print(f"Adjoint Step {k+1}/{Nt-1}: J={res.fun:.4e}", end='\r')

    total_time = time.perf_counter() - start_time
    final_state_sensors = np.interp(sensor_grid, x_grid, state_history[-1])
    final_mse = np.mean((final_state_sensors - target_x_np_sensors)**2)
    print(f"\nAdjoint Method finished in {total_time:.4f}s. Final MSE: {final_mse:.4e}")
    return np.array(state_history), np.array(control_history), total_time

def run_sbto_evaluation(args, config, target_x_np, initial_x_np):
    """Open-loop shooting-based trajectory optimisation (SBTO) through the DeepONet surrogate.

    The whole control sequence (NT_SOLVER - 1 steps) is optimised with Adam by
    back-propagating through a full surrogate rollout. The optimised sequence is
    then replayed once without gradients to record the state trajectory.

    Args:
        args: Parsed CLI arguments; reads output_base_dir, run_id and the
            sbto_* hyperparameters (optim_steps, lr, terminal/running/effort weights).
        config: Problem configuration dict.
        target_x_np: Target profile on the sensor grid.
        initial_x_np: Initial state on the sensor grid.

    Returns:
        (states, controls, wall_time): ``states`` stacks the initial state plus
        one state per step; ``controls`` is a (NT_SOLVER - 1, NUM_BASIS_FUNCTIONS)
        array; ``wall_time`` covers the optimisation loop only (not the replay).
    """
    print(f"\n--- [3/4] Evaluating SBTO ---")
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Load Surrogate (the DeepONet run is referenced by the controller run's hyperparams)
    controller_run_dir = os.path.join(args.output_base_dir, args.run_id)
    with open(os.path.join(controller_run_dir, "hyperparams.yaml"), 'r') as f:
        ch_params = yaml.safe_load(f)
    deeponet_run_id = ch_params['deeponet_run_id']
    deeponet_run_dir = os.path.join(args.output_base_dir, deeponet_run_id)
    deeponet_model_path = os.path.join(deeponet_run_dir, "burgers_propagator_best.pth")
    with open(os.path.join(deeponet_run_dir, "hyperparams.yaml"), 'r') as f:
        dh_params = yaml.safe_load(f)
    model_arg_keys = ['branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn']
    deeponet_kwargs = {key: dh_params[key] for key in model_arg_keys}
    physics_simulator = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **deeponet_kwargs,
    ).to(DEVICE)
    physics_simulator.load_state_dict(torch.load(deeponet_model_path, map_location=DEVICE))
    physics_simulator.eval()

    initial_x_torch = torch.from_numpy(initial_x_np).float().unsqueeze(0).to(DEVICE)
    target_x_torch = torch.from_numpy(target_x_np).float().unsqueeze(0).to(DEVICE)
    x_grid_sensors_torch = torch.linspace(0, config['L'], config['M_SENSORS'], device=DEVICE).unsqueeze(0).unsqueeze(-1)

    num_control_steps = config['NT_SOLVER'] - 1
    # Unconstrained parameters; tanh maps them into [-CONTROL_SCALE, CONTROL_SCALE].
    control_params = torch.zeros(1, num_control_steps, config['NUM_BASIS_FUNCTIONS'], device=DEVICE, requires_grad=True)
    optimizer = optim.Adam([control_params], lr=args.sbto_lr)
    mse_loss_fn = nn.MSELoss()

    print(f"Starting SBTO optimization for {args.sbto_optim_steps} steps...")
    start_time = time.perf_counter()
    for i in range(args.sbto_optim_steps):
        optimizer.zero_grad()
        x_current = initial_x_torch
        running_loss, total_effort = 0.0, 0.0
        control_sequence = torch.tanh(control_params) * config['CONTROL_SCALE']

        for k in range(num_control_steps):
            u_k = control_sequence[:, k, :]
            x_current = physics_simulator(x_current, u_k, x_grid_sensors_torch).squeeze(-1)
            running_loss += mse_loss_fn(x_current, target_x_torch)
            total_effort += torch.mean(u_k**2)

        terminal_loss = mse_loss_fn(x_current, target_x_torch)
        total_loss = (args.sbto_terminal_weight * terminal_loss +
                      args.sbto_running_weight * (running_loss / num_control_steps) +
                      args.sbto_effort_weight * (total_effort / num_control_steps))

        total_loss.backward()
        optimizer.step()
        if (i + 1) % 50 == 0:
            print(f"Step {i+1}/{args.sbto_optim_steps} | Loss: {total_loss.item():.4e}")

    # CUDA kernels are launched asynchronously: wait for the queued optimizer
    # work to finish so the timing covers the computation, not just its launch.
    if DEVICE == "cuda":
        torch.cuda.synchronize()
    total_time = time.perf_counter() - start_time

    final_controls = torch.tanh(control_params.detach()) * config['CONTROL_SCALE']
    final_control_sequence = final_controls.cpu().numpy().squeeze(0)

    # Replay the optimised open-loop sequence to record the trajectory.
    state_history = [initial_x_np]
    x_current_eval = initial_x_torch
    with torch.no_grad():
        for k in range(num_control_steps):
            x_current_eval = physics_simulator(x_current_eval, final_controls[:, k, :], x_grid_sensors_torch).squeeze(-1)
            state_history.append(x_current_eval.cpu().numpy().flatten())

    final_mse = np.mean((state_history[-1] - target_x_np)**2)
    print(f"SBTO finished in {total_time:.4f}s. Final MSE: {final_mse:.4e}")
    return np.array(state_history), final_control_sequence, total_time

def run_deeponet_mpc_evaluation(args, config, target_x_np, initial_x_np):
    """Receding-horizon MPC that plans through the DeepONet surrogate with Adam.

    At each of the NT_SOLVER - 1 steps a fresh control plan (zero-initialised,
    squashed by tanh into [-CONTROL_SCALE, CONTROL_SCALE]) is optimised over a
    horizon of ``min(args.nmpc_horizon, remaining steps)``. Only its first
    action is applied to the surrogate state.

    Args:
        args: Parsed CLI arguments; reads output_base_dir, run_id and the
            nmpc_* hyperparameters (horizon, optim_steps, lr and loss weights).
        config: Problem configuration dict.
        target_x_np: Target profile on the sensor grid.
        initial_x_np: Initial state on the sensor grid.

    Returns:
        (states, controls, wall_time): stacked sensor-grid states (initial state
        plus one per step), stacked applied controls and the loop duration (s).
    """
    print(f"\n--- [4/4] Evaluating DeepONet MPC (Model: {args.run_id}) ---")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Surrogate: locate the DeepONet run through the controller run's hyperparams ---
    ctrl_dir = os.path.join(args.output_base_dir, args.run_id)
    with open(os.path.join(ctrl_dir, "hyperparams.yaml"), 'r') as fh:
        ctrl_hparams = yaml.safe_load(fh)
    onet_dir = os.path.join(args.output_base_dir, ctrl_hparams['deeponet_run_id'])
    onet_weights_path = os.path.join(onet_dir, "burgers_propagator_best.pth")
    with open(os.path.join(onet_dir, "hyperparams.yaml"), 'r') as fh:
        onet_hparams = yaml.safe_load(fh)
    onet_arch_keys = ('branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn')
    onet_arch = {name: onet_hparams[name] for name in onet_arch_keys}
    physics_simulator = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **onet_arch,
    ).to(device)
    physics_simulator.load_state_dict(torch.load(onet_weights_path, map_location=device))
    physics_simulator.eval()

    start_state = torch.from_numpy(initial_x_np).float().unsqueeze(0).to(device)
    goal = torch.from_numpy(target_x_np).float().unsqueeze(0).to(device)
    grid = torch.linspace(0, config['L'], config['M_SENSORS'], device=device).unsqueeze(0).unsqueeze(-1)

    trajectory = [start_state.cpu().numpy().flatten()]
    applied = []
    state = start_state
    n_steps = config['NT_SOLVER'] - 1

    tic = time.time()
    for step in range(n_steps):
        print(f"DeepONet MPC Step {step+1}/{n_steps}", end='\r')
        horizon = min(args.nmpc_horizon, n_steps - step)
        # Cold-start plan every step (no warm start from the previous solution).
        plan = torch.zeros(1, horizon, config['NUM_BASIS_FUNCTIONS'], device=device, requires_grad=True)
        opt = optim.Adam([plan], lr=args.nmpc_lr)

        for _ in range(args.nmpc_optim_steps):
            opt.zero_grad()
            x_pred = state
            tracking, effort = 0.0, 0.0
            u_seq = torch.tanh(plan) * config['CONTROL_SCALE']

            for h in range(horizon):
                u_h = u_seq[:, h, :]
                x_pred = physics_simulator(x_pred, u_h, grid).squeeze(-1)
                tracking += nn.functional.mse_loss(x_pred, goal)
                effort += torch.mean(u_h**2)

            end_error = nn.functional.mse_loss(x_pred, goal)
            objective = (args.nmpc_terminal_weight * end_error +
                         args.nmpc_running_weight * (tracking / horizon) +
                         args.nmpc_effort_weight * (effort / horizon))
            objective.backward()
            opt.step()

        with torch.no_grad():
            first_action = (torch.tanh(plan[:, 0, :]) * config['CONTROL_SCALE']).detach()
            applied.append(first_action.cpu().numpy().flatten())
            state = physics_simulator(state, first_action, grid).squeeze(-1)
            trajectory.append(state.cpu().numpy().flatten())

    elapsed = time.time() - tic
    final_mse = np.mean((trajectory[-1] - target_x_np)**2)
    print(f"\nDeepONet MPC finished in {elapsed:.4f}s. Final MSE: {final_mse:.4e}")
    return np.array(trajectory), np.array(applied), elapsed

# ===========================================================================
# 3. PLOTTING AND MAIN
# ===========================================================================

def generate_comparison_plots_split(all_results, sensor_grid, config, output_dir, target_subset, filename_suffix):
    """Plot the final state of every controller, one panel per target, in a single row.

    Each panel overlays the target profile, the initial condition and the final
    state y(T, x) reached by PDE-OP, the adjoint method, SBTO and DeepONet MPC.
    The figure is saved as ``<output_dir>/comparison_burgers_<filename_suffix>.pdf``.

    Args:
        all_results: Mapping target name -> result dict as assembled in ``main``.
        sensor_grid: Sensor-grid coordinates used as the x-axis.
        config: Problem configuration; ``config['L']`` is read to rebuild the
            solver grid for state histories stored at a different resolution.
        output_dir: Directory the PDF is written to (must already exist).
        target_subset: Ordered target names, one panel each (any count >= 1).
        filename_suffix: Suffix used in the output file name and log message.
    """
    print(f"\n--- Generating {filename_suffix} Plot ---")
    if not target_subset:
        print("No targets requested; skipping plot.")
        return

    font_size = 40
    legend_font_size = 40
    n_panels = len(target_subset)

    # squeeze=False always yields a 2-D axes array, so indexing works for any
    # panel count; width keeps the original 38 in per three panels.
    fig, axes = plt.subplots(1, n_panels, figsize=(38 * n_panels / 3, 9), squeeze=False)
    axes = axes[0]

    colors = {'PDE-OP': 'r', 'Adjoint Method': 'g', 'SBTO': 'b', 'DeepONet MPC': 'm'}
    styles = {'PDE-OP': '-', 'Adjoint Method': ':', 'SBTO': '--', 'DeepONet MPC': ':'}
    # (results key, legend label, line width), in legend order.
    methods = [
        ('recurrent', 'PDE-OP', 2.5),
        ('adjoint_nmpc', 'Adjoint Method', 3),
        ('sbto', 'SBTO', 2.5),
        ('neural_mpc', 'DeepONet MPC', 3),
    ]

    def final_state_on_sensor_grid(state_history):
        """Last state of a history, interpolated onto ``sensor_grid`` if stored on another grid."""
        final_state = state_history[-1]
        if state_history.shape[1] != len(sensor_grid):
            solver_grid = np.linspace(0, config['L'], state_history.shape[1])
            return np.interp(sensor_grid, solver_grid, final_state)
        return final_state

    legend_ax = None  # first panel that actually received curves
    for ax, target_type in zip(axes, target_subset):
        if target_type not in all_results:
            continue
        results = all_results[target_type]

        ax.plot(sensor_grid, results['target_x'], 'k:', lw=4, label=r'$y_{\text{target}}(x)$')
        ax.plot(sensor_grid, results['initial_x'], 'k--', lw=2.5, label='$y(0, x)$')
        for key, label, lw in methods:
            ax.plot(sensor_grid, final_state_on_sensor_grid(results[key]['states']),
                    color=colors[label], linestyle=styles[label], lw=lw, label=label)

        ax.set_xlabel('x', fontsize=font_size)
        ax.grid(True, linestyle='--', alpha=0.6)
        ax.tick_params(axis='both', which='major', labelsize=font_size)
        if legend_ax is None:
            legend_ax = ax

    axes[0].set_ylabel('$y(T, x)$', fontsize=font_size)

    # Take legend entries from a populated panel so a missing first target
    # does not produce an empty legend.
    if legend_ax is not None:
        handles, labels = legend_ax.get_legend_handles_labels()
        fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.15), ncol=6, fontsize=legend_font_size)
    fig.tight_layout(rect=[0, 0.1, 1, 1])

    save_path = os.path.join(output_dir, f"comparison_burgers_{filename_suffix}.pdf")
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved plot to: {save_path}")

def main(args):
    """Evaluate all four controllers on every target, print summaries and save plots/results.

    Args:
        args: Parsed CLI arguments: config path, output base directory, the
            controller ``run_id`` and the SBTO / DeepONet-MPC hyperparameters.
    """
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)

    # Define Groups
    targets_group_1 = ['sine', 'parabola', 'zero']
    targets_group_2 = ['step', 'high_freq', 'complex_gaussian']
    all_targets = targets_group_1 + targets_group_2

    # Create the output directory up front so a bad path fails before the long evaluations.
    output_dir = os.path.join(args.output_base_dir, "comparison_plots_burgers_split")
    os.makedirs(output_dir, exist_ok=True)

    # The sensor grid is target-independent; it is also needed by the plots below.
    sensor_grid = np.linspace(0, config['L'], config['M_SENSORS'])

    def get_mse(states, target):
        """Final-state MSE on the sensor grid, interpolating solver-grid histories first."""
        final_state = states[-1]
        if states.shape[1] != len(target):
            x_grid_classical = np.linspace(0, config['L'], states.shape[1])
            final_state = np.interp(sensor_grid, x_grid_classical, final_state)
        return np.mean((final_state - target)**2)

    all_results = {}
    for target_type in all_targets:
        print(f"\n\n{'='*80}")
        print(f"--- Starting Burgers' Eval for Target: '{target_type.upper()}' ---")
        print(f"--- Using Model Run ID: '{args.run_id}' ---")
        print(f"{'='*80}")

        initial_x_np = get_initial_condition(config, sensor_grid)
        target_x_np = generate_target_profile(config, target_type, sensor_grid)

        # --- RUN EVALUATIONS ---
        recurrent_states, recurrent_controls, recurrent_time = run_recurrent_controller_evaluation(args, config, target_x_np, initial_x_np)
        adjoint_states, adjoint_controls, adjoint_time = run_adjoint_nmpc_evaluation(config, target_x_np, initial_x_np)
        sbto_states, sbto_controls, sbto_time = run_sbto_evaluation(args, config, target_x_np, initial_x_np)
        nmpc_states, nmpc_controls, nmpc_time = run_deeponet_mpc_evaluation(args, config, target_x_np, initial_x_np)

        all_results[target_type] = {
            'recurrent': {'states': recurrent_states, 'controls': recurrent_controls, 'time': recurrent_time},
            'adjoint_nmpc': {'states': adjoint_states, 'controls': adjoint_controls, 'time': adjoint_time},
            'sbto': {'states': sbto_states, 'controls': sbto_controls, 'time': sbto_time},
            'neural_mpc': {'states': nmpc_states, 'controls': nmpc_controls, 'time': nmpc_time},
            'target_x': target_x_np,
            'initial_x': initial_x_np
        }

        # --- PRINT SUMMARY ---
        print(f"\n--- SUMMARY FOR '{target_type.upper()}' ---")
        print(f"{'Method':<20} | {'Time Taken (s)':<20} | {'Final MSE':<20}")
        print("-" * 65)
        summary_rows = [
            ('PDE-OP', recurrent_time, recurrent_states),
            ('Adjoint Method', adjoint_time, adjoint_states),
            ('SBTO', sbto_time, sbto_states),
            ('DeepONet MPC', nmpc_time, nmpc_states),
        ]
        for label, elapsed, states in summary_rows:
            print(f"{label:<20} | {elapsed:<20.4f} | {get_mse(states, target_x_np):<20.4e}")
        print("-" * 65)

    generate_comparison_plots_split(all_results, sensor_grid, config, output_dir, targets_group_1, "group1_standard")
    generate_comparison_plots_split(all_results, sensor_grid, config, output_dir, targets_group_2, "group2_advanced")

    # NOTE: each value is a nested dict, which NumPy stores as a pickled 0-d object
    # array; reload with np.load(path, allow_pickle=True)[target].item().
    results_save_path = os.path.join(output_dir, "all_simulation_results.npz")
    np.savez_compressed(results_save_path, **all_results)
    print(f"\nSaved all numerical results to: {results_save_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Unified evaluation of multiple controllers for Burgers' Eq.")
    parser.add_argument("--config_path", type=str, required=True)
    parser.add_argument("--output_base_dir", type=str, required=True)

    # SBTO Hyperparameters
    parser.add_argument("--sbto_optim_steps", type=int, default=150)
    parser.add_argument("--sbto_lr", type=float, default=1e-2)
    parser.add_argument("--sbto_terminal_weight", type=float, default=1.0)
    parser.add_argument("--sbto_effort_weight", type=float, default=1e-5)
    parser.add_argument("--sbto_running_weight", type=float, default=0.1)

    # DeepONet MPC Hyperparameters
    parser.add_argument("--nmpc_horizon", type=int, default=10)
    parser.add_argument("--nmpc_optim_steps", type=int, default=25)
    parser.add_argument("--nmpc_lr", type=float, default=2e-2)
    parser.add_argument("--nmpc_terminal_weight", type=float, default=1.0)
    parser.add_argument("--nmpc_effort_weight", type=float, default=1e-5)
    parser.add_argument("--nmpc_running_weight", type=float, default=0.1)
    parser.add_argument("--run_id", type=str, required=False, default=None, help="Run ID of the trained controller model to evaluate.")

    args = parser.parse_args()
    # Every evaluator resolves <output_base_dir>/<run_id>/hyperparams.yaml, so a
    # run id is mandatory; fail with a usage message instead of a TypeError
    # from os.path.join(..., None) deep inside the first evaluation.
    if args.run_id is None:
        parser.error("--run_id is required (name of the trained controller run under --output_base_dir)")
    main(args)