# === Module Imports ===
import os
import numpy as np
import torch.nn as nn
import gym
from torch.distributions import Categorical
from spiketorch import Network, LIF, Static
import torch.nn.functional as F
from scipy.special import logit

# Restrict CUDA to the physical GPU with index 1; every other GPU becomes
# invisible to CUDA-aware libraries in this process.
# NOTE(review): on a single-GPU machine (whose only GPU is index 0) this
# setting hides that GPU — remove the line or change the index there.
# It only has an effect if CUDA has not been initialized yet; presumably
# true here, but verify that importing spiketorch does not initialize CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# === Hyperparameters ===
NUMIN = 4      # Number of input neurons (one per CartPole observation dimension)
NUM = 10       # Neurons in each excitatory group and in each inhibitory group, per layer
NUMO = 1       # Number of output neurons (rate is thresholded into a binary action)
HIDDEN = 1     # Number of EI layers (each with recurrent E/I projections)
POPSIZE = 10   # CMA-ES population size (not referenced in the code shown here)
GENERATIONS = 100  # Max number of evolutionary iterations (not referenced in the code shown here)

# === Define the Model ===
def _split_weights(params, hidden_dim):
    """
    Slice the flat parameter vector into one weight chunk per projection.

    Chunk order (the order in which ``Model`` consumes them):
        * 4 chunks of NUMIN*NUM: input E->E, E->I, I->E, I->I into layer 0.
        * hidden_dim*8 - 4 chunks of NUM*NUM: layer 0's four recurrent
          projections (E->E, E->I, I->E, I->I); then, for every later layer
          i, four feed-forward projections from layer i-1 followed by its
          four recurrent projections.
        * 2 chunks of NUM*NUMO: last-layer E->output and I->output.
    Entries of ``params`` beyond the last chunk are left unused.

    Raises:
        ValueError: if ``hidden_dim`` < 1 or ``params`` is too short.
    """
    if hidden_dim < 1:
        raise ValueError(f"hidden_dim must be >= 1, got {hidden_dim}")
    sizes = ([NUMIN * NUM] * 4
             + [NUM * NUM] * (hidden_dim * 8 - 4)
             + [NUM * NUMO] * 2)
    needed = sum(sizes)
    if len(params) < needed:
        raise ValueError(
            f"params has {len(params)} entries but {needed} weights are required"
        )
    chunks = []
    offset = 0
    for size in sizes:
        chunks.append(params[offset:offset + size])
        offset += size
    return chunks


def Model(params, I_input=0, hidden_dim=HIDDEN):
    """
    Build and run a layered excitatory/inhibitory LIF network.

    Topology: an excitatory and an inhibitory copy of the NUMIN inputs,
    ``hidden_dim`` layers each made of an excitatory and an inhibitory group
    of NUM neurons with all-to-all recurrent E/I projections, and an output
    group of NUMO neurons driven by the last layer. Every projection is an
    all-to-all ``Static`` synapse.

    Args:
        params: flat sequence of synaptic weights in the order described by
            ``_split_weights``; trailing extra entries are ignored.
        I_input: offset current for the input neurons. The excitatory input
            group receives ``I_input`` and the inhibitory one ``-I_input``.
        hidden_dim: number of EI layers; must be >= 1.

    Returns:
        ``output_group.rate`` after simulating ``run_time`` (100) time units,
        i.e. the output neurons' firing rate as reported by spiketorch
        (NOTE(review): confirm its exact type/shape in spiketorch).

    Raises:
        ValueError: if ``hidden_dim`` < 1 or ``params`` holds too few weights.
    """
    # LIF neuron parameters shared by every group
    fthreshold = -50
    freset = -60
    c_m = 1
    v_rest = -60
    tau_syn_e = 5.0
    tau_syn_i = 5.0
    tau_m = 20
    frefractory = 0
    fv = -55
    run_time = 100
    dt = 0.1
    delay = 0

    # Validate and slice weights before creating any network objects
    weight_list = _split_weights(params, hidden_dim)

    net = Network(dt=dt)  # Create spiketorch network

    # Keyword arguments common to every LIF group
    lif_kwargs = dict(v_init=fv, v_rest=v_rest, v_reset=freset, c_m=c_m,
                      tau_m=tau_m, tau_i=tau_syn_i, tau_e=tau_syn_e,
                      v_thresh=fthreshold, tau_refrac=frefractory, dt=dt)

    def make_lif(num, name, i_offset):
        """Create a LIF group of ``num`` neurons with shape (1, num)."""
        return LIF(num=num, shape=(1, num), name=name, i_offset=i_offset, **lif_kwargs)

    synapses = []  # hold a reference to every Static synapse for the whole run

    def connect(pre, post, name, num, weight, syn_type):
        """Create a Static synapse and wire ``pre`` -> ``post`` all-to-all."""
        syn = Static(num=num, name=name, weight=weight, delay=delay, dt=dt, tau=0)
        net.all_to_all(pre, post, syn, syn_type=syn_type)
        synapses.append(syn)

    # === Neuron groups (created in the same order as before) ===
    # The observation current is fed with opposite signs to two input copies
    input_group_e = make_lif(NUMIN, 'Inpute', I_input)
    input_group_i = make_lif(NUMIN, 'Inputi', -I_input)

    E_group = []  # Excitatory group per layer
    I_group = []  # Inhibitory group per layer
    for i in range(hidden_dim):
        E_group.append(make_lif(NUM, f'Exc{i}', 10))  # fixed offset current of 10
        I_group.append(make_lif(NUM, f'Inh{i}', 0))

    output_group = make_lif(NUMO, 'Output', 0)

    # === Synaptic connections ===
    for i in range(hidden_dim):
        # Layer i's chunks start at 8*i: four feed-forward, then four recurrent.
        # idx must be defined for every layer, including layer 0.
        idx = i * 8
        exc, inh = E_group[i], I_group[i]
        if i == 0:
            pre_e, pre_i, n_ff = input_group_e, input_group_i, NUMIN * NUM
        else:
            pre_e, pre_i, n_ff = E_group[i - 1], I_group[i - 1], NUM * NUM

        # Feed-forward projections (from the inputs, or from the previous layer)
        connect(pre_e, exc, f'SInpE2E{i}', n_ff, weight_list[idx], 'exec')
        connect(pre_e, inh, f'SInpE2I{i}', n_ff, weight_list[idx + 1], 'exec')
        connect(pre_i, exc, f'SInpI2E{i}', n_ff, weight_list[idx + 2], 'inh')
        connect(pre_i, inh, f'SInpI2I{i}', n_ff, weight_list[idx + 3], 'inh')

        # Recurrent EI projections within the layer
        connect(exc, exc, f'SE2E{i}', NUM * NUM, weight_list[idx + 4], 'exec')
        connect(exc, inh, f'SE2I{i}', NUM * NUM, weight_list[idx + 5], 'exec')
        connect(inh, exc, f'SI2E{i}', NUM * NUM, weight_list[idx + 6], 'inh')
        connect(inh, inh, f'SI2I{i}', NUM * NUM, weight_list[idx + 7], 'inh')

    # Last EI layer -> output
    connect(E_group[-1], output_group, 'SE2O', NUM * NUMO, weight_list[hidden_dim * 8], 'exec')
    connect(I_group[-1], output_group, 'SI2O', NUM * NUMO, weight_list[hidden_dim * 8 + 1], 'inh')

    # Compile and simulate the network
    net.build()
    net.run(run_time)

    return output_group.rate

