{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aa89eacd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'divide': 'warn', 'over': 'warn', 'under': 'ignore', 'invalid': 'warn'}"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import scipy\n",
    "from matplotlib import pyplot as plt\n",
    "from scipy.stats import multivariate_normal\n",
    "from matplotlib.collections import LineCollection\n",
    "from matplotlib.colors import ListedColormap, LinearSegmentedColormap\n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
    "import matplotlib.ticker as mticker\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "import warnings\n",
    "import numpy.matlib\n",
    "import matplotlib.ticker as ticker\n",
    "warnings.filterwarnings('ignore')\n",
    "import matplotlib.colors as mcolors\n",
    "from matplotlib.animation import FuncAnimation, PillowWriter\n",
    "import itertools\n",
    "import pandas as pd\n",
    "from joblib import Parallel, delayed\n",
    "from multiprocessing import cpu_count\n",
    "from scipy.optimize import minimize\n",
    "from numpy.linalg import inv, det\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "np.seterr(all='ignore')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "id": "7dcf4caf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Param():\n",
    "    \"\"\"Return the dictionary of model constants shared by Inputs() and SOM2D().\n",
    "\n",
    "    All values are fixed inside the function; nothing is read from outside it.\n",
    "\n",
    "    Returns:\n",
    "        dict keyed by descriptive strings (e.g. 'CriticalW', 'Neuron', 'space').\n",
    "        'Recurrent Weight' and 'recurrent' hold the same value (0.5 * CriticalW),\n",
    "        'E to S' is a (1, 2) array and 'space' holds the N preferred stimuli.\n",
    "    \"\"\"\n",
    "    # Time constants (the S time constant is 5x the E one) and the Euler grid.\n",
    "    tau_exc = 1\n",
    "    tau_inh = 5*tau_exc\n",
    "    sim_time = 100\n",
    "    step = 0.01\n",
    "\n",
    "    # Stimulus space and population size.\n",
    "    x_hi = 180\n",
    "    x_lo = -180\n",
    "    x_span = x_hi - x_lo\n",
    "    n_neurons = 180\n",
    "\n",
    "    # Noise level, divisive-normalisation strength and inhibitory gain.\n",
    "    fano = 0.5\n",
    "    k_norm = 0.0005\n",
    "    inh_gain = 10\n",
    "\n",
    "    # Kernel widths; width_se is chosen so that width_es**2 + width_se**2 == width_e**2.\n",
    "    width_es = 20\n",
    "    width_e = 40\n",
    "    width_se = np.sqrt(width_e**2 - width_es**2)\n",
    "    width_s = np.sqrt((width_se**2 + width_e**2)/2)\n",
    "\n",
    "    # Weights are expressed relative to the critical weight, which depends on neuron density.\n",
    "    density = n_neurons/(x_span)\n",
    "    w_crit = np.sqrt(8*np.sqrt(2*np.pi)*k_norm*width_e/density)\n",
    "    w_rec = 0.5*w_crit\n",
    "    w_se = 0.5*w_crit\n",
    "    u_ref = w_crit / (2 * (np.sqrt(np.pi)) * k_norm * width_e)\n",
    "    ff_strength = 0.8*u_ref\n",
    "    se_weights = np.array([[w_se, w_se]])\n",
    "\n",
    "    # RNG seed and number of populations.\n",
    "    rng_seed = 13\n",
    "    n_pop = 2\n",
    "\n",
    "    # linspace yields N+1 points including x_lo; dropping the first leaves N points in (x_lo, x_hi].\n",
    "    pref_stim = np.linspace(x_lo, x_hi, n_neurons + 1)[1:]\n",
    "\n",
    "    return {\n",
    "        'feedforward': ff_strength, 'density': density, 'CriticalW': w_crit,\n",
    "        'V width': width_s, 'Recurrent Weight': w_rec, 'Neuron': n_neurons,\n",
    "        'Time': sim_time, 'dist_range': x_span, 'xmin': x_lo, 'xmax': x_hi,\n",
    "        'FanoFactor': fano, 'time scale': step, 'time constantE': tau_exc,\n",
    "        'space': pref_stim, 'inhibitory gain': inh_gain, 'k': k_norm,\n",
    "        'U width': width_e, 'Re width': width_se, 'Rs width': width_es,\n",
    "        'time constantS': tau_inh, 'recurrent': w_rec, 'E to S': se_weights,\n",
    "        'seed': rng_seed, 'Dimension': n_pop,\n",
    "    }\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "id": "ea7e6786",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Inputs(plot=True):\n",
    "    \"\"\"Build the Fourier-domain connection kernels and the feedforward drive for SOM2D.\n",
    "\n",
    "    Parameters:\n",
    "        plot: if True (the default) draw the three spatial kernels in a new\n",
    "            figure. Pass False to skip the figure, e.g. inside parameter sweeps.\n",
    "\n",
    "    Returns:\n",
    "        dict with\n",
    "          'E Kernel', 'ES Kernel', 'SE Kernel': np.fft.fft of Gaussian kernels of\n",
    "              width 'U width', 'Rs width' and 'Re width' over the periodic distance\n",
    "              to the first preferred stimulus (length-N complex arrays);\n",
    "          'Feedforward Input': array of shape (N, 2, int(Time/time scale)).\n",
    "              Channel 0 is a Gaussian bump of height 'feedforward' centred at -20;\n",
    "              channel 1 (centred at +20) is scaled by 0, i.e. switched off.\n",
    "              The bump is identical at every time step.\n",
    "    \"\"\"\n",
    "    params = Param()\n",
    "    pref_stim = params['space']\n",
    "    x_span = params['dist_range']\n",
    "    width_e = params['U width']\n",
    "    width_es = params['Rs width']\n",
    "    width_se = params['Re width']\n",
    "    ff_strength = params['feedforward']\n",
    "    n_steps = int(params['Time']/params['time scale'])\n",
    "\n",
    "    def ring_distance(delta):\n",
    "        # Wrap a displacement onto the ring of circumference x_span -> (-x_span/2, x_span/2].\n",
    "        # Using x_span (not N) keeps the wrap correct when N != x_span/2.\n",
    "        return np.angle(np.exp(1j*delta*(2*np.pi/x_span))) * x_span/(2*np.pi)\n",
    "\n",
    "    offset = ring_distance(pref_stim - pref_stim[0])\n",
    "\n",
    "    def gaussian(width):\n",
    "        # Gaussian of the given width with the continuum 1/(sqrt(2*pi)*width) normalisation.\n",
    "        return np.exp(-offset**2/(2*width**2))/(np.sqrt(2*np.pi)*width)\n",
    "\n",
    "    kernel_e = gaussian(width_e)\n",
    "    kernel_es = gaussian(width_es)\n",
    "    kernel_se = gaussian(width_se)\n",
    "\n",
    "    if plot:\n",
    "        fig, ax = plt.subplots(figsize=(8,8))\n",
    "        fig.subplots_adjust(left=0.15, bottom=0.2, right=0.85, top=0.9, wspace=0.04, hspace=0.3)\n",
    "        ax.plot(pref_stim, kernel_e, label='Recurrent Kernel', linewidth=1)\n",
    "        ax.plot(pref_stim, kernel_es, label='WES Kernel', linewidth=1)\n",
    "        ax.plot(pref_stim, kernel_se, label='WSE Kernel', linewidth=1)\n",
    "        ax.set(xlabel='Preferred stimulus', ylabel='Kernel value', title='Connection kernels')\n",
    "        ax.legend()\n",
    "\n",
    "    def bump(centre, height):\n",
    "        # Gaussian drive centred at `centre`, repeated unchanged for every time step -> (N, n_steps).\n",
    "        profile = height*np.exp(-(ring_distance(centre - pref_stim)**2) / (4*width_e**2))\n",
    "        return np.tile(profile[:, np.newaxis], (1, n_steps))\n",
    "\n",
    "    # Axis 1 indexes the two populations; the second bump is switched off (height 0).\n",
    "    ff_input = np.stack([bump(-20, ff_strength), bump(20, 0*ff_strength)], axis=1)\n",
    "\n",
    "    return {'E Kernel': np.fft.fft(kernel_e), 'ES Kernel': np.fft.fft(kernel_es),\n",
    "            'SE Kernel': np.fft.fft(kernel_se), 'Feedforward Input': ff_input}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "id": "78b989fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _decode_angle(rates, pref_stim):\n",
    "    \"\"\"Population-vector readout, in degrees, for every column of `rates`.\n",
    "\n",
    "    rates: (N, n_steps) array; pref_stim: the N preferred stimuli, treated as\n",
    "    degrees. NaN entries are replaced by 0 via np.nan_to_num.\n",
    "    \"\"\"\n",
    "    vector = np.exp(1j*pref_stim * np.pi/ 180) @ rates\n",
    "    return np.nan_to_num(np.angle(vector, deg=True))\n",
    "\n",
    "\n",
    "def SOM2D(ffl,wei,Wcoup,**kwargs):\n",
    "    \"\"\"Euler-simulate two coupled rate populations and decode their bump positions.\n",
    "\n",
    "    For each population (D = 2) the synaptic input U and the S-variable V evolve as\n",
    "\n",
    "        U <- (1 - dt/tauE) U + dt/tauE [ (WE * O) @ Wpeak + wei (WES * Vc) + ffl I_ff ]\n",
    "             + sqrt(F [U]_+ dt) xi\n",
    "        V <- (1 - dt/tauS) V + dt/tauS wse (WSE * O)\n",
    "\n",
    "    where * is circular convolution (computed with FFTs),\n",
    "    O = [U]_+^2 / (1 + k sum_space [U]_+^2) is the divisively normalised rate and\n",
    "    Vc = gain V. U itself is not rectified; only the copies used for noise and rate are.\n",
    "\n",
    "    Parameters:\n",
    "        ffl: feedforward gain applied to both populations' input.\n",
    "        wei: weight of the Vc -> U kernel for both populations (HMC2D passes a negative value).\n",
    "        Wcoup: off-diagonal entry of the 2x2 recurrent mixing matrix Wpeak.\n",
    "        **kwargs: accepted for call compatibility and ignored.\n",
    "\n",
    "    Returns:\n",
    "        dict with decoded angles 'E stim1'/'E stim2' (from O) and 'I stim1'/'I stim2'\n",
    "        (from Vc), each of length int(T/dt); the inputs under 'Feedforward',\n",
    "        'Inhibitory' and 'Coupled'; and the full (N, 2, int(T/dt)) histories\n",
    "        'E Firing Rate' (O), 'I FR' (Vc), 'Synaptic Input' (U), 'I Synaptic Input' (V).\n",
    "    \"\"\"\n",
    "    # Local import: warnings are silenced only for the duration of this call below,\n",
    "    # instead of permanently changing the global warnings / numpy error state.\n",
    "    import warnings\n",
    "\n",
    "    params = Param()\n",
    "    N = params['Neuron']\n",
    "    D = params['Dimension']\n",
    "    T = params['Time']\n",
    "    dt = params['time scale']\n",
    "    tauE = params['time constantE']\n",
    "    tauS = params['time constantS']\n",
    "    k = params['k']\n",
    "    gain = params['inhibitory gain']\n",
    "    F = params['FanoFactor']\n",
    "    PrefStim = params['space']\n",
    "    Wee = params['Recurrent Weight']\n",
    "    wse = params['E to S']\n",
    "    n_steps = int(T/dt)\n",
    "\n",
    "    wes = np.array([[wei,wei]])\n",
    "    ff = np.array([[ffl,ffl]])\n",
    "    Wpeak = np.array([[Wee,Wcoup],[Wcoup,Wee]])\n",
    "\n",
    "    Idict = Inputs()\n",
    "    # Column vectors so the length-N kernel spectra broadcast over the population axis.\n",
    "    WE = Idict['E Kernel'].reshape(N, 1)\n",
    "    WES = Idict['ES Kernel'].reshape(N, 1)\n",
    "    WSE = Idict['SE Kernel'].reshape(N, 1)\n",
    "    ContFF = Idict['Feedforward Input']\n",
    "\n",
    "    U = np.zeros([N, D, n_steps])\n",
    "    V = np.zeros([N, D, n_steps])\n",
    "    Oc = np.zeros([N, D, n_steps])\n",
    "    Vc = np.zeros([N, D, n_steps])\n",
    "\n",
    "    with warnings.catch_warnings(), np.errstate(all='ignore'):\n",
    "        warnings.simplefilter('ignore')\n",
    "        np.random.seed(params['seed'])\n",
    "\n",
    "        for i in range(n_steps - 1):\n",
    "            # np.maximum returns a copy. A boolean-mask assignment on the view U[:, :, i]\n",
    "            # would silently rectify the stored synaptic input itself.\n",
    "            noisevr = np.maximum(U[:, :, i], 0)\n",
    "\n",
    "            # Spectra of the current E rates and S activity (axis 0 = space).\n",
    "            Oft = np.fft.fft(Oc[:, :, i], axis=0)\n",
    "            Oinft = np.fft.fft(Vc[:, :, i], axis=0)\n",
    "\n",
    "            # Circular convolutions. Inputs are real and the kernels symmetric, so the\n",
    "            # imaginary parts are expected to be round-off; drop them explicitly.\n",
    "            FRE = (np.fft.ifft(WE * Oft, axis=0) @ Wpeak).real\n",
    "            FRSE = (np.fft.ifft(WSE * Oft, axis=0) * wse).real\n",
    "            FRI = (np.fft.ifft(WES * Oinft, axis=0) * wes).real\n",
    "\n",
    "            U[:, :, i+1] = ((1-dt/tauE)*U[:, :, i]) + (FRE*dt/tauE) + (FRI*dt/tauE) + (ff*ContFF[:, :, i]*dt/tauE) + (np.sqrt(F*noisevr*dt) * np.random.randn(N, D))\n",
    "            V[:, :, i+1] = ((1-dt/tauS)*V[:, :, i]) + (FRSE*(dt/tauS))\n",
    "\n",
    "            # Divisive normalisation of the squared, rectified input (a copy, so U keeps its sign).\n",
    "            Ocn = np.square(np.maximum(U[:, :, i+1], 0))\n",
    "            Oc[:, :, i+1] = Ocn/(k*np.sum(Ocn, axis=0) + 1)\n",
    "\n",
    "            # NOTE(review): uses V at step i while Oc uses U at step i+1 (one-step lag); confirm intended.\n",
    "            Vc[:, :, i+1] = gain*V[:, :, i]\n",
    "\n",
    "        result = {\n",
    "            'E stim1': _decode_angle(Oc[:, 0, :], PrefStim),\n",
    "            'I stim1': _decode_angle(Vc[:, 0, :], PrefStim),\n",
    "            'E stim2': _decode_angle(Oc[:, 1, :], PrefStim),\n",
    "            'I stim2': _decode_angle(Vc[:, 1, :], PrefStim),\n",
    "            'Feedforward': ffl, 'Inhibitory': wei, 'Coupled': Wcoup,\n",
    "            'E Firing Rate': Oc, 'I FR': Vc, 'Synaptic Input': U,\n",
    "            'I Synaptic Input': V,\n",
    "        }\n",
    "    return result\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "id": "e10d8196",
   "metadata": {},
   "outputs": [],
   "source": [
    "def HMC2D():\n",
    "    \"\"\"Run SOM2D at a single operating point and return its result dict.\n",
    "\n",
    "    All strengths are multiples of the critical weight Wc from Param():\n",
    "    feedforward 1.7*Wc, inhibitory weight -0.9*Wc (passed to SOM2D as a negative\n",
    "    value) and cross-population coupling 0.2*Wc.\n",
    "    \"\"\"\n",
    "    Wc = Param()['CriticalW']\n",
    "    lh = 1.7*Wc\n",
    "    wei = 0.9*Wc\n",
    "    Wcoup = 0.2*Wc\n",
    "\n",
    "    # For Hamiltonian sampling. SOM2D accepts these keyword arguments but does not read them.\n",
    "    UtH = SOM2D(lh, -wei, Wcoup, Noise=True, temp=None, tmpgrph=True)\n",
    "    return UtH\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "workd",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
