
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import argparse
import time

from data_and_models import PropagatorDeepONet, RecurrentController
from scipy.optimize import fsolve, minimize
from scipy import sparse
from scipy.sparse.linalg import spsolve

def generate_target_profile(config, target_type, x_grid):
    """Return the desired final-state profile sampled on ``x_grid``.

    Supported ``target_type`` values, with ``L = config['L']``:
      * ``'sine'``     -> 0.8 * sin(pi * x / L)
      * ``'parabola'`` -> 4 * 0.5 * x * (L - x)
      * ``'zero'``     -> zeros shaped like ``x_grid`` (``config`` is not read)

    Raises:
        ValueError: if ``target_type`` is none of the above.
    """
    print(f"Generating '{target_type}' target profile for a grid of size {len(x_grid)}...")
    if target_type == 'zero':
        return np.zeros_like(x_grid)
    if target_type != 'sine' and target_type != 'parabola':
        raise ValueError(f"Invalid target type: {target_type}")
    # Only the non-trivial profiles depend on the domain length.
    domain_length = config['L']
    if target_type == 'sine':
        return 0.8 * np.sin(np.pi * x_grid / domain_length)
    return 4 * 0.5 * x_grid * (domain_length - x_grid)

def get_initial_condition(config, x_grid):
    """Return 0.5 * sin(2 * pi * x / L) on ``x_grid`` with the first and last entries set to 0."""
    profile = 0.5 * np.sin(2 * np.pi * x_grid / config['L'])
    # Pin both boundary samples to zero (homogeneous Dirichlet data).
    profile[[0, -1]] = 0.0
    return profile

def run_recurrent_controller_evaluation(args, config, target_x_np, initial_x_np):
    """Roll out the trained recurrent controller in closed loop with its DeepONet surrogate.

    The controller checkpoint and ``hyperparams.yaml`` are read from
    ``<args.output_base_dir>/<args.run_id>``; the surrogate propagator is loaded
    from the run named by the controller's ``deeponet_run_id`` hyperparameter.
    For ``NT_SOLVER - 1`` steps the controller proposes a control from the current
    state, the target and its recurrent hidden state, and the surrogate (not a PDE
    solver) produces the next state.

    Returns:
        (states, controls, total_time): states stacked with the initial state first,
        one flattened control per step, and the wall-clock time of the rollout loop
        only (model loading is excluded).
    """
    print(f"\n--- [1/3] Evaluating Recurrent NN Controller (Model: {args.run_id}) ---")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_yaml(run_dir):
        # Every training run stores its hyperparameters next to its checkpoint.
        with open(os.path.join(run_dir, "hyperparams.yaml"), 'r') as handle:
            return yaml.safe_load(handle)

    def restore_weights(network, checkpoint):
        network.load_state_dict(torch.load(checkpoint, map_location=device))
        network.eval()

    # Controller: only the architecture-related hyperparameters are forwarded.
    controller_dir = os.path.join(args.output_base_dir, args.run_id)
    controller_hp = load_yaml(controller_dir)
    controller_extra = {name: value for name, value in controller_hp.items()
                        if name in ('hidden_dim', 'num_layers', 'activation_fn')}
    controller = RecurrentController(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        control_scale=config['CONTROL_SCALE'],
        **controller_extra,
    ).to(device)
    restore_weights(controller, os.path.join(controller_dir, "burgers_controller_model.pth"))

    # Surrogate propagator that the controller was trained against; all keys are required.
    surrogate_dir = os.path.join(args.output_base_dir, controller_hp['deeponet_run_id'])
    surrogate_hp = load_yaml(surrogate_dir)
    surrogate_extra = {name: surrogate_hp[name] for name in
                       ('branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn')}
    physics_simulator = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **surrogate_extra,
    ).to(device)
    restore_weights(physics_simulator, os.path.join(surrogate_dir, "burgers_propagator_best.pth"))
    print("Successfully loaded recurrent controller and surrogate model.")

    def as_batch(array):
        # numpy vector -> float32 tensor with a leading batch dimension of 1.
        return torch.from_numpy(array).float().unsqueeze(0).to(device)

    def as_flat(tensor):
        return tensor.cpu().numpy().flatten()

    target_batch = as_batch(target_x_np)
    state = as_batch(initial_x_np)
    # Surrogate query points, shape (1, M_SENSORS, 1).
    query_grid = torch.linspace(0, config['L'], config['M_SENSORS'], device=device).unsqueeze(0).unsqueeze(-1)
    states, controls = [as_flat(state)], []
    start_time = time.time()
    with torch.no_grad():
        hidden = None
        for _ in range(config['NT_SOLVER'] - 1):
            action, hidden = controller(state, target_batch, hidden)
            controls.append(as_flat(action))
            state = physics_simulator(state, action, query_grid).squeeze(-1)
            states.append(as_flat(state))
    total_time = time.time() - start_time
    final_mse = np.mean((states[-1] - target_x_np)**2)
    print(f"Recurrent controller finished in {total_time:.4f}s. Final MSE (surrogate): {final_mse:.4e}")
    return np.array(states), np.array(controls), total_time

