{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "Code adapted from `riken_demo.m` in the FieteLab repository: https://github.com/FieteLab/neural-circuit-inference/tree/main\n",
    "\n",
    "Values for the recurrent weight ($r$) and the spiking threshold are taken from `thresholds_pinned.mat` in the same repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recurrence strengths r and spiking thresholds (from thresholds_pinned.mat).\n",
    "# Row k of the table supplies all_rs[k] and all_threshs[k].\n",
    "r_threshold_pairs = [\n",
    "    (0.0025, 0.0013),\n",
    "    (0.005, 0.0012),\n",
    "    (0.0075, 0.0011),\n",
    "    (0.01, 0.001),\n",
    "    (0.0125, 9.55e-4),\n",
    "    (0.015, 8.93e-4),\n",
    "    (0.0175, 8.52e-4),\n",
    "    (0.02, 8.15e-4),\n",
    "    (0.0225, 7.75e-4),\n",
    "    (0.025, 7.4e-4),\n",
    "]\n",
    "all_rs = [r_value for r_value, _ in r_threshold_pairs]\n",
    "all_threshs = [thresh_value for _, thresh_value in r_threshold_pairs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Simulation constants ---\n",
    "dt = 0.0001   # integration time-step, in seconds\n",
    "N = 100       # how many neurons are simulated\n",
    "tau = 0.01    # synaptic time-constant, in seconds\n",
    "\n",
    "# --- Weight-kernel constants ---\n",
    "sig_1 = 6.98  # standard deviation of the first Gaussian\n",
    "sig_2 = 7     # standard deviation of the second Gaussian\n",
    "a1 = 1        # overall scale applied to the kernel\n",
    "a2 = 1.0005   # multiplier on the second (subtracted) Gaussian\n",
    "\n",
    "# Distance between neurons i and j with wrap-around at N: the smaller of\n",
    "# |i - j| and N - |i - j|. Broadcasting a column against a row gives an N x N grid.\n",
    "neuron_ids = np.arange(N)\n",
    "separation = np.abs(neuron_ids[:, np.newaxis] - neuron_ids[np.newaxis, :])\n",
    "x = np.minimum(separation, N - separation)\n",
    "\n",
    "# W is a difference of two Gaussians evaluated on that distance.\n",
    "x_squared = x**2\n",
    "gauss_1 = np.exp(-x_squared / (2 * sig_1**2))\n",
    "gauss_2 = np.exp(-x_squared / (2 * sig_2**2))\n",
    "W = a1 * (gauss_1 - a2 * gauss_2)\n",
    "\n",
    "# Show the matrix as an image with a colour scale.\n",
    "plt.imshow(W)\n",
    "plt.colorbar()\n",
    "plt.title('Weight matrix W')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expect an N x N matrix: one row and one column per neuron\n",
    "np.shape(W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the weight matrix to a .npy file (path is relative to the working directory)\n",
    "weights_file = 'W.npy'\n",
    "np.save(weights_file, W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Other parameters\n",
    "b = 0.001             # uniform feedforward input\n",
    "noise_sd = 0.3        # amplitude of noise\n",
    "noise_sparsity = 1.5  # noise is injected with the probability that a standard normal exceeds this\n",
    "r = 0.005             # recurrence strength; must be one of the entries of all_rs\n",
    "\n",
    "# Look the spiking threshold up in the (all_rs, all_threshs) table instead of\n",
    "# hard-coding it, so changing r cannot leave a stale threshold behind.\n",
    "# list.index raises ValueError if r is not in the table. For r = 0.005 this is 0.0012.\n",
    "threshold = all_threshs[all_rs.index(r)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_rec = 1  # seconds of simulated time after which spikes are collected\n",
    "# Round before converting: start_rec / dt is a float quotient that can land just\n",
    "# below the intended whole number, and a bare int() would then drop one step.\n",
    "start_t = int(round(start_rec / dt))  # the same instant, counted in simulation steps\n",
    "print(start_t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = np.random.default_rng(seed=0)   # fixed seed so the run is reproducible\n",
    "s = np.zeros((N, 1))                  # starting activations, one column entry per neuron\n",
    "\n",
    "total_time = 1_000_000                # number of simulation steps (each advances time by dt)\n",
    "\n",
    "# Preallocate the recordings: one row per step, one column per neuron.\n",
    "# This replaces appending a million small (N, 1) arrays to Python lists.\n",
    "activation = np.empty((total_time, N))              # activations after each update\n",
    "all_spikes = np.empty((total_time, N), dtype=bool)  # binary spikes emitted at each step\n",
    "\n",
    "# r * W does not change inside the loop, so compute it once.\n",
    "rW = r * W\n",
    "\n",
    "for t in range(total_time):\n",
    "    # Dynamics. The two normal draws are kept in their original order\n",
    "    # (amplitude first, gate second) so the seeded random stream is consumed identically.\n",
    "    noise_amplitude = rng.standard_normal(size=(N, 1))\n",
    "    noise_gate = rng.standard_normal(size=(N, 1)) > noise_sparsity\n",
    "    noise = noise_sd * noise_amplitude * noise_gate\n",
    "    I = rW @ s + b * (1 + noise)                 # neuron inputs\n",
    "    spike = I > threshold                        # binary neural spikes\n",
    "    s = s + spike.astype(float) - s / tau * dt   # update activations\n",
    "    activation[t] = s[:, 0]\n",
    "    all_spikes[t] = spike[:, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the recordings into arrays and drop any length-1 axes -> (steps, neurons)\n",
    "activation = np.squeeze(np.array(activation))\n",
    "all_spikes = np.squeeze(np.array(all_spikes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep activations in single precision (float32) to reduce their size\n",
    "activation = activation.astype(dtype=np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shapes of the two recordings, displayed side by side\n",
    "np.shape(activation), np.shape(all_spikes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total spike count over all neurons, skipping the first start_t steps\n",
    "all_spikes[start_t:].sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Echo the recurrence strength and spiking threshold used for this run\n",
    "(r, threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The file name embeds the recurrence strength r of this run\n",
    "activation_file = f'activation_r{r}.npy'\n",
    "np.save(activation_file, activation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The file name embeds the recurrence strength r of this run\n",
    "spikes_file = f'spikes_r{r}.npy'\n",
    "np.save(spikes_file, all_spikes)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
