
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import argparse
import time

# --- Dependencies for NN Controller ---
from data_and_models import DeepONetWithBias
from train_controller_direct import DirectDecisionMaker

# --- Dependencies for Classical Optimizers ---
from scipy.optimize import minimize
from scipy.sparse import spdiags
from scipy.sparse.linalg import spsolve

def generate_target_profile(config, target_type, sensor_locations):
    """
    Build the desired final-state profile sampled at the sensor locations.

    Supported ``target_type`` values:
      * ``'sine'``     - ``config['V_REF_VAL']`` plus a sine wave of amplitude
                         0.2 evaluated as ``sin(3 * 2*pi * sensor_locations)``.
      * ``'ramp'``     - linear ramp from 0.5 to 1.5 with one value per sensor.
      * ``'constant'`` - ``config['V_REF_VAL']`` repeated for every sensor.

    Raises:
        ValueError: if ``target_type`` is none of the names above.
    """
    print(f"Generating '{target_type}' target profile...")

    if target_type == 'constant':
        return np.full(len(sensor_locations), config['V_REF_VAL'])

    if target_type == 'ramp':
        return np.linspace(0.5, 1.5, len(sensor_locations))

    if target_type == 'sine':
        oscillation = 0.2 * np.sin(3.0 * 2 * np.pi * sensor_locations)
        return config['V_REF_VAL'] + oscillation

    raise ValueError(f"Invalid target type: {target_type}")


def run_nn_controller_evaluation(args, config, target_zeta_np):
    """
    Loads and evaluates the pre-trained direct decision-making NN controller.

    The controller checkpoint and its ``hyperparams.yaml`` are read from
    ``<args.output_base_dir>/<args.run_id>``. The surrogate (DeepONet) run is
    located through the ``deeponet_run_id`` entry of the controller's
    hyperparameters and loaded from ``<args.output_base_dir>/<deeponet_run_id>``.

    Args:
        args: Namespace providing ``output_base_dir`` and ``run_id``.
        config: Mapping providing ``M_SENSORS``, ``TRUNK_INPUT_DIM``, ``L`` and
            ``T_FINAL``.
        target_zeta_np: NumPy array with the target profile at the sensors.

    Returns:
        Tuple ``(final_state_achieved_np, u_profile_np, inference_time)``: the
        surrogate-predicted final state at the sensors, the flattened control
        produced by the controller, and the controller forward-pass time in
        seconds.
    """
    print("\n--- [1/3] Evaluating Neural Network Controller ---")
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Load Controller and Surrogate Model ---
    controller_run_dir = os.path.join(args.output_base_dir, args.run_id)
    controller_model_path = os.path.join(controller_run_dir, "controller_model_direct_best.pth")
    with open(os.path.join(controller_run_dir, "hyperparams.yaml"), 'r') as f:
        controller_hyperparams = yaml.safe_load(f)
    # Only the architecture-related hyperparameters are forwarded to the constructor.
    controller_kwargs = {k: v for k, v in controller_hyperparams.items() if k in ['hidden_dim', 'num_layers', 'activation_fn']}
    controller = DirectDecisionMaker(input_dim=config['M_SENSORS'], output_dim=config['M_SENSORS'], **controller_kwargs).to(DEVICE)
    # NOTE(review): torch.load unpickles the checkpoint; only load files from trusted runs.
    controller.load_state_dict(torch.load(controller_model_path, map_location=DEVICE))
    controller.eval()

    deeponet_run_id = controller_hyperparams['deeponet_run_id']
    deeponet_run_dir = os.path.join(args.output_base_dir, deeponet_run_id)
    deeponet_model_path = os.path.join(deeponet_run_dir, "deeponet_model.pth")
    with open(os.path.join(deeponet_run_dir, "hyperparams.yaml"), 'r') as f:
        deeponet_hyperparams = yaml.safe_load(f)
    model_arg_keys = ['latent_dim', 'branch_depth', 'branch_width', 'trunk_depth', 'trunk_width', 'activation_fn']
    deeponet_kwargs = {key: deeponet_hyperparams[key] for key in model_arg_keys}
    physics_simulator = DeepONetWithBias(branch_input_dim=config['M_SENSORS'], trunk_input_dim=config['TRUNK_INPUT_DIM'], **deeponet_kwargs).to(DEVICE)
    physics_simulator.load_state_dict(torch.load(deeponet_model_path, map_location=DEVICE))
    physics_simulator.eval()
    print("Successfully loaded NN controller and surrogate model.")

    # --- Get Controller's Action and Simulate Outcome ---
    target_zeta_torch = torch.from_numpy(target_zeta_np).float().unsqueeze(0).to(DEVICE)
    sensor_locs_np = np.linspace(0, config['L'], config['M_SENSORS'])

    with torch.no_grad():
        # CUDA kernels are launched asynchronously: without synchronizing
        # before and after the forward pass the timer would only capture the
        # launch overhead, not the actual inference time.
        if DEVICE == "cuda":
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        u_at_sensors_torch = controller(target_zeta_torch)
        if DEVICE == "cuda":
            torch.cuda.synchronize()
        inference_time = time.perf_counter() - start_time

        sensor_locs_torch = torch.from_numpy(sensor_locs_np).float().to(DEVICE).unsqueeze(1)
        # Explicit dtype: an integer T_FINAL in the YAML config would otherwise
        # create an int64 column next to the float32 sensor coordinates.
        final_time_column = torch.full(
            (config['M_SENSORS'], 1),
            config['T_FINAL'],
            dtype=sensor_locs_torch.dtype,
            device=DEVICE,
        )
        # One trunk query (x, T_FINAL) per sensor location.
        trunk_input_queries = torch.cat([sensor_locs_torch, final_time_column], dim=1)
        # The same control vector is paired with every trunk query.
        branch_input_tiled = u_at_sensors_torch.repeat(config['M_SENSORS'], 1)
        final_state_predicted = physics_simulator(branch_input_tiled, trunk_input_queries)

    final_state_achieved_np = final_state_predicted.squeeze().cpu().numpy()
    u_profile_np = u_at_sensors_torch.cpu().numpy().flatten()
    final_mse = np.mean((final_state_achieved_np - target_zeta_np)**2)
    print(f"NN controller finished in {inference_time:.6f}s. Final MSE (surrogate): {final_mse:.4e}")

    return final_state_achieved_np, u_profile_np, inference_time