def run_pod_mpc_evaluation(config, target_x_np_sensors, initial_x_np_sensors, *,
                           horizon=10, dt_mpc=0.05, q_weight=1.0, r_weight=5e-5,
                           num_pod_modes=5, num_snapshots=100, seed=0):
    """Run a POD-reduced Galerkin MPC controller on the full-order Burgers' model.

    The plant is a Crank-Nicolson discretisation of u_t = -u u_x + nu u_xx + B v
    on NX_SOLVER uniform points with u = 0 re-imposed at both ends after every
    step, driven by NUM_BASIS_FUNCTIONS sine actuators. A POD basis is built from
    one randomly forced open-loop trajectory. At each control step an SLSQP
    problem over the reduced dynamics is solved, and the first control of the
    sequence is applied to the full-order plant.

    The closed loop uses the same time grid as the adjoint NMPC baseline:
    NT_SOLVER - 1 steps of size T_FINAL / (NT_SOLVER - 1). Prediction and plant
    both hold each control constant over its step (zero-order hold). Note that
    the prediction step ``dt_mpc`` is independent of the plant step.

    Args:
        config: dict read for L, VISCOSITY, NX_SOLVER, NT_SOLVER, T_FINAL,
            NUM_BASIS_FUNCTIONS and M_SENSORS.
        target_x_np_sensors: target profile on the M_SENSORS grid.
        initial_x_np_sensors: initial state on the M_SENSORS grid.
        horizon: number of predicted control moves per MPC problem.
        dt_mpc: time step of the reduced prediction model.
        q_weight: weight of the reduced-coordinate tracking error.
        r_weight: weight of the control effort.
        num_pod_modes: number of POD modes kept.
        num_snapshots: number of 50-step trajectory segments used for POD
            (10 snapshots are kept per segment).
        seed: seed for the random snapshot forcing (None for fresh entropy).

    Returns:
        (state_history, control_history, total_time). ``state_history`` has
        shape (NT_SOLVER, NX_SOLVER) and starts with the initial state.
        ``control_history`` has shape (NT_SOLVER - 1, NUM_BASIS_FUNCTIONS).
        ``total_time`` is the wall-clock time of the closed loop, excluding
        POD basis construction.

    Raises:
        ValueError: if horizon, num_snapshots or num_pod_modes is < 1.
    """
    if horizon < 1 or num_snapshots < 1 or num_pod_modes < 1:
        raise ValueError("horizon, num_snapshots and num_pod_modes must all be >= 1")
    print("\n--- [2/3] Evaluating POD-Based MPC ---")
    L, nu = config['L'], config['VISCOSITY']
    Nx = config['NX_SOLVER']
    m = config['NUM_BASIS_FUNCTIONS']
    # NT_SOLVER - 1 control steps spanning [0, T_FINAL], like the adjoint NMPC baseline.
    Nt_sim = max(config['NT_SOLVER'] - 1, 0)
    dt = config['T_FINAL'] / Nt_sim if Nt_sim > 0 else 0.0
    x = np.linspace(0, L, Nx)
    dx = L / (Nx - 1)
    sensor_grid = np.linspace(0, L, config['M_SENSORS'])
    u_ref = np.interp(x, sensor_grid, target_x_np_sensors)
    u0 = np.interp(x, sensor_grid, initial_x_np_sensors)
    # Actuator matrix (Nx, m): column j is sin((j + 1) * pi * x / L).
    B = np.array([np.sin((j + 1) * np.pi * x / L) for j in range(m)]).T
    # Dense central-difference operators for u_x and u_xx. Boundary rows stay zero.
    D = np.zeros((Nx, Nx))
    D2 = np.zeros((Nx, Nx))
    for i in range(1, Nx - 1):
        D[i, i - 1], D[i, i + 1] = -1 / (2 * dx), 1 / (2 * dx)
        D2[i, i - 1], D2[i, i], D2[i, i + 1] = 1 / dx**2, -2 / dx**2, 1 / dx**2
    rng = np.random.default_rng(seed)

    def cn_residual(u_next, u_prev, f, dt_local):
        """Crank-Nicolson residual of u_t = -u u_x + nu u_xx + f, with f held over the step."""
        rhs_prev = -u_prev * (D @ u_prev) + nu * (D2 @ u_prev) + f
        rhs_next = -u_next * (D @ u_next) + nu * (D2 @ u_next) + f
        return (u_next - u_prev) - (dt_local / 2) * (rhs_prev + rhs_next)

    def full_cn_step(u_prev, f, dt_local):
        """Advance the full-order state one CN step and re-impose u = 0 at both ends."""
        u_next = fsolve(cn_residual, u_prev, args=(u_prev, f, dt_local))
        u_next[0], u_next[-1] = 0.0, 0.0
        return u_next

    def generate_snapshots(num_snaps, amp=1.0, steps_per_snap=50, stride=5):
        """Sample one open-loop trajectory from u0 under random actuator forcing.

        Keeps every ``stride``-th state of each ``steps_per_snap``-step segment,
        returned column-wise as an (Nx, n_snapshots) array.
        """
        columns = []
        u = u0.copy()
        for _ in range(num_snaps):
            for i in range(steps_per_snap):
                f = B @ (amp * (2 * rng.random(m) - 1))
                u = full_cn_step(u, f, dt)
                if i % stride == 0:
                    columns.append(u)
        return np.column_stack(columns)

    print("Generating POD basis... (this may take a moment)")
    U_snap = generate_snapshots(num_snapshots)
    u_mean = U_snap.mean(axis=1)
    U_pod, _, _ = np.linalg.svd(U_snap - u_mean[:, np.newaxis], full_matrices=False)
    V_pod = U_pod[:, :num_pod_modes]
    z_ref = V_pod.T @ (u_ref - u_mean)

    def reduced_cn_step(z_prev, f, dt_local):
        """Galerkin-projected CN step in POD coordinates (full residual projected by V_pod^T)."""
        u_prev_full = V_pod @ z_prev + u_mean

        def residual(z_next):
            return V_pod.T @ cn_residual(V_pod @ z_next + u_mean, u_prev_full, f, dt_local)

        return fsolve(residual, z_prev)

    def mpc_cost(v_seq_flat, z_start):
        """Reduced-coordinate tracking + control-effort cost of a predicted control sequence."""
        v_seq = v_seq_flat.reshape((horizon, m))
        cost, z_pred = 0.0, z_start
        for v_i in v_seq:
            # Zero-order hold, matching how the plant applies the control below.
            z_pred = reduced_cn_step(z_pred, B @ v_i, dt_mpc)
            cost += q_weight * np.sum((z_pred - z_ref)**2) + r_weight * np.sum(v_i**2)
        # Extra terminal penalty on the last predicted state.
        cost += q_weight * np.sum((z_pred - z_ref)**2)
        return cost

    bounds = [(-1.0, 1.0)] * (horizon * m)
    state_history = np.zeros((Nt_sim + 1, Nx))
    control_history = np.zeros((Nt_sim, m))
    u_current = u0.copy()
    state_history[0] = u_current
    start_time = time.time()
    for k in range(Nt_sim):
        z_current = V_pod.T @ (u_current - u_mean)
        res = minimize(mpc_cost, np.zeros(horizon * m), args=(z_current,), method='SLSQP', bounds=bounds)
        v_apply = res.x.reshape((horizon, m))[0]
        u_current = full_cn_step(u_current, B @ v_apply, dt)
        state_history[k + 1], control_history[k] = u_current, v_apply
        print(f"POD-MPC Step {k+1}/{Nt_sim} completed. Cost: {res.fun:.4e}", end='\r')
    total_time = time.time() - start_time
    final_state_sensors = np.interp(sensor_grid, x, state_history[-1])
    final_mse = np.mean((final_state_sensors - target_x_np_sensors)**2)
    print(f"\nPOD-MPC finished in {total_time:.4f}s. Final MSE: {final_mse:.4e}")
    return state_history, control_history, total_time

