import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import argparse
import time

# --- Dependencies for Recurrent Controller ---
from data_and_models import PropagatorDeepONet, RecurrentController

# --- Dependencies for Adjoint-based Controller ---
import scipy.sparse as sp
from scipy.sparse.linalg import spsolve
from scipy.optimize import minimize

# --- Dependencies for MPC Controller ---
import cvxpy as cp

def generate_target_profile(target_type, x_grid, L):
    """
    Build the target temperature profile for the requested target type.

    Args:
        target_type: One of 'sine', 'ramp' or 'constant'.
        x_grid: Spatial grid points at which the profile is evaluated.
        L: Domain length; one full sine period spans ``L`` for 'sine'.

    Returns:
        A NumPy array holding one target value per grid point.

    Raises:
        ValueError: If ``target_type`` is not one of the supported names.
    """
    print(f"Generating '{target_type}' target profile...")

    if target_type == 'sine':
        # Offset sine wave: mean 0.6, amplitude 0.3.
        phase = 2 * np.pi * x_grid / L
        return 0.6 + 0.3 * np.sin(phase)

    if target_type == 'ramp':
        # Linear ramp from 0.5 to 1.5 across the grid.
        return np.linspace(0.5, 1.5, len(x_grid))

    if target_type == 'constant':
        return np.full(len(x_grid), 1.0)

    raise ValueError(f"Invalid target type: {target_type}")

def run_recurrent_evaluation(args, config, target_T_np):
    """
    Roll out the pre-trained recurrent controller against the DeepONet surrogate.

    The controller weights and hyperparameters are read from
    ``<args.output_base_dir>/<args.run_id>``; the surrogate is read from the
    run directory named by the controller's ``deeponet_run_id`` hyperparameter.
    Starting from an all-zero state, controller and surrogate are stepped
    ``config['NT_SOLVER']`` times with gradient tracking disabled.

    Args:
        args: Namespace providing ``output_base_dir`` and ``run_id``.
        config: Mapping providing 'M_SENSORS', 'NUM_BASIS_FUNCTIONS',
            'TRUNK_INPUT_DIM', 'L' and 'NT_SOLVER'.
        target_T_np: Target profile as a NumPy array.

    Returns:
        Tuple ``(states, controls)`` of NumPy arrays: ``states`` stacks the
        initial state followed by one flattened state per step, ``controls``
        stacks one flattened controller output per step.
    """
    print("\n--- [1/3] Evaluating Recurrent Controller ---")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def read_hyperparams(run_dir):
        # Every run directory stores its settings in hyperparams.yaml.
        with open(os.path.join(run_dir, "hyperparams.yaml"), 'r') as handle:
            return yaml.safe_load(handle)

    # --- Controller ---
    ctrl_dir = os.path.join(args.output_base_dir, args.run_id)
    ctrl_weights = os.path.join(ctrl_dir, "recurrent_controller_model.pth")
    ctrl_hp = read_hyperparams(ctrl_dir)
    # Only architecture-related entries are forwarded to the constructor.
    ctrl_kwargs = {
        name: value
        for name, value in ctrl_hp.items()
        if name in ('hidden_dim', 'num_layers', 'activation_fn')
    }
    controller = RecurrentController(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        **ctrl_kwargs,
    ).to(device)
    controller.load_state_dict(torch.load(ctrl_weights, map_location=device))
    controller.eval()

    # --- Surrogate (the run is referenced by the controller's hyperparameters) ---
    surrogate_dir = os.path.join(args.output_base_dir, ctrl_hp['deeponet_run_id'])
    surrogate_weights = os.path.join(surrogate_dir, "propagator_deeponet_best.pth")
    surrogate_hp = read_hyperparams(surrogate_dir)
    surrogate_kwargs = {}
    for name in ('branch_depth', 'branch_width', 'trunk_depth',
                 'trunk_width', 'latent_dim', 'activation_fn'):
        surrogate_kwargs[name] = surrogate_hp[name]
    physics_simulator = PropagatorDeepONet(
        M_sensors=config['M_SENSORS'],
        num_basis_functions=config['NUM_BASIS_FUNCTIONS'],
        trunk_input_dim=config['TRUNK_INPUT_DIM'],
        **surrogate_kwargs,
    ).to(device)
    physics_simulator.load_state_dict(torch.load(surrogate_weights, map_location=device))
    physics_simulator.eval()
    print("Successfully loaded recurrent controller and surrogate model.")

    # --- Rollout inputs ---
    target = torch.from_numpy(target_T_np).float().unsqueeze(0).to(device)
    T_current = torch.full((1, config['M_SENSORS']), 0.0, device=device)
    hidden = None
    sensor_positions = np.linspace(0, config['L'], config['M_SENSORS'])
    # Reshaped to (1, M_SENSORS, 1) before being passed to the surrogate.
    sensor_locs = torch.from_numpy(sensor_positions).float().to(device)
    sensor_locs = sensor_locs.unsqueeze(0).unsqueeze(-1)

    states = [T_current.cpu().numpy().flatten()]
    controls = []
    start_time = time.time()
    with torch.no_grad():
        for _ in range(config['NT_SOLVER']):
            w_k, hidden = controller(T_current, target, hidden)
            controls.append(w_k.cpu().numpy().flatten())
            T_current = physics_simulator(T_current, w_k, sensor_locs).squeeze(-1)
            states.append(T_current.cpu().numpy().flatten())
    total_time = time.time() - start_time

    final_mse = np.mean((states[-1] - target_T_np)**2)
    print(f"Recurrent controller finished in {total_time:.2f}s. Final MSE: {final_mse:.4e}")
    return np.array(states), np.array(controls)

