{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define the neuron and network models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainpy as bp\n",
    "import brainpy.math as bm\n",
    "from brainpy.types import Shape\n",
    "\n",
    "# Fix the global BrainPy RNG seed so that noisy runs are reproducible after a kernel restart;\n",
    "# calling seed() without an argument would pick an arbitrary seed on every run.\n",
    "SEED = 42\n",
    "bm.random.seed(SEED)\n",
    "\n",
    "# Figures open in separate Qt windows (needs a Qt binding such as PyQt5 installed).\n",
    "%matplotlib qt5\n",
    "\n",
    "# define neuron model with gaussian recurrent connection\n",
    "class GaussRecUnits(bp.dyn.NeuDyn):\n",
    "    '''Ring of rate units on a circular feature space with Gaussian recurrent connections.\n",
    "\n",
    "    On every update the rate is r = u^2 / (1 + k * sum(u^2)) + noise * N(0, 1), computed from\n",
    "    the previous u; the synaptic input then follows tau * du/dt = -u + conn_mat @ r + input,\n",
    "    integrated with one explicit Euler step of size dt.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, size: Shape, tau=1., J0=1.1, k=5e-4, a=2/9*bm.pi, z_min=-bm.pi, z_max=bm.pi, noise=2.):\n",
    "        super().__init__(size=size)\n",
    "        self.tau = tau  # The time constant\n",
    "        self.k = k  # The inhibition strength\n",
    "        self.a = a  # The width of the Gaussian connection\n",
    "        self.noise_0 = noise  # The noise level (std of the additive Gaussian noise on r)\n",
    "\n",
    "        # feature space\n",
    "        self.z_min = z_min\n",
    "        self.z_max = z_max\n",
    "        self.z_range = z_max - z_min\n",
    "        self.x = bm.linspace(z_min, z_max, size, endpoint=False)  # The encoded feature values\n",
    "        self.rho = size / self.z_range  # The neural density\n",
    "        self.dx = self.z_range / size  # The spacing between neighbouring encoded values\n",
    "\n",
    "        self.J = J0 * self.Jc()  # The connection strength (J0 in units of the critical strength)\n",
    "        self.conn_mat = self.make_conn()  # The connection matrix\n",
    "\n",
    "        # variables\n",
    "        self.r = bm.Variable(bm.zeros(size))  # The neural firing rate\n",
    "        self.u = bm.Variable(bm.zeros(size))  # The neural synaptic input\n",
    "        self.input = bm.Variable(bm.zeros(size))  # The external input\n",
    "\n",
    "    # critical connection strength\n",
    "    def Jc(self):\n",
    "        return bm.sqrt(8 * bm.sqrt(2 * bm.pi) * self.k * self.a / self.rho)\n",
    "\n",
    "    # wrap a difference onto the circular feature space, i.e. into [-z_range/2, z_range/2]\n",
    "    def dist(self, d):\n",
    "        d = bm.remainder(d, self.z_range)\n",
    "        d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d)\n",
    "        return d\n",
    "\n",
    "    # make the (size x size) Gaussian connection matrix over circular distances\n",
    "    def make_conn(self):\n",
    "        dis = self.x[:, None] - self.x[None, :]\n",
    "        d = self.dist(dis)\n",
    "        conn = self.J * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)\n",
    "        return conn\n",
    "\n",
    "    # Initialize the neural activity as a Gaussian bump centred on feature value 0.\n",
    "    # Assign through .value (as update() does) so the existing state Variables are updated\n",
    "    # rather than the attributes being rebound. The misspelt name is kept for callers.\n",
    "    def initialze(self):\n",
    "        bump = bm.exp(-0.5 * bm.square(self.x / self.a))\n",
    "        norm = bm.sqrt(2 * bm.pi) * self.a\n",
    "        self.u.value = 10. * bump / norm\n",
    "        self.r.value = 30. * bump / norm\n",
    "\n",
    "    # Population-vector decoding along `axis`: the angle of sum_i r_i * exp(1j * x_i).\n",
    "    # The sum is deliberately NOT divided by sum(r): that normalisation leaves the angle\n",
    "    # unchanged when sum(r) > 0, but flips it by pi when noise drives sum(r) negative and\n",
    "    # produces NaN when sum(r) == 0.\n",
    "    def decode(self, r, axis=0):\n",
    "        expo_r = bm.exp(1j * self.x) * r\n",
    "        return bm.angle(bm.sum(expo_r, axis=axis))\n",
    "\n",
    "    # advance one time step with external `input`; returns the rate Variable\n",
    "    def update(self, input):\n",
    "        self.input.value = input\n",
    "        dt = bp.share['dt']\n",
    "        r1 = bm.square(self.u)\n",
    "        r2 = 1.0 + self.k * bm.sum(r1)  # global divisive inhibition\n",
    "        self.r.value = r1 / r2 + self.noise_0 * bm.random.randn(self.num)\n",
    "        Irec = bm.dot(self.conn_mat, self.r)  # recurrent input\n",
    "        self.u.value = self.u + (-self.u + Irec + self.input) / self.tau * dt\n",
    "        self.input[:] = 0.  # external input is consumed each step\n",
    "        return self.r\n",
    "    \n",
    "\n",
    "# define neuron model with non-recurrent connection\n",
    "class NonRecUnits(bp.dyn.NeuDyn):\n",
    "    '''Rate units without recurrent connections.\n",
    "\n",
    "    The synaptic input follows tau * du/dt = -u + input (one explicit Euler step per update);\n",
    "    the rate is read out from the previous u as f(u) + noise * N(0, 1), where f is the square\n",
    "    function when `square` is True and ReLU otherwise.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, size: Shape, tau=1., z_min=-bm.pi, z_max=bm.pi, noise=2., square=False):\n",
    "        super().__init__(size=size)\n",
    "        self.tau = tau  # time constant of u\n",
    "        self.noise_0 = noise  # std of the additive Gaussian noise on r\n",
    "        self.square = square  # True -> square activation, False -> ReLU\n",
    "\n",
    "        # circular feature space [z_min, z_max)\n",
    "        self.z_min, self.z_max = z_min, z_max\n",
    "        self.z_range = z_max - z_min\n",
    "        self.x = bm.linspace(z_min, z_max, size, endpoint=False)  # encoded feature values\n",
    "        self.rho = size / self.z_range  # neural density\n",
    "        self.dx = self.z_range / size  # spacing between neighbouring encoded values\n",
    "\n",
    "        # state variables\n",
    "        self.r = bm.Variable(bm.zeros(size))  # firing rate\n",
    "        self.u = bm.Variable(bm.zeros(size))  # synaptic input\n",
    "        self.input = bm.Variable(bm.zeros(size))  # external input buffer\n",
    "\n",
    "    # activation function selected by self.square\n",
    "    def activate(self, x):\n",
    "        return bm.square(x) if self.square else bm.relu(x)\n",
    "\n",
    "    # wrap a difference onto the circular feature space, i.e. into [-z_range/2, z_range/2]\n",
    "    def dist(self, d):\n",
    "        wrapped = bm.remainder(d, self.z_range)\n",
    "        return bm.where(wrapped > 0.5 * self.z_range, wrapped - self.z_range, wrapped)\n",
    "\n",
    "    # advance one time step with external `input`; returns the rate Variable\n",
    "    def update(self, input):\n",
    "        self.input.value = input\n",
    "        step = bp.share['dt']\n",
    "        jitter = self.noise_0 * bm.random.randn(self.num)\n",
    "        # the rate uses the previous u; u is integrated afterwards\n",
    "        self.r.value = self.activate(self.u) + jitter\n",
    "        drive = self.input - self.u\n",
    "        self.u.value = self.u + drive / self.tau * step\n",
    "        self.input[:] = 0.\n",
    "        return self.r\n",
    "    \n",
    "\n",
    "\n",
    "# The intact network contains one group of EPG neurons (recurrent units), two groups of P-EN\n",
    "# neurons (non-recurrent units), one group of FC2 neurons (recurrent units), two groups of PFL3\n",
    "# neurons (non-recurrent units) and two DN neurons (non-recurrent units).\n",
    "\n",
    "class Drosophila(bp.DynamicalSystem):\n",
    "    '''Heading/goal steering circuit assembled from GaussRecUnits and NonRecUnits.\n",
    "\n",
    "    EPG holds the heading representation and FC2 the goal representation. Each step, EPG\n",
    "    drives the left/right P-EN groups, whose output gain is modulated by the (mean-removed)\n",
    "    DN rates of the previous step; P-EN feeds back to EPG through Gaussian connections shifted\n",
    "    by +/- PEN_shift; shifted one-to-one EPG connections plus FC2 drive the left/right PFL3\n",
    "    groups; and each PFL3 group is summed into a single DN unit.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, size=180, z_min=-bm.pi, z_max=bm.pi, noise=2., **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.size = size  # The number of neurons in each neuron group except DN\n",
    "\n",
    "        # feature space\n",
    "        self.z_min = z_min\n",
    "        self.z_max = z_max\n",
    "        self.z_range = z_max - z_min\n",
    "        self.x = bm.linspace(z_min, z_max, size, endpoint=False)  # The encoded feature values\n",
    "        self.rho = size / self.z_range  # The neural density\n",
    "        self.dx = self.z_range / size  # The spacing between neighbouring encoded values\n",
    "\n",
    "        # shifts\n",
    "        self.PEN_shift = 1/18*bm.pi  # the shift of the connection from PEN to EPG\n",
    "        self.PFL3_shift = 3/8*bm.pi  # the shift of the connection from EPG to PFL3\n",
    "        # number of grid intervals per shift; int() truncates (e.g. 33.75 -> 33 for PFL3 at size=180)\n",
    "        self.PEN_shift_num = int(self.PEN_shift / self.dx)\n",
    "        self.PFL3_shift_num = int(self.PFL3_shift / self.dx)\n",
    "\n",
    "        # neurons\n",
    "        self.EPG = GaussRecUnits(size=size, noise=noise)  # heading direction\n",
    "        self.FC2 = GaussRecUnits(size=size, noise=noise)  # goal direction\n",
    "        self.PENl = NonRecUnits(size=size, noise=noise)\n",
    "        self.PENr = NonRecUnits(size=size, noise=noise)\n",
    "        self.PFL3l = NonRecUnits(size=size, noise=noise, square=True)\n",
    "        self.PFL3r = NonRecUnits(size=size, noise=noise, square=True)\n",
    "        self.DNl = NonRecUnits(size=1, noise=noise)\n",
    "        self.DNr = NonRecUnits(size=1, noise=noise)\n",
    "\n",
    "        # weights\n",
    "        self.w_EPG2PEN = 0.2\n",
    "        self.w_PEN2EPG = 1.\n",
    "        self.w_EPG2PFL3 = 1.\n",
    "        self.w_FC2PFL3 = 1.\n",
    "        self.w_PFL32DN = 0.2/size\n",
    "        self.w_DN2PEN = 1\n",
    "        self.PEN_gain = 84.5\n",
    "\n",
    "        # conn\n",
    "        # EPG -> PFL3L: pre neuron i projects to post neuron (i + PFL3_shift_num) mod size\n",
    "        pre_ids = bm.arange(self.size)\n",
    "        post_ids = bm.roll(pre_ids, shift=-self.PFL3_shift_num)\n",
    "        conn_EPG2PFL3l = bp.conn.IJConn(pre_ids, post_ids, pre=self.size, post=self.size)\n",
    "\n",
    "        # EPG -> PFL3R: pre neuron i projects to post neuron (i - PFL3_shift_num) mod size\n",
    "        pre_ids = bm.arange(self.size)\n",
    "        post_ids = bm.roll(pre_ids, shift=self.PFL3_shift_num)\n",
    "        conn_EPG2PFL3r = bp.conn.IJConn(pre_ids, post_ids, pre=self.size, post=self.size)\n",
    "\n",
    "        self.syn_EPG2PFL3l = bp.dnn.CSRLinear(conn_EPG2PFL3l, self.w_EPG2PFL3*bm.ones(self.size))\n",
    "        self.syn_EPG2PFL3r = bp.dnn.CSRLinear(conn_EPG2PFL3r, self.w_EPG2PFL3*bm.ones(self.size))\n",
    "\n",
    "        self.synapses()\n",
    "\n",
    "        # init heading direction\n",
    "        self.EPG.initialze()\n",
    "        # init goal direction\n",
    "        self.FC2.initialze()\n",
    "\n",
    "    # (re)build every synapse that depends on the scalar weights above; call again after\n",
    "    # changing one of those weights so the change takes effect\n",
    "    def synapses(self):\n",
    "        self.W_PENl2EPG = self.w_PEN2EPG*self.make_conn(self.PEN_shift)\n",
    "        self.W_PENr2EPG = self.w_PEN2EPG*self.make_conn(-self.PEN_shift)\n",
    "\n",
    "        # synapses\n",
    "        self.syn_EPG2PENl = bp.dnn.OneToOne(self.size, self.w_EPG2PEN)\n",
    "        self.syn_EPG2PENr = bp.dnn.OneToOne(self.size, self.w_EPG2PEN)\n",
    "        self.syn_PENl2EPG = bp.dnn.Linear(self.size, self.size, self.W_PENl2EPG)\n",
    "        self.syn_PENr2EPG = bp.dnn.Linear(self.size, self.size, self.W_PENr2EPG)\n",
    "        self.syn_FC2PFL3l = bp.dnn.OneToOne(self.size, self.w_FC2PFL3)\n",
    "        self.syn_FC2PFL3r = bp.dnn.OneToOne(self.size, self.w_FC2PFL3)\n",
    "        self.syn_PFL32DNl = bp.dnn.Linear(self.size, self.DNl.num, self.w_PFL32DN*bm.ones([self.size, self.DNl.num]))\n",
    "        self.syn_PFL32DNr = bp.dnn.Linear(self.size, self.DNr.num, self.w_PFL32DN*bm.ones([self.size, self.DNr.num]))\n",
    "        self.syn_DNl2PENl = bp.dnn.Linear(self.DNl.num, self.size, self.w_DN2PEN*bm.ones([self.DNl.num, self.size]))\n",
    "        self.syn_DNr2PENr = bp.dnn.Linear(self.DNr.num, self.size, self.w_DN2PEN*bm.ones([self.DNr.num, self.size]))\n",
    "\n",
    "    # move the heading direction representation by `shift` neurons (for testing)\n",
    "    def move_heading(self, shift):\n",
    "        self.EPG.r.value = bm.roll(self.EPG.r, shift)\n",
    "        self.EPG.u.value = bm.roll(self.EPG.u, shift)\n",
    "\n",
    "    # move the goal direction representation by `shift` neurons (for testing). Assign through\n",
    "    # .value, like move_heading, so the existing FC2 state Variables are updated rather than\n",
    "    # rebound; this matters because DN_diff calls it inside a jitted loop.\n",
    "    def move_goal(self, shift):\n",
    "        self.FC2.r.value = bm.roll(self.FC2.r, shift)\n",
    "        self.FC2.u.value = bm.roll(self.FC2.u, shift)\n",
    "\n",
    "    # wrap a difference onto the circular feature space, i.e. into [-z_range/2, z_range/2]\n",
    "    def dist(self, d):\n",
    "        d = bm.remainder(d, self.z_range)\n",
    "        d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d)\n",
    "        return d\n",
    "\n",
    "    # normalized Gaussian (width EPG.a) over circular distance, offset by `shift`\n",
    "    def make_conn(self, shift):\n",
    "        d = self.dist(self.x[:, None] - self.x[None, :] + shift)\n",
    "        conn = bm.exp(-0.5 * bm.square(d / self.EPG.a)) / (bm.sqrt(2 * bm.pi) * self.EPG.a)\n",
    "        return conn\n",
    "\n",
    "    # one network step driven by the external FC2 input; returns the EPG, FC2, DNl, DNr rates\n",
    "    def update(self, goal_input):\n",
    "        FC_output = self.FC2.update(goal_input)  # goal direction\n",
    "\n",
    "        # EPG output last time step\n",
    "        EPG_output = self.EPG.r\n",
    "        # DN output last time step, with the mean of the two DNs removed\n",
    "        DNl_output = self.DNl.r - 1/2*(self.DNl.r + self.DNr.r)\n",
    "        DNr_output = self.DNr.r - 1/2*(self.DNl.r + self.DNr.r)\n",
    "\n",
    "        # PEN input\n",
    "        PENl_input = self.syn_EPG2PENl(EPG_output)\n",
    "        PENr_input = self.syn_EPG2PENr(EPG_output)\n",
    "        # PEN output and DN-modulated gain\n",
    "        self.PENl.update(PENl_input)\n",
    "        self.PENr.update(PENr_input)\n",
    "        self.PENl.r.value = (self.PEN_gain + self.syn_DNl2PENl(DNl_output))*self.PENl.r\n",
    "        self.PENr.r.value = (self.PEN_gain + self.syn_DNr2PENr(DNr_output))*self.PENr.r\n",
    "\n",
    "        # EPG input\n",
    "        EPG_input = self.syn_PENl2EPG(self.PENl.r) + self.syn_PENr2EPG(self.PENr.r)\n",
    "        # EPG output\n",
    "        EPG_output = self.EPG.update(EPG_input)\n",
    "\n",
    "        # PFL3 input\n",
    "        PFL3l_input = self.syn_EPG2PFL3l(EPG_output) + self.syn_FC2PFL3l(FC_output)\n",
    "        PFL3r_input = self.syn_EPG2PFL3r(EPG_output) + self.syn_FC2PFL3r(FC_output)\n",
    "        # PFL3 output\n",
    "        PFL3l_output = self.PFL3l.update(PFL3l_input)\n",
    "        PFL3r_output = self.PFL3r.update(PFL3r_input)\n",
    "\n",
    "        # DN input\n",
    "        DNl_input = self.syn_PFL32DNl(PFL3l_output)\n",
    "        DNr_input = self.syn_PFL32DNr(PFL3r_output)\n",
    "        # DN update\n",
    "        self.DNl.update(DNl_input)\n",
    "        self.DNr.update(DNr_input)\n",
    "\n",
    "        return self.EPG.r, self.FC2.r, self.DNl.r, self.DNr.r\n",
    "\n",
    "    # zero the rates and synaptic inputs of the P-EN, PFL3 and DN groups (EPG and FC2 untouched)\n",
    "    def reset(self):\n",
    "        self.PENl.r.value = bm.zeros(self.size)\n",
    "        self.PENr.r.value = bm.zeros(self.size)\n",
    "        self.PFL3l.r.value = bm.zeros(self.size)\n",
    "        self.PFL3r.r.value = bm.zeros(self.size)\n",
    "        self.DNl.r.value = bm.zeros(1)\n",
    "        self.DNr.r.value = bm.zeros(1)\n",
    "        self.PENl.u.value = bm.zeros(self.size)\n",
    "        self.PENr.u.value = bm.zeros(self.size)\n",
    "        self.PFL3l.u.value = bm.zeros(self.size)\n",
    "        self.PFL3r.u.value = bm.zeros(self.size)\n",
    "        self.DNl.u.value = bm.zeros(1)\n",
    "        self.DNr.u.value = bm.zeros(1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run the simulations and plot the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/Shared/Anaconda/anaconda3/envs/brainpy/lib/python3.9/site-packages/jax/_src/ops/scatter.py:96: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=int32 to dtype=uint32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.002971172332763672,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "",
       "rate": null,
       "total": 400,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2641f6fc874e417aac3775318b398ee4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/400 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.\n",
      "/var/folders/mj/0tys327n22ng7h284dc264yw0000gn/T/ipykernel_1819/881151219.py:71: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n",
      "  plt.figure()\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "\n",
    "def jump_goal(jump_times = 3, T=200, dt=.5, neuron_num=180):\n",
    "    '''Simulate the network while the goal direction jumps between +pi/4 and -pi/4.\n",
    "\n",
    "    The run of length T (time step dt) is split into `jump_times` equal epochs; in epoch i\n",
    "    FC2 receives a Gaussian input bump centred at (-1)**i * pi/4. The first quarter of the run\n",
    "    is discarded as a transient before heading, goal, angular velocity and DN activity are\n",
    "    decoded and plotted. Note: this sets the global BrainPy dt.\n",
    "    '''\n",
    "    drosophila = Drosophila(size=neuron_num, noise=0.)  # __init__ already builds the synapses\n",
    "\n",
    "    # simulation parameters\n",
    "    bm.set_dt(dt)\n",
    "    t_steps = int(T / bm.get_dt())\n",
    "\n",
    "    # goal inputs: one Gaussian bump per epoch, alternating between +pi/4 and -pi/4\n",
    "    fc2 = drosophila.FC2\n",
    "    goal_input = bm.zeros([t_steps, fc2.num])\n",
    "    goal_jumps = bm.linspace(0, t_steps, jump_times + 1, endpoint=True, dtype=int)\n",
    "    for i in range(jump_times):\n",
    "        goal = (-1) ** i * bm.pi / 4\n",
    "        goal_input[goal_jumps[i]:goal_jumps[i + 1]] += 50. * bm.exp(-0.5 * bm.square(drosophila.dist(fc2.x - goal) / fc2.a)) / (bm.sqrt(2 * bm.pi) * fc2.a)\n",
    "\n",
    "    # run\n",
    "    indices = bm.arange(t_steps)\n",
    "    EPG_activity, FC_activity, DNl, DNr = bm.for_loop(drosophila.step_run, (indices, goal_input), progress_bar=True, jit=True)\n",
    "\n",
    "    # omit the first quarter of the run (transient)\n",
    "    start = int(t_steps / 4)\n",
    "    EPG_activity = EPG_activity[start:]\n",
    "    FC_activity = FC_activity[start:]\n",
    "    DNr = DNr[start:]\n",
    "    DNl = DNl[start:]\n",
    "    t_steps = t_steps - start\n",
    "    T = T - start * bm.get_dt()\n",
    "    t_axis = bm.linspace(0, T, t_steps)\n",
    "\n",
    "    # decode heading and goal direction; the per-step heading change is wrapped onto the\n",
    "    # circle BEFORE dividing by dt, so the angular velocity is correct for any dt\n",
    "    heading_directions = drosophila.EPG.decode(EPG_activity, axis=1)\n",
    "    de_goal_directions = drosophila.FC2.decode(FC_activity, axis=1)\n",
    "    velocity = drosophila.dist(heading_directions[1:] - heading_directions[:-1]) / bm.get_dt() * 1000\n",
    "\n",
    "    # goal direction and heading direction over time\n",
    "    plt.figure(figsize=(10, 5))\n",
    "    plt.plot(t_axis, de_goal_directions, label='Goal direction')\n",
    "    plt.plot(t_axis, heading_directions, label='Heading direction')\n",
    "    plt.legend()\n",
    "    plt.xlabel('Time')\n",
    "    plt.ylabel('Direction')\n",
    "\n",
    "    # velocity against DN difference, each divided by its range\n",
    "    white_velocity = velocity / (velocity.max() - velocity.min())\n",
    "    DNd = DNl - DNr\n",
    "    white_DN = DNd / (DNd.max() - DNd.min())\n",
    "    plt.figure(figsize=(5, 5))\n",
    "    plt.scatter(white_DN[1:], white_velocity, marker='o', c='w', edgecolors='r')\n",
    "    plt.xlabel('DNl-DNr')\n",
    "    plt.ylabel('Velocity')\n",
    "    # no plt.legend() here: nothing in this figure carries a label\n",
    "    # plt.savefig('jump_goal.eps')\n",
    "\n",
    "    # heading direction over time, coloured by time\n",
    "    z = bm.linspace(0, 1, t_steps - 1)\n",
    "    plt.figure()\n",
    "    plt.scatter(bm.linspace(0, T, t_steps - 1), heading_directions[1:], c=z, marker='o', cmap='hot_r')\n",
    "    plt.xlabel('Time')\n",
    "    plt.ylabel('Direction')\n",
    "    plt.title('Heading direction')\n",
    "    plt.colorbar()\n",
    "\n",
    "    # DN activity over time\n",
    "    plt.figure(figsize=(10, 5))\n",
    "    plt.plot(t_axis, DNl, label='DNl')\n",
    "    plt.plot(t_axis, DNr, label='DNr')\n",
    "    plt.xlabel('Time')\n",
    "    plt.ylabel('Activity')\n",
    "    plt.legend()\n",
    "    # plt.savefig('DN_activity.eps')\n",
    "\n",
    "    # EPG activity with the decoded goal direction superimposed as a dashed line\n",
    "    plt.figure()\n",
    "    plt.pcolormesh(t_axis, drosophila.EPG.x, EPG_activity.T, cmap='hot_r')\n",
    "    plt.plot(t_axis, de_goal_directions, c='r', label='Goal direction', linestyle='--')\n",
    "    plt.xlabel('Time')\n",
    "    plt.ylabel('Direction')\n",
    "    plt.title('EPG activity')\n",
    "    # plt.savefig('EPG_activity.eps')\n",
    "\n",
    "    # angular velocity against heading-goal difference, coloured by DNl activity\n",
    "    direction_diff = drosophila.EPG.dist(de_goal_directions - heading_directions)[1:]\n",
    "    plt.figure()\n",
    "    plt.scatter(direction_diff, velocity, c=DNl[1:], marker='o', cmap='hot_r')\n",
    "    plt.xlabel('Angular Difference')\n",
    "    plt.ylabel('Velocity')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# plot the DN difference relative to angular difference between heading and goal direction\n",
    "def DN_diff():\n",
    "    '''Rotate the goal bump in small steps and record the DN response.\n",
    "\n",
    "    At every outer step the FC2 (goal) representation is rolled by `goal_shift` neurons and the\n",
    "    network is then run for `n_inner` steps with zero external goal input; the state at the end\n",
    "    of each inner run is recorded. The DN difference is finally plotted against the heading-goal\n",
    "    angular difference between the 6th and 7th alignment points. t_steps depends on the\n",
    "    currently configured global dt.\n",
    "    '''\n",
    "    neu_num = 180\n",
    "    drosophila = Drosophila(size=neu_num, noise=0.)\n",
    "\n",
    "    # simulation parameters\n",
    "    T = 1000\n",
    "    t_steps = int(T / bm.get_dt())\n",
    "    start = int(t_steps / 4)\n",
    "\n",
    "    goal_shift = 5  # neurons the goal bump is rolled by per outer step\n",
    "    n_inner = 30  # network steps simulated after every shift\n",
    "\n",
    "    def body(t, goal_input):\n",
    "        drosophila.move_goal(goal_shift)\n",
    "        # run n_inner steps with the same goal input and keep only the final state\n",
    "        indice = bm.arange(n_inner)\n",
    "        inputs = bm.repeat(goal_input[None, :], n_inner, axis=0)\n",
    "        EPG_activity, FC_activity, DNl, DNr = bm.for_loop(drosophila.step_run, (indice, inputs), progress_bar=False, jit=True)\n",
    "        return EPG_activity[-1, :], FC_activity[-1, :], DNl[-1, :], DNr[-1, :]\n",
    "\n",
    "    # goal inputs (none: the goal only moves through move_goal)\n",
    "    goal_input = bm.zeros([t_steps, drosophila.FC2.num])\n",
    "\n",
    "    # run\n",
    "    indices = bm.arange(t_steps)\n",
    "    EPG_activity, FC_activity, DNl, DNr = bm.for_loop(body, (indices, goal_input), progress_bar=True, jit=True)\n",
    "\n",
    "    # decode heading and goal direction from neural activity\n",
    "    heading_directions = drosophila.EPG.decode(EPG_activity, axis=1)\n",
    "    de_goal_directions = drosophila.FC2.decode(FC_activity, axis=1)\n",
    "\n",
    "    # DN difference divided by its range, and heading-goal angular difference\n",
    "    DNd = DNl - DNr\n",
    "    white_DN = DNd / (DNd.max() - DNd.min())\n",
    "    Ang_diff = drosophila.EPG.dist(de_goal_directions - heading_directions)\n",
    "\n",
    "    t_axis = bm.linspace(0, T, t_steps - start)\n",
    "\n",
    "    # plot heading and goal\n",
    "    plt.figure()\n",
    "    plt.plot(t_axis, heading_directions[start:], label='Heading direction')\n",
    "    plt.plot(t_axis, de_goal_directions[start:], label='Goal direction')\n",
    "    plt.xlabel('Time')\n",
    "    plt.ylabel('Direction')\n",
    "    plt.legend()\n",
    "\n",
    "    # plot the DN activity over time\n",
    "    plt.figure()\n",
    "    plt.scatter(t_axis, white_DN[start:], label='DN activity', marker='o', c='r')\n",
    "    plt.xlabel('Time')\n",
    "    plt.ylabel('Activity')\n",
    "    plt.legend()\n",
    "\n",
    "    # 3d surface of EPG activity: x axis is time, y axis is direction\n",
    "    fig = plt.figure()\n",
    "    ax = fig.add_subplot(111, projection='3d')\n",
    "    X, Y = bm.meshgrid(bm.arange(t_steps), drosophila.EPG.x)\n",
    "    ax.plot_surface(X, Y, EPG_activity.T, cmap='viridis')\n",
    "    ax.set_xlabel('Time')\n",
    "    ax.set_ylabel('Direction')\n",
    "    ax.set_zlabel('Activity')\n",
    "    plt.title('EPG activity')\n",
    "\n",
    "    # time points where heading and goal are aligned to within a third of the grid spacing\n",
    "    Align_points = bm.where(bm.abs(Ang_diff) < 1/3*drosophila.dx)[0]\n",
    "    if len(Align_points) < 7:\n",
    "        print(f'Only {len(Align_points)} alignment points found; need at least 7 to plot one period.')\n",
    "        return\n",
    "\n",
    "    period_start = Align_points[5]\n",
    "    period_end = Align_points[6]\n",
    "\n",
    "    # plot the DN activity difference against angular difference over one period\n",
    "    plt.figure()\n",
    "    plt.scatter(Ang_diff[period_start:period_end], white_DN[period_start:period_end], c='r', label='DN activity difference', marker='o')\n",
    "    # horizontal line and vertical line at 0\n",
    "    plt.axhline(0, linestyle='--', c='k')\n",
    "    plt.axvline(0, linestyle='--', c='k')\n",
    "    plt.xlabel('Angular Difference')\n",
    "    plt.ylabel('Activity')\n",
    "    plt.legend()\n",
    "    # plt.savefig('DN_diff_dots.eps')\n",
    "\n",
    "\n",
    "def moving_goal(w_FC2PFL3=None):\n",
    "    '''Track a goal that oscillates with growing amplitude and then holds still.\n",
    "\n",
    "    For the first 3/4 of the run the goal follows pi * (t/T) * sin(4*pi*t/T); for the last\n",
    "    quarter it stays at its final value. FC2 receives a normalized Gaussian bump (scaled by 100)\n",
    "    centred on the goal, and the goal and the decoded heading are plotted over time.\n",
    "    If w_FC2PFL3 is given it overrides the FC2 -> PFL3 weight (e.g. 0. to cut the goal signal).\n",
    "    t_steps depends on the currently configured global dt.\n",
    "    '''\n",
    "    neurons_num = 180\n",
    "    drosophila = Drosophila(neurons_num, noise=0)\n",
    "    if w_FC2PFL3 is not None:\n",
    "        drosophila.w_FC2PFL3 = w_FC2PFL3\n",
    "        drosophila.synapses()  # rebuild so the new weight takes effect\n",
    "\n",
    "    # simulation parameters\n",
    "    T = 1000\n",
    "    t_steps = int(T / bm.get_dt())\n",
    "    n_move = int(3 / 4 * t_steps)  # steps during which the goal moves\n",
    "\n",
    "    # goal trajectory\n",
    "    t_move = bm.linspace(0, T * 3 / 4, n_move)\n",
    "    steps = bm.zeros(t_steps)\n",
    "    steps[:n_move] = bm.pi * t_move / T * bm.sin(4 * bm.pi * t_move / T)\n",
    "    steps[n_move:] = steps[n_move - 1]\n",
    "    goal_directions = drosophila.dist(steps)\n",
    "\n",
    "    # goal inputs: Gaussian bump centred on the goal at every time step\n",
    "    fc2 = drosophila.FC2\n",
    "    goal_input = 100 * bm.exp(-0.5 * bm.square(drosophila.dist(fc2.x[None, :] - steps[:, None]) / fc2.a)) / (bm.sqrt(2 * bm.pi) * fc2.a)\n",
    "\n",
    "    # run\n",
    "    indices = bm.arange(t_steps)\n",
    "    EPG_activity, FC_activity, DNl, DNr = bm.for_loop(drosophila.step_run, (indices, goal_input), progress_bar=True, jit=True)\n",
    "\n",
    "    # decode heading direction from neural activity\n",
    "    heading_directions = drosophila.EPG.decode(EPG_activity, axis=1)\n",
    "\n",
    "    # plot the goal direction and the heading direction\n",
    "    t_axis = bm.linspace(0, T, t_steps)\n",
    "    plt.figure(figsize=(5, 5))\n",
    "    plt.plot(t_axis, goal_directions, label='Goal direction')\n",
    "    plt.plot(t_axis, heading_directions, label='Heading direction')\n",
    "    plt.xlabel('Time')\n",
    "    plt.ylabel('Direction')\n",
    "    plt.title('Moving goal')\n",
    "    plt.legend()\n",
    "    # plt.savefig('moving_goal.eps')\n",
    "\n",
    "def DN_tuning(weights=None, jump_times=9, T=600, neuron_num=180):\n",
    "    '''Scan the PFL3 -> DN weight and plot DNl activity against angular difference and velocity.\n",
    "\n",
    "    For each factor in `weights` (default: 10 values from 0.5 to 2) a fresh network is built with\n",
    "    w_PFL32DN equal to its default times that factor, and driven by a goal that jumps\n",
    "    `jump_times` times between +pi/4 and -pi/4. The first quarter of every run is discarded and\n",
    "    the remaining samples of all runs are pooled into one scatter plot of (heading-goal\n",
    "    difference, angular velocity) coloured by DNl activity. Note: sets the global dt to 0.5.\n",
    "    '''\n",
    "    if weights is None:\n",
    "        weights = bm.linspace(0.5, 2, 10)\n",
    "\n",
    "    # simulation parameters\n",
    "    bm.set_dt(.5)\n",
    "    t_steps = int(T / bm.get_dt())\n",
    "    start = int(t_steps / 4)\n",
    "\n",
    "    diffs, velocities, dnls = [], [], []\n",
    "    for weight in weights:\n",
    "        # A fresh model per weight: the factor scales the DEFAULT weight (it is not compounded\n",
    "        # across iterations), the synapses are rebuilt outside any jit trace, and every run\n",
    "        # starts from the same initial state.\n",
    "        drosophila = Drosophila(size=neuron_num, noise=0.)\n",
    "        drosophila.w_PFL32DN = drosophila.w_PFL32DN * float(weight)\n",
    "        drosophila.synapses()\n",
    "\n",
    "        # goal inputs: one Gaussian bump per epoch, alternating between +pi/4 and -pi/4\n",
    "        fc2 = drosophila.FC2\n",
    "        goal_input = bm.zeros([t_steps, fc2.num])\n",
    "        goal_jumps = bm.linspace(0, t_steps, jump_times + 1, endpoint=True, dtype=int)\n",
    "        for i in range(jump_times):\n",
    "            goal = (-1) ** i * bm.pi / 4\n",
    "            goal_input[goal_jumps[i]:goal_jumps[i + 1]] += 50. * bm.exp(-0.5 * bm.square(drosophila.dist(fc2.x - goal) / fc2.a)) / (bm.sqrt(2 * bm.pi) * fc2.a)\n",
    "\n",
    "        # run\n",
    "        indices = bm.arange(t_steps)\n",
    "        EPG_activity, FC_activity, DNl, DNr = bm.for_loop(drosophila.step_run, (indices, goal_input), progress_bar=False, jit=True)\n",
    "\n",
    "        # decode after omitting the transient; wrap the heading change before dividing by dt\n",
    "        heading = drosophila.EPG.decode(EPG_activity[start:], axis=1)\n",
    "        goal_dir = drosophila.FC2.decode(FC_activity[start:], axis=1)\n",
    "        velocity = drosophila.dist(heading[1:] - heading[:-1]) / bm.get_dt() * 1000\n",
    "\n",
    "        # align every series with velocity, which is one sample shorter\n",
    "        diffs.append(drosophila.dist(goal_dir - heading)[1:])\n",
    "        velocities.append(velocity)\n",
    "        dnls.append(DNl[start + 1:, 0])\n",
    "\n",
    "    X = bm.concatenate(diffs)\n",
    "    Y = bm.concatenate(velocities)\n",
    "    Z = bm.concatenate(dnls)\n",
    "\n",
    "    plt.figure()\n",
    "    plt.scatter(X, Y, c=Z, marker='o', cmap='hot_r')\n",
    "    plt.xlabel('Angular Difference')\n",
    "    plt.ylabel('Velocity')\n",
    "    plt.title('DNl activity')\n",
    "    plt.colorbar()\n",
    "    plt.show()\n",
    "    \n",
    "\n",
    "# Pick the experiment(s) to run; the others are left commented out.\n",
    "# DN_tuning()\n",
    "# DN_diff()\n",
    "# moving_goal()\n",
    "jump_goal(jump_times=3)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "brainpy",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
