

import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import argparse
import time

from data_and_models import PropagatorDeepONet, RecurrentController
from scipy.optimize import fsolve, minimize
from scipy import sparse
from scipy.sparse.linalg import spsolve

def generate_target_profile(config, target_type, x_grid):
    """Return the desired final-state profile evaluated on ``x_grid``.

    Supported ``target_type`` values:
      * ``'sine'``     -> 0.8 * sin(pi * x / L)
      * ``'parabola'`` -> 4 * 0.5 * x * (L - x)
      * ``'zero'``     -> all zeros

    ``config['L']`` (domain length) is only read for the profiles that use it.

    Raises:
        ValueError: for any other ``target_type``.
    """
    print(f"Generating '{target_type}' target profile for a grid of size {len(x_grid)}...")
    if target_type == 'sine':
        return 0.8 * np.sin(np.pi * x_grid / config['L'])
    if target_type == 'parabola':
        return 4 * 0.5 * x_grid * (config['L'] - x_grid)
    if target_type == 'zero':
        return np.zeros_like(x_grid)
    raise ValueError(f"Invalid target type: {target_type}")

def get_initial_condition(config, x_grid):
    """Return the initial state 0.5 * sin(2*pi*x / L) sampled on ``x_grid``.

    The first and last samples are forced to exactly 0.0 so the profile
    satisfies homogeneous Dirichlet conditions despite floating-point sin error.
    """
    profile = 0.5 * np.sin(2 * np.pi * x_grid / config['L'])
    profile[[0, -1]] = 0.0
    return profile