def run_adjoint_evaluation(config, target_T_np):
    """
    Runs the simulation using the Adjoint-based optimal control method.

    At every time step a finite-horizon tracking problem is minimized with
    L-BFGS-B, using a gradient computed by a backward (adjoint) recursion.
    Only the first control of the optimized horizon is applied, and the
    remaining controls seed the next step's initial guess.

    Args:
        config: Mapping providing 'L', 'D', 'ALPHA', 'NUM_BASIS_FUNCTIONS',
            'M_SENSORS', 'T_FINAL' and 'NT_SOLVER'.
        target_T_np: Target profile with one value per grid point
            (length ``config['M_SENSORS']``).

    Returns:
        Tuple ``(states, controls)``: ``states`` has shape
        ``(NT_SOLVER + 1, M_SENSORS)`` with the all-zero initial state in
        row 0; ``controls`` has shape ``(NT_SOLVER, NUM_BASIS_FUNCTIONS)``.
    """
    print("\n--- [2/3] Evaluating Adjoint-Based Controller ---")
    L, D, alpha = config['L'], config['D'], config['ALPHA']
    # Hard-coded coefficients of the beta * (T_amb - x) term in the model.
    beta, T_amb = 0.5, 0.0
    m = config['NUM_BASIS_FUNCTIONS']  # number of control inputs
    N = config['M_SENSORS'] - 1        # number of grid intervals
    dx = L / N
    dt = config['T_FINAL'] / config['NT_SOLVER']
    n_steps = config['NT_SOLVER']
    Np = 10                            # optimization horizon (steps)
    Q = 2*np.eye(N + 1)                # state tracking weight
    R = 1e-5 * np.eye(m)               # control effort weight
    u_min, u_max = -1.0, 1.0
    inner_maxiter = 30                 # L-BFGS-B iteration cap per time step

    def get_state_space_model(N, dx, dt, D, alpha, m, beta=0.0, T_amb=0.0):
        """Return (A, B, g) such that x[k+1] = A @ x[k] + B @ u[k] + g."""
        # The 1/2 factors split the diffusion and beta terms evenly between
        # the implicit (M_L) and explicit (M_R) sides of the update.
        r = D * dt / (2 * dx**2)
        s = beta * dt / 2.0
        ml_diag = np.full(N + 1, 1 + 2*r + s)
        mr_diag = np.full(N + 1, 1 - 2*r - s)
        off_diag = np.full(N, -r)
        M_L = sp.diags([off_diag, ml_diag, off_diag], [-1, 0, 1], format='csc')
        M_R = sp.diags([-off_diag, mr_diag, -off_diag], [-1, 0, 1], format='csc')
        # First and last rows use a doubled neighbour coefficient
        # (presumably a mirrored, zero-flux boundary -- confirm).
        M_L[0, 1] = -2*r
        M_R[0, 1] = 2*r
        M_L[N, N-1] = -2*r
        M_R[N, N-1] = 2*r
        x_grid_local = np.linspace(0.0, L, N+1)
        # Column j of the input matrix is alpha * cos(j * pi * x / L).
        B_c = np.zeros((N + 1, m))
        for j in range(m):
            B_c[:, j] = alpha * np.cos(j * np.pi * x_grid_local / L)
        s_vec = beta * T_amb * np.ones(N + 1)
        A = spsolve(M_L, M_R)
        B = spsolve(M_L, dt * B_c)
        g = spsolve(M_L, dt * s_vec)
        return A, B, g

    def horizon_cost_and_grad(u_horizon_flat, x0, horizon_len, x_ref_target, A, B, g):
        """Return the horizon cost and its gradient w.r.t. the flat controls."""
        U_h = u_horizon_flat.reshape((horizon_len, m))
        # Forward pass: simulate the state trajectory over the horizon.
        X_h = np.zeros((horizon_len + 1, N + 1))
        X_h[0] = x0
        for t in range(horizon_len):
            X_h[t+1] = A @ X_h[t] + B @ U_h[t] + g
        J = sum((X_h[t] - x_ref_target) @ (Q @ (X_h[t] - x_ref_target)) for t in range(horizon_len + 1))
        J += sum(U_h[t] @ (R @ U_h[t]) for t in range(horizon_len))
        # Backward pass: Lambda[t] accumulates dJ/dx[t] through the dynamics.
        Lambda = np.zeros((horizon_len + 1, N + 1))
        Lambda[-1] = 2.0 * (Q @ (X_h[-1] - x_ref_target))
        for t in reversed(range(horizon_len)):
            Lambda[t] = A.T @ Lambda[t+1] + 2.0 * (Q @ (X_h[t] - x_ref_target))
        grad = np.array([2.0 * (R @ U_h[t]) + B.T @ Lambda[t+1] for t in range(horizon_len)])
        return float(J), grad.ravel()

    A, B, g = get_state_space_model(N, dx, dt, D, alpha, m, beta, T_amb)
    x_current = np.zeros(N + 1)
    history_x, history_u = [x_current], []
    prev_horizon_flat = None
    start_time = time.time()
    for k in range(n_steps):
        # The horizon shrinks near the end so it never runs past the final step.
        horizon_len = min(Np, n_steps - k)
        u0_flat = np.zeros(horizon_len * m)
        if prev_horizon_flat is not None and prev_horizon_flat.reshape((-1, m)).shape[0] > 1:
            # Warm start: drop the control that was just applied, pad with zeros.
            prev = prev_horizon_flat.reshape((-1, m))
            shifted = np.vstack((prev[1:], np.zeros((1, m))))
            if shifted.shape[0] < horizon_len:
                shifted = np.vstack((shifted, np.zeros((horizon_len - shifted.shape[0], m))))
            u0_flat = shifted[:horizon_len].ravel()
        bounds = [(u_min, u_max)] * (horizon_len * m)
        cost_args = (x_current, horizon_len, target_T_np, A, B, g)
        res = minimize(
            horizon_cost_and_grad, u0_flat, args=cost_args,
            method='L-BFGS-B', jac=True, bounds=bounds,
            options={'maxiter': inner_maxiter, 'ftol': 1e-6},
        )
        # Stopping on the iteration cap is reported as success=False, yet the
        # returned iterate is still the best point found. Keep it unless it is
        # non-finite or worse than the warm start; otherwise a capped solve
        # would throw away all progress made at this step.
        cost_at_guess, _ = horizon_cost_and_grad(u0_flat, *cost_args)
        if np.all(np.isfinite(res.x)) and res.fun <= cost_at_guess:
            u_opt_flat = res.x
        else:
            u_opt_flat = u0_flat
        prev_horizon_flat = u_opt_flat
        # Receding horizon: apply only the first control of the plan.
        u_apply = u_opt_flat.reshape((horizon_len, m))[0]
        x_next = A @ x_current + B @ u_apply + g
        history_x.append(x_next)
        history_u.append(u_apply)
        x_current = x_next
    total_time = time.time() - start_time
    final_mse = np.mean((history_x[-1] - target_T_np)**2)
    print(f"Adjoint controller finished in {total_time:.2f}s. Final MSE: {final_mse:.4e}")
    return np.array(history_x), np.array(history_u)