def run_adjoint_nmpc_evaluation(config, target_x_np_sensors, initial_x_np_sensors):
    """Closed-loop receding-horizon NMPC on the full-order Burgers' model, with adjoint gradients.

    The plant is a Crank-Nicolson discretisation of
    u_t = -(u^2 / 2)_x + nu * u_xx + B v on NX_SOLVER uniform points. Each step is
    solved by Newton's method with a sparse analytic Jacobian.

    At each of the NT_SOLVER - 1 steps, an L-BFGS-B problem is solved over a
    horizon of min(10, remaining steps) control vectors, each bounded to [-1, 1].
    Gradients come from the discrete adjoint. Only the first control of the
    optimised sequence is applied, held constant over the step.

    Args:
        config: dict read for L, VISCOSITY, NX_SOLVER, NT_SOLVER, T_FINAL,
            NUM_BASIS_FUNCTIONS and M_SENSORS.
        target_x_np_sensors: target profile on the M_SENSORS grid (linearly
            interpolated onto the solver grid).
        initial_x_np_sensors: initial state on the M_SENSORS grid (linearly
            interpolated onto the solver grid).

    Returns:
        (state_history, control_history, total_time). ``state_history`` stacks
        u0 first and then one state per step (NT_SOLVER rows of length
        NX_SOLVER). ``control_history`` holds one applied control vector per
        step. ``total_time`` is the wall-clock time of the closed loop.
    """
    print("\n--- [3/3] Evaluating Adjoint-Based NMPC ---")
    L, nu = config['L'], config['VISCOSITY']
    Nx, Nt = config['NX_SOLVER'], config['NT_SOLVER']
    # NT_SOLVER - 1 steps of this size span [0, T_FINAL].
    dt = config['T_FINAL'] / (Nt - 1) if Nt > 1 else 0.0
    m = config['NUM_BASIS_FUNCTIONS']
    # Receding-horizon length, L-BFGS-B iteration cap per step, control-effort weight.
    H, max_opt_iters, alpha = 10, 40, 5e-5
    x_grid = np.linspace(0, L, Nx)
    sensor_grid = np.linspace(0, L, config['M_SENSORS'])
    u_ref, u0 = np.interp(x_grid, sensor_grid, target_x_np_sensors), np.interp(x_grid, sensor_grid, initial_x_np_sensors)
    # Actuator matrix (Nx, m): column j is sin((j + 1) * pi * x / L).
    B = np.array([np.sin((j + 1) * np.pi * x_grid / L) for j in range(m)]).T
    dx = L / (Nx - 1)
    # Sparse central-difference operators for d/dx and d^2/dx^2. Their first and last
    # rows are not boundary-aware; enforce_bcs overwrites those rows wherever they are used.
    D = sparse.diags([-1, 1], [-1, 1], shape=(Nx, Nx), format='csr') / (2 * dx)
    D2 = sparse.diags([1, -2, 1], [-1, 0, 1], shape=(Nx, Nx), format='csr') / dx**2
    I_sp = sparse.eye(Nx, format='csr')
    # Conservative advection term d/dx(u^2 / 2).
    def advective(u): return D.dot(0.5 * (u**2))
    # Jacobian of R(u) = -d/dx(u^2 / 2) + nu * u_xx with respect to u (the forcing does not depend on u).
    def dRdu(u): return -D.dot(sparse.diags(u)) + nu * D2
    def enforce_bcs(A_or_v):
        """Replace the PDE rows at both boundary nodes by Dirichlet rows.

        Dense vectors get their first/last entries zeroed in place. Sparse matrices
        get their first/last rows replaced by identity rows; a new CSR matrix is returned.
        """
        if isinstance(A_or_v, np.ndarray): A_or_v[0], A_or_v[-1] = 0.0, 0.0
        else: A_or_v = A_or_v.tolil(); A_or_v[0, :], A_or_v[-1, :] = 0, 0; A_or_v[0, 0], A_or_v[-1, -1] = 1.0, 1.0; A_or_v = A_or_v.tocsr()
        return A_or_v
    def solve_cn_step(u_n, v_n):
        """Advance one CN step with control v_n held constant over the step.

        Uses Newton's method for at most 25 iterations, stopping once ||delta|| < 1e-10.
        No warning is raised if it has not converged by then.

        The boundary residual entries are zeroed and the Jacobian has identity
        boundary rows, so delta is 0 there. The boundary values of the new state
        therefore stay equal to those of u_n.

        Returns the new state plus dR/du evaluated at u_n and at u_np1. These
        Jacobians have no boundary-row modification and are used by the adjoint.
        """
        f_n = B @ v_n
        u_np1, Rn = u_n.copy(), -advective(u_n) + nu * D2.dot(u_n) + f_n
        for _ in range(25):
            Rnp1 = -advective(u_np1) + nu * D2.dot(u_np1) + f_n
            F = enforce_bcs(u_np1 - u_n - 0.5 * dt * (Rn + Rnp1))
            J = enforce_bcs(I_sp - 0.5 * dt * dRdu(u_np1))
            delta = spsolve(J, -F)
            u_np1 += delta
            if np.linalg.norm(delta) < 1e-10: break
        return u_np1, dRdu(u_n), dRdu(u_np1)
    def compute_cost_and_grad(v_flat, u_init, Hk):
        """Horizon cost and its gradient with respect to the flattened (Hk, m) control sequence.

        Cost:
            J = dt/2 * sum_{n=1..Hk} ||u_n - u_ref||^2 + alpha*dt/2 * sum_n ||v_n||^2.

        The gradient uses the discrete adjoint of the constraint
        G_n = u_{n+1} - u_n - dt/2 * (R(u_n) + R(u_{n+1})) = 0, where:
          * A_list[n]  = dG_n/du_{n+1}
          * B_list[n]  = dG_n/du_n
          * dG_n/dv_n  = -dt * B

        q below is the negative of the Lagrange multiplier and satisfies
        A_n^T q_n = p_{n+1}. This gives dJ/dv_n = dt * (alpha * v_n + B^T q_n).

        NOTE(review): enforce_bcs gives B_list[n] a +1 boundary diagonal, whereas the
        boundary condition u_{n+1} = u_n would give -1. This should not affect the
        gradient, for two reasons:
          * the boundary entries of q never feed back into interior rows;
          * B vanishes (up to rounding) at x = 0 and x = L.
        Re-check this if the boundary treatment changes.
        """
        V_seq = v_flat.reshape((Hk, m))
        U = np.zeros((Hk + 1, Nx)); U[0] = u_init
        A_list, B_list = [], []
        # Forward sweep: roll the plant model over the horizon and store the constraint Jacobians.
        for n in range(Hk):
            u_np1, dRdu_n, dRdu_np1 = solve_cn_step(U[n], V_seq[n])
            U[n+1] = u_np1
            A_list.append(enforce_bcs(I_sp - 0.5 * dt * dRdu_np1))
            B_list.append(enforce_bcs(-I_sp - 0.5 * dt * dRdu_n))
        J = 0.5 * dt * np.sum((U[1:] - u_ref)**2) + 0.5 * alpha * dt * np.sum(V_seq**2)
        # dJ/du_Hk seeds the backward sweep.
        p_next = dt * (U[-1] - u_ref)
        grad = np.zeros_like(V_seq)
        for n in reversed(range(Hk)):
            q = spsolve(A_list[n].T, p_next)
            grad[n, :] = dt * (alpha * V_seq[n] + (B.T @ q))
            # Right-hand side for step n - 1; the value computed at n == 0 is never used.
            p_next = dt * (U[n] - u_ref) - (B_list[n].T @ q)
        return J, grad.ravel()
    state_history, control_history = [u0], []
    u_current = u0.copy()
    start_time = time.time()
    for k in range(Nt - 1):
        # The horizon shrinks near the end so predictions never run past the final step.
        Hk = min(H, Nt - 1 - k)
        # Cold start from zero controls each step. jac=True because the objective returns (J, grad).
        # The lambda is consumed within this iteration, so capturing u_current/Hk is safe.
        res = minimize(lambda v: compute_cost_and_grad(v, u_current, Hk), np.zeros(Hk*m), method='L-BFGS-B', jac=True, bounds=[(-1.0, 1.0)] * (Hk * m), options={'maxiter': max_opt_iters, 'ftol': 1e-6})
        v_apply = res.x.reshape((Hk, m))[0]
        control_history.append(v_apply)
        u_current, _, _ = solve_cn_step(u_current, v_apply)
        state_history.append(u_current)
        print(f"Adjoint-NMPC Step {k+1}/{Nt-1}: J={res.fun:.4e}", end='\r')
    total_time = time.time() - start_time
    state_history, control_history = np.array(state_history), np.array(control_history)
    final_state_sensors = np.interp(sensor_grid, x_grid, state_history[-1])
    final_mse = np.mean((final_state_sensors - target_x_np_sensors)**2)
    print(f"\nAdjoint-NMPC finished in {total_time:.4f}s. Final MSE: {final_mse:.4e}")
    return state_history, control_history, total_time

