# === Module Imports ===
import os
import numpy as np
import torch
import torch.nn as nn
import gym
from torch.distributions import Categorical
import random
from spiketorch import Network, LIF, Static
import torch.nn.functional as F
import imageio
import cma
from cma.optimization_tools import EvalParallel2

# Restrict CUDA to the device with index 1 (optional; useful on multi-GPU servers)
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# === Network & Evolution Parameters ===
NUMIN = 4         # input population size (CartPole has 4 observations)
NUM = 20          # recurrent hidden population size
NUMO = 1          # output population size (binary decision)
HIDDEN = 1        # hidden layer count (currently fixed at 1)
POPSIZE = 10      # candidates sampled per CMA-ES generation
GENERATIONS = 50  # upper bound on CMA-ES generations

# === SLURM Job Info for Multi-node Parallelization (Optional) ===
# Each value is None when the corresponding variable is not set in the environment.
job_id = os.environ.get("SLURM_JOB_ID")
node_id = os.environ.get("SLURM_NODEID")
task_id = os.environ.get("SLURM_PROCID")

# === Define Spiking Neural Network Model ===
def Model(params, I_input=0):
    """
    Build an input -> recurrent -> output spiking network and simulate it once.

    Args:
        params: flat parameter vector. The first NUMIN * NUM entries are the
            input-to-recurrent weights, the next NUM * NUM entries are the
            recurrent weights, and everything after that up to (but excluding)
            the final two entries is used as the recurrent-to-output weights.
            The final two entries are not read by this function.
        I_input: value handed to the input population as its ``i_offset``
            (defaults to 0).
    Returns:
        The ``rate`` attribute of the output population after the run.
    """
    # Simulation settings (units are whatever spiketorch expects -- TODO confirm)
    dt = 0.1
    run_time = 100

    # LIF constants shared by all three populations
    v_init = -55
    v_rest = -60
    v_reset = -60
    v_thresh = -50
    c_m = 1
    tau_m = 20
    tau_i = 5.0
    tau_e = 5.0
    tau_refrac = 0

    def make_population(size, label, offset):
        # Every population differs only in size, name and offset current.
        return LIF(num=size, shape=(1, size), name=label,
                   v_init=v_init, v_rest=v_rest, v_reset=v_reset,
                   c_m=c_m, tau_m=tau_m, tau_i=tau_i, tau_e=tau_e,
                   v_thresh=v_thresh, i_offset=offset,
                   tau_refrac=tau_refrac, dt=dt)

    net = Network(dt=dt)

    input_group = make_population(NUMIN, 'Input', I_input)   # driven by I_input
    rnn_group = make_population(NUM, 'Rnn', 10)              # constant offset of 10
    output_group = make_population(NUMO, 'Output', 0)        # no offset current

    # Boundaries of the three weight segments inside the flat vector
    in_end = NUMIN * NUM
    rec_end = in_end + NUM * NUM

    # (synapse count, name, weights, source population, target population)
    wiring = (
        (NUMIN * NUM, 'Si2r', params[:in_end], input_group, rnn_group),
        (NUM * NUM, 'Sr', params[in_end:rec_end], rnn_group, rnn_group),
        (NUM * NUMO, 'Sr2o', params[rec_end:-2], rnn_group, output_group),
    )

    # NOTE(review): syn_type='exec' is passed through verbatim; confirm against
    # spiketorch that this is the intended label for these connections.
    for count, label, weights, source, target in wiring:
        synapse = Static(num=count, name=label, weight=weights, delay=0, dt=dt, tau=0)
        net.all_to_all(source, target, synapse, syn_type='exec')

    # Finalise the network, then simulate for run_time
    net.build()
    net.run(run_time)

    return output_group.rate

# === Evaluation Function for CartPole Task ===
def evaluate_cartpole(params, episodes=2):
    """
    Evaluate an SNN controller on CartPole-v1 over multiple episodes.

    At every environment step a fresh network is simulated via ``Model`` and
    its output rate is compared with a threshold to choose the action.

    Args:
        params: flat parameter vector. ``params[-2]`` scales the observation
            into the input current, ``params[-1]`` (times 100) is the rate
            threshold above which action 1 is chosen; the whole vector is
            also passed to ``Model``.
        episodes: number of episodes to average over (must be >= 1).
    Returns:
        Negative average episode reward (for CMA-ES minimization).
    Raises:
        ValueError: if ``episodes`` is smaller than 1.
    """
    if episodes < 1:
        # Fail before creating an environment instead of dividing by zero later.
        raise ValueError(f"episodes must be >= 1, got {episodes}")

    # Both quantities depend only on params, so compute them once.
    gain = params[-2]
    ctl_th = params[-1] * 100                           # Action threshold

    env = gym.make('CartPole-v1')
    total_reward = 0

    try:
        for _ in range(episodes):
            observation, _info = env.reset()
            episode_reward = 0
            for _step in range(500):  # maximum steps per episode
                I_input = observation * 1000 * gain     # Scale observation into input current
                action_rate = Model(params, I_input)
                action = 1 if action_rate > ctl_th else 0   # Thresholded decision
                observation, reward, terminated, truncated, _info = env.step(action)
                episode_reward += reward
                # End the episode on failure (terminated) or time limit (truncated).
                if terminated or truncated:
                    break
            total_reward += episode_reward
    finally:
        # Release the environment even if the simulation or a step raises.
        env.close()

    return -total_reward / episodes  # Negative for CMA-ES (which minimizes)

# === CMA-ES Training Function ===
def train_cmaes(generations=GENERATIONS, population_size=POPSIZE):
    """
    Evolve SNN controller parameters with CMA-ES.

    Candidates are scored in parallel with ``evaluate_cartpole``; the search
    ends after ``generations`` iterations or as soon as the best observed
    reward reaches 500.

    Args:
        generations: maximum number of ask/tell iterations.
        population_size: value passed to CMA-ES as ``popsize``.
    Returns:
        The candidate with the highest reward seen in any generation, or
        None if no generation was run.
    """
    # Three weight blocks plus two trailing scalars (gain + threshold)
    num_params = NUMIN * NUM + NUM * NUM + NUM * NUMO + 2

    start_point = np.random.uniform(-1, 1, size=num_params)  # random starting mean
    initial_step = 0.5                                        # initial step size
    optimizer = cma.CMAEvolutionStrategy(
        start_point, initial_step, {'popsize': population_size}
    )

    champion = None
    best_score = -np.inf
    worker_count = 5  # parallel evaluation processes

    with EvalParallel2(evaluate_cartpole, number_of_processes=worker_count) as run_batch:
        for generation in range(generations):
            candidates = optimizer.ask()
            costs = run_batch(candidates)
            optimizer.tell(candidates, costs)
            optimizer.disp()

            # Costs are negated rewards, so the lowest cost is the best candidate.
            leader = np.argmin(costs)
            leader_reward = -costs[leader]
            if leader_reward > best_score:
                best_score = leader_reward
                champion = candidates[leader]

            print(f"Generation {generation+1}: Best Score = {best_score}")

            if best_score >= 500:
                print("Solved!")
                break

    return champion

# Script entry point: run the evolutionary search and store the best parameters.
if __name__ == "__main__":
    best_params = train_cmaes()

    # The file name embeds the SLURM identifiers read at module level; any of
    # them that is unset is None and therefore appears as "None" in the name.
    filename = f"results_job{job_id}_node{node_id}_task{task_id}.npy"
    # allow_pickle=True lets np.save store the result even when it is not a
    # plain numeric array (e.g. None when no generation was run).
    np.save(filename, best_params, allow_pickle=True)
    print(f"Data saved to {filename}")