def _get_pde_solver_matrices(config):
    """Helper to build Crank-Nicolson matrices A and B."""
    NX_SOLVER, NT_SOLVER = config['NX_SOLVER'], config['NT_SOLVER']
    L, T_FINAL = config['L'], config['T_FINAL']
    D, BETA = config['D'], config['BETA']
    DX = L / (NX_SOLVER - 1)
    DT = T_FINAL / (NT_SOLVER - 1)
    
    lambda_ = D * DT / (2 * DX**2)
    A_diag_val = 1 + 2 * lambda_ + 0.5 * BETA * DT
    A_diagonals = [np.full(NX_SOLVER, -lambda_), np.full(NX_SOLVER, A_diag_val), np.full(NX_SOLVER, -lambda_)]
    A = spdiags(A_diagonals, [-1, 0, 1], NX_SOLVER, NX_SOLVER, format='csc')
    A[0, 1], A[-1, -2] = -2 * lambda_, -2 * lambda_
    B_diag_val = 1 - 2 * lambda_ - 0.5 * BETA * DT
    B_diagonals = [np.full(NX_SOLVER, lambda_), np.full(NX_SOLVER, B_diag_val), np.full(NX_SOLVER, lambda_)]
    B = spdiags(B_diagonals, [-1, 0, 1], NX_SOLVER, NX_SOLVER, format='csc')
    B[0, 1], B[-1, -2] = 2 * lambda_, 2 * lambda_
    
    return A, B

def _solve_pde(config, u_profile_full):
    """Integrate the PDE forward to ``T_FINAL`` with Crank-Nicolson and return the final state.

    The state starts at ``INITIAL_STATE_VAL`` on every one of the ``NX_SOLVER``
    grid points. Each of the ``NT_SOLVER - 1`` steps solves
    ``A @ V_next = B @ V + DT * (ALPHA * u_profile_full + BETA * V_REF_VAL)``.

    Args:
        config: Mapping with the solver and physical parameters.
        u_profile_full: Control values on the solver grid.

    Returns:
        NumPy array with the state on the solver grid after the last step.
    """
    n_x = config['NX_SOLVER']
    n_steps = config['NT_SOLVER'] - 1
    dt = config['T_FINAL'] / n_steps

    forcing = config['ALPHA'] * u_profile_full + config['BETA'] * config['V_REF_VAL']
    # The forcing is time-independent, so its per-step contribution is fixed.
    forcing_per_step = forcing * dt

    state = np.full(n_x, config['INITIAL_STATE_VAL'])
    implicit, explicit = _get_pde_solver_matrices(config)

    remaining = n_steps
    while remaining > 0:
        state = spsolve(implicit, explicit @ state + forcing_per_step)
        remaining -= 1
    return state