def generate_comparison_plots(all_results, sensor_grid, config, output_dir):
    """Save one PDF comparing every controller's final state for each target.

    The figure has one panel per target ('sine', 'parabola', 'zero'), each with
    its own y-scale. A single legend is shared beneath the panels. The classical
    solvers' final states are linearly interpolated from the NX_SOLVER grid onto
    ``sensor_grid``; the recurrent controller's final state is plotted as stored.

    Legend naming: the POD-MPC result is labelled 'NMPC' and the adjoint NMPC
    result 'Adjoint-Method'. The file is written to
    ``<output_dir>/comparison_final_state_ALL_TARGETS.pdf``; ``output_dir`` is
    created if missing.
    """
    print(f"\n--- Generating Combined Comparison Plot for All Targets ---")
    os.makedirs(output_dir, exist_ok=True)

    font_size = 40
    legend_font_size = 40
    solver_grid = np.linspace(0, config['L'], config['NX_SOLVER'])

    fig, axes = plt.subplots(1, 3, figsize=(38, 9), sharey=False)

    for ax, target_type in zip(axes, ['sine', 'parabola', 'zero']):
        entry = all_results[target_type]
        # (values on sensor_grid, format string, line width, legend label) in legend order.
        curves = [
            (entry['target_x'], 'k:', 4, r'$y_{\text{target}}(x)$'),
            (entry['initial_x'], 'b--', 2, '$y(0, x)$'),
            (entry['recurrent']['states'][-1], 'r-', 2.5, 'PDE-OP'),
            (np.interp(sensor_grid, solver_grid, entry['nmpc']['states'][-1]), 'm:', 2.5, 'Adjoint-Method'),
            (np.interp(sensor_grid, solver_grid, entry['pod_mpc']['states'][-1]), 'c-.', 2.5, 'NMPC'),
        ]
        for values, fmt, width, label in curves:
            ax.plot(sensor_grid, values, fmt, lw=width, label=label)

        ax.set_xlabel('x', fontsize=font_size)
        ax.grid(True, linestyle='--', alpha=0.6)
        ax.tick_params(axis='both', which='major', labelsize=font_size)
        ax.set_ylabel('$y(t, x)$', fontsize=font_size)

    # Every panel has identical labels, so the first panel's handles serve for all.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center',
               bbox_to_anchor=(0.5, -0.05), ncol=5, fontsize=legend_font_size)
    fig.subplots_adjust(bottom=0.3, wspace=0.3)

    save_path = os.path.join(output_dir, "comparison_final_state_ALL_TARGETS.pdf")
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)
    print(f"\nSaved combined comparison plot to: {save_path}")