# === Evaluate a single parameter set on the CartPole-v1 environment ===
def evaluate_cartpole(params, node_num, max_steps=500):
    """
    Run one CartPole-v1 episode driven by the SNN and print its total reward.

    Args:
        params (np.ndarray): Flattened parameter vector: the synaptic weights
            consumed by ``Model`` followed by two control genes.
            ``params[-2]`` scales the observation into input current, and
            ``params[-1] * 100`` is the output-rate threshold above which
            action 1 is chosen.
        node_num (int): Index of the current evaluation. It is accepted for
            callers' tracking/logging and does not affect the simulation.
        max_steps (int): Upper bound on environment steps (default 500,
            CartPole-v1's episode limit).

    Returns:
        float: The accumulated episode reward (also printed).
    """
    env = gym.make('CartPole-v1')
    total_reward = 0
    # The action threshold depends only on params, so compute it once
    ctl_th = params[-1] * 100
    try:
        observation, _ = env.reset()

        for _ in range(max_steps):
            # Convert observation to input current for SNN
            I_input = observation * 1000 * params[-2]

            # Model's signature is (params, I_input, hidden_dim); node_num is
            # not a network parameter, so hidden_dim keeps its default.
            rate = Model(params, I_input)

            # Threshold the output rate to produce binary action (0 or 1)
            action = 1 if rate > ctl_th else 0

            observation, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward

            # Stop when the episode ends; stepping a finished env is not a
            # valid continuation of the same episode.
            if terminated or truncated:
                break
    finally:
        env.close()

    print(total_reward)
    return total_reward

# === Load multiple parameter sets from .npz files ===
def load_params(folder_path='./'):
    """
    Load the ``arr_0`` array from every ``.npz`` file in a folder.

    Args:
        folder_path: Directory to scan (default: the current directory).
            Subdirectories are not searched.

    Returns:
        List[np.ndarray]: One array per ``.npz`` file, in sorted filename order.

    Raises:
        KeyError: If a file has no ``arr_0`` entry.
    """
    npz_files = sorted(
        path
        for path in (os.path.join(folder_path, name) for name in os.listdir(folder_path))
        if path.endswith('.npz') and os.path.isfile(path)
    )

    param = []
    for file in npz_files:
        # NOTE(security): allow_pickle=True lets np.load unpickle object
        # arrays, which can execute arbitrary code. Only load trusted files.
        # The context manager closes the underlying zip file handle.
        with np.load(file, allow_pickle=True) as arr:
            param.append(arr['arr_0'])  # Assumes data is stored under 'arr_0'
    return param

# === Initialize multiple sets of SNN parameters ===
def init_params(count=40, num_params=None):
    """
    Create randomized parameter vectors for CMA-ES or parallel evaluation.

    Each vector is drawn uniformly with ``np.random.rand`` and mapped through
    ``logit`` so its entries lie on the unbounded real line.

    Args:
        count: Number of vectors to create (default 40).
        num_params: Length of each vector. Defaults to the size expected by
            ``Model`` and ``evaluate_cartpole`` for the module's network
            constants: all synaptic weights plus 2 control genes.

    Returns:
        List[np.ndarray]: ``count`` 1-D float arrays of length ``num_params``.
    """
    if num_params is None:
        # Total number of trainable parameters in the network (+2 control genes)
        num_params = NUMIN * NUM * 4 + NUM * NUM * (HIDDEN * 8 - 4) + NUM * NUMO * 2 + 2

    # np.random.rand samples from [0, 1); clip away an exact 0 (or 1) so the
    # logit stays finite.
    eps = 1e-12
    return [
        logit(np.clip(np.random.rand(num_params), eps, 1 - eps))
        for _ in range(count)
    ]

# === Main execution block ===
if __name__ == "__main__":
    # Load all parameter sets
    # params = load_params()

    # Initialize parameter sets
    params = init_params()
    
    # Evaluate each parameter set sequentially
    node_num = 0
    for param in params:
        evaluate_cartpole(param, node_num)
        node_num += 1