def run_direct_optimizer_evaluation(config, target_zeta_np):
    """
    Optimise the sensor-level control with L-BFGS-B and no user-supplied gradient.

    No ``jac`` is passed to ``minimize``, so SciPy approximates the gradient
    numerically (Direct method). The cost is the mean squared mismatch between
    the simulated final state and the target on the solver grid, plus
    ``1e-5 * sum(u**2)`` as a control-effort penalty.

    Returns:
        Tuple ``(V_final_sensors, u_optimal, total_time)``: the final state
        interpolated back to the sensors, the optimised control at the
        sensors, and the optimisation time in seconds.
    """
    print("\n--- [2/3] Evaluating Classical DIRECT Optimizer ---")
    n_sensors = config['M_SENSORS']
    sensor_x = np.linspace(0, config['L'], n_sensors)
    solver_x = np.linspace(0, config['L'], config['NX_SOLVER'])
    target_on_solver_grid = np.interp(solver_x, sensor_x, target_zeta_np)

    def simulate(u_at_sensors_np):
        # Lift the sensor-level control to the solver grid, then run the PDE.
        return _solve_pde(config, np.interp(solver_x, sensor_x, u_at_sensors_np))

    def classical_direct_objective(u_at_sensors_np):
        mismatch = simulate(u_at_sensors_np) - target_on_solver_grid
        cost = np.mean(mismatch**2) + 1e-5 * np.sum(u_at_sensors_np**2)
        print(f"Trying direct control... Cost = {cost:.6f}", end='\r')
        return cost

    box = [(config['U_MIN'], config['U_MAX']) for _ in range(n_sensors)]
    u_start = np.zeros(n_sensors)

    tic = time.time()
    outcome = minimize(
        classical_direct_objective,
        u_start,
        method='L-BFGS-B',
        bounds=box,
        options={'disp': False, 'maxiter': 100},
    )
    total_time = time.time() - tic

    u_optimal = outcome.x
    # Re-simulate with the optimum and report the error at the sensors.
    V_final_sensors = np.interp(sensor_x, solver_x, simulate(u_optimal))
    final_mse = np.mean((V_final_sensors - target_zeta_np)**2)
    print(f"\nDirect optimizer finished in {total_time:.4f}s. Final MSE: {final_mse:.4e}")
    return V_final_sensors, u_optimal, total_time