def main(args):
    """Evaluate all three controllers on every target profile, then save results and a plot.

    For each target in ('sine', 'parabola', 'zero'), the recurrent NN controller,
    POD-MPC and adjoint NMPC are run from the same initial condition. A table of
    wall-clock time and final MSE on the sensor grid is printed for each target.

    Outputs are written to ``<output_base_dir>/comparison_plots_burgers_combined``:
      * ``all_simulation_results.npz``, saved first so that a plotting failure
        cannot discard the expensive simulation results;
      * the combined final-state figure.

    Args:
        args: parsed CLI namespace with ``config_path``, ``output_base_dir`` and ``run_id``.
    """
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)

    # These grids do not depend on the target, so build them once.
    sensor_grid = np.linspace(0, config['L'], config['M_SENSORS'])
    x_grid_classical = np.linspace(0, config['L'], config['NX_SOLVER'])

    all_results = {}
    for target_type in ('sine', 'parabola', 'zero'):
        print(f"\n\n{'='*80}")
        print(f"--- Starting Burgers' Eval for Target: '{target_type.upper()}' ---")
        print(f"--- Using Recurrent Model Run ID: '{args.run_id}' ---")
        print(f"{'='*80}")

        config['target_type'] = target_type
        initial_x_np = get_initial_condition(config, sensor_grid)
        target_x_np = generate_target_profile(config, target_type, sensor_grid)

        recurrent_states, recurrent_controls, recurrent_time = run_recurrent_controller_evaluation(args, config, target_x_np, initial_x_np)
        pod_mpc_states, pod_mpc_controls, pod_mpc_time = run_pod_mpc_evaluation(config, target_x_np, initial_x_np)
        nmpc_states, nmpc_controls, nmpc_time = run_adjoint_nmpc_evaluation(config, target_x_np, initial_x_np)

        all_results[target_type] = {
            'recurrent': {'states': recurrent_states, 'controls': recurrent_controls, 'time': recurrent_time},
            'pod_mpc': {'states': pod_mpc_states, 'controls': pod_mpc_controls, 'time': pod_mpc_time},
            'nmpc': {'states': nmpc_states, 'controls': nmpc_controls, 'time': nmpc_time},
            'target_x': target_x_np,
            'initial_x': initial_x_np,
        }

        # The recurrent model already lives on the sensor grid. The classical
        # solvers' final states are interpolated from the NX_SOLVER grid.
        recurrent_final_mse = np.mean((recurrent_states[-1] - target_x_np)**2)
        pod_final_mse = np.mean((np.interp(sensor_grid, x_grid_classical, pod_mpc_states[-1]) - target_x_np)**2)
        nmpc_final_mse = np.mean((np.interp(sensor_grid, x_grid_classical, nmpc_states[-1]) - target_x_np)**2)
        print(f"\n--- SUMMARY FOR '{target_type.upper()}' ---")
        print(f"{'Method':<20} | {'Time Taken (s)':<20} | {'Final MSE':<20}")
        print("-" * 65)
        print(f"{'Recurrent NN':<20} | {recurrent_time:<20.4f} | {recurrent_final_mse:<20.4e}")
        print(f"{'POD-MPC':<20} | {pod_mpc_time:<20.4f} | {pod_final_mse:<20.4e}")
        print(f"{'Adjoint-NMPC':<20} | {nmpc_time:<20.4f} | {nmpc_final_mse:<20.4e}")
        print("-" * 65)

    output_dir = os.path.join(args.output_base_dir, "comparison_plots_burgers_combined")
    os.makedirs(output_dir, exist_ok=True)

    # Persist the numerical results before plotting.
    # Each value is a nested dict, which NumPy stores as a pickled 0-d object
    # array. To read it back, use np.load(path, allow_pickle=True) and unwrap
    # with e.g. data['sine'].item().
    results_save_path = os.path.join(output_dir, "all_simulation_results.npz")
    np.savez_compressed(results_save_path, **all_results)
    print(f"\nSaved all numerical results to: {results_save_path}")

    generate_comparison_plots(all_results, sensor_grid, config, output_dir)

    print("\n--- Unified evaluation for all targets complete. ---")

if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        description="Unified evaluation of Recurrent, POD-MPC, and Adjoint-NMPC controllers for Burgers' Eq."
    )
    # Every flag is a mandatory string; only --run_id carries help text.
    for flag, extra in (
        ("--config_path", {}),
        ("--output_base_dir", {}),
        ("--run_id", {"help": "The run_id of the trained recurrent controller."}),
    ):
        cli.add_argument(flag, type=str, required=True, **extra)
    main(cli.parse_args())