# === Module Imports ===
import os
import numpy as np
import torch.nn as nn
import gym
from spiketorch import Network, LIF, Static
import torch.nn.functional as F
import cma
from cma.optimization_tools import EvalParallel2
from scipy.special import expit, logit

# Expose only CUDA device index 1 to this process.
# NOTE(review): on a machine with a single GPU (index 0) this setting hides it — confirm intent.
os.environ.update(CUDA_VISIBLE_DEVICES='1')

# === Hyperparameters ===
# Network sizes
NUMIN = 4      # input neurons, one per CartPole observation component
NUM = 10       # neurons in each excitatory and each inhibitory group of a layer
NUMO = 1       # output neurons (their rate is thresholded into a binary action)
HIDDEN = 1     # number of recurrent EI layers
# Evolution settings
POPSIZE = 10   # CMA-ES population size
GENERATIONS = 100  # upper bound on the number of CMA-ES generations

# === SLURM job identifiers (used to name the result file) ===
# Each value is None when the corresponding variable is not set.
job_id, node_id, task_id = (
    os.environ.get(var_name)
    for var_name in ("SLURM_JOB_ID", "SLURM_NODEID", "SLURM_PROCID")
)

# === Define the Model ===
def Model(params, I_input=0, hidden_dim=HIDDEN):
    """
    Build a stack of recurrent excitatory/inhibitory (EI) layers, simulate it
    once, and return the output group's firing rate.

    Args:
        params: flat, sliceable sequence of synaptic weights. It is consumed
            in order as 4 blocks of NUMIN*NUM weights (input -> first layer),
            then (hidden_dim*8 - 4) blocks of NUM*NUM weights (recurrent and
            inter-layer connections), then 2 blocks of NUM*NUMO weights
            (last layer -> output). Trailing entries are not read here.
        I_input: external input current from the environment observation;
            applied as i_offset to the 'Inpute' group and, negated, to the
            'Inputi' group.
        hidden_dim: number of EI layers (at least 1; the output is wired to
            the last layer).
    Returns:
        output_group.rate: the `rate` attribute of the output group after
        the network has been run for `run_time`.
    """
    # LIF neuron parameters shared by every group
    fthreshold = -50
    freset = -60
    c_m = 1
    v_rest = -60
    tau_syn_e = 5.0
    tau_syn_i = 5.0
    tau_m = 20
    frefractory = 0
    fv = -55
    run_time = 100
    dt = 0.1
    delay = 0

    net = Network(dt=dt)  # Create spiketorch network

    def make_group(size, name, i_offset):
        # Every group uses the same LIF constants; only size, name and
        # offset current differ.
        return LIF(num=size, shape=(1, size), name=name,
                   v_init=fv, v_rest=v_rest, v_reset=freset,
                   c_m=c_m, tau_m=tau_m, tau_i=tau_syn_i, tau_e=tau_syn_e,
                   v_thresh=fthreshold, i_offset=i_offset,
                   tau_refrac=frefractory, dt=dt)

    # Holds every synapse object for the duration of the run, as the
    # original per-type lists did.
    synapses = []

    def connect(pre, post, name, weight, size, syn_type):
        # Create one static synapse group and wire pre -> post all-to-all.
        synapse = Static(num=size, name=name, weight=weight,
                         delay=delay, dt=dt, tau=0)
        synapses.append(synapse)
        net.all_to_all(pre, post, synapse, syn_type=syn_type)

    # === Define neuron groups ===
    # Input layers: the same observation drives one group with +I_input and
    # the other with -I_input.
    input_group_e = make_group(NUMIN, 'Inpute', I_input)
    input_group_i = make_group(NUMIN, 'Inputi', -I_input)

    # Hidden EI layers
    E_group = []  # Excitatory
    I_group = []  # Inhibitory
    for i in range(hidden_dim):
        E_group.append(make_group(NUM, f'Exc{i}', 10))
        I_group.append(make_group(NUM, f'Inh{i}', 0))

    # Output neuron group
    output_group = make_group(NUMO, 'Output', 0)

    # === Parse weights from flat parameter list ===
    # Blocks are laid out back to back: 4 input blocks, then the NUM*NUM
    # blocks, then the 2 output blocks.
    block_sizes = (
        [NUMIN * NUM] * 4
        + [NUM * NUM] * (hidden_dim * 8 - 4)
        + [NUM * NUMO] * 2
    )
    weight_list = []
    offset = 0
    for size in block_sizes:
        weight_list.append(params[offset:offset + size])
        offset += size

    # === Create synaptic connections ===
    for i in range(hidden_dim):
        # Index of the first of the 8 weight blocks owned by layer i.
        # Computed for every layer: it was previously left undefined for
        # layer 0, which made the recurrent wiring below raise NameError.
        idx = i * 8

        if i == 0:
            # First layer is driven by the input groups
            src_e, src_i, fan_in = input_group_e, input_group_i, NUMIN * NUM
        else:
            # Later layers are driven by the previous EI layer
            src_e, src_i, fan_in = E_group[i-1], I_group[i-1], NUM * NUM

        # Feed-forward connections into layer i
        connect(src_e, E_group[i], f'SInpE2E{i}', weight_list[idx], fan_in, 'exec')
        connect(src_e, I_group[i], f'SInpE2I{i}', weight_list[idx+1], fan_in, 'exec')
        connect(src_i, E_group[i], f'SInpI2E{i}', weight_list[idx+2], fan_in, 'inh')
        connect(src_i, I_group[i], f'SInpI2I{i}', weight_list[idx+3], fan_in, 'inh')

        # Recurrent EI dynamics within layer i
        connect(E_group[i], E_group[i], f'SE2E{i}', weight_list[idx+4], NUM * NUM, 'exec')
        connect(E_group[i], I_group[i], f'SE2I{i}', weight_list[idx+5], NUM * NUM, 'exec')
        connect(I_group[i], E_group[i], f'SI2E{i}', weight_list[idx+6], NUM * NUM, 'inh')
        connect(I_group[i], I_group[i], f'SI2I{i}', weight_list[idx+7], NUM * NUM, 'inh')

    # Connect last EI layer to output
    connect(E_group[-1], output_group, 'SE2O', weight_list[hidden_dim*8], NUM * NUMO, 'exec')
    connect(I_group[-1], output_group, 'SI2O', weight_list[hidden_dim*8+1], NUM * NUMO, 'inh')

    # Compile and simulate the network
    net.build()
    net.run(run_time)

    return output_group.rate