def run_adjoint_optimizer_evaluation(config, target_zeta_np):
    """
    Runs the classical optimizer using analytical adjoint-based gradients.

    The cost is ``mean((V_final - target)**2)`` on the solver grid plus
    ``1e-5 * sum(u**2)``. Its gradient is obtained from the discrete adjoint of
    the Crank-Nicolson recursion ``A V_{n+1} = B V_n + DT * (ALPHA*u + BETA*V_REF)``:

        g      = (2 / NX) * (V_final - target)
        dJ/du  = ALPHA * DT * sum_{k=0}^{N-1} (A^-T B^T)^k A^-T g,   N = NT - 1

    and then pulled back to the sensor-level control through the transpose of
    the linear-interpolation operator that lifts sensor values to the solver
    grid. This makes the gradient consistent with the cost that is returned,
    which L-BFGS-B's line search relies on.

    Args:
        config: Mapping with solver/physical parameters and control bounds.
        target_zeta_np: Target profile at the sensor locations.

    Returns:
        Tuple ``(V_final_sensors, u_optimal, total_time)``: the final state
        interpolated to the sensors, the optimised control at the sensors, and
        the optimisation time in seconds.
    """
    print("\n--- [3/3] Evaluating ADJOINT-Based Optimizer ---")
    sensor_locs_np = np.linspace(0, config['L'], config['M_SENSORS'])
    x_grid_solver_np = np.linspace(0, config['L'], config['NX_SOLVER'])
    target_V_profile_full = np.interp(x_grid_solver_np, sensor_locs_np, target_zeta_np)
    NX_SOLVER, NT_SOLVER = config['NX_SOLVER'], config['NT_SOLVER']
    DT = config['T_FINAL'] / (NT_SOLVER - 1)
    ALPHA = config['ALPHA']

    # The system matrices do not depend on the control, so their transposes
    # are built once instead of on every objective evaluation.
    A, B = _get_pde_solver_matrices(config)
    A_T, B_T = A.T.tocsc(), B.T.tocsc()

    # np.interp is linear in the interpolated values, so interpolating each
    # unit vector yields the columns of the (NX_SOLVER x M_SENSORS) matrix W
    # with u_full = W @ u_sensors. The chain rule then needs W.T.
    interp_matrix = np.column_stack([
        np.interp(x_grid_solver_np, sensor_locs_np, basis_vector)
        for basis_vector in np.eye(config['M_SENSORS'])
    ])

    iter_count = [0]

    def adjoint_objective_and_grad(u_at_sensors_np):
        u_profile_full = np.interp(x_grid_solver_np, sensor_locs_np, u_at_sensors_np)
        V_final_full = _solve_pde(config, u_profile_full)
        residual = V_final_full - target_V_profile_full

        # Cost
        tracking_error = np.mean(residual**2)
        effort_penalty = 1e-5 * np.sum(u_at_sensors_np**2)
        cost = tracking_error + effort_penalty

        # --- Adjoint Solve (backward sweep over the time steps) ---
        adjoint = (2.0 / NX_SOLVER) * residual
        adjoint_sum = np.zeros(NX_SOLVER)
        for _ in range(NT_SOLVER - 1):
            stage = spsolve(A_T, adjoint)   # A^-T applied to the current adjoint
            adjoint_sum += stage            # sensitivity to this step's source term
            adjoint = B_T @ stage           # propagate one step further back

        # Gradient
        grad_effort = 2 * 1e-5 * u_at_sensors_np
        grad_tracking_full = ALPHA * DT * adjoint_sum
        grad_tracking_sensors = interp_matrix.T @ grad_tracking_full
        gradient = grad_effort + grad_tracking_sensors

        print(f"Trying adjoint control (iter {iter_count[0]})... Cost = {cost:.6f}", end='\r')
        iter_count[0] += 1
        return cost, gradient

    bounds = [(config['U_MIN'], config['U_MAX']) for _ in range(config['M_SENSORS'])]
    initial_guess = np.zeros(config['M_SENSORS'])
    start_time = time.time()
    result = minimize(adjoint_objective_and_grad, initial_guess, method='L-BFGS-B', jac=True, bounds=bounds, options={'disp': False, 'maxiter': 100})
    total_time = time.time() - start_time
    u_optimal = result.x
    V_final = _solve_pde(config, np.interp(x_grid_solver_np, sensor_locs_np, u_optimal))
    V_final_sensors = np.interp(sensor_locs_np, x_grid_solver_np, V_final)
    final_mse = np.mean((V_final_sensors - target_zeta_np)**2)
    print(f"\nAdjoint optimizer finished in {total_time:.4f}s. Final MSE: {final_mse:.4e}")
    return V_final_sensors, u_optimal, total_time


def generate_comparison_plots(all_results, sensor_locs_np, config, output_dir):
    """
    Draw one side-by-side figure comparing all methods for every target type.

    The figure has three panels in the order sine, ramp, constant. Each panel
    shows the target, the initial state and the final states reached by the NN
    controller, the adjoint optimizer and the direct optimizer. A single shared
    legend is placed below the panels, and the figure is written to
    ``comparison_final_state_ALL_TARGETS.pdf`` inside ``output_dir`` (created
    if missing).

    Args:
        all_results: Mapping from target type to a dict holding
            ``'target_zeta'`` and, under ``'nn'``, ``'adjoint'`` and
            ``'direct'``, a dict with a ``'state'`` array.
        sensor_locs_np: Sensor x-coordinates used as the horizontal axis.
        config: Mapping providing ``INITIAL_STATE_VAL``.
        output_dir: Directory that receives the PDF.
    """
    print("\n--- Generating Combined Comparison Plot for All Targets ---")
    os.makedirs(output_dir, exist_ok=True)

    font_size = 40
    legend_font_size = 40

    # (results key, line style, line width, legend label) in drawing order.
    method_curves = (
        ('nn', 'r-', 2.5, 'PDE-OP'),
        ('adjoint', 'm:', 2.5, 'Adjoint Method'),
        ('direct', 'c-.', 2.5, 'Direct Method'),
    )
    initial_state = np.full_like(sensor_locs_np, config['INITIAL_STATE_VAL'])

    fig, axes = plt.subplots(1, 3, figsize=(38, 9), sharey=False)

    for panel, target_type in zip(axes, ('sine', 'ramp', 'constant')):
        per_target = all_results[target_type]

        panel.plot(sensor_locs_np, per_target['target_zeta'], 'k:', lw=4, label=r'$y_{\text{target}}(x)$')
        panel.plot(sensor_locs_np, initial_state, 'b--', lw=2, label='$y(0, x)$')
        for key, line_style, width, legend_label in method_curves:
            panel.plot(sensor_locs_np, per_target[key]['state'], line_style, lw=width, label=legend_label)

        panel.set_xlabel('x', fontsize=font_size)
        panel.set_ylabel('$y(t, x)$', fontsize=font_size)
        panel.tick_params(axis='both', which='major', labelsize=font_size)
        panel.grid(True, linestyle='--', alpha=0.6)

    # Every panel carries the same five curves, so the first one defines the legend.
    legend_handles, legend_labels = axes[0].get_legend_handles_labels()
    fig.legend(
        legend_handles,
        legend_labels,
        loc='lower center',
        bbox_to_anchor=(0.5, -0.05),
        ncol=5,
        fontsize=legend_font_size,
    )
    fig.subplots_adjust(bottom=0.3, wspace=0.3)

    save_path = os.path.join(output_dir, "comparison_final_state_ALL_TARGETS.pdf")
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)
    print(f"\nSaved combined comparison plot to: {save_path}")


