{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Izhikevich spiking network with mixed neuron types\n",
    "\n",
    "Simulates a randomly connected network of Izhikevich neurons (Izhikevich, *Simple Model of Spiking Neurons*, IEEE TNN 2003) for `T` ms with a 1 ms time step. The excitatory population is split into four equal groups with regular spiking (RS), intrinsically bursting (IB), chattering (CH) and fast spiking (FS) parameters; the inhibitory population uses low-threshold spiking (LTS) parameters.\n",
    "\n",
    "Outputs: a spike raster of the whole network and `neuron_activity`, a dict mapping each neuron type to its recorded membrane-potential traces."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# for reproducibility:\n",
    "np.random.seed(0)\n",
    "\n",
    "# simulation time (one 1 ms step per loop iteration):\n",
    "T = 1200  # ms\n",
    "\n",
    "# population sizes:\n",
    "Ne = 1600  # number of excitatory neurons\n",
    "Ni = 400   # number of inhibitory neurons\n",
    "Ne_total = Ne + Ni  # total number of neurons (excitatory + inhibitory)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Neuron parameters\n",
    "\n",
    "Each parameter set is `[a, b, c, d, name]`. Neurons are laid out in index order RS, IB, CH, FS (excitatory, indices `0..Ne-1`) followed by LTS (inhibitory, indices `Ne..Ne_total-1`). The tag in parentheses at the end of `name` (e.g. `(RS)`) is used later to group neurons by type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters: a, b, c, d, name\n",
    "p_RS  = [0.02, 0.2, -65, 8, \"regular spiking (RS)\"]\n",
    "p_IB  = [0.02, 0.2, -55, 4, \"intrinsically bursting (IB)\"]\n",
    "p_CH  = [0.02, 0.2, -51, 2, \"chattering (CH)\"]\n",
    "p_FS  = [0.1, 0.2, -65, 2, \"fast spiking (FS)\"]\n",
    "p_TC  = [0.02, 0.25, -65, 0.05, \"thalamic-cortical (TC)\"]  # not used in this network\n",
    "p_LTS = [0.02, 0.25, -65, 2, \"low-threshold spiking (LTS)\"]\n",
    "p_RZ  = [0.1, 0.26, -65, 2, \"resonator (RZ)\"]  # not used in this network\n",
    "\n",
    "# neuron types present in the network (also the keys of neuron_activity)\n",
    "neuron_type_list = ['RS', 'IB', 'CH', 'FS', 'LTS']\n",
    "\n",
    "# excitatory sub-populations: 25% each; FS takes the remainder so the groups always sum to Ne\n",
    "Ne_RS = int(Ne * 0.25)\n",
    "Ne_IB = int(Ne * 0.25)\n",
    "Ne_CH = int(Ne * 0.25)\n",
    "Ne_FS = Ne - Ne_RS - Ne_IB - Ne_CH\n",
    "\n",
    "# (parameter set, number of neurons) in index order; the inhibitory block comes last\n",
    "populations = [(p_RS, Ne_RS), (p_IB, Ne_IB), (p_CH, Ne_CH), (p_FS, Ne_FS), (p_LTS, Ni)]\n",
    "\n",
    "# per-neuron parameter columns, shape (Ne_total, 1), filled group by group\n",
    "a = np.zeros((Ne_total, 1))\n",
    "b = np.zeros((Ne_total, 1))\n",
    "c = np.zeros((Ne_total, 1))\n",
    "d = np.zeros((Ne_total, 1))\n",
    "names = []\n",
    "\n",
    "start = 0\n",
    "for (p_a, p_b, p_c, p_d, p_name), count in populations:\n",
    "    stop = start + count\n",
    "    a[start:stop] = p_a\n",
    "    b[start:stop] = p_b\n",
    "    c[start:stop] = p_c\n",
    "    d[start:stop] = p_d\n",
    "    names.extend([p_name] * count)\n",
    "    start = stop\n",
    "\n",
    "assert start == Ne_total and len(names) == Ne_total, \"populations must cover every neuron exactly once\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Synaptic weights\n",
    "\n",
    "`S[i, j]` is the weight from presynaptic neuron `j` onto postsynaptic neuron `i` (the simulation adds the column sums `S[:, fired]` to the input current). Columns of excitatory neurons (`j < Ne`) are positive and columns of inhibitory neurons (`j >= Ne`) are negative, so `S` must be square: `Ne_total x Ne_total`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# excitatory columns first (Ne of them), then inhibitory columns (Ni of them)\n",
    "S = np.hstack((0.05 * np.random.rand(Ne_total, Ne), -0.5 * np.random.rand(Ne_total, Ni)))\n",
    "assert S.shape == (Ne_total, Ne_total)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulation\n",
    "\n",
    "Each 1 ms step:\n",
    "\n",
    "1. Input current: Gaussian noise (s.d. 5 for excitatory, 2 for inhibitory neurons) plus the synaptic input from the neurons that fired in the previous step.\n",
    "2. Euler integration of `v' = 0.04 v^2 + 5 v + 140 - u + I` in two 0.5 ms sub-steps (for numerical stability) and of `u' = a (b v - u)` in one 1 ms step.\n",
    "3. Spike detection (`v >= 30` mV).\n",
    "4. Recording of `I` and `v`; spiking neurons are stored at the +30 mV peak so spikes stay visible in `v_array`.\n",
    "5. After-spike reset `v <- c`, `u <- u + d`, then recording of `u`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# initial values of v and u:\n",
    "v = -65 * np.ones((Ne_total, 1))\n",
    "u = b * v\n",
    "\n",
    "# recorded data, one column per ms:\n",
    "I_array = np.zeros((Ne_total, T))\n",
    "v_array = np.zeros((Ne_total, T))\n",
    "u_array = np.zeros((Ne_total, T))\n",
    "\n",
    "spike_chunks = []  # one (n_fired, 2) array of [time, neuron index] rows per step with spikes\n",
    "fired = np.array([], dtype=int)  # nobody has fired before t = 0\n",
    "\n",
    "for t in range(T):\n",
    "    # step 1: random input current plus synaptic input from the neurons that fired in the previous step\n",
    "    I = np.vstack((5 * np.random.randn(Ne, 1), 2 * np.random.randn(Ni, 1)))\n",
    "    I += S[:, fired].sum(axis=1, keepdims=True)\n",
    "\n",
    "    # step 2: Euler integration; v in two 0.5 ms sub-steps for numerical stability, u in one 1 ms step\n",
    "    v += 0.5 * (0.04 * v**2 + 5 * v + 140 - u + I)\n",
    "    v += 0.5 * (0.04 * v**2 + 5 * v + 140 - u + I)\n",
    "    u += a * (b * v - u)\n",
    "\n",
    "    # step 3: spike detection\n",
    "    fired = np.where(v >= 30)[0]\n",
    "    if fired.size > 0:\n",
    "        spike_chunks.append(np.column_stack((np.full(fired.size, t, dtype=float), fired)))\n",
    "\n",
    "    # step 4: record; spiking neurons are stored at exactly +30 mV (equalized spike peak),\n",
    "    # otherwise the reset below would make every spike invisible in v_array\n",
    "    I_array[:, t] = I.ravel()\n",
    "    v_array[:, t] = np.minimum(v, 30).ravel()\n",
    "\n",
    "    # step 5: after-spike reset, then record u\n",
    "    v[fired] = c[fired]\n",
    "    u[fired] += d[fired]\n",
    "    u_array[:, t] = u.ravel()\n",
    "\n",
    "# spike timings: column 0 = time (ms), column 1 = neuron index\n",
    "firings = np.vstack(spike_chunks) if spike_chunks else np.empty((0, 2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Membrane potential by neuron type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# neuron_activity[tag] has shape (number of neurons of that type, T). Neurons are matched on the\n",
    "# exact '(TAG)' suffix of their name rather than a substring test, which could match the wrong type.\n",
    "neuron_activity = {}\n",
    "for neuron_type in neuron_type_list:\n",
    "    tag = f\"({neuron_type})\"\n",
    "    neuron_indices = [i for i, name in enumerate(names) if name.endswith(tag)]\n",
    "    neuron_activity[neuron_type] = v_array[neuron_indices, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shapes of the per-type activity arrays: (neurons of that type, T)\n",
    "for neuron_type, activity in neuron_activity.items():\n",
    "    print(f\"Neuron Type: {neuron_type}\")\n",
    "    print(f\"Shape of the activity array: {activity.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Spike raster\n",
    "\n",
    "Excitatory neurons occupy the rows below the horizontal line, inhibitory neurons the rows above it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "boundary = Ne / Ne_total  # excitatory/inhibitory boundary as a fraction of the y axis (ylim starts at 0)\n",
    "label_box = dict(facecolor='white', alpha=1)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 7))\n",
    "ax.scatter(firings[:, 0], firings[:, 1], s=1, c='k')\n",
    "ax.axhline(y=Ne, color='k', linestyle='-', linewidth=1)\n",
    "ax.text(0.8, boundary - 0.04, 'excitatory', color='k', fontsize=12, ha='left', va='center', transform=ax.transAxes, bbox=label_box)\n",
    "ax.text(0.8, boundary + 0.04, 'inhibitory', color='k', fontsize=12, ha='left', va='center', transform=ax.transAxes, bbox=label_box)\n",
    "\n",
    "# the first 100 ms are not shown (presumably the start-up transient)\n",
    "ax.set(title='Spike raster', xlabel='Time (ms)', ylabel='Neuron index', xlim=(100, T), ylim=(0, Ne_total))\n",
    "ax.set_yticks(np.arange(0, Ne_total + 1, 200))\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