def run_recurrent_controller_evaluation(args, config, target_x_np, initial_x_np):
    """Roll out the trained recurrent controller in closed loop with the DeepONet surrogate.

    The controller is restored from
    ``<output_base_dir>/<run_id>/burgers_controller_model.pth``. The surrogate is
    restored from ``<output_base_dir>/<deeponet_run_id>/burgers_propagator_best.pth``,
    where ``deeponet_run_id`` is read from the controller's ``hyperparams.yaml``.
    The loop then runs ``NT_SOLVER - 1`` steps with no gradients: controller ->
    control at sensors, surrogate -> next state.

    Returns:
        (states, controls, elapsed_seconds). ``states`` stacks the initial state
        plus every predicted state (presumably shape (NT_SOLVER, M_SENSORS) --
        depends on the surrogate's output; TODO confirm). ``controls`` stacks
        the per-step controller outputs. The printed MSE is measured against
        the surrogate's final prediction, not a PDE solver.
    """
    print("\n--- [1/3] Evaluating Recurrent NN Controller (Direct) ---")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def read_hyperparams(run_dir):
        # Each run directory stores its constructor arguments in hyperparams.yaml.
        with open(os.path.join(run_dir, "hyperparams.yaml"), 'r') as handle:
            return yaml.safe_load(handle)

    def restore(module, weights_path):
        # Load weights onto the active device and switch to inference mode.
        module.load_state_dict(torch.load(weights_path, map_location=device))
        module.eval()
        return module

    # Controller: only the recognised architecture keys are forwarded.
    controller_dir = os.path.join(args.output_base_dir, args.run_id)
    controller_weights = os.path.join(controller_dir, "burgers_controller_model.pth")
    controller_hp = read_hyperparams(controller_dir)
    controller_keys = ('hidden_dim', 'num_layers', 'activation_fn')
    controller_opts = {name: value for name, value in controller_hp.items() if name in controller_keys}
    controller = restore(
        RecurrentController(
            M_sensors=config['M_SENSORS'],
            num_basis_functions=config['M_SENSORS'],
            control_scale=config['CONTROL_SCALE'],
            **controller_opts,
        ).to(device),
        controller_weights,
    )

    # Surrogate: every architecture key is required (KeyError if absent).
    surrogate_dir = os.path.join(args.output_base_dir, controller_hp['deeponet_run_id'])
    surrogate_weights = os.path.join(surrogate_dir, "burgers_propagator_best.pth")
    surrogate_hp = read_hyperparams(surrogate_dir)
    surrogate_keys = ('branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'latent_dim', 'activation_fn')
    surrogate_opts = {name: surrogate_hp[name] for name in surrogate_keys}
    physics_simulator = restore(
        PropagatorDeepONet(
            M_sensors=config['M_SENSORS'],
            num_basis_functions=config['M_SENSORS'],
            trunk_input_dim=config['TRUNK_INPUT_DIM'],
            **surrogate_opts,
        ).to(device),
        surrogate_weights,
    )
    print("Successfully loaded recurrent controller and surrogate model.")

    def to_batch(array):
        # numpy vector -> float32 tensor with a leading batch axis of 1.
        return torch.from_numpy(array).float().unsqueeze(0).to(device)

    def to_numpy(tensor):
        return tensor.cpu().numpy().flatten()

    target = to_batch(target_x_np)
    state = to_batch(initial_x_np)
    # Trunk-net query coordinates: sensor locations shaped (1, M_SENSORS, 1).
    sensor_coords = torch.linspace(0, config['L'], config['M_SENSORS'], device=device).unsqueeze(0).unsqueeze(-1)

    states, controls = [to_numpy(state)], []
    tic = time.time()
    with torch.no_grad():
        memory = None  # recurrent hidden state carried across steps
        for _ in range(config['NT_SOLVER'] - 1):
            action, memory = controller(state, target, memory)
            controls.append(to_numpy(action))
            state = physics_simulator(state, action, sensor_coords).squeeze(-1)
            states.append(to_numpy(state))
    elapsed = time.time() - tic

    final_mse = np.mean((states[-1] - target_x_np) ** 2)
    print(f"Recurrent controller finished in {elapsed:.4f}s. Final MSE (surrogate): {final_mse:.4e}")
    return np.array(states), np.array(controls), elapsed

def run_pod_mpc_direct_evaluation(config, target_x_np_sensors, initial_x_np_sensors, *, num_snapshots=100):
    """Closed-loop evaluation of a POD-reduced nonlinear MPC with distributed forcing.

    A POD basis is built from one randomly forced Crank-Nicolson (CN) trajectory
    of the full-order viscous Burgers model. At each plant step SLSQP
    optimises a forcing sequence over the horizon using the Galerkin-projected
    reduced model. The first planned forcing is then applied to the full-order
    plant, held constant over the step.

    Args:
        config: dict providing 'L', 'VISCOSITY', 'NX_SOLVER', 'NT_SOLVER',
            'T_FINAL' and 'M_SENSORS'.
        target_x_np_sensors: target profile on the uniform M_SENSORS grid.
        initial_x_np_sensors: initial state on the same sensor grid.
        num_snapshots: number of 50-step forced segments used to build the
            snapshot matrix (10 snapshots each). Defaults to 100. Snapshot
            generation uses the global ``np.random`` state.

    Returns:
        (state_history, control_history, total_time):
        - state_history has shape (NT_SOLVER, NX_SOLVER) on the solver grid.
        - control_history has shape (NT_SOLVER, NX_SOLVER); its last row is never
          written and stays zero.
        - total_time is the wall-clock time of the closed loop only; snapshot and
          POD construction are excluded.
    """
    print("\n--- [2/3] Evaluating POD-Based MPC (Direct) ---")
    L, nu = config['L'], config['VISCOSITY']
    Nx, Nt_sim = config['NX_SOLVER'], config['NT_SOLVER']
    T_sim = config['T_FINAL']
    dt = T_sim / (Nt_sim - 1) if Nt_sim > 1 else 0.0
    # MPC tuning. NOTE(review): dt_mpc (the reduced model's prediction step) is
    # fixed at 0.02 and independent of the plant step `dt`; they coincide only
    # when T_FINAL / (NT_SOLVER - 1) == 0.02.
    N_horizon, dt_mpc, Q_weight, R_weight = 1, 0.02, 1, 1e-5
    r_pod_modes, f_min, f_max = 5, -1.0, 1.0
    x, dx = np.linspace(0, L, Nx), L / (Nx - 1)
    sensor_grid = np.linspace(0, L, config['M_SENSORS'])
    # Interpolate the sensor-grid profiles onto the full-order solver grid.
    u_ref = np.interp(x, sensor_grid, target_x_np_sensors)
    u0 = np.interp(x, sensor_grid, initial_x_np_sensors)

    # Dense central-difference operators for d/dx and d^2/dx^2. The boundary
    # rows stay zero; Dirichlet values are re-imposed after each full step.
    D, D2 = np.zeros((Nx, Nx)), np.zeros((Nx, Nx))
    for i in range(1, Nx - 1):
        D[i, i - 1], D[i, i + 1] = -1 / (2 * dx), 1 / (2 * dx)
        D2[i, i - 1], D2[i, i], D2[i, i + 1] = 1 / dx**2, -2 / dx**2, 1 / dx**2

    def cn_implicit_eq(u_next, u_prev, f_prev, f_next, dt_local):
        """CN residual of u_t = -u u_x + nu u_xx + f; zero for a valid step."""
        adv_prev, diff_prev = u_prev * (D @ u_prev), nu * (D2 @ u_prev)
        adv_next, diff_next = u_next * (D @ u_next), nu * (D2 @ u_next)
        return (u_next - u_prev) - 0.5 * dt_local * ((-adv_prev + diff_prev + f_prev) + (-adv_next + diff_next + f_next))

    def full_cn_step(u_prev, f_prev, f_next, dt_local, guess=None):
        """Advance the full-order model one CN step with fsolve; pins u=0 at both ends."""
        u_next = fsolve(cn_implicit_eq, guess if guess is not None else u_prev,
                        args=(u_prev, f_prev, f_next, dt_local), xtol=1e-9, maxfev=100)
        u_next[0], u_next[-1] = 0.0, 0.0
        return u_next

    def generate_snapshots_direct(num_snaps=100, amp=1.0):
        """Collect POD snapshots from one long randomly forced trajectory.

        The trajectory starts from u0 and is never reset. It runs num_snaps
        segments of 50 CN steps and records every 5th state. Each step's forcing
        is a sum of 1-3 modes sin(k*pi*x/L), k in 1..4, with amplitudes uniform
        in [-amp, amp] and zero at the boundary nodes.
        """
        print("Generating snapshots for POD with direct forcing...")
        U_snap, snap_idx = np.zeros((Nx, num_snaps * 10)), 0
        u, f_prev = u0.copy(), np.zeros(Nx)
        for _ in range(num_snaps):
            for i in range(50):
                f_rand = sum(amp * (2 * np.random.rand() - 1) * np.sin(np.random.randint(1, 5) * np.pi * x / L)
                             for _ in range(np.random.randint(1, 4)))
                f_rand[0], f_rand[-1] = 0.0, 0.0
                u = full_cn_step(u, f_prev, f_rand, dt)
                f_prev = f_rand
                if i % 5 == 0:
                    U_snap[:, snap_idx] = u
                    snap_idx += 1
        return U_snap[:, :snap_idx]

    U_snap = generate_snapshots_direct(num_snaps=num_snapshots)
    u_mean = np.mean(U_snap, axis=1)
    # Leading r left singular vectors of the mean-centred snapshot matrix.
    V_pod = np.linalg.svd(U_snap - u_mean[:, None], full_matrices=False)[0][:, :r_pod_modes]
    z_ref = V_pod.T @ (u_ref - u_mean)

    def reduced_cn_step(z_prev, f_prev, f_next, dt_local, guess=None):
        """Galerkin-projected CN step: solve V^T r(V z_next + mean) = 0 for z_next."""
        def reduced_implicit(z_next):
            u_prev_full, u_next_full = V_pod @ z_prev + u_mean, V_pod @ z_next + u_mean
            return V_pod.T @ cn_implicit_eq(u_next_full, u_prev_full, f_prev, f_next, dt_local)
        return fsolve(reduced_implicit, guess if guess is not None else z_prev, xtol=1e-9, maxfev=60)

    def mpc_cost_direct(f_seq_flat, z_current):
        """Reduced-order tracking cost of a forcing plan (stage + terminal terms)."""
        f_seq = f_seq_flat.reshape((N_horizon, Nx))
        cost, z_pred = 0.0, z_current.copy()
        for f_k in f_seq:
            # Zero-order hold: each planned forcing is held constant over its
            # prediction step, matching how the plant applies it below
            # (full_cn_step(u, f_apply, f_apply, dt)). Blending from a zero
            # forcing would make the model respond to only ~half of f.
            z_next = reduced_cn_step(z_pred, f_k, f_k, dt_mpc, guess=z_pred)
            cost += Q_weight * np.sum((z_next - z_ref) ** 2) + R_weight * np.sum(f_k ** 2)
            z_pred = z_next
        cost += Q_weight * np.sum((z_pred - z_ref) ** 2)
        return cost

    # Forcing is pinned to zero on the Dirichlet nodes and box-bounded elsewhere.
    bounds = [(0, 0) if j == 0 or j == Nx - 1 else (f_min, f_max)
              for _ in range(N_horizon) for j in range(Nx)]
    # Columns are time indices; the last control column is never written.
    u_history, f_history = np.zeros((Nx, Nt_sim)), np.zeros((Nx, Nt_sim))
    u, z_current, f_plan_prev = u0.copy(), V_pod.T @ (u0 - u_mean), np.zeros((N_horizon, Nx))
    start_time = time.time()
    for k in range(Nt_sim - 1):
        # Warm start: shift the previous plan by one step, repeating its last entry.
        f_init = np.vstack([f_plan_prev[1:], f_plan_prev[-1:]]).ravel()
        res = minimize(mpc_cost_direct, f_init, args=(z_current,), method='SLSQP', bounds=bounds,
                       options={'maxiter': 200, 'ftol': 1e-6})
        f_plan_prev = res.x.reshape((N_horizon, Nx))
        f_apply = f_plan_prev[0]
        # Plant: full-order CN step with the forcing held constant over the step.
        u = full_cn_step(u, f_apply, f_apply, dt)
        z_current = V_pod.T @ (u - u_mean)
        u_history[:, k + 1], f_history[:, k] = u, f_apply
        print(f"POD-MPC (Direct) Step {k+1}/{Nt_sim-1} completed. Cost: {res.fun:.4e}", end='\r')
    total_time = time.time() - start_time
    u_history[:, 0] = u0
    state_history, control_history = u_history.T, f_history.T
    final_state_sensors = np.interp(sensor_grid, x, state_history[-1])
    final_mse = np.mean((final_state_sensors - target_x_np_sensors) ** 2)
    print(f"\nPOD-MPC (Direct) finished in {total_time:.4f}s. Final MSE: {final_mse:.4e}")
    return state_history, control_history, total_time

def run_adjoint_nmpc_direct_evaluation(config, target_x_np_sensors, initial_x_np_sensors):
    """Closed-loop evaluation of a receding-horizon NMPC driven by adjoint gradients.

    At every plant step, L-BFGS-B optimises the next ``min(10, remaining)``
    distributed forcing profiles. The gradient comes from a discrete adjoint
    sweep through the Newton-solved Crank-Nicolson (CN) full-order model. Only
    the first profile is applied before re-planning.

    Args:
        config: dict providing 'L', 'VISCOSITY', 'NX_SOLVER', 'NT_SOLVER',
            'T_FINAL' and 'M_SENSORS'.
        target_x_np_sensors: target profile on the uniform M_SENSORS grid.
        initial_x_np_sensors: initial state on the same sensor grid.

    Returns:
        (history_u, history_control, total_time): states of shape
        (NT_SOLVER, NX_SOLVER), applied forcing of shape
        (NT_SOLVER - 1, NX_SOLVER), and the loop's wall-clock seconds.
    """
    print("\n--- [3/3] Evaluating Adjoint-NMPC (Direct) ---")
    L, nu = config['L'], config['VISCOSITY']
    Nx, Nt = config['NX_SOLVER'], config['NT_SOLVER']
    dt = config['T_FINAL'] / (Nt - 1) if Nt > 1 else 0.0
    # Horizon length, optimiser iterations per step, control weight, forcing bounds.
    H, max_opt_iters, alpha, v_min, v_max = 10, 40, 5e-5, -1.0, 1.0
    x, dx = np.linspace(0, L, Nx), L / (Nx - 1)
    sensor_grid = np.linspace(0, L, config['M_SENSORS'])
    u_ref = np.interp(x, sensor_grid, target_x_np_sensors)
    u0 = np.interp(x, sensor_grid, initial_x_np_sensors)
    R = alpha * sparse.eye(Nx, format='csr')
    D = sparse.diags([-1, 1], [-1, 1], shape=(Nx, Nx)) / (2 * dx)
    D2 = sparse.diags([1, -2, 1], [-1, 0, 1], shape=(Nx, Nx)) / dx**2
    I_sp = sparse.eye(Nx, format='csr')
    # 1 on interior nodes, 0 on the two Dirichlet nodes. The boundary residual
    # rows are overwritten by enforce_bcs_rhs, so forcing applied there never
    # reaches the state and must get no adjoint contribution in the gradient.
    interior_mask = np.ones(Nx)
    interior_mask[[0, -1]] = 0.0

    def enforce_bcs_mat(A):
        """Replace the first/last rows of A with identity rows (Dirichlet BCs)."""
        A = A.tolil()
        A[0, :], A[-1, :] = 0.0, 0.0
        A[0, 0], A[-1, -1] = 1.0, 1.0
        return A.tocsr()

    def enforce_bcs_rhs(v):
        """Zero the first/last entries of v in place and return it."""
        v[0], v[-1] = 0.0, 0.0
        return v

    def advective(u):
        """Conservative advection term d/dx(u^2 / 2)."""
        return D.dot(0.5 * (u**2))

    def dRdu(u):
        """Jacobian of R(u) = -d/dx(u^2/2) + nu*u_xx with respect to u."""
        return -D.dot(sparse.diags(u)) + nu * D2

    def solve_cn_step(u_n, f_n):
        """One CN step with f_n held over the step, solved by Newton (<= 25 iters).

        Boundary rows of the Newton system are identity with zero RHS, so the
        boundary values of u_n are carried over unchanged.
        Returns (u_{n+1}, dR/du at u_n, dR/du at u_{n+1}).
        """
        u_np1, Rn = u_n.copy(), -advective(u_n) + nu * D2.dot(u_n) + f_n
        for _ in range(25):
            Rnp1 = -advective(u_np1) + nu * D2.dot(u_np1) + f_n
            F = enforce_bcs_rhs(u_np1 - u_n - 0.5 * dt * (Rn + Rnp1))
            J = enforce_bcs_mat(I_sp - 0.5 * dt * dRdu(u_np1))
            delta = spsolve(J, -F)
            u_np1 += delta
            if np.linalg.norm(delta) < 1e-10:
                break
        return u_np1, dRdu(u_n), dRdu(u_np1)

    def compute_cost_and_grad(u_ctrl_flat, u_init, Hk):
        """Return (J, dJ/d forcing) for an Hk-step horizon starting from u_init.

        J = 0.5*dt*sum_n |U_{n+1} - u_ref|^2 + 0.5*dt*sum_n f_n^T R f_n.
        """
        U_ctrl_seq = u_ctrl_flat.reshape((Hk, Nx))
        U, A_list, B_list = np.zeros((Hk + 1, Nx)), [], []
        U[0] = u_init
        # Forward sweep: keep the CN residual Jacobians w.r.t. U_{n+1} (A) and
        # U_n (B), with their boundary rows replaced by enforce_bcs_mat.
        for n in range(Hk):
            u_np1, dRdu_n, dRdu_np1 = solve_cn_step(U[n], U_ctrl_seq[n])
            U[n + 1] = u_np1
            A_list.append(enforce_bcs_mat(I_sp - 0.5 * dt * dRdu_np1))
            B_list.append(enforce_bcs_mat(-I_sp - 0.5 * dt * dRdu_n))
        J = 0.5 * dt * np.sum((U[1:] - u_ref) ** 2) + 0.5 * dt * sum(c @ R @ c for c in U_ctrl_seq)
        # Backward (adjoint) sweep.
        p_next, grad = dt * (U[-1] - u_ref), np.zeros_like(U_ctrl_seq)
        for n in reversed(range(Hk)):
            q = spsolve(A_list[n].T, p_next)
            # dF_n/df_n is -dt*I on interior rows and 0 on the Dirichlet rows, so
            # the adjoint term is masked there and only the penalty gradient remains.
            grad[n, :] = dt * (R @ U_ctrl_seq[n] + interior_mask * q)
            p_next = dt * (U[n] - u_ref) - (B_list[n].T @ q)
        return J, grad.ravel()

    history_u, history_control = np.zeros((Nt, Nx)), np.zeros((Nt - 1, Nx))
    history_u[0] = u0
    u_current = u0.copy()
    start_time = time.time()
    for k in range(Nt - 1):
        # Shrink the horizon near the end so it never runs past the final step.
        Hk = min(H, Nt - 1 - k)
        bounds = [(v_min, v_max)] * (Hk * Nx)
        res = minimize(compute_cost_and_grad, np.zeros(Hk * Nx), args=(u_current, Hk),
                       method='L-BFGS-B', jac=True, bounds=bounds,
                       options={'maxiter': max_opt_iters, 'ftol': 1e-6})
        u_apply = res.x.reshape((Hk, Nx))[0]
        u_current, _, _ = solve_cn_step(u_current, u_apply)
        history_u[k + 1], history_control[k] = u_current, u_apply
        print(f"Adjoint-NMPC (Direct) Step {k+1}/{Nt-1}: J={res.fun:.4e}", end='\r')
    total_time = time.time() - start_time
    final_state_sensors = np.interp(sensor_grid, x, history_u[-1])
    final_mse = np.mean((final_state_sensors - target_x_np_sensors) ** 2)
    print(f"\nAdjoint-NMPC (Direct) finished in {total_time:.4f}s. Final MSE: {final_mse:.4e}")
    return history_u, history_control, total_time

def generate_comparison_plots(all_results, sensor_grid, config, output_dir):
    """Save a one-row PDF comparing each method's final state, one panel per target.

    Args:
        all_results: mapping target name -> dict with 'target_x', 'initial_x'
            and per-method entries 'recurrent', 'pod_mpc', 'nmpc', each holding a
            'states' history whose last row is the final state. Recurrent states
            are on the sensor grid; POD-MPC and adjoint-NMPC states are on the
            NX_SOLVER grid and are interpolated onto the sensor grid here.
        sensor_grid: x-coordinates of the sensor grid.
        config: dict providing 'L' and 'NX_SOLVER'.
        output_dir: output directory (created if missing).

    Panels follow the order 'sine', 'parabola', 'zero' for the targets that
    are present, then any other targets in insertion order.

    Raises:
        ValueError: if all_results is empty.
    """
    preferred_order = ['sine', 'parabola', 'zero']
    target_order = [t for t in preferred_order if t in all_results]
    target_order += [t for t in all_results if t not in preferred_order]
    if not target_order:
        raise ValueError("all_results is empty; nothing to plot.")

    print("\n--- Generating Combined Comparison Plot for All Targets (Direct Control) ---")
    os.makedirs(output_dir, exist_ok=True)

    font_size = 40
    legend_font_size = 40

    n_panels = len(target_order)
    # squeeze=False keeps `axes` indexable for a single panel too; the width
    # scales with the panel count (38 in. for the default three panels).
    fig, axes = plt.subplots(1, n_panels, figsize=(38 * n_panels / 3, 9), sharey=False, squeeze=False)
    axes = axes[0]
    x_grid_classical = np.linspace(0, config['L'], config['NX_SOLVER'])

    for ax, target_type in zip(axes, target_order):
        results = all_results[target_type]
        target_x_np = results['target_x']
        initial_x_np = results['initial_x']

        ax.plot(sensor_grid, target_x_np, 'k:', lw=4, label=r'$y_{\text{target}}(x)$')
        ax.plot(sensor_grid, initial_x_np, 'b--', lw=2, label='$y(0, x)$')

        # Classical solvers run on the finer solver grid; resample to sensors.
        pod_final = np.interp(sensor_grid, x_grid_classical, results['pod_mpc']['states'][-1])
        nmpc_final = np.interp(sensor_grid, x_grid_classical, results['nmpc']['states'][-1])

        ax.plot(sensor_grid, results['recurrent']['states'][-1], 'r-', lw=2.5, label='PDE-OP')
        ax.plot(sensor_grid, nmpc_final, 'm:', lw=3.5, label='Adjoint-Method')
        ax.plot(sensor_grid, pod_final, 'c-.', lw=2.5, label='NMPC')

        ax.set_xlabel('x', fontsize=font_size)
        ax.grid(True, linestyle='--', alpha=0.6)
        ax.tick_params(axis='both', which='major', labelsize=font_size)
        ax.set_ylabel('$y(t, x)$', fontsize=font_size)

    # A single shared legend below all panels (labels are identical per panel).
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center',
               bbox_to_anchor=(0.5, -0.05), ncol=5, fontsize=legend_font_size)
    fig.subplots_adjust(bottom=0.3, wspace=0.3)

    save_path = os.path.join(output_dir, "comparison_final_state_ALL_TARGETS_direct.pdf")
    # Save the figure we own rather than pyplot's implicit "current" figure.
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)
    print(f"\nSaved combined comparison plot to: {save_path}")