def run_mpc_evaluation(config, target_T_np):
    """
    Runs the simulation using the Model Predictive Control (MPC) method.

    A single parameterized CVXPY problem over a fixed horizon is built once
    and re-solved with OSQP at every time step; only the first control of
    each solution is applied. If a solve does not end with an optimal (or
    optimal-inaccurate) status, a zero control is applied for that step and
    a warning is printed so the degraded run is visible in the log.

    Args:
        config: Mapping providing 'L', 'D', 'ALPHA', 'NUM_BASIS_FUNCTIONS',
            'M_SENSORS', 'T_FINAL' and 'NT_SOLVER'.
        target_T_np: Target profile with one value per grid point
            (length ``config['M_SENSORS']``).

    Returns:
        Tuple ``(states, controls)``: ``states`` has shape
        ``(NT_SOLVER + 1, M_SENSORS)`` with the all-zero initial state in
        row 0; ``controls`` has shape ``(NT_SOLVER, NUM_BASIS_FUNCTIONS)``.
    """
    print("\n--- [3/3] Evaluating MPC Controller ---")
    L, D, alpha = config['L'], config['D'], config['ALPHA']
    # Hard-coded coefficients of the beta * (T_amb - x) term in the model.
    beta, T_amb = 0.5, 0.0
    m = config['NUM_BASIS_FUNCTIONS']  # number of control inputs
    N = config['M_SENSORS'] - 1        # number of grid intervals
    dx = L / N
    dt = config['T_FINAL'] / config['NT_SOLVER']
    n_steps = config['NT_SOLVER']
    Np = 10                            # prediction horizon (steps)
    # Q weights the final predicted state, T_Q weights every stage state.
    # NOTE(review): the name T_Q suggests a terminal weight, but it is used
    # as the stage weight below -- confirm this assignment is intended.
    Q = 2*np.eye(N + 1)
    T_Q = 1e-1 * np.eye(N + 1)
    R = 1e-5 * np.eye(m)               # control effort weight
    u_min, u_max = -1.0, 1.0

    def get_state_space_model(N, dx, dt, D, alpha, m, beta=0.0, T_amb=0.0):
        """Return (A, B, g) such that x[k+1] = A @ x[k] + B @ u[k] + g."""
        # The 1/2 factors split the diffusion and beta terms evenly between
        # the implicit (M_L) and explicit (M_R) sides of the update.
        r = D * dt / (2 * dx**2)
        s = beta * dt / 2.0
        ml_diag = np.full(N + 1, 1 + 2*r + s)
        mr_diag = np.full(N + 1, 1 - 2*r - s)
        off_diag = np.full(N, -r)
        M_L = sp.diags([off_diag, ml_diag, off_diag], [-1, 0, 1], format='csc')
        M_R = sp.diags([-off_diag, mr_diag, -off_diag], [-1, 0, 1], format='csc')
        # First and last rows use a doubled neighbour coefficient
        # (presumably a mirrored, zero-flux boundary -- confirm).
        M_L[0, 1] = -2*r
        M_R[0, 1] = 2*r
        M_L[N, N-1] = -2*r
        M_R[N, N-1] = 2*r
        x_grid_local = np.linspace(0.0, L, N+1)
        # Column j of the input matrix is alpha * cos(j * pi * x / L).
        B_c = np.zeros((N + 1, m))
        for j in range(m):
            B_c[:, j] = alpha * np.cos(j * np.pi * x_grid_local / L)
        s_vec = beta * T_amb * np.ones(N + 1)
        A = spsolve(M_L, M_R)
        B = spsolve(M_L, dt * B_c)
        g = spsolve(M_L, dt * s_vec)
        return A, B, g

    A, B, g = get_state_space_model(N, dx, dt, D, alpha, m, beta, T_amb)
    nx, nu = B.shape

    # Build the optimization problem once; only parameter values change per step.
    x_k = cp.Parameter(nx)
    x_ref = cp.Parameter(nx)
    u_seq = cp.Variable((nu, Np))
    x_seq = cp.Variable((nx, Np + 1))
    cost = 0
    constraints = [x_seq[:, 0] == x_k]
    for i in range(Np):
        cost += cp.quad_form(x_seq[:, i] - x_ref, T_Q) + cp.quad_form(u_seq[:, i], R)
        constraints += [x_seq[:, i+1] == A @ x_seq[:, i] + B @ u_seq[:, i] + g]
        constraints += [u_min <= u_seq[:, i], u_seq[:, i] <= u_max]
    cost += cp.quad_form(x_seq[:, Np] - x_ref, Q)
    problem = cp.Problem(cp.Minimize(cost), constraints)

    # The reference never changes during the run, so it is set once.
    x_ref.value = target_T_np

    x_current = np.zeros(nx)
    history_x, history_u = [x_current], []
    n_fallbacks = 0
    start_time = time.time()
    for k in range(n_steps):
        x_k.value = x_current
        problem.solve(solver=cp.OSQP, warm_start=True, verbose=False)
        if problem.status in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]:
            u_optimal = u_seq[:, 0].value
        else:
            # Report the fallback instead of silently applying zero control.
            n_fallbacks += 1
            print(f"Warning: MPC solve at step {k} ended with status '{problem.status}'; applying zero control.")
            u_optimal = np.zeros(nu)
        x_next = A @ x_current + B @ u_optimal + g
        history_x.append(x_next)
        history_u.append(u_optimal)
        x_current = x_next
    total_time = time.time() - start_time
    if n_fallbacks:
        print(f"Warning: MPC fell back to zero control on {n_fallbacks} of {n_steps} steps.")
    final_mse = np.mean((history_x[-1] - target_T_np)**2)
    print(f"MPC controller finished in {total_time:.2f}s. Final MSE: {final_mse:.4e}")
    return np.array(history_x), np.array(history_u)