# === Environment Evaluation Function ===
def evaluate_cartpole(params, episodes=2):
    """
    Evaluate a given policy in the CartPole-v1 environment.

    At every step the observation is scaled into an input current, a fresh
    network is built and simulated by Model, and the output rate is compared
    against a threshold to pick action 1 or 0.

    Args:
        params: flattened parameter vector. params[-2] scales the observation
            into the input current and params[-1] sets the action threshold;
            the whole vector is also passed to Model as synaptic weights.
        episodes: number of episodes to average
    Returns:
        Negative average reward (CMA-ES minimizes the objective)
    """
    env = gym.make('CartPole-v1')
    total_reward = 0
    # Action threshold on the output rate; it does not change between steps.
    ctl_th = params[-1] * 100

    try:
        for _ in range(episodes):
            observation, _ = env.reset()
            episode_reward = 0
            for _ in range(500):  # Max 500 steps
                I_input = observation * 1000 * params[-2]  # Scale observation to input current
                action_rate = Model(params, I_input)
                action = 1 if action_rate > ctl_th else 0
                observation, reward, terminated, truncated, _ = env.step(action)
                episode_reward += reward
                # An episode is over when the environment reports either
                # termination or time-limit truncation.
                if terminated or truncated:
                    break
            total_reward += episode_reward
    finally:
        # Release the environment even if the model or a step raises.
        env.close()

    return -total_reward / episodes  # CMA-ES is a minimizer

# === CMA-ES Optimization Loop ===
def train_cmaes(generations=GENERATIONS, population_size=POPSIZE, num_cores=10, sigma=0.5):
    """
    Run CMA-ES optimization to evolve the parameters of the controller.

    The search runs in an unconstrained space; every candidate is mapped
    through expit before it is evaluated by evaluate_cartpole.

    Args:
        generations: maximum number of CMA-ES generations to run
        population_size: CMA-ES population size ('popsize' option)
        num_cores: number of worker processes used to evaluate a population
        sigma: initial CMA-ES step size
    Returns:
        best_params: best candidate found, in the unconstrained search space
            (i.e. as returned by es.ask(); apply expit to obtain the vector
            that was actually evaluated). None if no generation was run.
        best_scores: best score so far, recorded once per generation
    """
    # 4 input blocks + (HIDDEN*8 - 4) hidden blocks + 2 output blocks,
    # plus the two trailing scalars read by evaluate_cartpole.
    num_params = NUMIN*NUM*4 + NUM*NUM*(HIDDEN*8-4) + NUM*NUMO*2 + 2

    init_params_raw = np.random.rand(num_params)
    init_params = logit(init_params_raw)  # Convert to unconstrained space

    es = cma.CMAEvolutionStrategy(init_params, sigma, {'popsize': population_size})

    best_params = None
    best_score = -np.inf
    best_scores = []

    fitness = evaluate_cartpole

    with EvalParallel2(fitness, number_of_processes=num_cores) as eval_all:
        for generation in range(generations):
            solutions = es.ask()
            scores = eval_all(expit(solutions))  # Convert back to bounded space
            es.tell(solutions, scores)
            es.disp()

            # Scores are negated rewards, so the smallest score is the best.
            best_gen_idx = np.argmin(scores)
            if -scores[best_gen_idx] > best_score:
                best_score = -scores[best_gen_idx]
                best_params = solutions[best_gen_idx]

            print(f"Generation {generation+1}: Best Score = {best_score}")
            best_scores.append(best_score)

            # Stop early once the best average reward reaches 500.
            if best_score >= 500:
                print("Solved!")
                break

    return best_params, best_scores

if __name__ == "__main__":
    # Run the optimization and store the result in one .npz file per
    # SLURM job/node/task.
    best_params, best_scores = train_cmaes()

    # The SLURM identifiers are None (rendered as "None") when the
    # corresponding environment variables are not set.
    filename = f"results_job{job_id}_node{node_id}_task{task_id}.npz"
    # np.savez stores each keyword argument as a named array. Pickling of
    # object arrays is already enabled by default, so no allow_pickle keyword
    # is passed: on NumPy versions whose savez does not define that keyword
    # it would be written to the archive as an extra array.
    np.savez(filename, best_params=best_params, best_scores=best_scores)
    print(f"Data saved to {filename}")