def main(args):
    """
    Run the NN, direct and adjoint evaluations for every target type.

    For each of the targets sine, ramp and constant this generates the target
    profile, runs the three methods, prints a timing / MSE summary table and
    collects the results. Afterwards the combined comparison plot and an
    ``all_simulation_results.npz`` archive are written to
    ``<output_base_dir>/<run_id>/comparison_plots_v2_combined``.

    Args:
        args: Namespace providing ``config_path``, ``output_base_dir`` and
            ``run_id``.
    """
    with open(args.config_path, 'r') as config_file:
        config = yaml.safe_load(config_file)

    sensor_locs_np = np.linspace(0, config['L'], config['M_SENSORS'])
    banner = '=' * 80
    rule = "-" * 65
    # (row label, results key, decimals used for the time column)
    summary_rows = (
        ('NN Controller', 'nn', 6),
        ('Direct Optimizer', 'direct', 4),
        ('Adjoint Optimizer', 'adjoint', 4),
    )

    all_results = {}
    for target_type in ('sine', 'ramp', 'constant'):
        config['target_type'] = target_type
        print(f"\n\n{banner}")
        print(f"--- Starting Unified Evaluation for Target: '{target_type.upper()}' ---")
        print(f"NN Controller Run ID: '{args.run_id}'")
        print(f"{banner}")

        target_zeta_np = generate_target_profile(config, target_type, sensor_locs_np)

        # Each runner returns (final state at sensors, control, elapsed seconds).
        per_method = {}
        for key, outputs in (
            ('nn', run_nn_controller_evaluation(args, config, target_zeta_np)),
            ('direct', run_direct_optimizer_evaluation(config, target_zeta_np)),
            ('adjoint', run_adjoint_optimizer_evaluation(config, target_zeta_np)),
        ):
            state, control, elapsed = outputs
            per_method[key] = {'state': state, 'control': control, 'time': elapsed}
        per_method['target_zeta'] = target_zeta_np
        all_results[target_type] = per_method

        print(f"\n--- SUMMARY FOR '{target_type.upper()}' ---")
        print(f"{'Method':<20} | {'Time Taken (s)':<20} | {'Final MSE':<20}")
        print(rule)
        for label, key, decimals in summary_rows:
            entry = per_method[key]
            mse = np.mean((entry['state'] - target_zeta_np)**2)
            print(f"{label:<20} | {entry['time']:<20.{decimals}f} | {mse:<20.4e}")
        print(rule)

    output_dir = os.path.join(args.output_base_dir, args.run_id, "comparison_plots_v2_combined")
    generate_comparison_plots(all_results, sensor_locs_np, config, output_dir)

    results_save_path = os.path.join(output_dir, "all_simulation_results.npz")
    # NOTE: each value is a nested dict, which NumPy stores as a pickled
    # object array; reading the archive back needs np.load(..., allow_pickle=True).
    np.savez_compressed(results_save_path, **all_results)
    print(f"\nSaved all numerical results to: {results_save_path}")

    print("\n--- Unified evaluation for all targets complete. ---")

if __name__ == "__main__":
    # Command-line entry point: all three options are required strings.
    cli = argparse.ArgumentParser(description="Unified evaluation of NN, Direct, and Adjoint controllers.")
    for option, description in (
        ("--config_path", "Path to the main YAML config file."),
        ("--output_base_dir", "Base directory where runs are stored."),
        ("--run_id", "The run_id of the trained NN controller."),
    ):
        cli.add_argument(option, type=str, required=True, help=description)
    main(cli.parse_args())