{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea98285b",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset\n",
    "import sys, os\n",
    "#print(\"Filepath\")\n",
    "#print(str(os.path.dirname(os.path.abspath(__file__))))\n",
    "sys.path.append('../')\n",
    "from rnn.vae import VAE\n",
    "from rnn.train import train_VAE\n",
    "from rnn.datasets import SU_dataset, transform,Basic_dataset\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from pyrnn.model import RNN, predict\n",
    "from pyrnn.train import train_rnn\n",
    "from pyrnn.train import save_rnn, load_rnn\n",
    "import matplotlib.pyplot as plt\n",
    "from rnn.saving import save_model, load_model\n",
    "import matplotlib as mpl\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a4031f4",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Run configuration\n",
    "train = True  # True: train the teacher RNN from scratch, False: load a saved one\n",
    "\n",
    "# Seed the torch and numpy RNGs used below so that Restart & Run All\n",
    "# reproduces the same networks, simulated trials and figures\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "218e54cf",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Paths\n",
    "data_path = \"\"  # NOTE(review): not used anywhere else in this notebook\n",
    "out_dir = \"\"    # passed to train_VAE and used as the directory of the saved VAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c991c226",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Teacher RNN hyper-parameters (passed to pyrnn's RNN)\n",
    "model_params = {\n",
    "    \"nonlinearity\": \"relu\",      # activation function\n",
    "    \"rank\": 2,                   # rank, set to 0 for full rank\n",
    "    \"n_inp\": 2,                  # amount of input units\n",
    "    \"p_inp\": 1,                  # probability of connection input\n",
    "    \"n_rec\": 128,                # amount of recurrent units\n",
    "    \"p_rec\": 1,                  # probability of connection recurrent\n",
    "    \"n_out\": 2,                  # amount of output units\n",
    "    \"scale_w_inp\": 1,            # scale input weights\n",
    "    \"scale_w_out\": 1,            # scale output weights\n",
    "    \"w_rec_dist\": \"gauss\",       # recurrent weight dist, gauss or gamma\n",
    "    \"spectr_rad\": 1,             # gain param, recurrent weights\n",
    "    \"spectr_norm\": True,         # use spectral normalisation on rec weights\n",
    "    \"train_w_inp\": True,         # train input weights\n",
    "    \"train_w_inp_scale\": False,  # train input scaling factor\n",
    "    \"train_w_rec\": True,         # train recurrent weights\n",
    "    \"train_b_rec\": True,         # train recurrent bias\n",
    "    \"train_taus\": False,         # train time constants\n",
    "    \"train_m\": False,            # train m\n",
    "    \"train_n\": True,             # train n\n",
    "    \"scale_n\": 1,\n",
    "    \"scale_m\": 1,\n",
    "    \"loadings\": None,\n",
    "    \"cov\": None,\n",
    "    \"train_w_out\": True,         # train output weights\n",
    "    \"train_w_out_scale\": False,  # train output scaling factor\n",
    "    \"train_x0\": False,           # train initial state\n",
    "    \"tau_lims\": [100],           # tau limits (min, max) or (value) in ms\n",
    "    'project_taus': 'sigmoid',   # projection map keeping taus within limits (\"exp\", \"sigmoid\" or \"clip\")\n",
    "    'tau_mean': 100,             # if tau distribution, specify mean\n",
    "    'tau_std': 1,                # if tau distribution, specify std\n",
    "    \"dt\": 10,                    # timestep in ms\n",
    "    \"noise_std\": 0.05,           # noise std\n",
    "    \"scale_x0\": 0.1,             # std of initial state, if gaussian\n",
    "    \"conn_mask\": None,           # connection mask (tensor of size n_rec*n_rec)\n",
    "    \"dale_mask\": None,           # dale mask (tensor of size n_rec*n_rec with only -1 and 1's on diagonal)\n",
    "}\n",
    "\n",
    "# Teacher RNN optimisation settings (passed to pyrnn's train_rnn)\n",
    "training_params = {\n",
    "    \"n_epochs\": 10000,           # number of passes through possible trials\n",
    "    \"lr\": 1e-3,                  # learning rate\n",
    "    \"batch_size\": 4,             # batch size\n",
    "    \"clip_gradient\": 1,          # to avoid explosion of gradients\n",
    "    \"cuda\": False,               # train on GPU if True\n",
    "    \"loss_fn\": \"mse\",            # loss function (mse, cos or none)\n",
    "    \"optimizer\": \"adam\",         # optimizer (adam)\n",
    "    \"osc_reg_cost\": 0,           # oscillatory regularisation weight\n",
    "    \"osc_reg_freq\": 2,           # oscillatory regularisation frequency\n",
    "    \"offset_reg_cost\": 0,        # offset regularisation cost\n",
    "    \"l2_rates_reg\": .1,          # l2 regularisation on rates\n",
    "}\n",
    "\n",
    "\n",
    "class Reaching(Dataset):\n",
    "    \"\"\"Delayed reaching task with n_stim equally spaced directions.\n",
    "\n",
    "    Trial idx uses the direction theta = 2*pi*idx/n_stim. Input channels 0 and 1\n",
    "    carry cos(theta) and sin(theta) for stim_dur steps starting at onset; after a\n",
    "    further delay_dur steps the two target channels carry cos(theta) and sin(theta)\n",
    "    until the end of the trial. The three durations are drawn uniformly from their\n",
    "    [low, high) ranges every time a trial is accessed.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, task_params):\n",
    "        \"\"\"\n",
    "        Initialize a Reaching task\n",
    "        Args:\n",
    "            task_params: dictionary with 'n_stim', 'trial_len' and the\n",
    "                [low, high) time-step ranges 'onset', 'stim_dur', 'delay_dur'\n",
    "        \"\"\"\n",
    "        self.task_params = task_params\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"One trial per stimulus direction (timings are re-sampled on each access).\"\"\"\n",
    "        return self.task_params['n_stim']\n",
    "\n",
    "    def _sample_steps(self, key):\n",
    "        \"\"\"Draw one integer uniformly from the [low, high) range task_params[key].\"\"\"\n",
    "        low, high = self.task_params[key][0], self.task_params[key][1]\n",
    "        return torch.randint(low=low, high=high, size=(1,))[0]\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"\n",
    "        Returns a trial\n",
    "\n",
    "        Args:\n",
    "            idx, trial index (selects the stimulus direction)\n",
    "\n",
    "        Returns:\n",
    "            input, Tensor of size [trial_len, 2]: (cos, sin) of the direction during the stimulus, 0 elsewhere\n",
    "            target, Tensor of size [trial_len, 2]: (cos, sin) of the direction after the delay, 0 before\n",
    "            mask, Tensor of size [trial_len, 2]: 1 from stimulus offset to the end of the trial, 0 before\n",
    "        \"\"\"\n",
    "        trial_len = self.task_params['trial_len']\n",
    "        phase = np.pi*2*idx/self.task_params['n_stim']\n",
    "        cp = np.cos(phase)\n",
    "        sp = np.sin(phase)\n",
    "\n",
    "        # keep the draw order (onset, stimulus, delay) fixed so seeded runs stay reproducible\n",
    "        onset = self._sample_steps('onset')\n",
    "        stim_dur = self._sample_steps('stim_dur')\n",
    "        delay_dur = self._sample_steps('delay_dur')\n",
    "        stim_end = onset + stim_dur\n",
    "        delay_end = stim_end + delay_dur\n",
    "\n",
    "        inputs = torch.zeros(trial_len, 2)\n",
    "        inputs[onset:stim_end, 0] = cp\n",
    "        inputs[onset:stim_end, 1] = sp\n",
    "\n",
    "        targets = torch.zeros(trial_len, 2)\n",
    "        targets[delay_end:, 0] = cp\n",
    "        targets[delay_end:, 1] = sp\n",
    "\n",
    "        # the loss also covers the delay, where the target is still 0\n",
    "        mask = torch.zeros_like(targets)\n",
    "        mask[stim_end:] = 1\n",
    "        return inputs, targets, mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "889a2c33",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Teacher network: the low-rank RNN trained on (or loaded for) the reaching task\n",
    "rnn_osc = RNN(model_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2636eb89",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Reaching task used to train the teacher; [low, high) ranges are sampled uniformly (units: time steps)\n",
    "task_params = {\n",
    "    'onset': [25, 26],     # time till target stimulus onset\n",
    "    'trial_len': 150,      # trial duration\n",
    "    'stim_dur': [25, 26],  # target stimulus duration\n",
    "    'delay_dur': [0, 1],   # time till movement onset cue\n",
    "    'n_stim': 16,          # number of stimulus locations\n",
    "}\n",
    "reaching = Reaching(task_params)\n",
    "stimulus, target, loss_mask = reaching[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9897ad38",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Plot an example trial: inputs (left), targets (right) and the loss mask\n",
    "\n",
    "fig, axs = plt.subplots(3, 2, figsize=(6, 3))\n",
    "fig.suptitle(\"Example trial\")\n",
    "axs[0, 0].set_title(\"Input channels\")\n",
    "axs[0, 1].set_title(\"Target\")\n",
    "# channel 0 carries cos(theta) and channel 1 sin(theta) (see Reaching.__getitem__)\n",
    "axs[0, 0].set_ylabel(r\"$\\cos(\\theta)$\")\n",
    "axs[1, 0].set_ylabel(r\"$\\sin(\\theta)$\")\n",
    "axs[2, 0].set_ylabel(\"loss mask\")\n",
    "\n",
    "for i in range(2):\n",
    "    axs[i, 0].plot(stimulus[:, i], color='teal')\n",
    "    axs[i, 1].plot(target[:, i], color='purple')\n",
    "    for col in range(2):\n",
    "        axs[i, col].set_xlim(0, task_params['trial_len'])\n",
    "        axs[i, col].set_ylim(-1.2, 1.2)\n",
    "    axs[i, 0].set_xticklabels([])\n",
    "axs[2, 0].plot(loss_mask)\n",
    "axs[2, 0].set_xlim(0, task_params['trial_len'])\n",
    "axs[2, 0].set_xlabel(\"time steps\")\n",
    "# the bottom-right panel is removed, so the time label goes on the lowest remaining target panel\n",
    "axs[0, 1].set_xticklabels([])\n",
    "axs[1, 1].set_xlabel(\"time steps\")\n",
    "fig.delaxes(axs[2, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "594e63cf",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Fit the teacher RNN on the reaching task, or restore a previously saved one\n",
    "# NOTE(review): this loads ../data/reach_rnn, whereas the save cell below writes ../data/reach_rnn_128 - confirm which checkpoint is intended\n",
    "if not train:\n",
    "    rnn_osc, model_params, task_params, training_params = load_rnn(\"../data/reach_rnn\")\n",
    "else:\n",
    "    losses, reg_losses = train_rnn(rnn_osc, training_params, reaching, sync_wandb=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "348de904",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Persist only a freshly trained network; with train=False it was just loaded from disk\n",
    "# and re-saving would copy that checkpoint over ../data/reach_rnn_128\n",
    "if train:\n",
    "    save_rnn(\"../data/reach_rnn_128\", rnn_osc, model_params, task_params, training_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9eaf2e0b",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Autonomous run of the teacher: 1000 steps with zero input\n",
    "rates, pred = predict(rnn_osc, torch.zeros(1000, 2))\n",
    "\n",
    "fig, axs = plt.subplots(figsize=(8, 2))\n",
    "axs.plot(pred[0, :, :])\n",
    "axs.set_title(\"Teacher output, zero input\")\n",
    "axs.set_xlabel(\"timesteps\")\n",
    "axs.set_ylabel(\"output\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c239054e",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Apply the RNN's nonlinearity to the states returned by predict above and plot every unit\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(rnn_osc.rnn.nonlinearity(torch.from_numpy(rates[0])))\n",
    "ax.set_xlabel(\"timesteps\")\n",
    "ax.set_ylabel(\"rate\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71b077dd",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Effective low-rank connectivity of the teacher: W_or = m @ (n * 0.1 / n_rec)\n",
    "U_raw = torch.clone(rnn_osc.rnn.m.detach())\n",
    "V_raw = torch.clone(rnn_osc.rnn.n.detach()*.1/model_params['n_rec'])\n",
    "W_or = U_raw @ V_raw\n",
    "\n",
    "# Re-factorise W_or with a truncated SVD: U gets orthonormal columns, V absorbs the singular values\n",
    "tr = 2  # number of singular components kept\n",
    "U_svd, s_svd, Vh_svd = torch.linalg.svd(W_or, full_matrices=False)\n",
    "U = U_svd[:, :tr]\n",
    "s = s_svd[:tr]\n",
    "V = s.unsqueeze(1) * Vh_svd[:tr, :]\n",
    "W = U @ V\n",
    "print(torch.linalg.norm(W_or - W))  # should be (very close to) 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9d13f9d",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Teacher bias and input weights, used by the latent simulations below\n",
    "B = rnn_osc.rnn.b_rec.detach()  # recurrent bias\n",
    "I = rnn_osc.rnn.w_inp.detach()  # input weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3e7ba04",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a93525e",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Simulate the teacher's 2-D latent dynamics for n_angles reach directions (n_repeats noisy repeats):\n",
    "#   z_t = 0.9 z_{t-1} + V relu(U z_{t-1} + B + (u_{t-1} @ I)^T) + R_z * noise\n",
    "# NOTE(review): 0.9 looks like 1 - dt/tau (dt=10, tau=100 in model_params) - confirm\n",
    "n_angles = 8\n",
    "dur = 300\n",
    "R_z = 0.05  # std of the latent noise\n",
    "n_repeats = 5\n",
    "r0 = .1     # radius of the random initial state\n",
    "# NOTE(review): training trials show the stimulus on steps 25-49; confirm the longer window is intended\n",
    "stim_on, stim_off = 25, 57\n",
    "\n",
    "# integer-based grid: a float-step arange can return n_angles + 1 points through rounding\n",
    "angles = torch.arange(n_angles) * (2 * np.pi / n_angles)\n",
    "Trelu = nn.ReLU()\n",
    "# the direction input is identical in every repeat, so build it once\n",
    "inputs = torch.zeros(n_angles, dur, 2)\n",
    "inputs[:, stim_on:stim_off, 0] = torch.cos(angles).unsqueeze(1)\n",
    "inputs[:, stim_on:stim_off, 1] = torch.sin(angles).unsqueeze(1)\n",
    "\n",
    "z_all = torch.zeros(n_repeats, n_angles, dur, 2)\n",
    "for ri in range(n_repeats):\n",
    "    ph0 = torch.rand(n_angles) * 2 * np.pi\n",
    "    z = torch.zeros(n_angles, dur, 2)\n",
    "    z[:, 0, 0] = r0 * torch.cos(ph0)\n",
    "    z[:, 0, 1] = r0 * torch.sin(ph0)\n",
    "    for t in range(1, dur):\n",
    "        z[:, t] = 0.9 * z[:, t - 1]\n",
    "        X = U @ z[:, t - 1].T + B.unsqueeze(1) + (inputs[:, t - 1] @ I).T\n",
    "        z[:, t] += (V @ Trelu(X)).T + torch.randn(n_angles, 2) * R_z\n",
    "    z_all[ri] = z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb72b5e2",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def relu(x):\n",
    "    \"\"\"Elementwise max(x, 0) for numpy arrays.\"\"\"\n",
    "    return np.maximum(x, 0)\n",
    "\n",
    "\n",
    "# hsv is cyclic (0 and 1 are the same red): sample it without the endpoint so the\n",
    "# first and last reach directions do not share a colour\n",
    "prop_cycle = [plt.cm.hsv(i) for i in np.linspace(0, 1, len(angles), endpoint=False)]\n",
    "\n",
    "# constant input for the flow field: direction `phase`, amplitude r (r=0 means no input)\n",
    "phase = np.pi\n",
    "r = 0\n",
    "cp = np.cos(phase)\n",
    "sp = np.sin(phase)\n",
    "inp1 = cp*r\n",
    "inp2 = sp*r\n",
    "u_in = np.array([inp1, inp2])\n",
    "\n",
    "# numpy copies of the teacher parameters, converted once instead of on every dyn_eq call\n",
    "U_np, V_np, I_np, B_np = U.numpy(), V.numpy(), I.numpy(), B.numpy()\n",
    "\n",
    "\n",
    "def dyn_eq(x, y):\n",
    "    \"\"\"Noise-free one-step change z_{t+1} - z_t of the teacher's latent dynamics at z = (x, y).\"\"\"\n",
    "    z = np.array([x, y])\n",
    "    z = -0.1*z + V_np@relu(U_np@z + u_in@I_np + B_np)\n",
    "    return z[0], z[1]\n",
    "\n",
    "\n",
    "xlims = 20\n",
    "ylims = 20\n",
    "\n",
    "X, Y = np.meshgrid(np.linspace(-xlims, xlims, 30), np.linspace(-ylims, ylims, 30))\n",
    "u, v = np.zeros_like(X), np.zeros_like(X)\n",
    "NI, NJ = X.shape\n",
    "\n",
    "norm = np.zeros((NI, NJ))  # log speed of the flow, shown as background\n",
    "for i in range(NI):\n",
    "    for j in range(NJ):\n",
    "        dx, dy = dyn_eq(X[i, j], Y[i, j])\n",
    "        u[i, j] = dx\n",
    "        v[i, j] = dy\n",
    "        norm[i, j] = np.log(np.linalg.norm([dx, dy]))\n",
    "\n",
    "with mpl.rc_context(fname=\"matplotlibrc\"):\n",
    "    fig, ax = plt.subplots(1, 2, figsize=(8, 4), dpi=200)\n",
    "\n",
    "    T1 = 0\n",
    "    T2 = 300\n",
    "    ax[0].set_box_aspect(1)\n",
    "    ax[0].spines[['right', 'top']].set_visible(False)\n",
    "    ax[0].imshow(norm, extent=[-xlims, xlims, -ylims, ylims],\n",
    "                 origin='lower', cmap='bone', vmax=np.max(norm), aspect='auto')\n",
    "    ax[0].streamplot(X, Y, u, v, color='lavender')\n",
    "\n",
    "    ax[0].set_xlim(-xlims, xlims)\n",
    "    ax[0].set_ylim(-ylims, ylims)\n",
    "    ax[0].set_xticks([-xlims, 0, xlims])\n",
    "    ax[0].set_yticks([-ylims, 0, ylims])\n",
    "    ax[0].set_xlabel(r\"$z_1$\")\n",
    "    ax[0].set_ylabel(r'$z_2$')\n",
    "    ax[0].set_prop_cycle('color', prop_cycle)\n",
    "    for z in z_all:\n",
    "        for ang_i in range(n_angles):\n",
    "            ax[0].plot(z[ang_i, :, 0], z[ang_i, :, 1], alpha=.5)\n",
    "\n",
    "    # one single-point line per direction, only to build the legend\n",
    "    for ang_i in range(n_angles):\n",
    "        ax[0].plot(z[ang_i, :1, 0], z[ang_i, :1, 1], alpha=1, label=rf\"{angles[ang_i]/(np.pi):.2f}$\\pi$\")\n",
    "    ax[0].legend(title='angle', loc='upper right', bbox_to_anchor=(1.3, 1.1))\n",
    "    ax[1].set_box_aspect(1)\n",
    "    ax[1].set_xlim(0, T2-T1)\n",
    "    ax[1].spines[['right', 'top']].set_visible(False)\n",
    "    ax[1].set_yticks(np.arange(0, 20, 2))\n",
    "    ax[1].set_yticklabels([])\n",
    "    ax[1].set_xlabel(\"timesteps\")\n",
    "    ax[1].set_ylabel(\"neurons\")\n",
    "    ax[0].set_title(\"latents\")\n",
    "    ax[1].set_title(\"observed\")\n",
    "    # turn off ax 1\n",
    "    ax[1].axis('off')\n",
    "    fig.savefig(\"GT.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2bd1d07",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Make data generator for training a student: 2-D latent trajectories are simulated with the\n",
    "# teacher's low-rank dynamics (global U, V, B, I from the cells above) and read out into neurons\n",
    "\n",
    "class Trial_gen(Dataset):\n",
    "    \"\"\"Simulated trials for training the student VAE.\n",
    "\n",
    "    task_params keys:\n",
    "        dur: time steps per trial\n",
    "        n_trials: number of trials\n",
    "        n_neurons: number of observed neurons; must equal the number of rows of U\n",
    "        w: stored as self.w (not used by this class)\n",
    "        non_lin: nonlinearity of the recurrent input, e.g. nn.ReLU()\n",
    "        R_z: std of the latent noise\n",
    "        R_x: std of the observation noise\n",
    "        out: \"rates\" (data = non_lin(U z + B)) or \"currents\" (data = U z)\n",
    "        reach_task: optional Reaching dataset the stimuli are drawn from\n",
    "            (falls back to the global `reaching`)\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, task_params):\n",
    "        self.dur = task_params['dur']\n",
    "        self.n_trials = task_params['n_trials']\n",
    "        self.N = task_params['n_neurons']\n",
    "        self.w = task_params['w']\n",
    "        self.non_lin = task_params['non_lin']\n",
    "        self.R_z = task_params['R_z']\n",
    "        self.R_x = task_params['R_x']\n",
    "        out = task_params[\"out\"]\n",
    "        # fail early with a clear message instead of a shape error deep in the simulation\n",
    "        if U.shape[0] != self.N:\n",
    "            raise ValueError(f\"n_neurons={self.N} does not match the {U.shape[0]} rows of U (teacher units)\")\n",
    "        if out not in (\"rates\", \"currents\"):\n",
    "            raise ValueError(f\"unknown out={out!r}, expected 'rates' or 'currents'\")\n",
    "\n",
    "        # one shuffled batch of (up to) 8 reach stimuli, tiled to n_trials trials\n",
    "        reach_task = task_params['reach_task'] if 'reach_task' in task_params else reaching\n",
    "        Reaching_loader = DataLoader(reach_task, batch_size=8, shuffle=True)\n",
    "        s, _, _ = next(iter(Reaching_loader))  # s: trial x time x stimulus channel\n",
    "        n_tiles = -(-self.n_trials // s.shape[0])  # ceil division\n",
    "        s = s.repeat(n_tiles, 1, 1)[:self.n_trials]\n",
    "        if s.shape[1] < self.dur - 1:\n",
    "            raise ValueError(f\"dur={self.dur} needs at least {self.dur - 1} stimulus steps, got {s.shape[1]}\")\n",
    "\n",
    "        ph0 = torch.randn(self.n_trials)*np.pi*2\n",
    "        r0 = torch.randn(self.n_trials)*0.1\n",
    "        self.latents = torch.zeros(2, self.n_trials, self.dur, dtype=torch.float32)\n",
    "        self.data = torch.zeros(self.N, self.n_trials, self.dur, dtype=torch.float32)\n",
    "        self.latents[0, :, 0] = r0*torch.cos(ph0)\n",
    "        self.latents[1, :, 0] = r0*torch.sin(ph0)\n",
    "        self.latents[:, :, 0] += torch.randn(2, self.n_trials)*self.R_z\n",
    "\n",
    "        # z_t = 0.9 z_{t-1} + V non_lin(U z_{t-1} + B + (s_{t-1} @ I)^T) + R_z * noise\n",
    "        for t in range(1, self.dur):\n",
    "            self.latents[:, :, t] = 0.9*self.latents[:, :, t-1]\n",
    "            X = U@self.latents[:, :, t-1] + B.unsqueeze(1) + (s[:, t-1]@I).T\n",
    "            self.latents[:, :, t] += V@self.non_lin(X) + torch.randn(2, self.n_trials)*self.R_z\n",
    "\n",
    "        for t in range(self.dur):\n",
    "            if out == \"rates\":\n",
    "                self.data[:, :, t] = self.non_lin(U@self.latents[:, :, t] + B.unsqueeze(1))\n",
    "            else:  # \"currents\"\n",
    "                self.data[:, :, t] = U@self.latents[:, :, t]\n",
    "        self.data += torch.randn(self.N, self.n_trials, self.dur)*self.R_x\n",
    "        self.data_eval = self.data\n",
    "        self.stim = s\n",
    "        self.task_params = task_params\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.n_trials\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"\n",
    "        Return a trial of length self.dur\n",
    "        Args:\n",
    "            idx (int): trial index\n",
    "        Returns:\n",
    "            data (torch.tensor; n_neurons x self.dur): noisy observed activity of trial idx\n",
    "            stim (torch.tensor; stimulus channels x stimulus length): its input stimulus\n",
    "        \"\"\"\n",
    "        stim = self.stim[idx].T.to(device=self.data.device)\n",
    "        return self.data[:, idx], stim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0808ed9",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Build the student-training dataset: simulated latents read out into noisy neurons\n",
    "batch_size = 4\n",
    "task_params = {\"dur\": 150,\n",
    "               \"n_trials\": 96,\n",
    "               \"name\": \"Sine\",\n",
    "               # the readout U @ z has one row per row of U, so the neuron count must match it\n",
    "               \"n_neurons\": U.shape[0],\n",
    "               \"out\": \"currents\",\n",
    "               \"w\": .1,\n",
    "               \"R_x\": .1,\n",
    "               \"R_z\": .05,\n",
    "               \"non_lin\": nn.ReLU(),\n",
    "               \"reach_task\": reaching\n",
    "               }\n",
    "task = Trial_gen(task_params)\n",
    "data_loader = DataLoader(\n",
    "    task, batch_size=batch_size, shuffle=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aac9fbb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latent trajectories of the first 5 simulated trials; the dot marks each trial's input direction\n",
    "fig, ax = plt.subplots()\n",
    "for tr_i in range(5):\n",
    "    latent_code = task.latents[:, tr_i]\n",
    "    stim = task.stim[tr_i]\n",
    "    # the stimulus is 0 outside its window, so its largest-norm step lies inside it\n",
    "    # (a fixed index such as 50 can fall just after the stimulus and always give (0, 0))\n",
    "    stim_on = stim[stim.norm(dim=1).argmax()]\n",
    "    line, = ax.plot(latent_code[0], latent_code[1])\n",
    "    ax.scatter(stim_on[0], stim_on[1], color=line.get_color())\n",
    "ax.set_xlabel(r\"$z_1$\")\n",
    "ax.set_ylabel(r\"$z_2$\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fa322bd",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d387fbbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_trial(activity, latents, t_max=None):\n",
    "    \"\"\"Plot latent time courses, neuron traces and a neuron x time heatmap of one trial.\n",
    "\n",
    "    Args:\n",
    "        activity: (n_neurons, T) tensor of observed activity\n",
    "        latents: (dim_z, T) tensor of latent time courses\n",
    "        t_max: if given, only the first t_max time steps are shown\n",
    "    \"\"\"\n",
    "    window = slice(None, t_max)\n",
    "    fig, ax = plt.subplots(3, figsize=(7, 5))\n",
    "    for latent in latents[:, window]:\n",
    "        ax[0].plot(latent)\n",
    "    ax[1].plot(activity[:, window].T)\n",
    "    ax[2].imshow(activity[:, window], aspect='auto', interpolation='none')\n",
    "    ax[0].set_ylabel('latents')\n",
    "    ax[1].set_ylabel(\"neuron activity\")\n",
    "    ax[2].set_ylabel(\"neurons\")\n",
    "    ax[2].set_xlabel(\"time steps\")\n",
    "    ax[0].set_xticks([])\n",
    "    ax[1].set_xticks([])\n",
    "    for a in ax:\n",
    "        a.spines[['right', 'top']].set_visible(False)\n",
    "    return fig\n",
    "\n",
    "\n",
    "tr_i = 2\n",
    "rates = task.data[:, tr_i]\n",
    "latent_code = task.latents[:, tr_i]\n",
    "\n",
    "# latent trajectory in the (z_1, z_2) plane\n",
    "fig, ax = plt.subplots(1, figsize=(7, 5))\n",
    "ax.plot(latent_code[0].numpy(), latent_code[1].numpy())\n",
    "ax.set_xlabel(r\"$z_1$\")\n",
    "ax.set_ylabel(r\"$z_2$\")\n",
    "\n",
    "# full trial, then a zoom on the first 100 time steps\n",
    "plot_trial(rates, latent_code)\n",
    "plot_trial(rates, latent_code, t_max=100);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e40ab22",
   "metadata": {},
   "source": [
    "## Set up a VAE with a CNN encoder and a PLRNN prior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d465fb1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# VAE hyper-parameters\n",
    "dim_z = 2                         # latent dimensionality\n",
    "dim_N = 120                       # VAE dim_N (presumably the number of PLRNN units in the prior) - confirm\n",
    "dim_x = task_params['n_neurons']  # number of observed neurons\n",
    "dim_u = 2                         # number of input channels\n",
    "bs = 10                           # batch size\n",
    "cuda = torch.cuda.is_available()  # a hard-coded True crashes on machines without a GPU\n",
    "n_epochs = 500\n",
    "wandb = False                     # passed to train_VAE as sync_wandb\n",
    "\n",
    "# initialise encoder\n",
    "enc_params = {\n",
    "    \"init_kernel_sizes\": [11, 5, 5, 3],\n",
    "    \"strides\": [1]*4,\n",
    "    \"padding\": 'valid',          # 'same' or 'valid'\n",
    "    \"padding_mode\": 'circular',  # 'zeros' or 'circular' or 'reflect'\n",
    "    \"nonlinearity\": 'gelu',\n",
    "    \"dilations\": [1]*4,\n",
    "    \"n_channels\": [24, 16, 16, 16],\n",
    "    \"n_hidden\": 24,\n",
    "    \"init_scale\": .1,\n",
    "    \"constant_var\": True,\n",
    "    \"obs_grad\": False,\n",
    "}\n",
    "\n",
    "# initialise prior\n",
    "prior_params = {\n",
    "    \"train_noise_obs\": True,\n",
    "    \"train_noise_prior\": True,\n",
    "    \"train_noise_prior_t0\": True,\n",
    "    \"init_noise_z\": .05,\n",
    "    \"init_noise_z_t0\": 1,\n",
    "    \"init_noise_x\": task_params['R_x'],\n",
    "    \"scalar_noise_z\": \"Cov\",\n",
    "    \"scalar_noise_x\": False,\n",
    "    \"scalar_noise_z_t0\": \"Cov\",\n",
    "    'identity_readout': True,\n",
    "    'activation': \"relu\",\n",
    "    \"exp_par\": True,\n",
    "    \"shared_tau\": .7,\n",
    "    \"readout_rates\": task_params[\"out\"],\n",
    "    \"train_obs_bias\": False,\n",
    "    \"train_obs_weights\": False,\n",
    "    \"train_latent_bias\": False,\n",
    "    \"train_neuron_bias\": True,\n",
    "    \"orth\": False,\n",
    "    \"m_norm\": False,\n",
    "    \"weight_dist\": \"uniform\",\n",
    "    \"weight_scaler\": 1,  # /dim_N\n",
    "    'initial_state': 'trainable',\n",
    "}\n",
    "\n",
    "# VAE training settings\n",
    "training_params = {\n",
    "    \"lr\": 1e-3,\n",
    "    \"lr_end\": 1e-5,\n",
    "    \"opt_eps\": 1e-8,\n",
    "    \"CosineRestarts\": 0,\n",
    "    \"beta\": 0.5,\n",
    "    \"n_epochs\": n_epochs,\n",
    "    \"regularisation\": \"None\",\n",
    "    \"regularisation_params\": [1e-4, 0],  # ratio, lambda\n",
    "    \"annealing\": False,\n",
    "    \"annealing_epochs\": 500,\n",
    "    \"grad_norm\": 0,\n",
    "    \"bootstrap\": False,  # listed once: a repeated dict key silently overrides the earlier value\n",
    "    \"eval_epochs\": 50,\n",
    "    \"batch_size\": bs,\n",
    "    \"cuda\": cuda,\n",
    "    'smoothing': 20,\n",
    "    'freq_cut_off': 10000,\n",
    "    \"sim_obs_noise\": 0,\n",
    "    \"sim_latent_noise\": 1,\n",
    "    \"k\": 64,\n",
    "    \"loss_f\": \"opt_VGTF\",\n",
    "    \"MC_q\": True,\n",
    "    \"dreg_q\": \"none\",\n",
    "    \"MC_p\": True,\n",
    "    \"dreg_p\": \"none\",\n",
    "    \"resample\": \"systematic\",  # systematic, multinomial or none\n",
    "    \"L2_reg\": 0,\n",
    "    \"observation_likelihood\": \"Gauss\",  # observation likelihood\n",
    "    \"alpha\": .25,\n",
    "    \"alpha_decay\": .999,\n",
    "    \"alpha_method\": \"constant\",\n",
    "    \"alpha_update_interval\": 5,\n",
    "    \"run_eval\": True,\n",
    "    \"smooth_at_eval\": False,\n",
    "}\n",
    "\n",
    "VAE_params = {\n",
    "    \"dim_x\": dim_x,\n",
    "    \"dim_z\": dim_z,\n",
    "    \"dim_u\": dim_u,\n",
    "    \"dim_N\": dim_N,\n",
    "    \"enc_architecture\": \"CNN\",\n",
    "    \"enc_params\": enc_params,\n",
    "    \"prior_architecture\": \"PLRNN\",\n",
    "    \"prior_params\": prior_params,\n",
    "    \"causal\": True,\n",
    "}\n",
    "vae = VAE(VAE_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86916eb6",
   "metadata": {},
   "source": [
    "## Train the VAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0bfc852",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the VAE to the simulated student-training dataset\n",
    "train_VAE(vae, training_params, task,\n",
    "          sync_wandb=wandb, out_dir=out_dir, fname=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13ee6ed2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# os.path.join avoids a leading '/' (the filesystem root) when out_dir is empty\n",
    "save_model(vae, training_params, task_params, name=os.path.join(out_dir, \"reach_vae_120\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a760675",
   "metadata": {},
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e01a91be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latent noise parameters of the trained prior (mapped to a Cholesky factor by chol_cov_embed below)\n",
    "vae.prior.R_z"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "114f8fd8",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84727e39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the VAE to the CPU: the cells below convert its parameters to numpy\n",
    "vae.to_device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca659e7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Pull the trained prior's parameters out as numpy arrays\n",
    "prior = vae.prior.transition\n",
    "tau = prior.cast_A(prior.AW).detach().numpy().squeeze()  # multiplies z_{t-1} in the latent rollout below\n",
    "pV = (prior.n*prior.scaling).detach().numpy()\n",
    "pU = prior.m_transform(prior.m).detach().numpy()\n",
    "pB = prior.h.detach().numpy()  # bias; subtracted inside the relu in the rollout and flow field below\n",
    "pI = prior.Wu.detach().numpy()  # input weights\n",
    "\n",
    "# Re-factorise the effective connectivity pJ = pU @ pV by SVD (as done for the teacher)\n",
    "pJ = pU@pV\n",
    "pu,ps,pv = np.linalg.svd(pJ)\n",
    "# map from the original latent coordinates to the new SVD basis (assumes pU has 2 columns);\n",
    "# it must use pU before pU is overwritten on the next line\n",
    "projection_matrix = pu[:,:2].T@pU\n",
    "pU= pu[:,:2]\n",
    "pV = (pv[:2].T * ps[:2]).T\n",
    "# Cholesky factor of the latent noise, expressed in the new basis\n",
    "proj_chol = projection_matrix@vae.prior.chol_cov_embed(vae.prior.R_z).detach().numpy()\n",
    "pJ_r = pU@pV  # NOTE(review): not used below\n",
    "print(proj_chol@proj_chol.T)  # latent noise covariance in the new basis\n",
    "# NOTE(review): np.sqrt is elementwise, so a negative off-diagonal covariance prints NaN;\n",
    "# np.sqrt(np.diag(...)) would give the per-dimension std\n",
    "print(np.sqrt(proj_chol@proj_chol.T))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "035036cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulate the latent dynamics learned by the VAE prior (in the SVD basis of the previous cell)\n",
    "# for n_angles reach directions, n_repeats noisy repeats each\n",
    "n_angles = 8\n",
    "dur = 100\n",
    "n_repeats = 5\n",
    "r0 = .1  # radius of the random initial state\n",
    "\n",
    "# integer-based grid: a float-step arange can return n_angles + 1 points through rounding\n",
    "angles = np.arange(n_angles) * (2 * np.pi / n_angles)\n",
    "# the direction input is identical in every repeat, so build it once\n",
    "inputs = np.zeros((n_angles, dur, 2))\n",
    "inputs[:, 25:57, 0] = np.cos(angles)[:, None]\n",
    "inputs[:, 25:57, 1] = np.sin(angles)[:, None]\n",
    "\n",
    "z_all = np.zeros((n_repeats, n_angles, dur, 2))\n",
    "for ri in range(n_repeats):\n",
    "    ph0 = np.random.rand(n_angles) * 2 * np.pi\n",
    "    z = np.zeros((n_angles, dur, 2))\n",
    "    z[:, 0, 0] = r0 * np.cos(ph0)\n",
    "    z[:, 0, 1] = r0 * np.sin(ph0)\n",
    "    for t in range(1, dur):\n",
    "        z[:, t] = tau * z[:, t - 1]\n",
    "        X = pU @ z[:, t - 1].T - np.expand_dims(pB, 1) + (inputs[:, t - 1] @ pI.T).T\n",
    "        # latent noise is correlated through the projected Cholesky factor\n",
    "        z[:, t] += (pV @ relu(X)).T + (proj_chol @ np.random.randn(n_angles, 2).T).T\n",
    "    z_all[ri] = z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e92b1a4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec22ead6",
   "metadata": {},
   "outputs": [],
   "source": [
    "vae.to_device(\"cpu\")\n",
    "\n",
    "# Flow field of the learned prior for a constant input of amplitude r in direction `phase` (r=0: no input)\n",
    "phase = np.pi\n",
    "r = 0\n",
    "cp = np.cos(phase)\n",
    "sp = np.sin(phase)\n",
    "\n",
    "inp1 = cp*r\n",
    "inp2 = sp*r\n",
    "u_in = np.array([inp1, inp2])\n",
    "\n",
    "\n",
    "def dyn_eq(x, y):\n",
    "    \"\"\"Noise-free one-step change of the learned prior at z = (x, y), divided by h.\"\"\"\n",
    "    z = np.array([x, y])\n",
    "    X = pU@z + pI@u_in - pB\n",
    "    zn = tau*z + pV@relu(X)\n",
    "    h = .1  # dividing by h scales every vector by the same factor\n",
    "    dz = (zn - z)/h\n",
    "    return dz[0], dz[1]\n",
    "\n",
    "\n",
    "xlims = 3\n",
    "ylims = 3\n",
    "\n",
    "X, Y = np.meshgrid(np.linspace(-xlims, xlims, 30), np.linspace(-ylims, ylims, 30))\n",
    "u, v = np.zeros_like(X), np.zeros_like(X)\n",
    "NI, NJ = X.shape\n",
    "\n",
    "norm = np.zeros((NI, NJ))  # log speed of the flow, shown as background\n",
    "for i in range(NI):\n",
    "    for j in range(NJ):\n",
    "        dx, dy = dyn_eq(X[i, j], Y[i, j])\n",
    "        u[i, j] = dx\n",
    "        v[i, j] = dy\n",
    "        norm[i, j] = np.log(np.linalg.norm([dx, dy]))\n",
    "\n",
    "# hsv is cyclic (0 and 1 are the same red): drop the endpoint so every direction gets its own colour\n",
    "colors = [plt.cm.hsv(c) for c in np.linspace(0, 1, n_angles, endpoint=False)]\n",
    "\n",
    "with mpl.rc_context(fname='matplotlibrc'):\n",
    "    fig, ax = plt.subplots(1, 2, figsize=(8, 4), dpi=200)\n",
    "\n",
    "    ax[0].imshow(norm, extent=[-xlims, xlims, -ylims, ylims],\n",
    "                 origin='lower', cmap='bone', vmax=np.max(norm), aspect='auto')\n",
    "    ax[0].streamplot(X, Y, u, v, color='lavender')\n",
    "    T1 = 0\n",
    "    T2 = 300\n",
    "    ax[0].set_box_aspect(1)\n",
    "    ax[0].spines[['right', 'top']].set_visible(False)\n",
    "\n",
    "    ax[0].set_xlim(-xlims, xlims)\n",
    "    ax[0].set_ylim(-ylims, ylims)\n",
    "    ax[0].set_xticks([-xlims, 0, xlims])\n",
    "    ax[0].set_yticks([-ylims, 0, ylims])\n",
    "    ax[0].set_xlabel(r\"$z_1$\")\n",
    "    ax[0].set_ylabel(r'$z_2$')\n",
    "\n",
    "    ax[0].set_prop_cycle('color', colors)\n",
    "    for z in z_all:\n",
    "        for ang_i in range(n_angles):\n",
    "            ax[0].plot(z[ang_i, :, 0], z[ang_i, :, 1], alpha=.5)\n",
    "\n",
    "    # one single-point line per direction, only to carry a legend label\n",
    "    for ang_i in range(n_angles):\n",
    "        ax[0].plot(z[ang_i, :1, 0], z[ang_i, :1, 1], alpha=1, label=rf\"{angles[ang_i]/(np.pi):.2f}$\\pi$\")\n",
    "\n",
    "    ax[1].set_box_aspect(1)\n",
    "    ax[1].set_xlim(0, T2-T1)\n",
    "    ax[1].spines[['right', 'top']].set_visible(False)\n",
    "    ax[1].set_yticks(np.arange(0, 20, 2))\n",
    "    ax[1].set_yticklabels([])\n",
    "    ax[1].set_xlabel(\"timesteps\")\n",
    "    ax[1].set_ylabel(\"observed neurons\")\n",
    "    ax[0].set_title(\"latents\")\n",
    "    ax[1].set_title(\"observed\")\n",
    "    # ax[1] holds no data, so hide its axes as in the ground-truth figure\n",
    "    ax[1].axis('off')\n",
    "    fig.savefig(\"inferred.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afe0e77e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4d63cd9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2443fb0f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