def main(args):
    """Evaluate all three controllers on every target, print summaries, plot and save.

    Args:
        args: parsed CLI namespace with ``config_path``, ``output_base_dir`` and
            ``run_id``.

    Outputs go to ``<output_base_dir>/<run_id>/comparison_plots_direct_combined``.
    """
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)

    all_results = {}
    target_types = ['sine', 'parabola', 'zero']
    # Loop-invariant grids: sensor grid (NN + targets) and solver grid (classical methods).
    sensor_grid = np.linspace(0, config['L'], config['M_SENSORS'])
    x_grid_classical = np.linspace(0, config['L'], config['NX_SOLVER'])

    for target_type in target_types:
        config['target_type'] = target_type
        print(f"\n\n{'='*80}")
        print(f"--- Starting Burgers' Eval (Direct Control) for Target: '{target_type.upper()}' ---")
        print(f"--- Using Recurrent Model Run ID: '{args.run_id}' ---")
        print(f"{'='*80}")

        initial_x_np = get_initial_condition(config, sensor_grid)
        target_x_np = generate_target_profile(config, target_type, sensor_grid)

        # Run in the order announced by each method's "[1/3]".."[3/3]" banner.
        rec_states, rec_controls, rec_time = run_recurrent_controller_evaluation(args, config, target_x_np, initial_x_np)
        pod_states, pod_controls, pod_time = run_pod_mpc_direct_evaluation(config, target_x_np, initial_x_np)
        nmpc_states, nmpc_controls, nmpc_time = run_adjoint_nmpc_direct_evaluation(config, target_x_np, initial_x_np)

        results = {
            'recurrent': {'states': rec_states, 'controls': rec_controls, 'time': rec_time},
            'pod_mpc': {'states': pod_states, 'controls': pod_controls, 'time': pod_time},
            'nmpc': {'states': nmpc_states, 'controls': nmpc_controls, 'time': nmpc_time},
            'target_x': target_x_np,
            'initial_x': initial_x_np
        }
        all_results[target_type] = results

        # Classical solutions live on the solver grid; compare on the sensor grid.
        pod_final_mse = np.mean((np.interp(sensor_grid, x_grid_classical, results['pod_mpc']['states'][-1]) - target_x_np)**2)
        nmpc_final_mse = np.mean((np.interp(sensor_grid, x_grid_classical, results['nmpc']['states'][-1]) - target_x_np)**2)
        rec_final_mse = np.mean((results['recurrent']['states'][-1] - target_x_np)**2)
        print(f"\n--- SUMMARY FOR '{target_type.upper()}' (Direct Control) ---")
        print(f"{'Method':<25} | {'Time Taken (s)':<20} | {'Final MSE':<20}")
        print("-" * 70)
        print(f"{'Recurrent NN (Direct)':<25} | {results['recurrent']['time']:<20.4f} | {rec_final_mse:<20.4e}")
        print(f"{'POD-MPC (Direct)':<25} | {results['pod_mpc']['time']:<20.4f} | {pod_final_mse:<20.4e}")
        print(f"{'Adjoint-NMPC (Direct)':<25} | {results['nmpc']['time']:<20.4f} | {nmpc_final_mse:<20.4e}")
        print("-" * 70)

    output_dir = os.path.join(args.output_base_dir, args.run_id, "comparison_plots_direct_combined")
    # Create the directory here so saving results does not depend on the plotter.
    os.makedirs(output_dir, exist_ok=True)
    generate_comparison_plots(all_results, sensor_grid, config, output_dir)

    results_save_path = os.path.join(output_dir, "all_simulation_results_direct.npz")
    # Each per-target entry is a nested dict, stored as a pickled 0-d object
    # array; reload with np.load(path, allow_pickle=True)[target].item().
    np.savez_compressed(results_save_path, **all_results)
    print(f"\nSaved all numerical results to: {results_save_path}")

    print("\n--- Unified evaluation for all targets complete. ---")

if __name__ == "__main__":
    # Command-line entry point: all three options are required strings.
    parser = argparse.ArgumentParser(description="Unified evaluation for direct control of Burgers' Eq.")
    for flag in ("--config_path", "--output_base_dir", "--run_id"):
        parser.add_argument(flag, type=str, required=True)
    args = parser.parse_args()
    main(args)