def generate_comparison_plots(all_results, x_grid, config, output_dir):
    """
    Generates a single, combined plot showing the final state comparison for
    all three target types in a single row, with a shared legend below.
    Each subplot has its own independently scaled y-axis.

    Args:
        all_results: Mapping with the keys 'sine', 'ramp' and 'constant'.
            Each value holds 'target_T' plus 'recurrent', 'adjoint' and
            'mpc' entries, each with a 'states' sequence whose first and
            last rows are plotted.
        x_grid: Spatial grid used as the x-axis of every subplot.
        config: Unused here; kept so existing callers keep working.
        output_dir: Directory (created if missing) that receives
            "comparison_final_state_ALL_TARGETS.pdf".
    """
    print("\n--- Generating Combined Comparison Plot for All Targets ---")
    os.makedirs(output_dir, exist_ok=True)

    font_size = 40
    legend_font_size = 40

    fig, axes = plt.subplots(1, 3, figsize=(38, 9), sharey=False)
    try:
        target_order = ['sine', 'ramp', 'constant']

        for i, target_type in enumerate(target_order):
            ax = axes[i]
            results = all_results[target_type]
            target_T_np = results['target_T']

            # NOTE(review): `\text{...}` may not be accepted by mathtext in
            # older Matplotlib releases -- confirm the minimum version used.
            ax.plot(x_grid, target_T_np, 'k:', lw=4, label=r'$y_{\text{target}}(x)$')
            # The initial state is taken from the recurrent run.
            ax.plot(x_grid, results['recurrent']['states'][0], 'b--', lw=2.5, label='$y(0, x)$')
            ax.plot(x_grid, results['recurrent']['states'][-1], 'r-', lw=2.5, label='PDE-OP')
            ax.plot(x_grid, results['adjoint']['states'][-1], 'm-.', lw=2.5, label='Adjoint Method')
            ax.plot(x_grid, results['mpc']['states'][-1], 'c:', lw=2.5, label='LMPC')

            ax.set_xlabel('x', fontsize=font_size)
            ax.grid(True, linestyle='--', alpha=0.6)
            ax.tick_params(axis='both', which='major', labelsize=font_size)
            ax.set_ylabel('$y(t, x)$', fontsize=font_size)

        # Every subplot draws the same five curves, so one legend serves all.
        handles, labels = axes[0].get_legend_handles_labels()

        fig.legend(handles, labels, loc='lower center',
                   bbox_to_anchor=(0.5, -0.05), ncol=5, fontsize=legend_font_size)

        fig.subplots_adjust(bottom=0.3, wspace=0.3)

        save_path = os.path.join(output_dir, "comparison_final_state_ALL_TARGETS.pdf")
        # Save this figure explicitly instead of pyplot's "current" figure.
        fig.savefig(save_path, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
    print(f"\nSaved combined comparison plot to: {save_path}")


def main(args):
    """
    Evaluate all three controllers on every target profile, then save and plot.

    For each of the 'sine', 'ramp' and 'constant' targets the recurrent,
    adjoint and MPC evaluations are run and a final-MSE summary is printed.
    The collected results are written to disk before the comparison plot is
    generated, so a plotting failure cannot discard the simulation output.

    Args:
        args: Namespace providing ``config_path``, ``output_base_dir`` and
            ``run_id``.
    """
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)

    all_results = {}
    target_types = ['sine', 'ramp', 'constant']
    x_grid = np.linspace(0, config['L'], config['M_SENSORS'])

    for target_type in target_types:
        print(f"\n\n{'='*80}")
        print(f"--- Starting Unified Evaluation for Target: '{target_type.upper()}' ---")
        print(f"Recurrent Controller Run ID: '{args.run_id}'")
        print(f"{'='*80}")

        target_T_np = generate_target_profile(target_type, x_grid, config['L'])

        recurrent_states, recurrent_controls = run_recurrent_evaluation(args, config, target_T_np)
        adjoint_states, adjoint_controls = run_adjoint_evaluation(config, target_T_np)
        mpc_states, mpc_controls = run_mpc_evaluation(config, target_T_np)

        all_results[target_type] = {
            'recurrent': {'states': recurrent_states, 'controls': recurrent_controls},
            'adjoint': {'states': adjoint_states, 'controls': adjoint_controls},
            'mpc': {'states': mpc_states, 'controls': mpc_controls},
            'target_T': target_T_np
        }

        print(f"\n--- SUMMARY FOR '{target_type.upper()}' ---")
        print(f"{'Method':<20} | {'Final MSE':<20}")
        print("-" * 45)
        print(f"{'Recurrent':<20} | {np.mean((recurrent_states[-1] - target_T_np)**2):<20.4e}")
        print(f"{'Adjoint':<20} | {np.mean((adjoint_states[-1] - target_T_np)**2):<20.4e}")
        print(f"{'MPC':<20} | {np.mean((mpc_states[-1] - target_T_np)**2):<20.4e}")
        print("-" * 45)

    output_dir = os.path.join(args.output_base_dir, args.run_id, "comparison_plots_combined")
    # Create the directory here rather than relying on the plotting step.
    os.makedirs(output_dir, exist_ok=True)

    # Persist the numbers first: they are the expensive part of the run.
    # Each value is a nested dict, so NumPy stores it as a pickled object
    # array; reading it back requires np.load(..., allow_pickle=True).
    results_save_path = os.path.join(output_dir, "all_simulation_results.npz")
    np.savez_compressed(results_save_path, **all_results)
    print(f"\nSaved all numerical results to: {results_save_path}")

    generate_comparison_plots(all_results, x_grid, config, output_dir)

    print("\n--- Unified evaluation for all targets complete. ---")

if __name__ == "__main__":
    # Command-line entry point: all three options are required strings.
    parser = argparse.ArgumentParser(description="Unified evaluation of Recurrent, Adjoint, and MPC controllers.")
    for flag, help_text in (
        ("--config_path", "Path to the main YAML config file."),
        ("--output_base_dir", "Base directory where runs are stored."),
        ("--run_id", "The run_id of the trained recurrent controller."),
    ):
        parser.add_argument(flag, type=str, required=True, help=help_text)
    args = parser.parse_args()
    main(args)