{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset\n",
    "import sys, os\n",
    "#print(\"Filepath\")\n",
    "#print(str(os.path.dirname(os.path.abspath(__file__))))\n",
    "sys.path.append('../')\n",
    "from rnn.vae import VAE\n",
    "from rnn.train import train_VAE\n",
    "from rnn.datasets import Basic_dataset\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from pyrnn.model import RNN, predict\n",
    "from pyrnn.train import train_rnn\n",
    "from rnn.saving import save_model, load_model\n",
    "from pyrnn.train import save_rnn, load_rnn\n",
    "import matplotlib.pyplot as plt\n",
    "from itertools import combinations, chain\n",
    "\n",
    "import matplotlib as mpl\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Load osc RNN and generate data\n",
    "\n",
    "tr = 2\n",
    "rnn_osc, model_params, task_params, training_params = load_rnn(\"../data/osc_rnn\")\n",
    "rnn_osc40, model_params40, task_params40, training_params40 = load_rnn(\"../data/osc_rnn40\")\n",
    "\n",
    "U = torch.clone(rnn_osc.rnn.m.detach())\n",
    "V= torch.clone(rnn_osc.rnn.n.detach()*.1/model_params['n_rec'])\n",
    "W_or = U@V\n",
    "U, s, V = torch.linalg.svd(W_or, full_matrices=False)\n",
    "U, s,V = U[:, :tr], s[:tr], V[:tr, :]\n",
    "V = (V.T * s).T\n",
    "print(torch.linalg.norm(W_or-U@V)) #should be 0\n",
    "B = rnn_osc.rnn.b_rec.detach()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "\n",
    "class Trial_gen(Dataset):\n",
    "    def __init__(self, task_params):\n",
    "        self.dur =task_params['dur']\n",
    "        self.data_len = task_params['data_len']\n",
    "        self.n_trials = task_params['n_trials']\n",
    "        self.N = task_params['n_neurons']\n",
    "        self.w = task_params['w']\n",
    "        self.non_lin=task_params['non_lin']\n",
    "        self.R_z = task_params['R_z']\n",
    "        self.R_x = task_params['R_x']\n",
    "        ph0 = torch.randn(1)*np.pi*2\n",
    "        r0  = 1\n",
    "\n",
    "        self.latents = torch.zeros(2, self.data_len, dtype=torch.float32)\n",
    "        self.latents[0,0] = r0*np.cos(ph0)\n",
    "        self.latents[1,0] = r0*np.sin(ph0)\n",
    "        self.latents[:,0]+=torch.randn(2)*self.R_z\n",
    "        for t in range(1, self.data_len):\n",
    "            self.latents[:,t]+=0.9*self.latents[:,t-1]\n",
    "            self.latents[:,t]+=V@self.non_lin(U@self.latents[:,t-1]+B)+torch.randn(2)*self.R_z\n",
    "        self.data = self.non_lin(U@self.latents+B.unsqueeze(1))\n",
    "        self.data = U@self.latents\n",
    "        #self.data/=torch.max(self.data)\n",
    "        self.data+=torch.randn(self.N,self.data_len)*self.R_x\n",
    "        self.data = self.data.T\n",
    "        self.latents = self.latents.T\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.n_trials\n",
    "      \n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"\n",
    "        Return a trial of length self.dur\n",
    "        Args:\n",
    "            idx (int): trial index, arbitrary as trials are sampled randomly\n",
    "        Returns:\n",
    "            trial (torch.tensor; dim_x x self.dur): trial of length self.dur\n",
    "        \"\"\"\n",
    "        t_start = torch.randint(low=0,high=self.data.shape[0]-self.dur,size=(1,))[0]\n",
    "        t_end = t_start + self.dur\n",
    "        return self.data[t_start:t_end].T\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# plot example trial plus the latent signal underlying it\n",
    "batch_size=1\n",
    "task_params ={\"dur\":200,\n",
    "              \"data_len\":10000,\n",
    "              \"n_trials\":1,\n",
    "              \"name\":\"Sine\",\n",
    "              \"n_neurons\": 20,\n",
    "              \"w\":.1,\n",
    "              \"R_x\":.1,\n",
    "              \"R_z\":.2,\n",
    "              \"non_lin\":nn.ReLU(),\n",
    "             }\n",
    "task = Trial_gen(task_params)\n",
    "data_loader = DataLoader(\n",
    "   task, batch_size=batch_size, shuffle=True\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "ratesGT = task.data.T\n",
    "latent_codeGT = task.latents.T\n",
    "z0_hat = latent_codeGT[:,0]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def relu(x):\n",
    "    return np.maximum(x,0)\n",
    "\n",
    "def dyn_eq(x,y):\n",
    "    z = np.array([x,y])\n",
    "    z=-0.1*z+V.numpy()@relu(U.numpy()@z+B.numpy())\n",
    "    return z[0],z[1]\n",
    "\n",
    "xlims = 3\n",
    "ylims = 3\n",
    "\n",
    "X, Y = np.meshgrid(np.linspace(-xlims, xlims, 30), np.linspace(-ylims, ylims, 30))\n",
    "uGT, vGT = np.zeros_like(X), np.zeros_like(X)\n",
    "NI, NJ = X.shape\n",
    "\n",
    "normGT = np.zeros((NI,NJ))\n",
    "for i in range(NI):\n",
    "    for j in range(NJ):\n",
    "        x, y = X[i, j], Y[i, j]\n",
    "        dx, dy = dyn_eq(x,y)\n",
    "        uGT[i,j] = dx\n",
    "        vGT[i,j] = dy\n",
    "        normGT[i,j]=np.log(np.linalg.norm([dx,dy]))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "vae, vae_params,training_params, task_params=load_model('../models/SW20_1000')\n",
    "print(vae.prior.std_embed_x(vae.prior.R_x)[0])\n",
    "print(vae.prior.std_embed_z(vae.prior.R_z)[0])\n",
    "def orthogonalise_network(vae):\n",
    "    with torch.no_grad():\n",
    "        m_or = vae.prior.transition.m #20,2\n",
    "        n_or = vae.prior.transition.n\n",
    "        J = m_or@n_or\n",
    "        u,s,v = torch.linalg.svd(J)\n",
    "        projection_matrix = u[:,:vae.dim_z].T@m_or\n",
    "        proj_chol = projection_matrix@torch.diag(vae.prior.std_embed_z(vae.prior.R_z))\n",
    "        m_new = u[:,:vae.dim_z]\n",
    "        n_new = (v[:vae.dim_z].T * s[:vae.dim_z]).T\n",
    "        vae.prior.transition.m.copy_(m_new)\n",
    "        vae.prior.transition.n.copy_(n_new)\n",
    "        vae.prior.chol_cov_embed = lambda x: torch.tril(x)\n",
    "        vae.prior.R_z=torch.nn.Parameter(torch.linalg.cholesky(proj_chol@proj_chol.T))\n",
    "        vae.prior.params['scalar_noise_z']=\"Cov\"\n",
    "    return vae\n",
    "vae = orthogonalise_network(vae)\n",
    "print(vae.prior.std_embed_x(vae.prior.R_x)[0])\n",
    "print(vae.prior.std_embed_z(vae.prior.R_z)[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vae_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    z0=torch.ones(1,2,1,1)\n",
    "    data= task.data\n",
    "    #z0 = vae.encoder.mean_t0(data[0].unsqueeze(0))\n",
    "    z0 = z0_hat \n",
    "    dim_x = 20\n",
    "    dur = 10000#T2-T1\n",
    "    Z = vae.prior.get_latent_time_series(time_steps=dur, z0=z0)\n",
    "    Zn = Z.cpu().detach().numpy()[0,:,:,0]\n",
    "    data_gen = vae.prior.get_observation(Z,noise_scale=1).permute(0,2,1,3).reshape(dur, dim_x).T\n",
    "\n",
    "\n",
    "prior = vae.prior.transition\n",
    "tau = prior.cast_A(prior.AW).detach().numpy().squeeze()\n",
    "V = (prior.n*prior.scaling).detach().numpy()\n",
    "U = prior.m_transform(prior.m).detach().numpy()\n",
    "B = prior.h.detach().numpy()\n",
    "\n",
    "def dyn_eq(x,y):\n",
    "    z = np.array([x,y])\n",
    "    zn=(tau)*z+V@relu(U@z-B)\n",
    "    h=.1\n",
    "    dz = (zn-z)/h\n",
    "    return dz[0],dz[1]\n",
    "\n",
    "\n",
    "X, Y = np.meshgrid(np.linspace(-xlims, xlims, 30), np.linspace(-ylims, ylims, 30))\n",
    "u, v = np.zeros_like(X), np.zeros_like(X)\n",
    "NI, NJ = X.shape\n",
    "\n",
    "norm = np.zeros((NI,NJ))\n",
    "for i in range(NI):\n",
    "    for j in range(NJ):\n",
    "        x, y = X[i, j], Y[i, j]\n",
    "        dx, dy = dyn_eq(x,y)\n",
    "        u[i,j] = dx\n",
    "        v[i,j] = dy\n",
    "        norm[i,j]=np.log(np.linalg.norm([dx,dy]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "#plt.axis('square')\n",
    "#plt.axis([-xlims, xlims, -ylims, ylims])\n",
    "with mpl.rc_context(fname=\"matplotlibrc\"):\n",
    "    fig,ax = plt.subplots(2,2,figsize=(2,2))\n",
    "\n",
    "    T1 = 150\n",
    "    T2 = 250\n",
    "    lT1 = 150\n",
    "    lT2 = 250\n",
    "    ax[0,0].set_box_aspect(1)\n",
    "    ax[0,0].spines[['right','top']].set_visible(False)\n",
    "    ax[0,0].imshow(normGT,extent = [-xlims,xlims,-ylims,ylims], \n",
    "           origin ='lower',cmap ='bone',vmax=np.max(normGT),aspect='auto')\n",
    "    ax[0,0].streamplot(X, Y, uGT, vGT,color='lavender',density=.5,linewidth=.5, arrowsize=.5)\n",
    "\n",
    "    ax[0,0].set_xlim(-xlims,xlims)\n",
    "    ax[0,0].set_ylim(-ylims,ylims)\n",
    "    ax[0,0].set_xticks([])\n",
    "    ax[0,0].set_yticks([-ylims,ylims])\n",
    "    ax[0,0].set_ylabel(r'$z_2$')\n",
    "    ax[0,0].plot(latent_codeGT[0,lT1:lT2].numpy(),latent_codeGT[1,lT1:lT2].numpy(),color='purple',alpha=1)\n",
    "\n",
    "\n",
    "    for i in range(6):\n",
    "        ax[0,1].plot(ratesGT[i,T1:T2]+i*1.5, color='#393e4d',alpha=1,lw=1)\n",
    "    ax[0,1].set_box_aspect(1)\n",
    "    ax[0,1].set_xlim(0,T2-T1)\n",
    "    ax[0,1].spines[['right','top']].set_visible(False)\n",
    "    ax[0,1].set_yticks([])#np.arange(0,20,2))\n",
    "    ax[0,1].set_yticklabels([])\n",
    "    #ax[0,1].set_xlabel(\"timesteps\")\n",
    "    #ax[1].set_ylabel(\"neurons\")\n",
    "    ax[0,1].set_xticks([])\n",
    "    ax[0,0].set_title(\"latents\")\n",
    "    ax[0,1].set_title(\"observed\")\n",
    "\n",
    "\n",
    "    ax[1,0].imshow(norm,extent = [-xlims,xlims,-ylims,ylims], \n",
    "            origin ='lower',cmap ='bone',vmax=np.max(norm),aspect='auto')\n",
    "    ax[1,0].streamplot(X, Y, u, v,color='lavender',density=.5,linewidth=.5, arrowsize=.5)\n",
    "    #plt.axis('square')\n",
    "    #plt.axis([-xlims, xlims, -ylims, ylims])\n",
    "    ax[1,0].set_box_aspect(1)\n",
    "    ax[1,0].spines[['right','top']].set_visible(False)\n",
    "\n",
    "    ax[1,0].set_xlim(-xlims,xlims)\n",
    "    ax[1,0].set_ylim(-ylims,ylims)\n",
    "    ax[1,0].set_xticks([-xlims,xlims])\n",
    "    ax[1,0].set_yticks([-ylims,ylims])\n",
    "    ax[1,0].set_xlabel(r\"$z_1$\")\n",
    "    ax[1,0].set_ylabel(r'$z_2$')\n",
    "    ax[1,0].plot(Zn[0,lT1:lT2],Zn[1,lT1:lT2],color='hotpink')#,lw=1)\n",
    "\n",
    "\n",
    "    ax[0,0].set_xticks([-xlims,xlims])\n",
    "    ax[0,0].set_xticklabels([])\n",
    "    ax[0,1].set_xticks([0,T2-T1])\n",
    "    ax[0,1].set_xticklabels([])\n",
    "    ax[1,1].set_xticks([0,T2-T1])\n",
    "    #ax[1,1].set_xticklabels([])\n",
    "\n",
    "\n",
    "    for i in range(6):\n",
    "        ax[1,1].plot(data_gen[i,T1:T2]+i*1.5, color='#393e4d',alpha=1,lw=1)\n",
    "    ax[1,1].set_box_aspect(1)\n",
    "    ax[1,1].set_xlim(0,T2-T1)\n",
    "    ax[1,1].spines[['right','top']].set_visible(False)\n",
    "    ax[1,1].set_yticks([])#np.arange(0,20,2))\n",
    "    ax[1,1].set_yticklabels([])\n",
    "    ax[1,1].set_xlabel(\"timesteps\")\n",
    "    #ax[1,1].set_ylabel(\"observed neurons\")\n",
    "    ax[1,0].invert_yaxis()\n",
    "    ax[1,0].set_yticklabels([-xlims,xlims])\n",
    "\n",
    "    plt.savefig(\"../figures/Fig1A.svg\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "data_gen.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(Zn.flatten())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "\n",
    "# Plotting the scatter plot\n",
    "plt.figure(figsize=(1, 1))\n",
    "plt.scatter(torch.std(ratesGT,axis=-1), torch.std(data_gen,axis=-1), color='purple', alpha=0.6)\n",
    "plt.title('STD rates')\n",
    "plt.xlabel('Ground truth')\n",
    "plt.ylabel('Generated')\n",
    "#plot x = y line\n",
    "plt.plot([0, 1], [0, 1], color='grey', linestyle='--',zorder=-10)\n",
    "plt.yticks([0,1])\n",
    "plt.xticks([0,1])\n",
    "plt.xlim(0,1)\n",
    "plt.ylim(0,1)\n",
    "plt.savefig(\"../figures/Supp_Fig1A.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plotting the scatter plot\n",
    "fig = plt.figure(figsize=(2, 1))\n",
    "gs1 = fig.add_gridspec(nrows=2, ncols=4)#, left=0.05, right=0.48, wspace=0.05)\n",
    "ax1 = fig.add_subplot(gs1[:2, 2:])\n",
    "ax21 = fig.add_subplot(gs1[1, 0])\n",
    "ax22 = fig.add_subplot(gs1[1, 1], sharex=ax21, sharey=ax21)\n",
    "ax12 = fig.add_subplot(gs1[0, 0], sharex=ax21, sharey=ax21)\n",
    "ax11 = fig.add_subplot(gs1[0, 1], sharex=ax21, sharey=ax21)\n",
    "ax1.scatter(torch.std(ratesGT,axis=-1), torch.std(data_gen,axis=-1), color='teal', alpha=0.7, s=20,linewidths=0)\n",
    "ax1.scatter(torch.std(ratesGT,axis=-1), torch.std(data_gen,axis=-1)+.4, color='firebrick', alpha=0.7, s=20,linewidths=0)\n",
    "\n",
    "ax1.set_title('Latents hist')\n",
    "\n",
    "ax1.set_title('SD rates')\n",
    "ax1.set_xlabel('Ground truth')\n",
    "#ax1.set_ylabel('Generated')\n",
    "#plot x = y line\n",
    "ax1.plot([0, 1], [0, 1], color='grey', linestyle='--',zorder=-10)\n",
    "ax1.set_yticks([0,1])\n",
    "ax1.set_xticks([0,1])\n",
    "ax1.set_yticklabels([])\n",
    "ax1.set_xlim(0,1)\n",
    "ax1.set_ylim(0,1)\n",
    "ax1.set_box_aspect(1)\n",
    "ax21.set_box_aspect(1)\n",
    "ax22.set_box_aspect(1)\n",
    "ax12.set_box_aspect(1)\n",
    "ax11.set_box_aspect(1)\n",
    "\n",
    "#ax21.set_xticks([])\n",
    "#ax22.set_xticks([])\n",
    "#ax12.set_xticks([])\n",
    "#ax11.set_xticks([])\n",
    "#ax21.set_xticks([-5,0,5])\n",
    "\n",
    "\n",
    "#plt.setp(ax21.get_xticklabels(), visible=False) #bottom left\n",
    "ax21.set_xticks([-5,0,5])\n",
    "ax21.set_xlim([-5,5])\n",
    "ax22.set_xticks([-5,0,5])\n",
    "ax22.set_xlim([-5,5])\n",
    "ax21.set_xticklabels([])\n",
    "ax22.set_xlabel(r'$z_2$')\n",
    "ax21.set_xlabel(r'$z_1$')\n",
    "ax12.set_ylabel('teacher')\n",
    "ax21.set_ylabel('student')\n",
    "ax12.set_title('dist. of latents')\n",
    "\n",
    "plt.setp(ax12.get_xticklabels(), visible=False) # top left\n",
    "#plt.setp(ax22.get_xticklabels(), visible=False) #bottom right\n",
    "plt.setp(ax11.get_xticklabels(), visible=False) #top right\n",
    "\n",
    "ax12. set_yticks([])\n",
    "#plt.setp(ax12.get_yticks(), visible=False) # top left\n",
    "#plt.setp(ax11.get_yticks(), visible=False) #top right\n",
    "#plt.setp(ax22.get_yticks(), visible=False) #bottom right\n",
    "plt.setp(ax21.get_yticks(), visible=False) #bottom left\n",
    "\n",
    "bins = 30\n",
    "ax22.hist(Zn[1].flatten(),density=True,color='hotpink',bins=20)\n",
    "ax21.hist(Zn[0].flatten(),density=True,color='hotpink',bins=20)\n",
    "ax11.hist(latent_codeGT[1].flatten(),density=True,color='purple',bins=20)\n",
    "ax12.hist(latent_codeGT[0].flatten(),density=True,color='purple',bins=20)\n",
    "\n",
    "\n",
    "plt.savefig(\"../figures/Supp_Fig1A.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plotting the scatter plot\n",
    "fig = plt.figure(figsize=(2, 1))\n",
    "gs1 = fig.add_gridspec(nrows=2, ncols=4)#, left=0.05, right=0.48, wspace=0.05)\n",
    "ax1 = fig.add_subplot(gs1[:2, 2:])\n",
    "ax21 = fig.add_subplot(gs1[1, 0])\n",
    "ax22 = fig.add_subplot(gs1[1, 1], sharex=ax21, sharey=ax21)\n",
    "ax12 = fig.add_subplot(gs1[0, 0], sharex=ax21, sharey=ax21)\n",
    "ax11 = fig.add_subplot(gs1[0, 1], sharex=ax21, sharey=ax21)\n",
    "ax1.scatter(torch.std(ratesGT,axis=-1), torch.std(data_gen,axis=-1), color='teal', alpha=0.7, s=20,linewidths=0)\n",
    "#ax1.scatter(torch.std(ratesGT,axis=-1), torch.std(data_gen,axis=-1)+.4, color='firebrick', alpha=0.7, s=20,linewidths=0)\n",
    "\n",
    "ax1.set_title('Latents hist')\n",
    "\n",
    "ax1.set_title('SD rates')\n",
    "ax1.set_xlabel('teacher')\n",
    "#ax1.set_ylabel('Generated')\n",
    "#plot x = y line\n",
    "ax1.plot([0, 1], [0, 1], color='grey', linestyle='--',zorder=-10)\n",
    "ax1.set_yticks([0,1])\n",
    "ax1.set_xticks([0,1])\n",
    "ax1.set_yticklabels([])\n",
    "ax1.set_xlim(0,1)\n",
    "ax1.set_ylim(0,1)\n",
    "ax1.set_box_aspect(1)\n",
    "ax21.set_box_aspect(1)\n",
    "ax22.set_box_aspect(1)\n",
    "ax12.set_box_aspect(1)\n",
    "ax11.set_box_aspect(1)\n",
    "\n",
    "#ax21.set_xticks([])\n",
    "#ax22.set_xticks([])\n",
    "#ax12.set_xticks([])\n",
    "#ax11.set_xticks([])\n",
    "#ax21.set_xticks([-5,0,5])\n",
    "\n",
    "\n",
    "#plt.setp(ax21.get_xticklabels(), visible=False) #bottom left\n",
    "#ax21.set_xticks([-5,0,5])\n",
    "#ax21.set_xlim([-5,5])\n",
    "#ax22.set_xticks([-5,0,5])\n",
    "#ax22.set_xlim([-5,5])\n",
    "ax21.set_xticklabels([])\n",
    "ax11.set_title(r'$z_2$')\n",
    "ax12.set_title(r'$z_1$')\n",
    "ax21.set_xlabel(\"timelag\")\n",
    "ax12.set_ylabel('teacher')\n",
    "ax21.set_ylabel('student')\n",
    "#ax12.set_title('autocorr. latents')\n",
    "#fig.suptitle(\"autocorr. latents\")\n",
    "plt.setp(ax12.get_xticklabels(), visible=False) # top left\n",
    "#plt.setp(ax22.get_xticklabels(), visible=False) #bottom right\n",
    "plt.setp(ax11.get_xticklabels(), visible=False) #top right\n",
    "\n",
    "ax12.set_yticks([])\n",
    "#plt.setp(ax12.get_yticks(), visible=False) # top left\n",
    "#plt.setp(ax11.get_yticks(), visible=False) #top right\n",
    "#plt.setp(ax22.get_yticks(), visible=False) #bottom right\n",
    "plt.setp(ax21.get_yticks(), visible=False) #bottom left\n",
    "\n",
    "lag=120\n",
    "Corrs=[]\n",
    "for i in range(0,9500,lag):\n",
    "    Corr = np.flip(np.correlate(Zn[1,i:i+lag],Zn[1,i:i+lag*2],mode='valid'))\n",
    "    Corr/=Corr[0]\n",
    "    Corrs.append(Corr)\n",
    "mean = np.mean(Corrs,axis=0)\n",
    "std = np.std(Corrs,axis=0)*1\n",
    "ax22.plot(np.mean(Corrs,axis=0),color='hotpink',zorder=1000)\n",
    "ax22.fill_between(np.arange(lag+1), mean-std ,mean+std, alpha=0.2, color=\"hotpink\")\n",
    "Corrs=[]\n",
    "for i in range(0,9500,lag):\n",
    "    Corr = np.flip(np.correlate(Zn[0,i:i+lag],Zn[0,i:i+lag*2],mode='valid'))\n",
    "    Corr/=Corr[0]\n",
    "    Corrs.append(Corr)\n",
    "mean = np.mean(Corrs,axis=0)\n",
    "std = np.std(Corrs,axis=0)*1\n",
    "ax21.plot(np.mean(Corrs,axis=0),color='hotpink',zorder=1000)\n",
    "ax21.fill_between(np.arange(lag+1), mean-std ,mean+std, alpha=0.2, color=\"hotpink\")\n",
    "\n",
    "ax11.hist(latent_codeGT[1].flatten(),density=True,color='purple',bins=20)\n",
    "for i in range(0,9500,lag):\n",
    "    Corr = np.flip(np.correlate(latent_codeGT[1,i:i+lag],latent_codeGT[1,i:i+lag*2],mode='valid'))\n",
    "    Corr/=Corr[0]\n",
    "    Corrs.append(Corr)\n",
    "mean = np.mean(Corrs,axis=0)\n",
    "std = np.std(Corrs,axis=0)*1\n",
    "ax11.plot(np.mean(Corrs,axis=0),color='purple',zorder=1000)\n",
    "ax11.fill_between(np.arange(lag+1), mean-std ,mean+std, alpha=0.2, color=\"purple\")\n",
    "\n",
    "for i in range(0,9500,lag):\n",
    "    Corr = np.flip(np.correlate(latent_codeGT[0,i:i+lag],latent_codeGT[0,i:i+lag*2],mode='valid'))\n",
    "    Corr/=Corr[0]\n",
    "    Corrs.append(Corr)\n",
    "mean = np.mean(Corrs,axis=0)\n",
    "std = np.std(Corrs,axis=0)*1\n",
    "ax12.plot(np.mean(Corrs,axis=0),color='purple',zorder=1000)\n",
    "ax12.fill_between(np.arange(lag+1), mean-std ,mean+std, alpha=0.2, color=\"purple\")\n",
    "ax12.set_xticks([0,100])\n",
    "ax12.set_xticklabels([0,100])\n",
    "ax12.set_xlim(0,120)\n",
    "#ax12.hist(latent_codeGT[0].flatten(),density=True,color='purple',bins=20)\n",
    "\n",
    "\n",
    "plt.savefig(\"../figures/Supp_Fig1A.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ratesGT.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_pw_correlation(data):\n",
    "    \"\"\"Calculate the cross-correlation matrix for a dataset.\"\"\"\n",
    "    correlation_matrix = np.corrcoef(data, rowvar=False)\n",
    "    return correlation_matrix\n",
    "test_correlation = calculate_pw_correlation(ratesGT.T)\n",
    "gen_correlation = calculate_pw_correlation(data_gen.T)\n",
    "\n",
    "i_upper = np.triu_indices(20, k=1)\n",
    "test_corr_values = test_correlation[i_upper]\n",
    "gen_corr_values = gen_correlation[i_upper]\n",
    "\n",
    "\n",
    "# Plotting the scatter plot\n",
    "plt.figure(figsize=(1, 1))\n",
    "plt.scatter(test_corr_values,gen_corr_values, color='teal', alpha=0.7,linewidth=0, s=10)\n",
    "plt.title('pairwise correlations')\n",
    "plt.xlabel('teacher')\n",
    "plt.ylabel('student')\n",
    "#plot x = y line\n",
    "max = np.max([test_corr_values,gen_corr_values])*1.1\n",
    "plt.plot([0, max], [0, max], color='grey', linestyle='--',zorder=-10)\n",
    "#plt.yticks([0,1])\n",
    "#plt.xticks([0,1])\n",
    "#plt.xlim(0)\n",
    "#plt.ylim(0)\n",
    "plt.savefig(\"../figures/Supp_PWA.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Zn[1].shape\n",
    "Corrs=[]\n",
    "for i in range(0,9500,100):\n",
    "    Corr = np.flip(np.correlate(Zn[1,i:i+150],Zn[1,i:i+300],mode='valid'))\n",
    "    Corr/=Corr[0]\n",
    "    Corrs.append(Corr)\n",
    "    plt.plot(Corr,color='purple',alpha=.05)\n",
    "plt.plot(np.mean(Corrs,axis=0),color='purple',zorder=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "vae.prior.R_z@vae.prior.R_z.T.detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "plt.imshow((vae.prior.R_z@vae.prior.R_z.T).detach(),cmap='coolwarm',vmin=-.08,vmax=.08)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "plt.imshow(np.array([[.04,0],[0,.04]]),cmap='coolwarm',vmin=-.07,vmax=.07)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "rnn_osc,model_params,task_params,training_params = load_rnn(\"../data/osc_rnn40\")\n",
    "# Obtain and orthogonalise weights of the teacher model\n",
    "\n",
    "tr = 2 # rank 2\n",
    "U = torch.clone(rnn_osc.rnn.m.detach())\n",
    "V= torch.clone(rnn_osc.rnn.n.detach()*.1/model_params['n_rec'])\n",
    "W_or = U@V\n",
    "U, s, V = torch.linalg.svd(W_or, full_matrices=False)\n",
    "U, s,V = U[:, :tr], s[:tr], V[:tr, :]\n",
    "V = (V.T * s).T\n",
    "\n",
    "B = rnn_osc.rnn.b_rec.detach()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Create a dataset using the teacher model\n",
    "\n",
    "class Trial_gen(Dataset):\n",
    "    def __init__(self, task_params):\n",
    "        self.task_params = task_params\n",
    "        self.dur =task_params['dur']\n",
    "        self.n_trials = task_params['n_trials']\n",
    "        self.N = task_params['n_neurons']\n",
    "        self.w = task_params['w']\n",
    "        self.non_lin=task_params['non_lin']\n",
    "        self.R_z = task_params['R_z']\n",
    "\n",
    "        ph0 = torch.randn(self.n_trials)*np.pi*2\n",
    "        r0  = torch.randn(self.n_trials)+1\n",
    "        self.latents = torch.zeros(2, self.n_trials, self.dur, dtype=torch.float32)\n",
    "        self.rates= torch.zeros(self.N, self.n_trials, self.dur, dtype=torch.float32)\n",
    "        self.latents[0,:,0] = r0*np.cos(ph0)\n",
    "        self.latents[1,:,0] = r0*np.sin(ph0)\n",
    "        self.latents[:,:,0]+=torch.randn(2,self.n_trials)*self.R_z\n",
    "\n",
    "        for t in range(1, self.dur):\n",
    "            self.latents[:,:,t]+=0.9*self.latents[:,:,t-1]\n",
    "            self.latents[:,:,t]+=V@self.non_lin(U@self.latents[:,:,t-1]+B.unsqueeze(1))+torch.randn(2,self.n_trials)*self.R_z\n",
    "\n",
    "        if task_params[\"out\"]==\"rates\":\n",
    "            for t in range(self.dur):\n",
    "                self.rates[:,:,t] = U@self.non_lin(self.latents[:,:,t]+B.unsqueeze(1))\n",
    "        elif task_params[\"out\"]==\"currents\":\n",
    "            for t in range(self.dur):\n",
    "                self.rates[:,:,t] = U@self.latents[:,:,t]\n",
    "        self.rates*=task_params['B']\n",
    "        self.rates+=task_params['Bias']\n",
    "\n",
    "        if task_params['obs_rectify'] =='exp':\n",
    "            self.rates =torch.exp(self.rates)\n",
    "        elif task_params['obs_rectify'] =='relu':\n",
    "            self.rates= torch.relu(self.rates)+1e-10\n",
    "        elif task_params['obs_rectify'] =='softplus':\n",
    "            self.rates = torch.nn.functional.softplus(self.rates)\n",
    "        self.data = torch.poisson(self.rates)\n",
    "        self.data_eval= self.data\n",
    "    def __len__(self):\n",
    "        return self.n_trials\n",
    "      \n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"\n",
    "        Return a trial of length self.dur\n",
    "        Args:\n",
    "            idx (int): trial index, arbitrary as trials are sampled randomly\n",
    "        Returns:\n",
    "            trial (torch.tensor; dim_x x self.dur): trial of length self.dur\n",
    "        \"\"\"\n",
    "        \n",
    "        return self.data[:,idx], torch.zeros(0,self.dur,device=self.data.device)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "\n",
    "task_params ={\"dur\":10000,\n",
    "              \"n_trials\":1,\n",
    "              \"name\":\"Sine\",\n",
    "              \"n_neurons\": rnn_osc.rnn.N,\n",
    "              \"out\":\"currents\",\n",
    "              \"w\":.1,\n",
    "              \"R_z\":0.2,\n",
    "              \"Bias\":-3,\n",
    "              \"B\":4,\n",
    "              \"non_lin\":torch.nn.ReLU(),\n",
    "               \"obs_rectify\":\"softplus\",\n",
    "             }\n",
    "task = Trial_gen(task_params)\n",
    "batch_size= 1\n",
    "data_loader = DataLoader(\n",
    "   task, batch_size=batch_size, shuffle=True\n",
    ")\n",
    "\n",
    "ratesGT = task.data[:,0]\n",
    "latent_codeGT= task.latents[:,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def dyn_eq(x,y):\n",
    "    z = np.array([x,y])\n",
    "    z=-0.1*z+V.numpy()@relu(U.numpy()@z+B.numpy())\n",
    "    return z[0],z[1]\n",
    "\n",
    "xlims = 3\n",
    "ylims = 3\n",
    "\n",
    "XGT, YGT = np.meshgrid(np.linspace(-xlims, xlims, 30), np.linspace(-ylims, ylims, 30))\n",
    "uGT, vGT = np.zeros_like(X), np.zeros_like(X)\n",
    "NI, NJ = X.shape\n",
    "\n",
    "normGT = np.zeros((NI,NJ))\n",
    "for i in range(NI):\n",
    "    for j in range(NJ):\n",
    "        x, y = X[i, j], Y[i, j]\n",
    "        dx, dy = dyn_eq(x,y)\n",
    "        uGT[i,j] = dx\n",
    "        vGT[i,j] = dy\n",
    "        normGT[i,j]=np.log(np.linalg.norm([dx,dy]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "\n",
    "name = \"Sine_40_1000\"\n",
    "vae, vae_params,training_params, task_params=load_model('../models/'+name)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "task_params.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "print(vae.prior.observation.Bias.mean().item(),torch.diag(vae.prior.observation.B).mean().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    z0=torch.ones(1,2,1,1)\n",
    "    data= task.data\n",
    "    #z0 = vae.encoder.mean_t0(data[0].unsqueeze(0))\n",
    "    z0 = z0_hat \n",
    "    dim_x = 40\n",
    "    dur = 10000#T2-T1\n",
    "    Z = vae.prior.get_latent_time_series(time_steps=dur, z0=z0)\n",
    "    Zn = Z.cpu().detach().numpy()[0,:,:,0]\n",
    "    data_gen = vae.prior.get_observation(Z,noise_scale=1).permute(0,2,1,3).reshape(dur, dim_x).T\n",
    "    data_gen = torch.poisson(vae.obs_rectify(data_gen))\n",
    "\n",
    "prior = vae.prior.transition\n",
    "tau = prior.cast_A(prior.AW).detach().numpy().squeeze()\n",
    "V = (prior.n*prior.scaling).detach().numpy()\n",
    "U = prior.m_transform(prior.m).detach().numpy()\n",
    "B = prior.h.detach().numpy()\n",
    "\n",
    "def dyn_eq(x,y):\n",
    "    z = np.array([x,y])\n",
    "    zn=(tau)*z+V@relu(U@z-B)\n",
    "    h=.1\n",
    "    dz = (zn-z)/h\n",
    "    return dz[0],dz[1]\n",
    "\n",
    "xlims_inf = [-4,1]\n",
    "ylims_inf = [-5,0]\n",
    "\n",
    "X, Y = np.meshgrid(np.linspace(xlims_inf[0], xlims_inf[1], 30), np.linspace(ylims_inf[0], ylims_inf[1], 30))\n",
    "u, v = np.zeros_like(X), np.zeros_like(X)\n",
    "NI, NJ = X.shape\n",
    "\n",
    "norm = np.zeros((NI,NJ))\n",
    "for i in range(NI):\n",
    "    for j in range(NJ):\n",
    "        x, y = X[i, j], Y[i, j]\n",
    "        dx, dy = dyn_eq(x,y)\n",
    "        u[i,j] = dx\n",
    "        v[i,j] = dy\n",
    "        norm[i,j]=np.log(np.linalg.norm([dx,dy]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def orthogonalise_network(vae, rank=2):\n",
    "    \"\"\"Re-express the VAE prior's low-rank transition in an orthonormal basis (in place).\n",
    "\n",
    "    J = m @ n is factorised with an SVD: m is replaced by the top-`rank` left\n",
    "    singular vectors and n by the matching singular-value-scaled right singular\n",
    "    vectors, so m @ n is unchanged whenever J has rank <= `rank`. The latent noise\n",
    "    std is projected into the new basis and stored as a full covariance,\n",
    "    parameterised by its lower Cholesky factor.\n",
    "\n",
    "    Args:\n",
    "        vae: model whose prior (transition.m, transition.n, R_z, params) is modified\n",
    "        rank: number of singular components kept (default 2, as before)\n",
    "    Returns:\n",
    "        the same `vae` object\n",
    "\n",
    "    NOTE(review): J is built from the raw m and n, while other cells use\n",
    "    prior.m_transform(prior.m) and prior.n*prior.scaling - confirm these agree.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        m_or = vae.prior.transition.m  # presumably (N, rank); the original note said 20,2\n",
    "        n_or = vae.prior.transition.n\n",
    "        J = m_or @ n_or\n",
    "        u, s, vh = torch.linalg.svd(J, full_matrices=False)\n",
    "        u_r = u[:, :rank]\n",
    "        # Maps old latent coordinates onto the new orthonormal basis\n",
    "        projection_matrix = u_r.T @ m_or\n",
    "        proj_chol = projection_matrix @ torch.diag(vae.prior.std_embed_z(vae.prior.R_z))\n",
    "        n_new = s[:rank, None] * vh[:rank]  # == diag(s[:rank]) @ vh[:rank]\n",
    "        vae.prior.transition.m.copy_(u_r)\n",
    "        vae.prior.transition.n.copy_(n_new)\n",
    "        vae.prior.chol_cov_embed = lambda x: torch.tril(x)\n",
    "        # torch.cholesky is deprecated; torch.linalg.cholesky returns the same lower factor\n",
    "        vae.prior.R_z = torch.nn.Parameter(torch.linalg.cholesky(proj_chol @ proj_chol.T))\n",
    "        vae.prior.params['scalar_noise_z'] = \"Cov\"\n",
    "    return vae"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Mean inter-spike-interval (ISI) density: ground truth ratesGT ('Real') vs. VAE samples data_gen ('Generated')\n",
    "bins = np.linspace(0, 100, 100)\n",
    "\n",
    "\n",
    "def _mean_isi_hist(spike_counts, bins):\n",
    "    \"\"\"Average over neurons of the per-neuron ISI density histogram.\n",
    "\n",
    "    Args:\n",
    "        spike_counts: (neurons, time) array; a non-zero bin counts as a spike\n",
    "        bins: histogram bin edges\n",
    "    Returns:\n",
    "        (len(bins) - 1,) array. Neurons without any ISI inside the bin range are\n",
    "        skipped so they cannot turn the average into NaN.\n",
    "    \"\"\"\n",
    "    widths = np.diff(bins)\n",
    "    hists = []\n",
    "    for counts in spike_counts:\n",
    "        isi = np.diff(np.flatnonzero(counts))\n",
    "        hist, _ = np.histogram(isi, bins)\n",
    "        total = hist.sum()\n",
    "        if total:\n",
    "            hists.append(hist / (total * widths))  # same normalisation as density=True\n",
    "    if not hists:\n",
    "        return np.zeros(len(bins) - 1)\n",
    "    return np.mean(hists, axis=0)\n",
    "\n",
    "\n",
    "mean_hist = _mean_isi_hist(ratesGT.detach().numpy(), bins)  # real / teacher\n",
    "mean_hist_gen = _mean_isi_hist(data_gen.detach().numpy(), bins)  # generated / student\n",
    "\n",
    "plt.xlim(0, 60)\n",
    "plt.bar(bins[0:-1], mean_hist, width=2, color='purple', alpha=1, label=\"Real\")\n",
    "plt.bar(bins[0:-1], mean_hist_gen, width=2, color='teal', alpha=0.5, label=\"Generated\")\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_pw_correlation(data):\n",
    "    \"\"\"Pairwise Pearson correlation matrix between the columns of `data`.\n",
    "\n",
    "    Args:\n",
    "        data: array-like (or CPU tensor) of shape (time, neurons)\n",
    "    Returns:\n",
    "        (neurons, neurons) correlation matrix; rows/columns of constant (silent)\n",
    "        neurons are NaN\n",
    "    \"\"\"\n",
    "    return np.corrcoef(np.asarray(data), rowvar=False)\n",
    "\n",
    "\n",
    "test_correlation = calculate_pw_correlation(ratesGT.T)\n",
    "gen_correlation = calculate_pw_correlation(data_gen.T)\n",
    "\n",
    "# Each unordered neuron pair once: strict upper triangle, sized from the data\n",
    "i_upper = np.triu_indices(test_correlation.shape[0], k=1)\n",
    "test_corr_values = test_correlation[i_upper]\n",
    "gen_corr_values = gen_correlation[i_upper]\n",
    "\n",
    "# Plotting the scatter plot\n",
    "plt.figure(figsize=(1, 1))\n",
    "plt.scatter(test_corr_values, gen_corr_values, color='teal', alpha=0.2, linewidth=0, s=10)\n",
    "plt.title('pairwise correlations')\n",
    "plt.xlabel('teacher')\n",
    "plt.ylabel('student')\n",
    "# x = y reference line covering the data (negative correlations included). NaN-aware so a\n",
    "# silent neuron cannot erase it; the limits get their own names instead of shadowing builtin max\n",
    "all_corr = np.concatenate([test_corr_values, gen_corr_values])\n",
    "corr_hi = np.nanmax(all_corr) * 1.1\n",
    "corr_lo = np.minimum(0.0, np.nanmin(all_corr) * 1.1)\n",
    "plt.plot([corr_lo, corr_hi], [corr_lo, corr_hi], color='grey', linestyle='--', zorder=-10)\n",
    "plt.savefig(\"../figures/Supp_PWB.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Figure 1B: teacher (top row) vs. student/VAE (bottom row); latent flow fields (left), spike rasters (right).\n",
    "# xlims/ylims, normGT, XGT/YGT/uGT/vGT, latent_codeGT and ratesGT come from the teacher cells earlier on.\n",
    "vmax = 2  # colour ceiling for the spike-count rasters\n",
    "with mpl.rc_context(fname=\"matplotlibrc\"):\n",
    "    fig, ax = plt.subplots(2, 2, figsize=(2, 2))\n",
    "\n",
    "    T1 = 150  # raster time window\n",
    "    T2 = 250\n",
    "    lT1 = 150  # latent-trajectory time window\n",
    "    lT2 = 250\n",
    "\n",
    "    # Top left: teacher flow field (log speed + streamlines) and latent trajectory\n",
    "    ax[0, 0].set_box_aspect(1)\n",
    "    ax[0, 0].spines[['right', 'top']].set_visible(False)\n",
    "    ax[0, 0].imshow(normGT, extent=[-xlims, xlims, -ylims, ylims],\n",
    "                    origin='lower', cmap='bone', vmax=np.max(normGT), aspect='auto')\n",
    "    ax[0, 0].streamplot(XGT, YGT, uGT, vGT, color='lavender', density=.5, linewidth=.5, arrowsize=.5)\n",
    "    ax[0, 0].set_xlim(-xlims, xlims)\n",
    "    ax[0, 0].set_ylim(-ylims, ylims)\n",
    "    ax[0, 0].set_xticks([-xlims, xlims])\n",
    "    ax[0, 0].set_xticklabels([])\n",
    "    ax[0, 0].set_yticks([-ylims, ylims])\n",
    "    ax[0, 0].set_ylabel(r'$z_2$')\n",
    "    ax[0, 0].set_title(\"latents\")\n",
    "    ax[0, 0].plot(latent_codeGT[0, lT1:lT2].numpy(), latent_codeGT[1, lT1:lT2].numpy(), color='purple', alpha=1)\n",
    "\n",
    "    # Top right: teacher spike raster\n",
    "    ax[0, 1].imshow(ratesGT[:, T1:T2], cmap='Grays', aspect='auto', vmax=vmax)\n",
    "    ax[0, 1].set_box_aspect(1)\n",
    "    ax[0, 1].set_xlim(0, T2 - T1)\n",
    "    ax[0, 1].spines[['right', 'top']].set_visible(False)\n",
    "    ax[0, 1].set_xticks([0, T2 - T1])\n",
    "    ax[0, 1].set_xticklabels([])\n",
    "    ax[0, 1].set_yticks([])\n",
    "    ax[0, 1].set_yticklabels([])\n",
    "    ax[0, 1].set_title(\"observed\")\n",
    "\n",
    "    # Bottom left: student flow field and a sampled latent trajectory\n",
    "    ax[1, 0].imshow(norm, extent=[xlims_inf[0], xlims_inf[1], ylims_inf[0], ylims_inf[1]],\n",
    "                    origin='lower', cmap='bone', vmax=np.max(norm), aspect='auto')\n",
    "    ax[1, 0].streamplot(X, Y, u, v, color='lavender', density=.5, linewidth=.5, arrowsize=.5)\n",
    "    ax[1, 0].set_box_aspect(1)\n",
    "    ax[1, 0].spines[['right', 'top']].set_visible(False)\n",
    "    ax[1, 0].set_xlim(xlims_inf[0], xlims_inf[1])\n",
    "    ax[1, 0].set_ylim(ylims_inf[0], ylims_inf[1])\n",
    "    ax[1, 0].set_xticks([xlims_inf[0], xlims_inf[1]])\n",
    "    ax[1, 0].set_yticks([ylims_inf[0], ylims_inf[1]])\n",
    "    ax[1, 0].set_xlabel(r\"$z_1$\")\n",
    "    ax[1, 0].set_ylabel(r'$z_2$')\n",
    "    ax[1, 0].plot(Zn[0, lT1:lT2], Zn[1, lT1:lT2], color='hotpink')\n",
    "\n",
    "    # Bottom right: spike raster sampled from the student\n",
    "    ax[1, 1].imshow(data_gen[:, T1:T2], cmap='Grays', aspect='auto', vmax=vmax)\n",
    "    ax[1, 1].set_box_aspect(1)\n",
    "    ax[1, 1].set_xlim(0, T2 - T1)\n",
    "    ax[1, 1].set_xticks([0, T2 - T1])\n",
    "    ax[1, 1].spines[['right', 'top']].set_visible(False)\n",
    "    ax[1, 1].set_yticks([])\n",
    "    ax[1, 1].set_yticklabels([])\n",
    "    ax[1, 1].set_xlabel(\"timesteps\")\n",
    "\n",
    "    fig.savefig(\"../figures/Fig1B.svg\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Spike arrays as (time, neurons). The statistics cells below plot `spikes` on the x-axis,\n",
    "# labelled 'teacher' / 'Ground truth', so it holds ratesGT; `spikes_hat` holds the VAE samples.\n",
    "spikes = ratesGT.T.detach().numpy()\n",
    "spikes_hat = data_gen.T.detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _isi_stats(spike_matrix):\n",
    "    \"\"\"Per-neuron mean ISI and ISI coefficient of variation (CV = std / mean).\n",
    "\n",
    "    Args:\n",
    "        spike_matrix: (time, neurons) array; a non-zero entry counts as a spike\n",
    "    Returns:\n",
    "        (means, cvs) lists with one entry per neuron; NaN for neurons with < 2 spikes\n",
    "    \"\"\"\n",
    "    means, cvs = [], []\n",
    "    for neuron in np.asarray(spike_matrix).T:\n",
    "        isi = np.diff(np.flatnonzero(neuron))\n",
    "        if isi.size == 0:\n",
    "            means.append(np.nan)\n",
    "            cvs.append(np.nan)\n",
    "            continue\n",
    "        mean_isi = isi.mean()\n",
    "        means.append(mean_isi)\n",
    "        cvs.append(isi.std() / mean_isi)\n",
    "    return means, cvs\n",
    "\n",
    "\n",
    "# Teacher (`spikes`) vs. student (`spikes_hat`), one value per neuron\n",
    "Means_isi, CVs_isi = _isi_stats(spikes)\n",
    "Means_isi_hat, CVs_isi_hat = _isi_stats(spikes_hat)\n",
    "Means = np.mean(spikes, axis=0)  # mean count per time bin\n",
    "Means_hat = np.mean(spikes_hat, axis=0)\n",
    "\n",
    "fig, ax = plt.subplots(1, 2, figsize=(2, 1))\n",
    "\n",
    "# Left panel: mean counts on log-log axes; the dashed grey line is x = y\n",
    "ax[0].scatter(Means, Means_hat, color='teal', alpha=0.7, s=20, linewidth=0)\n",
    "ax[0].set_title('student')\n",
    "ax[0].set_xlabel('teacher')\n",
    "ax[0].set_ylabel('Generated')\n",
    "rate_hi = np.nanmax([Means, Means_hat]) * 1.1  # own name instead of shadowing builtin max\n",
    "ax[0].plot([0, rate_hi], [0, rate_hi], color='grey', linestyle='--', zorder=-1000)\n",
    "ax[0].set_aspect(1)\n",
    "ax[0].set_yscale('log')\n",
    "ax[0].set_xscale('log')\n",
    "ax[0].set_xticks([0.1, .3])\n",
    "ax[0].set_yticks([0.1, .3])\n",
    "from matplotlib.ticker import StrMethodFormatter, NullFormatter\n",
    "ax[0].yaxis.set_major_formatter(StrMethodFormatter('{x:.1f}'))\n",
    "ax[0].yaxis.set_minor_formatter(NullFormatter())\n",
    "ax[0].xaxis.set_major_formatter(StrMethodFormatter('{x:.1f}'))\n",
    "ax[0].xaxis.set_minor_formatter(NullFormatter())\n",
    "ax[0].set_yticklabels([])\n",
    "\n",
    "# Right panel: mean ISI\n",
    "ax[1].scatter(Means_isi, Means_isi_hat, color='teal', alpha=0.7, s=20, linewidth=0)\n",
    "ax[1].set_title('mean ISI')\n",
    "isi_hi = np.nanmax([Means_isi, Means_isi_hat]) * 1.1\n",
    "ax[1].plot([0, isi_hi], [0, isi_hi], color='grey', linestyle='--', zorder=-1000)\n",
    "ax[1].set_aspect(1)\n",
    "ax[1].set_xlim(0)\n",
    "ax[1].set_ylim(0)\n",
    "ax[1].set_yticklabels([])\n",
    "\n",
    "plt.savefig(\"../figures/Supp_Fig1B\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def _isi_stats(spike_matrix):\n",
    "    \"\"\"Per-neuron mean ISI and ISI coefficient of variation (CV = std / mean).\n",
    "\n",
    "    Args:\n",
    "        spike_matrix: (time, neurons) array; a non-zero entry counts as a spike\n",
    "    Returns:\n",
    "        (means, cvs) lists with one entry per neuron; NaN for neurons with < 2 spikes\n",
    "    \"\"\"\n",
    "    means, cvs = [], []\n",
    "    for neuron in np.asarray(spike_matrix).T:\n",
    "        isi = np.diff(np.flatnonzero(neuron))\n",
    "        if isi.size == 0:\n",
    "            means.append(np.nan)\n",
    "            cvs.append(np.nan)\n",
    "            continue\n",
    "        mean_isi = isi.mean()\n",
    "        means.append(mean_isi)\n",
    "        cvs.append(isi.std() / mean_isi)\n",
    "    return means, cvs\n",
    "\n",
    "\n",
    "# Ground truth (`spikes`) vs. generated (`spikes_hat`), one value per neuron\n",
    "Means_isi, CVs_isi = _isi_stats(spikes)\n",
    "Means_isi_hat, CVs_isi_hat = _isi_stats(spikes_hat)\n",
    "Means = np.mean(spikes, axis=0)  # mean count per time bin\n",
    "Means_hat = np.mean(spikes_hat, axis=0)\n",
    "\n",
    "fig, ax = plt.subplots(1, 3, figsize=(3, 1))\n",
    "\n",
    "# Panel 1: mean rates on log-log axes; dashed grey lines are x = y throughout\n",
    "ax[0].scatter(Means, Means_hat, color='teal', alpha=0.6)\n",
    "ax[0].set_title('mean rates')\n",
    "ax[0].set_xlabel('Ground truth')\n",
    "ax[0].set_ylabel('Generated')\n",
    "rate_hi = np.nanmax([Means, Means_hat]) * 1.1  # own name instead of shadowing builtin max\n",
    "ax[0].plot([0, rate_hi], [0, rate_hi], color='grey', linestyle='--', zorder=-1000)\n",
    "ax[0].set_aspect(1)\n",
    "ax[0].set_yscale('log')\n",
    "ax[0].set_xscale('log')\n",
    "ax[0].set_xticks([0.1, .3])\n",
    "ax[0].set_yticks([0.1, .3])\n",
    "from matplotlib.ticker import StrMethodFormatter, NullFormatter\n",
    "ax[0].yaxis.set_major_formatter(StrMethodFormatter('{x:.1f}'))\n",
    "ax[0].yaxis.set_minor_formatter(NullFormatter())\n",
    "ax[0].xaxis.set_major_formatter(StrMethodFormatter('{x:.1f}'))\n",
    "ax[0].xaxis.set_minor_formatter(NullFormatter())\n",
    "ax[0].set_yticklabels([])\n",
    "\n",
    "# Panel 2: mean ISI\n",
    "ax[1].scatter(Means_isi, Means_isi_hat, color='teal', alpha=0.6)\n",
    "ax[1].set_title('mean ISI')\n",
    "isi_hi = np.nanmax([Means_isi, Means_isi_hat]) * 1.1\n",
    "ax[1].plot([0, isi_hi], [0, isi_hi], color='grey', linestyle='--', zorder=-1000)\n",
    "ax[1].set_aspect(1)\n",
    "ax[1].set_xlim(0)\n",
    "ax[1].set_ylim(0)\n",
    "ax[1].set_yticklabels([])\n",
    "\n",
    "# Panel 3: ISI coefficient of variation; the identity line spans the observed CV range\n",
    "ax[2].scatter(CVs_isi, CVs_isi_hat, color='teal', alpha=0.6)\n",
    "ax[2].set_title('CV ISI')\n",
    "cv_lo = np.nanmin([CVs_isi, CVs_isi_hat]) * 0.9\n",
    "cv_hi = np.nanmax([CVs_isi, CVs_isi_hat]) * 1.1\n",
    "ax[2].plot([cv_lo, cv_hi], [cv_lo, cv_hi], color='grey', linestyle='--', zorder=-1000)\n",
    "ax[2].set_aspect(1)\n",
    "ax[2].set_yticklabels([])\n",
    "\n",
    "# Distinct file name: this used to write ../figures/Supp_Fig1B as well, overwriting the\n",
    "# two-panel figure saved by the previous cell on every Run All\n",
    "plt.savefig(\"../figures/Supp_Fig1B_CV\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "class Reaching(Dataset):\n",
    "    \"\"\"Delayed-reach task with randomly timed trials.\n",
    "\n",
    "    Trial `idx` uses the angle 2*pi*idx/len(self). Its (cos, sin) pair is the input\n",
    "    during a randomly timed stimulus window and the target from the end of a random\n",
    "    delay onward.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, task_params):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            task_params: dict with 'trial_len' (int) and [low, high) integer ranges\n",
    "                'onset', 'stim_dur' and 'delay_dur'\n",
    "        \"\"\"\n",
    "        self.task_params = task_params\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Fixed, arbitrary number of trials; they are randomly generated anyway.\"\"\"\n",
    "        return 200  # NOTE: task_params['n_stim'] is not used\n",
    "\n",
    "    def _sample(self, key):\n",
    "        \"\"\"One integer drawn uniformly from the [low, high) range task_params[key].\"\"\"\n",
    "        bounds = self.task_params[key]\n",
    "        return torch.randint(low=bounds[0], high=bounds[1], size=(1,))[0]\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"\n",
    "        Build trial `idx`.\n",
    "\n",
    "        Returns:\n",
    "            inputs: Tensor [trial_len, 2], (cos, sin) cue during the stimulus window, else 0\n",
    "            targets: Tensor [trial_len, 2], (cos, sin) from the end of the delay onward, else 0\n",
    "            mask: Tensor [trial_len, 2], 1 from stimulus offset onward, else 0\n",
    "        \"\"\"\n",
    "        trial_len = self.task_params['trial_len']\n",
    "        angle = np.pi * 2 * idx / self.__len__()\n",
    "        direction = (np.cos(angle), np.sin(angle))\n",
    "\n",
    "        # Timing draws, always in this order: onset, stimulus duration, delay\n",
    "        onset = self._sample('onset')\n",
    "        stim_dur = self._sample('stim_dur')\n",
    "        delay_dur = self._sample('delay_dur')\n",
    "        stim_end = onset + stim_dur\n",
    "        delay_end = delay_dur + onset + stim_dur\n",
    "\n",
    "        inputs = torch.zeros(trial_len, 2)\n",
    "        targets = torch.zeros(trial_len, 2)\n",
    "        for col, value in enumerate(direction):\n",
    "            inputs[onset:stim_end, col] = value\n",
    "            targets[delay_end:, col] = value\n",
    "\n",
    "        mask = torch.zeros_like(targets)\n",
    "        mask[stim_end:] = 1\n",
    "        return inputs, targets, mask\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Load the reaching RNN (teacher). NOTE(review): it is stored as `rnn_osc`, the same name as\n",
    "# the oscillator RNN loaded at the top; the cells below read it under that name.\n",
    "rnn_osc,model_params,task_params,training_params = load_rnn(\"../data/reach_rnn\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the task parameters returned by load_rnn (the next cell replaces them)\n",
    "task_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Reaching-task configuration ([a, b] ranges are sampled uniformly from [a, b));\n",
    "# this replaces the task_params returned by load_rnn above\n",
    "task_params = {\n",
    "    'onset': [25, 50],      # time till target stimulus onset (uniform between)\n",
    "    'trial_len': 200,       # trial duration\n",
    "    'stim_dur': [25, 50],   # target stimulus duration\n",
    "    'delay_dur': [10, 11],  # time till movement onset cue\n",
    "    'n_stim': 12,           # number of stimuli locations\n",
    "}\n",
    "reaching = Reaching(task_params)\n",
    "# Draw one example trial (index 1)\n",
    "stimulus, target, loss_mask = reaching[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def powerset(iterable):\n",
    "    \"\"\"All subsets of `iterable` as tuples, from the empty tuple up to the full set.\"\"\"\n",
    "    s = list(iterable)\n",
    "    return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))\n",
    "\n",
    "\n",
    "def find_fixed_points(a, V, U, hz, h, d=1, verbose=True):\n",
    "    \"\"\"\n",
    "    Find the fixed points of the piecewise-linear map\n",
    "        z -> diag(a) z + V.T relu(U z - h) - hz\n",
    "    (the map whose fixed-point equation is solved per region below; every call in\n",
    "    this notebook passes hz=0. NOTE(review): confirm the intended sign of hz).\n",
    "\n",
    "    Candidate linear regions (on/off patterns D of the ReLUs) are those bordering\n",
    "    intersections of R of the N switching hyperplanes U[i] z = h[i]. In each region\n",
    "    the fixed-point equation is linear; a solution is kept if it lies in that region.\n",
    "\n",
    "    Args:\n",
    "        a: numpy array of shape (R,)\n",
    "        V: numpy array of shape (N,R)\n",
    "        U: numpy array of shape (N,R)\n",
    "        hz: numpy array of shape (R,) (or scalar)\n",
    "        h: numpy array of shape (N,)\n",
    "        d: if 2, skip index pairs (i, i + N//2), treated as parallel hyperplanes\n",
    "        verbose: print progress and diagnostics (default True, as before)\n",
    "    Returns:\n",
    "        D_list: numpy array of shape (n_Ds,N) containing all candidate regions\n",
    "        D_inds: list of indices of regions in D_list that contain a fixed point\n",
    "        z_list: list of fixed points\n",
    "        n_singular: number of hyperplane sets without a unique intersection (skipped)\n",
    "        n_inverses: number of linear systems solved\n",
    "    \"\"\"\n",
    "    log = print if verbose else (lambda *args: None)\n",
    "    n_inverses = 0\n",
    "    N, R = U.shape\n",
    "\n",
    "    # First solve for all intersections of R hyperplanes\n",
    "    intersect_inds = np.array(list(combinations(np.arange(N), R)))\n",
    "    log(len(intersect_inds))\n",
    "\n",
    "    if d == 2:\n",
    "        ni = N // 2\n",
    "        par_inds = [i for i, el in enumerate(intersect_inds)\n",
    "                    if el[0] == el[1] + ni or el[1] == el[0] + ni]\n",
    "        intersect_inds = np.delete(intersect_inds, par_inds, axis=0)\n",
    "        log(\"removed parallel lines\")\n",
    "        log(len(intersect_inds))\n",
    "\n",
    "    n_subsets = 2 ** R  # == len(list(powerset(range(R))))\n",
    "    log(n_subsets)\n",
    "    D_list = np.zeros((n_subsets * len(intersect_inds), N), dtype='uint8')\n",
    "    it = 0\n",
    "    n_singular = 0\n",
    "    for inds in intersect_inds:\n",
    "        try:\n",
    "            z = np.linalg.solve(U[inds], h[inds])\n",
    "        except np.linalg.LinAlgError:\n",
    "            # Rank-deficient set (e.g. parallel hyperplanes): no single intersection point.\n",
    "            # A `matrix_rank(U_hat) > 0` test is not enough to rule this out.\n",
    "            n_singular += 1\n",
    "            continue\n",
    "        n_inverses += 1\n",
    "        # Regions bordering this intersection: units in `inds` may be on or off,\n",
    "        # every other unit keeps the state it has at z\n",
    "        D_init = np.array(U @ z - h > 0).astype('uint8')\n",
    "        D_init[inds] = 0\n",
    "        D_list[it] = D_init\n",
    "        it += 1\n",
    "        for on_units in list(powerset(inds))[1:]:\n",
    "            D = np.copy(D_init)\n",
    "            D[np.array(on_units)] = 1\n",
    "            D_list[it] = D\n",
    "            it += 1\n",
    "    log(\"n singular\")\n",
    "    log(n_singular)\n",
    "    # Throw away duplicate regions; rows left unused by skipped intersections are\n",
    "    # all-zero and merge into the (valid) all-off region\n",
    "    log(D_list.shape)\n",
    "    D_list = np.unique(D_list, axis=0)\n",
    "    log(D_list.shape)\n",
    "\n",
    "    # Finally solve for fixed points\n",
    "    z_list = []\n",
    "    D_inds = []\n",
    "    base = -np.eye(R) + np.diag(a)\n",
    "    for D_ind, D_init in enumerate(D_list):\n",
    "        A = base + V.T @ np.diag(D_init) @ U\n",
    "        b = V.T @ np.diag(D_init) @ h + hz\n",
    "        try:\n",
    "            z_hat = np.linalg.solve(A, b)\n",
    "        except np.linalg.LinAlgError:\n",
    "            # Singular region system: no isolated fixed point in this region; skip it\n",
    "            continue\n",
    "        n_inverses += 1\n",
    "\n",
    "        x_hat = U @ z_hat - h\n",
    "        if np.allclose(D_init, np.array(x_hat > 0).astype('uint8')):\n",
    "            log(\"Found a fixed point\")\n",
    "            log(z_hat)\n",
    "            z_list.append(z_hat)\n",
    "            D_inds.append(D_ind)\n",
    "    log(\"Done, found \" + str(len(z_list)) + \" fixed points\")\n",
    "    return D_list, D_inds, z_list, n_singular, n_inverses\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Re-factorise the teacher's low-rank connectivity with an SVD (orthonormal U) before the fixed-point search\n",
    "orth = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Low-rank factors of the teacher connectivity, W = U @ V\n",
    "U = torch.clone(rnn_osc.rnn.m.detach())\n",
    "V = torch.clone(rnn_osc.rnn.n.detach() * .1 / model_params['n_rec'])\n",
    "if orth:\n",
    "    # Re-factorise W by a truncated SVD so that U has orthonormal columns\n",
    "    tr = 2\n",
    "    W_or = U @ V\n",
    "    U, s, V = torch.linalg.svd(W_or, full_matrices=False)\n",
    "    U, s, V = U[:, :tr], s[:tr], V[:tr, :]\n",
    "    V = (V.T * s).T\n",
    "    W = U @ V\n",
    "    print(torch.linalg.norm(W_or - W))  # should be (very close to) 0\n",
    "\n",
    "# Recurrent bias and input weights\n",
    "B = rnn_osc.rnn.b_rec.detach()\n",
    "I = rnn_osc.rnn.w_inp.detach()\n",
    "# The simulation uses relu(U z + B), i.e. thresholds h = -B; find_fixed_points takes V as (N, R)\n",
    "D_list, D_inds, z_list, n_singular, n_inverses = find_fixed_points(\n",
    "    np.ones(2) * .9, V.T.numpy(), U.numpy(), 0, -B.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Number of linear systems solved by find_fixed_points (intersections + candidate regions)\n",
    "print(n_inverses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Simulate the teacher RNN in its rank-2 latent space: n_repeats noisy trials for each of n_angles cues\n",
    "n_angles = 8\n",
    "dur = 50\n",
    "R_z = 0.05  # std of the additive latent noise\n",
    "n_rec = U.shape[0]  # number of recurrent units (previously hard-coded as 60)\n",
    "\n",
    "n_repeats = 400\n",
    "z_all = torch.zeros(n_repeats, n_angles, dur, 2)\n",
    "X_all_GT = np.zeros((n_repeats, n_angles, dur, n_rec))\n",
    "\n",
    "# Trial-invariant quantities, computed once instead of on every repeat\n",
    "angles = torch.arange(0, 2 * np.pi, 2 * np.pi / n_angles)\n",
    "Trelu = nn.ReLU()\n",
    "# (cos, sin) cue for the first 25 steps. Named `stim` so the builtin `input` is not shadowed.\n",
    "stim = torch.zeros(n_angles, dur, 2)\n",
    "stim[:, 0:25, 0] = torch.cos(angles)[:, None]\n",
    "stim[:, 0:25, 1] = torch.sin(angles)[:, None]\n",
    "r0 = 0  # initial radius: every trial starts at the origin\n",
    "\n",
    "for ri in range(n_repeats):\n",
    "    z = torch.zeros(n_angles, dur, 2)\n",
    "    x = np.zeros((n_angles, dur, n_rec))\n",
    "    ph0 = torch.rand(n_angles) * 2 * np.pi  # initial phase; has no effect while r0 == 0\n",
    "    z[:, 0, 0] = r0 * torch.cos(ph0)\n",
    "    z[:, 0, 1] = r0 * torch.sin(ph0)\n",
    "    for t in range(1, dur):\n",
    "        # z_t = 0.9 z_{t-1} + V relu(U z_{t-1} + B + input drive) + noise\n",
    "        X = U @ z[:, t - 1].T + B.unsqueeze(1) + (stim[:, t - 1] @ I).T\n",
    "        z[:, t] = 0.9 * z[:, t - 1] + ((V @ Trelu(X)).T + torch.randn(len(angles), 2) * R_z)\n",
    "        x[:, t] = X.T.detach().numpy()\n",
    "\n",
    "    z_all[ri] = z\n",
    "    X_all_GT[ri] = x\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def Relu_derivative(x):\n",
    "    \"\"\"Derivative of max(x, 0): 1.0 where x > 0, else 0.0 (0 is used at the kink).\"\"\"\n",
    "    return np.array(x > 0).astype('float')\n",
    "\n",
    "\n",
    "def jacobian(V, U, h, a, z):\n",
    "    \"\"\"\n",
    "    Jacobian of the latent map z -> a z + V relu(U z + h), evaluated at z.\n",
    "\n",
    "    Conventions differ from find_fixed_points: here V is (R, N) and the bias h is\n",
    "    added inside the relu.\n",
    "\n",
    "    Args:\n",
    "        V: numpy array of shape (R,N)\n",
    "        U: numpy array of shape (N,R)\n",
    "        h: numpy array of shape (N,)\n",
    "        a: scalar, (R,) vector or (R,R) diagonal matrix of leak terms\n",
    "        z: numpy array of shape (R,), the evaluation point\n",
    "    Returns:\n",
    "        J: numpy array of shape (R,R)\n",
    "    \"\"\"\n",
    "    x = U @ z\n",
    "    D1 = np.diag(Relu_derivative(x + h))\n",
    "    # a * I keeps only the diagonal of a, so scalar, vector and diagonal-matrix a all work\n",
    "    J = a * np.eye(len(z)) + V @ D1 @ U\n",
    "    return J\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Eigenvalues of the Jacobian at the first fixed point found by find_fixed_points\n",
    "J_fp0 = jacobian(V.numpy(), U.numpy(), B.numpy(), np.diag(np.ones(2) * .9), z_list[0])\n",
    "e, v = np.linalg.eig(J_fp0)\n",
    "print(e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "from matplotlib.markers import MarkerStyle\n",
    "\n",
    "# Fixed-point marker styling for the phase-portrait figures\n",
    "dot_s = 10              # marker size\n",
    "dot_z = 100             # zorder of the markers\n",
    "dot_ew = .4             # marker edge width\n",
    "dot_s_st = 20           # marker size for star markers\n",
    "dot_fill = 'gainsboro'  # light face colour (unstable points, light half of saddles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def relu(x):\n",
    "    return np.maximum(x, 0)\n",
    "\n",
    "\n",
    "# One colour per reach angle from the cyclic hsv colormap\n",
    "prop_cycle = [plt.cm.hsv(i) for i in np.arange(0, 1, 1/len(angles))]\n",
    "\n",
    "# Constant input for the flow field: (cos, sin) of `phase` scaled by r (zero for r = 0)\n",
    "phase = np.pi\n",
    "r = 0\n",
    "cp = np.cos(phase)\n",
    "sp = np.sin(phase)\n",
    "inp1 = cp * r\n",
    "inp2 = sp * r\n",
    "u_in = np.array([inp1, inp2])\n",
    "\n",
    "# numpy copies of the teacher factors, converted once rather than on every dyn_eq call\n",
    "U_np, V_np, B_np, I_np = U.numpy(), V.numpy(), B.numpy(), I.numpy()\n",
    "a_teacher = np.diag(np.ones(2) * .9)  # leak of the simulated map (0.9 z)\n",
    "\n",
    "\n",
    "def dyn_eq(x, y):\n",
    "    \"\"\"One-step change z_next - z of the map z -> 0.9 z + V relu(U z + u_in @ I + B), at (x, y).\"\"\"\n",
    "    z = np.array([x, y])\n",
    "    z = -0.1 * z + V_np @ relu(U_np @ z + u_in @ I_np + B_np)\n",
    "    return z[0], z[1]\n",
    "\n",
    "\n",
    "def _plot_fixed_point(ax_, z_fp):\n",
    "    \"\"\"Mark a fixed point, styled by the eigenvalue moduli of the map's Jacobian there.\n",
    "\n",
    "    Both |e| > 1: unstable (light fill); both |e| < 1: stable (black);\n",
    "    otherwise: saddle (half light, half black).\n",
    "    \"\"\"\n",
    "    e, _ = np.linalg.eig(jacobian(V_np, U_np, B_np, a_teacher, z_fp))\n",
    "    style = dict(s=dot_s, edgecolor=\"black\", lw=dot_ew, zorder=dot_z)\n",
    "    if abs(e[0]) > 1 and abs(e[1]) > 1:\n",
    "        ax_.scatter(z_fp[0], z_fp[1], c=dot_fill, marker=MarkerStyle(\"o\"), **style)\n",
    "    elif abs(e[0]) < 1 and abs(e[1]) < 1:\n",
    "        ax_.scatter(z_fp[0], z_fp[1], c='black', marker=MarkerStyle(\"o\"), **style)\n",
    "    else:  # saddle\n",
    "        ax_.scatter(z_fp[0], z_fp[1], c=dot_fill, marker=MarkerStyle(\"o\", fillstyle=\"right\"), **style)\n",
    "        ax_.scatter(z_fp[0], z_fp[1], c='black', marker=MarkerStyle(\"o\", fillstyle=\"left\"), **style)\n",
    "\n",
    "\n",
    "# Teacher flow field on a 30 x 30 grid; norm holds the log speed for the background image\n",
    "xlims = 3\n",
    "ylims = 3\n",
    "\n",
    "X, Y = np.meshgrid(np.linspace(-xlims, xlims, 30), np.linspace(-ylims, ylims, 30))\n",
    "u, v = np.zeros_like(X), np.zeros_like(X)\n",
    "NI, NJ = X.shape\n",
    "\n",
    "norm = np.zeros((NI, NJ))\n",
    "for i in range(NI):\n",
    "    for j in range(NJ):\n",
    "        x, y = X[i, j], Y[i, j]\n",
    "        dx, dy = dyn_eq(x, y)\n",
    "        u[i, j] = dx\n",
    "        v[i, j] = dy\n",
    "        norm[i, j] = np.log(np.linalg.norm([dx, dy]))\n",
    "\n",
    "with mpl.rc_context(fname=\"matplotlibrc\"):\n",
    "    fig, ax = plt.subplots(1, 2, figsize=(2, 1), dpi=200)\n",
    "\n",
    "    ax[0].set_box_aspect(1)\n",
    "    ax[0].spines[['right', 'top']].set_visible(False)\n",
    "    ax[0].imshow(norm, extent=[-xlims, xlims, -ylims, ylims],\n",
    "                 origin='lower', cmap='bone', vmax=np.max(norm), aspect='auto')\n",
    "    ax[0].streamplot(X, Y, u, v, color='lavender', density=.5, linewidth=.5, arrowsize=.5)\n",
    "    ax[0].set_xlim(-xlims, xlims)\n",
    "    ax[0].set_ylim(-ylims, ylims)\n",
    "    ax[0].set_xticks([-xlims, xlims])\n",
    "    ax[0].set_yticks([-ylims, ylims])\n",
    "    ax[0].set_xlabel(r\"$z_1$\")\n",
    "    ax[0].set_ylabel(r'$z_2$')\n",
    "    ax[0].set_title(\"latents\")\n",
    "\n",
    "    # Simulated trajectories; with one colour per angle, trajectory colour i belongs to angle i\n",
    "    ax[0].set_prop_cycle('color', prop_cycle)\n",
    "    for z_trials in z_all:\n",
    "        for ang_i in range(n_angles):\n",
    "            ax[0].plot(z_trials[ang_i, :, 0], z_trials[ang_i, :, 1], alpha=.75, zorder=1000)\n",
    "\n",
    "    # Single-point proxies (drawn behind everything) that only feed the legend. Their colour is\n",
    "    # set explicitly: the continuing colour cycle need not line up with the labelled angle.\n",
    "    legend_src = z_all[-1]\n",
    "    for ang_i in [0, 1, -1]:\n",
    "        ax[0].plot(legend_src[ang_i, :1, 0], legend_src[ang_i, :1, 1], alpha=.75, color=prop_cycle[ang_i],\n",
    "                   label=rf\"{angles[ang_i]/(np.pi):.2f}$\\pi$\", zorder=-1)\n",
    "    ax[0].legend(title='angle', loc='upper right', bbox_to_anchor=(2, 1.3), fontsize=6)\n",
    "\n",
    "    # Right panel is hidden; only its square box is kept, which still shapes the layout\n",
    "    ax[1].set_box_aspect(1)\n",
    "    ax[1].axis('off')\n",
    "\n",
    "    for z_fp in z_list:\n",
    "        _plot_fixed_point(ax[0], z_fp)\n",
    "\n",
    "    plt.savefig(\"../figures/Fig1C1.svg\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Load the saved VAE named 'reach_vae' with its parameters (this also overwrites task_params)\n",
    "vae,vae_params,task_params,training_params=load_model(name='reach_vae')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Keep the VAE on the CPU: the following cells call .numpy() on its parameters\n",
    "vae.to_device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Learned prior transition of the VAE (student) as numpy arrays\n",
    "prior = vae.prior.transition\n",
    "tau = prior.cast_A(prior.AW).detach().numpy().squeeze()\n",
    "pV = (prior.n * prior.scaling).detach().numpy()\n",
    "pU = prior.m_transform(prior.m).detach().numpy()\n",
    "pB = prior.h.detach().numpy()\n",
    "pI = prior.Wu.detach().numpy()\n",
    "\n",
    "# Re-factorise pJ = pU @ pV by an SVD so that pU has orthonormal columns\n",
    "pJ = pU @ pV\n",
    "pu, ps, pv = np.linalg.svd(pJ)\n",
    "projection_matrix = pu[:, :2].T @ pU  # old latent coordinates -> new basis\n",
    "pU = pu[:, :2]\n",
    "pV = (pv[:2].T * ps[:2]).T\n",
    "# Latent-noise factor in the new basis; the simulation draws noise as proj_chol @ N(0, I)\n",
    "proj_chol = projection_matrix @ vae.prior.chol_cov_embed(vae.prior.R_z).detach().numpy()\n",
    "pJ_r = pU @ pV  # rank-2 reconstruction of pJ\n",
    "proj_cov = proj_chol @ proj_chol.T\n",
    "print(proj_cov)\n",
    "# Per-dimension noise std: sqrt of the diagonal only. An elementwise sqrt of the whole\n",
    "# matrix gives NaN for negative covariances and meaningless off-diagonal entries.\n",
    "print(np.sqrt(np.diag(proj_cov)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Simulate the learned (student) prior: n_repeats noisy trials for each of n_angles cues\n",
    "n_angles = 8\n",
    "dur = 50\n",
    "n_units = pU.shape[0]  # width of the prior's ReLU layer (previously hard-coded as 60)\n",
    "\n",
    "n_repeats = 400\n",
    "z_all = np.zeros((n_repeats, n_angles, dur, 2))\n",
    "X_all = np.zeros((n_repeats, n_angles, dur, n_units))\n",
    "\n",
    "# Trial-invariant quantities, computed once instead of on every repeat\n",
    "angles = np.arange(0, 2*np.pi, 2*np.pi/n_angles)\n",
    "# (cos, sin) cue for the first 25 steps. Named `stim` so the builtin `input` is not shadowed.\n",
    "stim = np.zeros((n_angles, dur, 2))\n",
    "stim[:, 0:25, 0] = np.cos(angles)[:, None]\n",
    "stim[:, 0:25, 1] = np.sin(angles)[:, None]\n",
    "r0 = 0  # initial radius: every trial starts at the origin\n",
    "\n",
    "for ri in range(n_repeats):\n",
    "    z = np.zeros((n_angles, dur, 2))\n",
    "    x = np.zeros((n_angles, dur, n_units))\n",
    "    ph0 = np.random.rand(n_angles) * 2 * np.pi  # initial phase; has no effect while r0 == 0\n",
    "    z[:, 0, 0] = r0 * np.cos(ph0)\n",
    "    z[:, 0, 1] = r0 * np.sin(ph0)\n",
    "    for t in range(1, dur):\n",
    "        # z_t = tau z_{t-1} + pV relu(pU z_{t-1} - pB + input drive) + proj_chol @ N(0, I)\n",
    "        X = pU @ z[:, t - 1].T - np.expand_dims(pB, 1) + (stim[:, t - 1] @ pI.T).T\n",
    "        z[:, t] = tau * z[:, t - 1]\n",
    "        z[:, t] += (pV @ relu(X)).T + (proj_chol @ np.random.randn(len(angles), 2).T).T\n",
    "        x[:, t] = X.T\n",
    "\n",
    "    z_all[ri] = z\n",
    "    X_all[ri] = x\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Inspect tau, the factor multiplying z in the prior's update (its shape depends on cast_A)\n",
    "tau"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Shape of the SVD-refactorised pV; (pv[:2].T * ps[:2]).T gives (2, N)\n",
    "pV.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Leak term as an (R,) vector: np.ones(R) * tau is right both for a scalar tau and for a\n",
    "# per-dimension tau, whereas np.array([tau, tau]) becomes (2, R) in the latter case\n",
    "a_prior = np.ones(pU.shape[1]) * tau\n",
    "D_list, D_inds, z_list, n_singular, n_inverses = find_fixed_points(a_prior, pV.T, pU, 0, pB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "vae.to_device(\"cpu\")\n",
    "\n",
    "# Constant input for the flow field: (cos, sin) of `phase` scaled by r (zero for r = 0)\n",
    "phase = np.pi\n",
    "r = 0\n",
    "cp, sp = np.cos(phase), np.sin(phase)\n",
    "inp1, inp2 = cp * r, sp * r\n",
    "u_in = np.array([inp1, inp2])\n",
    "\n",
    "\n",
    "def relu(x):\n",
    "    # elementwise max(x, 0)\n",
    "    return np.maximum(x, 0)\n",
    "\n",
    "\n",
    "def dyn_eq(x, y):\n",
    "    # Finite-difference velocity (z_next - z) / 0.1 of the learned prior map at the point (x, y)\n",
    "    z = np.array([x, y])\n",
    "    pre_act = pU @ z + pI @ u_in - pB\n",
    "    z_next = tau * z + pV @ relu(pre_act)\n",
    "    dz = (z_next - z) / .1\n",
    "    return dz[0], dz[1]\n",
    "\n",
    "\n",
    "xlims = ylims = 3\n",
    "\n",
    "X, Y = np.meshgrid(np.linspace(-xlims, xlims, 30), np.linspace(-ylims, ylims, 30))\n",
    "NI, NJ = X.shape\n",
    "u = np.zeros_like(X)\n",
    "v = np.zeros_like(X)\n",
    "norm = np.zeros((NI, NJ))  # log speed, used as the background image\n",
    "for i, j in np.ndindex(NI, NJ):\n",
    "    x, y = X[i, j], Y[i, j]\n",
    "    dx, dy = dyn_eq(x, y)\n",
    "    u[i, j], v[i, j] = dx, dy\n",
    "    norm[i, j] = np.log(np.linalg.norm([dx, dy]))\n",
    "\n",
    "with mpl.rc_context(fname='matplotlibrc'):\n",
    "\n",
    "    fig,ax = plt.subplots(1,2,figsize=(2,1),dpi=200)\n",
    "\n",
    "    ax[0].imshow(norm,extent = [-xlims,xlims,-ylims,ylims], \n",
    "            origin ='lower',cmap ='bone',vmax=np.max(norm),aspect='auto')\n",
    "    ax[0].streamplot(X, Y, u, v,color='lavender',density=.5,linewidth=.5,arrowsize=.5)\n",
    "    #plt.axis('square')\n",
    "    #plt.axis([-xlims, xlims, -ylims, ylims])\n",
    "    T1 = 0\n",
    "    T2 = 300\n",
    "    ax[0].set_box_aspect(1)\n",
    "    ax[0].spines[['right','top']].set_visible(False)\n",
    "\n",
    "    ax[0].set_xlim(-xlims,xlims)\n",
    "    ax[0].set_ylim(-ylims,ylims)\n",
    "    ax[0].set_xticks([-xlims,xlims])\n",
    "    ax[0].set_yticks([-ylims,ylims])\n",
    "    ax[0].set_xlabel(r\"$z_1$\")\n",
    "    ax[0].set_ylabel(r'$z_2$')\n",
    "\n",
    "    ax[0].set_prop_cycle('color',prop_cycle)\n",
    "    for z in z_all:\n",
    "        for ang_i in range(n_angles):\n",
    "            ax[0].plot(z[ang_i,:,0],z[ang_i,:,1],alpha = .75,zorder=1000)\n",
    "\n",
    "\n",
    "\n",
    "    ax[1].set_box_aspect(1)\n",
    "    ax[1].set_xlim(0,T2-T1)\n",
    "    ax[1].spines[['right','top']].set_visible(False)\n",
    "    ax[1].set_yticks(np.arange(0,20,2))\n",
    "    ax[1].set_yticklabels([])\n",
    "    ax[1].set_xlabel(\"timesteps\")\n",
    "    ax[1].set_ylabel(\"observed neurons\")\n",
    "    ax[0].set_title(\"latents\")\n",
    "    ax[1].axis('off')\n",
    "\n",
    "    incl_stable = False\n",
    "    incl_saddle = False\n",
    "    incl_unstable =False\n",
    "\n",
    "    for z in z_list:\n",
    "        #ax[0].scatter(z[0],z[1],color='red',zorder = 0)\n",
    "        e,v = np.linalg.eig(jacobian(pV,pU,-pB,np.diag(np.ones(2)*tau),z))\n",
    "        if abs(e[0])>1 and abs(e[1])>1:\n",
    "            if incl_unstable:\n",
    "                ax[0].scatter(z[0],z[1], s=dot_s, c=dot_fill, edgecolor=\"black\",lw=dot_ew,\n",
    "                                    marker=MarkerStyle(\"o\"),zorder=dot_z)\n",
    "            else:\n",
    "                incl_unstable=True\n",
    "                ax[0].scatter(z[0],z[1], s=dot_s, c=dot_fill, edgecolor=\"black\",lw=dot_ew,\n",
    "                    marker=MarkerStyle(\"o\"),zorder=dot_z,label='unstable')\n",
    "        elif abs(e[0])<1 and abs(e[1])<1:\n",
    "            if incl_stable:\n",
    "                ax[0].scatter(z[0],z[1], s=dot_s, c='black', edgecolor=\"black\",lw=dot_ew,\n",
    "                                    marker=MarkerStyle(\"o\"),zorder=dot_z)\n",
    "            else:\n",
    "                incl_stable=True\n",
    "                ax[0].scatter(z[0],z[1], s=dot_s, c='black', edgecolor=\"black\",lw=dot_ew,\n",
    "                        marker=MarkerStyle(\"o\"),zorder=dot_z,label='stable')\n",
    "        else: #saddle\n",
    "            if incl_saddle:\n",
    "                ax[0].scatter(z[0],z[1], s=dot_s, c=dot_fill, edgecolor=\"black\",lw=dot_ew,\n",
    "                                marker=MarkerStyle(\"o\", fillstyle=\"right\"),zorder=dot_z)\n",
    "                ax[0].scatter(z[0],z[1], s=dot_s, c='black', edgecolor=\"black\",lw=dot_ew,\n",
    "                            marker=MarkerStyle(\"o\", fillstyle=\"left\"),zorder=dot_z)\n",
    "            else:\n",
    "                incl_saddle=True\n",
    "                ax[0].scatter(z[0],z[1], s=dot_s, c=dot_fill, edgecolor=\"black\",lw=dot_ew,\n",
    "                marker=MarkerStyle(\"o\", fillstyle=\"right\"),zorder=dot_z,label='saddle')\n",
    "                ax[0].scatter(z[0],z[1], s=dot_s, c='black', edgecolor=\"black\",lw=dot_ew,\n",
    "                            marker=MarkerStyle(\"o\", fillstyle=\"left\"),zorder=dot_z,label='saddle')\n",
    "    ax[0].legend(title='fixed points', loc = 'upper right',bbox_to_anchor=(3,1.1))\n",
    "\n",
    "    plt.savefig(\"../figures/Fig1C2.svg\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "vae.to_device(\"cpu\")\n",
    "\n",
    "# Constant input u_in = r*(cos(phase), sin(phase)) used for the flow field.\n",
    "# With r=0 the input is zero, so `phase` has no effect on the plot.\n",
    "phase = np.pi\n",
    "r=0\n",
    "cp = np.cos(phase)\n",
    "sp = np.sin(phase)\n",
    "\n",
    "inp1 = cp*r\n",
    "inp2 = sp*r\n",
    "u_in = np.array([inp1,inp2])\n",
    "\n",
    "def relu(x):\n",
    "    \"\"\"Elementwise rectification max(x, 0).\"\"\"\n",
    "    return np.maximum(x,0)\n",
    "\n",
    "\n",
    "def dyn_eq(x,y):\n",
    "    \"\"\"Velocity of the latent map at the point (z_1, z_2) = (x, y).\n",
    "\n",
    "    One step of the map is z -> tau*z + pV @ relu(pU@z + pI@u_in - pB); the\n",
    "    returned pair is (z_next - z) / h with h = 0.1, i.e. the step rescaled for display.\n",
    "    \"\"\"\n",
    "    z = np.array([x,y])\n",
    "    X = pU@z+pI@u_in-pB\n",
    "    zn=(tau)*z+pV@relu(X)\n",
    "    h=.1\n",
    "    dz = (zn-z)/h\n",
    "    return dz[0],dz[1]\n",
    "\n",
    "xlims = 3  # half-range of z_1\n",
    "ylims =3   # half-range of z_2\n",
    "\n",
    "# Sample the velocity on a 30x30 grid: X/Y hold z_1/z_2, u/v the velocity\n",
    "# components, and norm the log speed used as background image.\n",
    "X, Y = np.meshgrid(np.linspace(-xlims, xlims, 30), np.linspace(-ylims, ylims, 30))\n",
    "u, v = np.zeros_like(X), np.zeros_like(X)\n",
    "NI, NJ = X.shape\n",
    "\n",
    "norm = np.zeros((NI,NJ))\n",
    "for i in range(NI):\n",
    "    for j in range(NJ):\n",
    "        dx, dy = dyn_eq(X[i, j], Y[i, j])\n",
    "        u[i,j] = dx\n",
    "        v[i,j] = dy\n",
    "        norm[i,j]=np.log(np.linalg.norm([dx,dy]))\n",
    "\n",
    "with mpl.rc_context(fname='matplotlibrc'):\n",
    "\n",
    "    fig,ax = plt.subplots(1,2,figsize=(2,1),dpi=200)\n",
    "\n",
    "    # Rotated view: z_2 on the horizontal axis, z_1 on the vertical axis (the\n",
    "    # horizontal axis is inverted at the end). Every layer -- background, flow\n",
    "    # field, trajectories, fixed points, limits and labels -- swaps coordinates.\n",
    "    ax[0].imshow(norm.T,extent = [-ylims,ylims,-xlims,xlims],\n",
    "            origin ='lower',cmap ='bone',vmax=np.max(norm),aspect='auto')\n",
    "    # Swap grid and velocity components; the transposes give streamplot the\n",
    "    # layout it requires (equal rows in the x grid, equal columns in the y grid).\n",
    "    ax[0].streamplot(Y.T, X.T, v.T, u.T,color='lavender',density=.5,linewidth=.5,arrowsize=.5)\n",
    "    ax[0].set_box_aspect(1)\n",
    "    ax[0].spines[['right','top']].set_visible(False)\n",
    "\n",
    "    ax[0].set_xlim(-ylims,ylims)\n",
    "    ax[0].set_ylim(-xlims,xlims)\n",
    "    ax[0].set_xticks([-ylims,ylims])\n",
    "    ax[0].set_yticks([-xlims,xlims])\n",
    "    ax[0].set_xlabel(r\"$z_2$\")\n",
    "    ax[0].set_ylabel(r'$z_1$')\n",
    "    ax[0].set_title(\"latents\")\n",
    "\n",
    "    # Simulated latent trajectories, plotted as (z_2, z_1). Loop names are\n",
    "    # distinct so the globals z / v are not overwritten by this cell.\n",
    "    ax[0].set_prop_cycle('color',prop_cycle)\n",
    "    for z_traj in z_all:\n",
    "        for ang_i in range(n_angles):\n",
    "            ax[0].plot(z_traj[ang_i,:,1],z_traj[ang_i,:,0],alpha = .75,zorder=1000)\n",
    "\n",
    "    # ax[1] is switched off; it appears to serve only as room for the legend,\n",
    "    # which is anchored at x=3 in ax[0] axes coordinates.\n",
    "    ax[1].set_box_aspect(1)\n",
    "    ax[1].axis('off')\n",
    "\n",
    "    # Fixed points, classified by the eigenvalue moduli of jacobian(...) at\n",
    "    # each point: all > 1 -> unstable, all < 1 -> stable, otherwise saddle.\n",
    "    # Keep one legend handle per class; a saddle's handle is the tuple of its\n",
    "    # two marker halves so the legend draws them overlaid as a single entry.\n",
    "    legend_handles = {}\n",
    "    for z_fp in z_list:\n",
    "        moduli = np.abs(np.linalg.eigvals(jacobian(pV,pU,-pB,np.diag(np.ones(2)*tau),z_fp)))\n",
    "        if np.all(moduli > 1):\n",
    "            kind = 'unstable'\n",
    "            handle = ax[0].scatter(z_fp[1],z_fp[0], s=dot_s, c=dot_fill, edgecolor=\"black\",lw=dot_ew,\n",
    "                                   marker=MarkerStyle(\"o\"),zorder=dot_z)\n",
    "        elif np.all(moduli < 1):\n",
    "            kind = 'stable'\n",
    "            handle = ax[0].scatter(z_fp[1],z_fp[0], s=dot_s, c='black', edgecolor=\"black\",lw=dot_ew,\n",
    "                                   marker=MarkerStyle(\"o\"),zorder=dot_z)\n",
    "        else:\n",
    "            kind = 'saddle'  # right half dot_fill, left half black\n",
    "            handle = (\n",
    "                ax[0].scatter(z_fp[1],z_fp[0], s=dot_s, c=dot_fill, edgecolor=\"black\",lw=dot_ew,\n",
    "                              marker=MarkerStyle(\"o\", fillstyle=\"right\"),zorder=dot_z),\n",
    "                ax[0].scatter(z_fp[1],z_fp[0], s=dot_s, c='black', edgecolor=\"black\",lw=dot_ew,\n",
    "                              marker=MarkerStyle(\"o\", fillstyle=\"left\"),zorder=dot_z),\n",
    "            )\n",
    "        legend_handles.setdefault(kind, handle)\n",
    "    if legend_handles:\n",
    "        ax[0].legend(list(legend_handles.values()), list(legend_handles.keys()),\n",
    "                     title='fixed points', loc = 'upper right',bbox_to_anchor=(3,1.1))\n",
    "    ax[0].invert_xaxis()\n",
    "    fig.savefig(\"../figures/Fig1C2.svg\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect prop_cycle: the colour list set as the axes colour cycle for the\n",
    "# latent trajectories and indexed per condition (prop_cycle[i]) in the box plots.\n",
    "prop_cycle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only cc.to_rgba(c, alpha=...) is used below, for the semi-transparent box fills.\n",
    "# NOTE(review): matplotlib keeps colorConverter for backward compatibility;\n",
    "# matplotlib.colors.to_rgba is the module-level equivalent. Consider moving this\n",
    "# import to the top imports cell.\n",
    "from matplotlib.colors import colorConverter as cc\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Activity distribution of one neuron per condition, pooled over the first axis\n",
    "# and all timesteps: teacher (X_all_GT, top) vs student (X_all, bottom).\n",
    "NEURON_SEED = 0  # fixed so the chosen neuron (and the saved figure) is reproducible\n",
    "n_conditions, n_neurons = X_all.shape[1], X_all.shape[-1]\n",
    "ni = np.random.default_rng(NEURON_SEED).integers(n_neurons)\n",
    "\n",
    "fig,ax = plt.subplots(2,1,figsize=(1,1),dpi=200, sharey=True)\n",
    "\n",
    "for i in range(n_conditions):\n",
    "    c=prop_cycle[i]\n",
    "    # Identical colour styling for the teacher and student boxes of condition i.\n",
    "    style = dict(positions=[i], widths=.6, patch_artist=True,\n",
    "                 boxprops=dict(facecolor=cc.to_rgba(c, alpha=.7), color=c),\n",
    "                 capprops=dict(color=c),\n",
    "                 whiskerprops=dict(color=c),\n",
    "                 medianprops=dict(color=c),\n",
    "                 flierprops={'marker': 'o', 'markersize': 1, 'markerfacecolor':c, 'markeredgecolor':c})\n",
    "    ax[0].boxplot(X_all_GT[:,i,:,ni].flatten(), **style)\n",
    "    ax[1].boxplot(X_all[:,i,:,ni].flatten(), **style)\n",
    "\n",
    "# Axis cosmetics do not depend on the condition, so set them once.\n",
    "ax[0].set_ylabel('teacher')\n",
    "ax[1].set_ylabel('student')\n",
    "for a in ax:\n",
    "    a.spines['bottom'].set_visible(False)\n",
    "    a.set_xticks([])\n",
    "\n",
    "fig.savefig(\"../figures/Supp_Fig1C.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape check; the cells around here index X_all as [:, condition, time, neuron].\n",
    "X_all.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per sample (all leading axes flattened), one column per neuron. Taking\n",
    "# the neuron count from the array avoids silently mixing axes if it is not 60.\n",
    "X_all.reshape(-1, X_all.shape[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-condition mean of neuron `ni` after the first 25 steps (the input period\n",
    "# in the simulation above), averaged over the first axis and over time.\n",
    "means = X_all[..., 25:, ni].mean(axis=(0, 2))\n",
    "means_GT = X_all_GT[..., 25:, ni].mean(axis=(0, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_pw_correlation(data):\n",
    "    \"\"\"Return the Pearson correlation matrix between the columns of `data`.\n",
    "\n",
    "    Rows are samples and columns are variables (rowvar=False), so for the\n",
    "    (samples, neurons) arrays below the result is (neurons, neurons).\n",
    "    \"\"\"\n",
    "    correlation_matrix = np.corrcoef(data, rowvar=False)\n",
    "    return correlation_matrix\n",
    "\n",
    "# Pool everything after the first 25 steps (the input period in the simulation\n",
    "# above) as samples, one column per neuron.\n",
    "# Teacher = X_all_GT (x-axis), student = X_all (y-axis), matching the box-plot labels.\n",
    "n_neurons = X_all.shape[-1]\n",
    "test_correlation = calculate_pw_correlation(X_all_GT[:,:,25:].reshape(-1,n_neurons))\n",
    "gen_correlation = calculate_pw_correlation(X_all[:,:,25:].reshape(-1,n_neurons))\n",
    "\n",
    "# Upper triangle without the diagonal: each neuron pair exactly once.\n",
    "i_upper = np.triu_indices(n_neurons, k=1)\n",
    "test_corr_values = test_correlation[i_upper]\n",
    "gen_corr_values = gen_correlation[i_upper]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(1, 1))\n",
    "ax.scatter(test_corr_values,gen_corr_values, color='teal', alpha=0.1,linewidth=0, s=10)\n",
    "ax.set_title('pairwise correlations')\n",
    "ax.set_xlabel('teacher')\n",
    "ax.set_ylabel('student')\n",
    "# x = y reference line. nanmin/nanmax ignore NaN entries (e.g. from a neuron with\n",
    "# constant activity); `lim` is used instead of `max` so the builtin is not shadowed.\n",
    "corr_values = np.concatenate([test_corr_values, gen_corr_values])\n",
    "lo = min(0.0, np.nanmin(corr_values))\n",
    "lim = np.nanmax(corr_values)*1.1\n",
    "ax.plot([lo, lim], [lo, lim], color='grey', linestyle='--',zorder=-10)\n",
    "fig.savefig(\"../figures/Supp_PWC.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of neuron pairs compared (upper triangle of the correlation matrix, k=1).\n",
    "test_corr_values.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected len(test_corr_values): unordered pairs of 60 neurons, n*(n-1)/2 = 1770.\n",
    "# np.triu_indices(..., k=1) excludes the diagonal, so (60*60)/2 = 1800 is wrong.\n",
    "60 * (60 - 1) // 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rnns",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
