{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "import os\n",
    "import sys\n",
    "file_dir = os.getcwd()\n",
    "sys.path.append(file_dir)\n",
    "from evaluation.klx_gmm import calc_kl_from_data\n",
    "from evaluation.pse import power_spectrum_helling\n",
    "import torch\n",
    "import numpy as np\n",
    "from rnn.datasets import Basic_dataset\n",
    "from rnn.saving import load_model\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from matplotlib.colors import ListedColormap\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from matplotlib.animation import FuncAnimation\n",
    "\n",
    "wandb.login()\n",
    "\n",
    "# Fix the RNG seeds so the prior rollouts and Poisson draws below are reproducible.\n",
    "seed = 42\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "# Hex colour codes for the i / real / generated series.\n",
    "i_code = '#C23B22'\n",
    "# Single '#': the previous '##779ECB' is not a valid matplotlib colour string.\n",
    "# #779ECB is the same colour as real_color below.\n",
    "real_code = '#779ECB'\n",
    "gen_code = '#56AE57'\n",
    "\n",
    "# The same colours as RGBA tuples with channels in [0, 1].\n",
    "i_color = (0.7608, 0.2314, 0.1333, 1)\n",
    "real_color = (0.4667, 0.6196, 0.7961, 1)\n",
    "gen_color = (0.3373, 0.6824, 0.3412, 1)\n",
    "# Channel-wise midpoint of the real and generated colours.\n",
    "average_color = tuple((x + y) / 2 for x, y in zip(real_color, gen_color))\n",
    "\n",
    "# Neutral grey for plots that mix real and generated data.\n",
    "mix_color = (0.5, 0.5, 0.5, 1)\n",
    "\n",
    "def hellinger_distance(p, q):\n",
    "    \"\"\"\n",
    "    Hellinger distance between two discrete probability distributions.\n",
    "\n",
    "    Computes ||sqrt(p) - sqrt(q)||_2 / sqrt(2) over all entries.\n",
    "\n",
    "    Parameters:\n",
    "    p (array-like): First distribution (converted with np.asarray).\n",
    "    q (array-like): Second distribution; must have the same shape as p.\n",
    "\n",
    "    Returns:\n",
    "    float: The distance; lies in [0, 1] when p and q are non-negative and sum to 1.\n",
    "\n",
    "    Raises:\n",
    "    AssertionError: if p and q have different shapes.\n",
    "    \"\"\"\n",
    "    p_arr, q_arr = np.asarray(p), np.asarray(q)\n",
    "    assert p_arr.shape == q_arr.shape, \"Input arrays must have the same shape\"\n",
    "\n",
    "    root_gap = np.sqrt(p_arr) - np.sqrt(q_arr)\n",
    "    return np.sqrt(np.sum(np.square(root_gap))) / np.sqrt(2)\n",
    "\n",
    "def binned_spikes_to_times(binned_spikes, bin_size=0.01):\n",
    "    \"\"\"\n",
    "    Convert one spike train of per-bin counts into individual spike times.\n",
    "\n",
    "    A bin at index i holding c spikes yields the times (i + k / c) * bin_size for\n",
    "    k = 0, ..., c - 1, i.e. the c spikes are spread evenly across that bin.\n",
    "\n",
    "    Parameters:\n",
    "    binned_spikes (array-like, expected 1-D): spike count per time bin.\n",
    "    bin_size (float): width of one bin; output times are in the same unit.\n",
    "\n",
    "    Returns:\n",
    "    np.ndarray: ascending spike times (float); empty if there are no spikes.\n",
    "    \"\"\"\n",
    "    binned_spikes = np.asarray(binned_spikes)\n",
    "    idxs = np.nonzero(binned_spikes)[0]\n",
    "    counts = binned_spikes[idxs].astype(int)\n",
    "    # One entry per spike: the index of the bin it came from.\n",
    "    spike_bins = np.repeat(idxs, counts)\n",
    "    # Rank of each spike inside its bin (0 .. c-1), computed without a Python\n",
    "    # loop; the old per-bin `spike_times == idx` scan was O(n_bins * n_spikes).\n",
    "    first_of_bin = np.repeat(np.cumsum(counts) - counts, counts)\n",
    "    rank_in_bin = np.arange(spike_bins.size) - first_of_bin\n",
    "    spike_times = spike_bins + rank_in_bin / np.repeat(counts, counts)\n",
    "    return spike_times * bin_size\n",
    "\n",
    "def compute_isi_stats(spikes, bin_size=0.01):\n",
    "    \"\"\"\n",
    "    Per-neuron mean and standard deviation of inter-spike intervals (ISIs).\n",
    "\n",
    "    ISIs are computed within each trial (never across trial boundaries) and then\n",
    "    pooled over trials.\n",
    "\n",
    "    Parameters:\n",
    "    spikes (array, shape (n_trials, n_timesteps, n_neurons)): binned spike counts.\n",
    "    bin_size (float): bin width forwarded to binned_spikes_to_times.\n",
    "\n",
    "    Returns:\n",
    "    (np.ndarray, np.ndarray): ISI means and stds, one entry per neuron. A neuron\n",
    "    with no trial containing at least two spikes gets NaN.\n",
    "    \"\"\"\n",
    "    n_trials, _, n_neurons = spikes.shape\n",
    "    isi_means = np.full(n_neurons, np.nan)\n",
    "    isi_stds = np.full(n_neurons, np.nan)\n",
    "    for neuron in range(n_neurons):\n",
    "        per_trial = [np.diff(binned_spikes_to_times(spikes[i, :, neuron], bin_size=bin_size))\n",
    "                     for i in range(n_trials)]\n",
    "        isis = np.concatenate(per_trial) if per_trial else np.empty(0)\n",
    "        # Leave NaN explicitly rather than calling np.mean on an empty array\n",
    "        # (RuntimeWarning) or np.concatenate on [] (ValueError if n_trials == 0).\n",
    "        if isis.size:\n",
    "            isi_means[neuron] = np.mean(isis)\n",
    "            isi_stds[neuron] = np.std(isis)\n",
    "    return isi_means, isi_stds\n",
    "\n",
    "# wandb public-API handle; later cells call api.run(...) on it.\n",
    "api = wandb.Api()\n",
    "# NOTE(review): this overrides the os.getcwd() value assigned above, so model paths\n",
    "# built as file_dir + '/ec013527/...' resolve to the absolute '/ec013527/...'.\n",
    "# Set this to the directory that holds the saved models before running.\n",
    "file_dir = \"\"\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "testing_TRAIN = np.load('train_hpc_11_10.npy')\n",
    "\n",
    "# wandb run path ('entity/project/run_id') of the trained model; fill in before running.\n",
    "# Kept in its own variable: the old `api = \"\"` replaced the wandb.Api() handle\n",
    "# with a str, so `api.run(...)` raised AttributeError here and in later cells.\n",
    "run_path = \"\"\n",
    "if not run_path:\n",
    "    raise ValueError(\"Set run_path to the wandb run path of the training model\")\n",
    "run = api.run(run_path) #hpc11, 10ms\n",
    "\n",
    "# Recover the checkpoint filename from the run's stdout log.\n",
    "# NOTE(review): the test-set cell matches 'Saved: models/ec013527' (no leading\n",
    "# slash); confirm which prefix this run actually logs.\n",
    "filename = None\n",
    "file = run.file(\"output.log\").download(replace=True)\n",
    "with open(file.name, \"r\") as file:\n",
    "    lines = file.readlines()\n",
    "    for line in lines:\n",
    "        if line.startswith(\"Saved: /models/ec013527\"):\n",
    "            filename = line.split(\"/\")[-1].strip()\n",
    "            break\n",
    "# Fail with a clear message instead of trying to load '.../ec013527/None'.\n",
    "if filename is None:\n",
    "    raise RuntimeError(\"No 'Saved: /models/ec013527...' line found in output.log\")\n",
    "print(filename)\n",
    "print(run.config)\n",
    "\n",
    "vae, params, task_params, training_params = load_model(file_dir+\"/ec013527/\"+str(filename))\n",
    "loaded_binary_matrix = testing_TRAIN\n",
    "task      = Basic_dataset(task_params, loaded_binary_matrix.T.astype(np.float32))\n",
    "dim_x     = task.data.shape[1]\n",
    "print(\"Data shape: \", task.data.shape)\n",
    "\n",
    "# NOTE(review): despite its name this holds the *test* split; it is not used below.\n",
    "train_TRAIN = np.load('test_hpc_11_10.npy')\n",
    "\n",
    "# Re-seed so the prior rollout and Poisson sampling are reproducible.\n",
    "seed = 42\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "T = task.data.shape[0]\n",
    "spikes = task.data[:T]                      # (T, X)\n",
    "x_sample = spikes.T.unsqueeze(0)            # (1, X, T)\n",
    "z_hat, _,_, _ = vae.encoder(x_sample)          # (1, Z, T), _, _\n",
    "# Start the prior rollout from the encoder's estimate of the first latent.\n",
    "z0 = z_hat[0, :, 0]\n",
    "Z = vae.prior.get_latent_time_series(time_steps=T, cut_off=0, z0=z0, noise_scale=1) # (1, Z, T)\n",
    "lam = vae.prior.get_observation(Z, noise_scale=0)\n",
    "lam = lam[0,:,:,0]\n",
    "lam = torch.exp(lam)          # (from Z to N to X)\n",
    "Z = Z[0]\n",
    "Z = Z[:,:,0]\n",
    "spikes_hat = torch.poisson(lam).T\n",
    "print(np.shape(Z))\n",
    "\n",
    "# Map the rollout through m and project onto the leading vae.dim_z left\n",
    "# singular vectors of the low-rank connectivity J = m @ n.\n",
    "m_or = vae.prior.transition.m #20,2\n",
    "n_or = vae.prior.transition.n\n",
    "J = m_or@n_or\n",
    "u,s,v = torch.linalg.svd(J)\n",
    "projection_matrix = u[:,:vae.dim_z].T@m_or\n",
    "\n",
    "Z_TRAIN = projection_matrix@Z\n",
    "Z_TRAIN = Z_TRAIN.detach().numpy()\n",
    "\n",
    "# z-score each projected prior latent over time.\n",
    "pZ1_TRAIN = (Z_TRAIN[0] - np.mean(Z_TRAIN[0]))/np.std(Z_TRAIN[0])\n",
    "pZ2_TRAIN = (Z_TRAIN[1] - np.mean(Z_TRAIN[1]))/np.std(Z_TRAIN[1])\n",
    "pZ3_TRAIN = (Z_TRAIN[2] - np.mean(Z_TRAIN[2]))/np.std(Z_TRAIN[2])\n",
    "pZ4_TRAIN = (Z_TRAIN[3] - np.mean(Z_TRAIN[3]))/np.std(Z_TRAIN[3])\n",
    "\n",
    "spikes_TRAIN = spikes.detach().numpy()\n",
    "spikes_hat_TRAIN = spikes_hat.detach().numpy()\n",
    "\n",
    "# Posterior inference on the full training recording (all bins held in).\n",
    "truncated_spk = loaded_binary_matrix[:, :]\n",
    "task      = Basic_dataset(task_params, truncated_spk.T.astype(np.float32))\n",
    "x=task.data.T.unsqueeze(0)\n",
    "print(x.shape)\n",
    "marginal_smoothing = False\n",
    "t_held_in = truncated_spk.shape[1]\n",
    "k=246\n",
    "n_passes = 5\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Average n_passes predict_NLB calls for both latents and rates.\n",
    "    for i in range(n_passes):\n",
    "        if i == 0:\n",
    "            Qzs_filt_avg_TRAIN, Qzs_sm_avg_TRAIN, Xs_filt_avg, Xs_sm_avg = vae.predict_NLB(x, u=None,k=k,t_held_in =t_held_in, t_forward=0,marginal_smoothing=marginal_smoothing)\n",
    "        else:\n",
    "            Qzs_filt_avg_tmp_TRAIN, Qzs_sm_avg_tmp_TRAIN, Xs_filt_avg_tmp, Xs_sm_avg_tmp = vae.predict_NLB(x, u=None, k=k, t_held_in=t_held_in, t_forward=0, marginal_smoothing=marginal_smoothing)\n",
    "            Qzs_filt_avg_TRAIN += Qzs_filt_avg_tmp_TRAIN\n",
    "            Qzs_sm_avg_TRAIN += Qzs_sm_avg_tmp_TRAIN\n",
    "            # Accumulate the rates as well; previously only the first pass was\n",
    "            # kept and then divided by 5, shrinking Xs_* fivefold.\n",
    "            Xs_filt_avg += Xs_filt_avg_tmp\n",
    "            Xs_sm_avg += Xs_sm_avg_tmp\n",
    "\n",
    "    Qzs_filt_avg_TRAIN /= n_passes\n",
    "    Qzs_sm_avg_TRAIN /= n_passes\n",
    "    Xs_filt_avg /= n_passes\n",
    "    Xs_sm_avg /= n_passes\n",
    "\n",
    "# Average each smoothed posterior latent over its last axis (presumably the k\n",
    "# samples), then z-score it over time.\n",
    "Z1 = torch.mean(Qzs_sm_avg_TRAIN[0,0,:,:],axis=1)\n",
    "Z2 = torch.mean(Qzs_sm_avg_TRAIN[0,1,:,:],axis=1)\n",
    "Z3 = torch.mean(Qzs_sm_avg_TRAIN[0,2,:,:],axis=1)\n",
    "Z4 = torch.mean(Qzs_sm_avg_TRAIN[0,3,:,:],axis=1)\n",
    "Z1_TRAIN = (Z1 - torch.mean(Z1))/torch.std(Z1)\n",
    "Z2_TRAIN = (Z2 - torch.mean(Z2))/torch.std(Z2)\n",
    "Z3_TRAIN = (Z3 - torch.mean(Z3))/torch.std(Z3)\n",
    "Z4_TRAIN = (Z4 - torch.mean(Z4))/torch.std(Z4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "testing = np.load('test_hpc_11_10.npy')\n",
    "print(testing.shape)\n",
    "\n",
    "run = api.run('vgtfi-rnn/uncategorized/lqjpl8b3') #hpc11, 10ms\n",
    "\n",
    "# Recover the checkpoint filename from the run's stdout log.\n",
    "# NOTE(review): the training-set cell matches 'Saved: /models/ec013527' (with a\n",
    "# leading slash); confirm which prefix this run actually logs.\n",
    "filename = None\n",
    "file = run.file(\"output.log\").download(replace=True)\n",
    "with open(file.name, \"r\") as file:\n",
    "    lines = file.readlines()\n",
    "    for line in lines:\n",
    "        if line.startswith(\"Saved: models/ec013527\"):\n",
    "            filename = line.split(\"/\")[-1].strip()\n",
    "            break\n",
    "# Fail with a clear message instead of trying to load '.../ec013527/None'.\n",
    "if filename is None:\n",
    "    raise RuntimeError(\"No 'Saved: models/ec013527...' line found in output.log\")\n",
    "print(filename)\n",
    "print(run.config)\n",
    "\n",
    "vae, params, task_params, training_params = load_model(file_dir+\"/ec013527/\"+str(filename))\n",
    "loaded_binary_matrix = testing\n",
    "task      = Basic_dataset(task_params, loaded_binary_matrix.T.astype(np.float32))\n",
    "dim_x     = task.data.shape[1]\n",
    "print(\"Data shape: \", task.data.shape)\n",
    "\n",
    "train = np.load('train_hpc_11_10.npy')\n",
    "\n",
    "# Re-seed so the prior rollout and Poisson sampling are reproducible.\n",
    "seed = 42\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "T = task.data.shape[0]\n",
    "spikes = task.data[:T]                      # (T, X)\n",
    "x_sample = spikes.T.unsqueeze(0)            # (1, X, T)\n",
    "z_hat, _,_, _ = vae.encoder(x_sample)          # (1, Z, T), _, _\n",
    "# Start the prior rollout from the encoder's estimate of the first latent.\n",
    "z0 = z_hat[0, :, 0]\n",
    "# Roll out burn_in extra steps and drop them, presumably so the kept T steps\n",
    "# are less dependent on z0.\n",
    "burn_in = 1000\n",
    "Z = vae.prior.get_latent_time_series(time_steps=T+burn_in, cut_off=0, z0=z0, noise_scale=1)\n",
    "lam = vae.prior.get_observation(Z, noise_scale=0)\n",
    "lam = lam[0,:,burn_in:,0]\n",
    "lam = torch.exp(lam)          # (from Z to N to X)\n",
    "Z = Z[0]\n",
    "Z = Z[:,burn_in:,0]\n",
    "spikes_hat = torch.poisson(lam).T\n",
    "print(np.shape(Z))\n",
    "\n",
    "# Map the rollout through m and project onto the leading vae.dim_z left\n",
    "# singular vectors of the low-rank connectivity J = m @ n.\n",
    "m_or = vae.prior.transition.m #20,2\n",
    "n_or = vae.prior.transition.n\n",
    "J = m_or@n_or\n",
    "u,s,v = torch.linalg.svd(J)\n",
    "projection_matrix = u[:,:vae.dim_z].T@m_or\n",
    "\n",
    "Z = projection_matrix@Z\n",
    "Z = Z.detach().numpy()\n",
    "\n",
    "spikes = spikes.detach().numpy()\n",
    "spikes_hat = spikes_hat.detach().numpy()\n",
    "\n",
    "# Posterior inference on the full test recording (all bins held in).\n",
    "truncated_spk = loaded_binary_matrix[:, :]\n",
    "task      = Basic_dataset(task_params, truncated_spk.T.astype(np.float32))\n",
    "x=task.data.T.unsqueeze(0)\n",
    "print(x.shape)\n",
    "marginal_smoothing = False\n",
    "t_held_in = truncated_spk.shape[1]\n",
    "k=246\n",
    "n_passes = 5\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Average n_passes predict_NLB calls for both latents and rates.\n",
    "    for i in range(n_passes):\n",
    "        if i == 0:\n",
    "            Qzs_filt_avg, Qzs_sm_avg, Xs_filt_avg, Xs_sm_avg = vae.predict_NLB(x, u=None,k=k,t_held_in =t_held_in, t_forward=0,marginal_smoothing=marginal_smoothing)\n",
    "        else:\n",
    "            Qzs_filt_avg_tmp, Qzs_sm_avg_tmp, Xs_filt_avg_tmp, Xs_sm_avg_tmp = vae.predict_NLB(x, u=None, k=k, t_held_in=t_held_in, t_forward=0, marginal_smoothing=marginal_smoothing)\n",
    "            Qzs_filt_avg += Qzs_filt_avg_tmp\n",
    "            Qzs_sm_avg += Qzs_sm_avg_tmp\n",
    "            Xs_filt_avg += Xs_filt_avg_tmp\n",
    "            Xs_sm_avg += Xs_sm_avg_tmp\n",
    "\n",
    "    Qzs_filt_avg /= n_passes\n",
    "    Qzs_sm_avg /= n_passes\n",
    "    Xs_filt_avg /= n_passes\n",
    "    Xs_sm_avg /= n_passes\n",
    "\n",
    "# Average each smoothed posterior latent over its last axis (presumably the k\n",
    "# samples), then z-score it over time.\n",
    "Z1 = torch.mean(Qzs_sm_avg[0,0,:,:],axis=1)\n",
    "Z2 = torch.mean(Qzs_sm_avg[0,1,:,:],axis=1)\n",
    "Z3 = torch.mean(Qzs_sm_avg[0,2,:,:],axis=1)\n",
    "Z4 = torch.mean(Qzs_sm_avg[0,3,:,:],axis=1)\n",
    "Z1 = (Z1 - torch.mean(Z1))/torch.std(Z1)\n",
    "Z2 = (Z2 - torch.mean(Z2))/torch.std(Z2)\n",
    "Z3 = (Z3 - torch.mean(Z3))/torch.std(Z3)\n",
    "Z4 = (Z4 - torch.mean(Z4))/torch.std(Z4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#import resample\n",
    "from scipy.signal import resample\n",
    "\n",
    "alla = np.load('train_hpc_11_10.npy')\n",
    "alla1 = np.load('test_hpc_11_10.npy')\n",
    "# Cast to float before the in-place z-scoring below: if the file stores integer\n",
    "# samples, writing z-scores back into an int array would truncate them.\n",
    "lfp = np.load('mazeLFP.npy').T.astype(np.float64)\n",
    "print(lfp.shape)\n",
    "\n",
    "# z-score each row (channel) independently.\n",
    "for i in range(lfp.shape[0]):\n",
    "    lfp[i] = (lfp[i] - np.mean(lfp[i]))/np.std(lfp[i])\n",
    "\n",
    "# Average the z-scored channels into a single trace.\n",
    "lfp = np.mean(lfp, axis=0)\n",
    "print(lfp.shape, Qzs_filt_avg.shape, alla.shape[1])\n",
    "\n",
    "# Split at the number of training bins; assumes the LFP starts at the first\n",
    "# training bin and continues straight into the test bins.\n",
    "train_lfp = lfp[:alla.shape[1]]\n",
    "test_lfp = lfp[alla.shape[1]:]\n",
    "print(lfp.shape, Qzs_filt_avg.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 5))\n",
    "color_dict = {\"wasserstein\": \"#cc241d\",\n",
    "              \"mmd\": \"#eebd35\",\n",
    "              \"c2st\": \"#458588\",\n",
    "              \"fid\": \"#8ec07c\", \n",
    "              \"kl\": \"#8ec07c\"}\n",
    "\n",
    "cmap = plt.get_cmap('tab10')\n",
    "\n",
    "duration = 1000\n",
    "\n",
    "# Tick helpers: 7 evenly spaced positions on [0, 400] labelled 0 .. 10.\n",
    "x_index = np.linspace(0, 400, 7)\n",
    "x_loc = np.linspace(0, 10, 7)\n",
    "x_loc = [round(i, 2) for i in x_loc]\n",
    "\n",
    "tri = 0\n",
    "latent_ind = 0\n",
    "\n",
    "# z-score each projected prior latent (rows of Z from the test-set cell).\n",
    "pZ1 = (Z[0] - np.mean(Z[0]))/np.std(Z[0])\n",
    "pZ2 = (Z[1] - np.mean(Z[1]))/np.std(Z[1])\n",
    "pZ3 = (Z[2] - np.mean(Z[2]))/np.std(Z[2])\n",
    "pZ4 = (Z[3] - np.mean(Z[3]))/np.std(Z[3])\n",
    "init = 0\n",
    "dur = 6000\n",
    "\n",
    "# One colour per latent index; prior dashed, posterior solid (previously Z4\n",
    "# reused Z1's colour and every latent was dashed, so lines were ambiguous).\n",
    "# NOTE(review): after z0 the prior rollout is not conditioned on data, so it\n",
    "# is not expected to be time-aligned with the posterior or the LFP.\n",
    "plt.plot(pZ1[init:init+dur],color=cmap(1), linestyle = 'dashed', label='priorZ1');\n",
    "plt.plot(pZ2[init:init+dur],color=cmap(2), linestyle = 'dashed', label='priorZ2');\n",
    "plt.plot(pZ3[init:init+dur],color=cmap(3), linestyle = 'dashed', label='priorZ3');\n",
    "plt.plot(pZ4[init:init+dur],color=cmap(4), linestyle = 'dashed', label='priorZ4');\n",
    "plt.plot(Z1[init:init+dur],color=cmap(1), label='postZ1');\n",
    "plt.plot(Z2[init:init+dur],color=cmap(2), label='postZ2');\n",
    "plt.plot(Z3[init:init+dur],color=cmap(3), label='postZ3');\n",
    "plt.plot(Z4[init:init+dur],color=cmap(4), label='postZ4');\n",
    "plt.plot(test_lfp[init:init+dur],color='black', alpha=0.7,linestyle = 'dashed', label='lfp');\n",
    "plt.xlabel(\"timesteps\")\n",
    "plt.ylabel(\"z-score\")\n",
    "plt.xlim(850,1000)\n",
    "plt.legend()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from neo.io.neuroscopeio import NeuroScopeIO\n",
    "from neo.io.klustakwikio import KlustaKwikIO\n",
    "from scipy.signal import resample\n",
    "\n",
    "# Directory with the NeuroScope session files; fill in before running.\n",
    "# NOTE(review): with path = '' this resolves to the absolute '/ec013527.xml'.\n",
    "path = \"\"\n",
    "filen= path+\"/ec013527.xml\"\n",
    "reader = NeuroScopeIO(filename=filen)\n",
    "\n",
    "seg = reader.read_segment(lazy=False)\n",
    "t, c = np.shape(seg.analogsignals[0])\n",
    "# Downsample every channel by a factor of 12.5 (FFT-based resample); the\n",
    "# target length is the same for all channels, so compute it once.\n",
    "resample_rate = 12.5\n",
    "n_samples = int(t/resample_rate)\n",
    "ds = []\n",
    "for i in range(c):\n",
    "    lfp = np.array(seg.analogsignals[0][:,i])\n",
    "    lfp_ds = resample(lfp,n_samples)\n",
    "    ds.append(lfp_ds)\n",
    "\n",
    "# NOTE(review): the other cells use 'train_hpc_11_10.npy' as the training split;\n",
    "# confirm this file gives the intended offset for the LFP window.\n",
    "alla = np.load('train_first_hpc.npy')\n",
    "\n",
    "# Stacked result is 3-D; [:, :, 0] drops the trailing per-channel axis.\n",
    "lfp = np.array(ds)[:,:,0]\n",
    "print(lfp.shape)\n",
    "# Keep the window that follows the training bins, as long as the posterior\n",
    "# (assumes axis 2 of Qzs_filt_avg is time). `Qzs_filt` was never defined.\n",
    "lfp = lfp[:,int(alla.shape[1]):int(alla.shape[1]) + int(Qzs_filt_avg.shape[2])]\n",
    "print(lfp.shape)\n",
    "# z-score each channel, then average across channels.\n",
    "for i in range(lfp.shape[0]):\n",
    "    lfp[i] = (lfp[i] - np.mean(lfp[i]))/np.std(lfp[i])\n",
    "\n",
    "lfp = np.mean(lfp, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.signal import welch\n",
    "# Define the sampling frequency\n",
    "# NOTE(review): the runs are tagged '10ms' and the coherence cell uses fs = 100;\n",
    "# confirm the rate test_lfp is actually sampled at before trusting the Hz axis.\n",
    "fs = 39 # Hz\n",
    "nperseg = 1024\n",
    "# Welch PSD of the channel-averaged, z-scored test-set LFP.\n",
    "frequencies0, psd0 = welch(test_lfp.reshape(-1), fs=fs, nperseg=nperseg)  # Adjust nperseg as needed\n",
    "#frequencies1, psd1 = welch(maze_filtered.reshape(-1), fs=fs, nperseg=nperseg)  # Adjust nperseg as needed\n",
    "\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.semilogy(frequencies0, psd0, alpha=0.8, zorder=0, label = 'lfp')\n",
    "#plt.semilogy(frequencies1, psd1, alpha=0.8, zorder=0, label = 'lfp_filtered')\n",
    "plt.xlabel(\"Frequency (f)\", labelpad=10, fontsize=14)\n",
    "plt.ylabel(\"PSD of LFP\", labelpad=10, fontsize=14) \n",
    "plt.xlim(0, 20)\n",
    "#plt.ylim(1e-4, 1)\n",
    "#prevent x-axis label from being cut off\n",
    "plt.subplots_adjust(bottom=0.3)\n",
    "plt.legend()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.interpolate import interp1d\n",
    "\n",
    "cmap = plt.get_cmap(\"tab10\")\n",
    "# Test-set location per bin (presumably the animal's position; also the\n",
    "# regression target later on).\n",
    "locs = np.load('test_loc.npy')\n",
    "\n",
    "plt.figure(figsize=(12, 3))\n",
    "init = 500\n",
    "dur = 5500\n",
    "# One colour per latent (Z4 previously reused Z1's colour).\n",
    "plt.plot(Z1[init:init+dur], color=cmap(1), label='postZ1')\n",
    "plt.plot(Z2[init:init+dur], color=cmap(2), label='postZ2')\n",
    "plt.plot(Z3[init:init+dur], color=cmap(3), label='postZ3')\n",
    "plt.plot(Z4[init:init+dur], color=cmap(4), label='postZ4')\n",
    "plt.plot(locs[init:init+dur], color='black', alpha = 0.8, label='location')\n",
    "plt.legend()\n",
    "# The x axis is a bin index counted from `init`, not seconds.\n",
    "plt.xlabel(f'Time bin (from {init})')\n",
    "plt.ylabel('Amplitude')\n",
    "plt.xlim(500, 1000)\n",
    "plt.title('Posterior latents and test-set location')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the posterior latents to NumPy for SciPy/sklearn. The guard makes the\n",
    "# cell safe to re-run: after the first run they are already ndarrays, which\n",
    "# have no .detach() and previously made a second run fail.\n",
    "Z1, Z2, Z3, Z4 = (z.detach().numpy() if torch.is_tensor(z) else z for z in (Z1, Z2, Z3, Z4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Band-pass the latents with a Butterworth filter applied by filtfilt.\n",
    "import scipy.signal\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "print(lfp.shape)\n",
    "# NOTE(review): the coherence cell uses fs = 100 (10 ms bins); confirm this rate.\n",
    "fs = 39  # Sampling frequency (Hz)\n",
    "t = np.arange(0, Qzs_filt_avg.shape[2]) / fs  # Time vector (s)\n",
    "lowcut = 0.1  # Hz\n",
    "highcut = 19  # Hz; must stay below the Nyquist frequency fs / 2\n",
    "low = lowcut / (fs / 2)\n",
    "high = highcut / (fs / 2)\n",
    "order = 4\n",
    "b, a = scipy.signal.butter(order, [low, high], btype='band')\n",
    "# The LFP filter was designed with identical settings; reuse the coefficients\n",
    "# instead of designing the same filter twice.\n",
    "blfp, alfp = b, a\n",
    "\n",
    "# filtfilt runs the filter forward and backward, giving zero phase shift.\n",
    "Z1_filtered = scipy.signal.filtfilt(b, a, Z1, axis=0)\n",
    "Z2_filtered = scipy.signal.filtfilt(b, a, Z2, axis=0)\n",
    "Z3_filtered = scipy.signal.filtfilt(b, a, Z3, axis=0)\n",
    "Z4_filtered = scipy.signal.filtfilt(b, a, Z4, axis=0)\n",
    "\n",
    "pZ1_filtered = scipy.signal.filtfilt(b, a, pZ1, axis=0)\n",
    "pZ2_filtered = scipy.signal.filtfilt(b, a, pZ2, axis=0)\n",
    "pZ3_filtered = scipy.signal.filtfilt(b, a, pZ3, axis=0)\n",
    "pZ4_filtered = scipy.signal.filtfilt(b, a, pZ4, axis=0)\n",
    "\n",
    "# Plot the band-passed prior (RNN rollout) latents.\n",
    "plt.figure(figsize=(12, 3))\n",
    "plt.plot(t, pZ1_filtered, label='Filtered prior Z1', color='red')\n",
    "plt.plot(t, pZ2_filtered, label='Filtered prior Z2', color='green')\n",
    "plt.plot(t, pZ3_filtered, label='Filtered prior Z3', color='orange')\n",
    "plt.xlabel('Time (s)')\n",
    "plt.ylabel('Amplitude')\n",
    "plt.legend()\n",
    "plt.ylim(-2, 2)\n",
    "plt.xlim(6, 10)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.signal import welch\n",
    "\n",
    "# Welch PSDs of the prior (RNN rollout) latents pZ1..pZ4 and of the test LFP.\n",
    "# NOTE(review): the coherence cell assumes 100 Hz (10 ms bins); confirm fs.\n",
    "fs = 39  # Hz\n",
    "nperseg = 400\n",
    "Z1_filtered = Z1_filtered.reshape(-1)\n",
    "Z2_filtered = Z2_filtered.reshape(-1)\n",
    "Z3_filtered = Z3_filtered.reshape(-1)\n",
    "Z4_filtered = Z4_filtered.reshape(-1)\n",
    "frequencies0, psd0 = welch(pZ1, fs=fs, nperseg=nperseg)\n",
    "frequencies1, psd1 = welch(pZ2, fs=fs, nperseg=nperseg)\n",
    "frequencies2, psd2 = welch(pZ3, fs=fs, nperseg=nperseg)\n",
    "frequencies3, psd3 = welch(pZ4, fs=fs, nperseg=nperseg)\n",
    "frequencies6, psd6 = welch(test_lfp, fs=fs, nperseg=nperseg)\n",
    "\n",
    "# Use a local colormap for Z4 instead of rebinding the global `cmap`, so\n",
    "# re-running the cell gives the same colours.\n",
    "tab20 = plt.get_cmap('tab20')\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.semilogy(frequencies0, psd0, color=cmap(1), alpha=0.8, zorder=0, label = 'rnn Z1')\n",
    "plt.semilogy(frequencies1, psd1,  color=cmap(3), alpha=0.8, zorder=0, label = 'rnn Z2')\n",
    "plt.semilogy(frequencies2, psd2,  color=cmap(0), alpha=0.8, zorder=0, label = 'rnn Z3')\n",
    "plt.semilogy(frequencies3, psd3, color=tab20(1), alpha=0.8, zorder=0, label = 'rnn Z4')\n",
    "plt.semilogy(frequencies6, psd6,  color='black', alpha=0.8, zorder=0, label = 'LFP')\n",
    "plt.xlabel(\"Frequency (Hz)\", labelpad=10, fontsize=14)\n",
    "plt.ylabel(\"PSD of Latents\", labelpad=10, fontsize=14) \n",
    "plt.xlim(0, 20)\n",
    "plt.ylim(1e-5, 10)\n",
    "# Leave room so the x-axis label is not cut off.\n",
    "plt.subplots_adjust(bottom=0.3)\n",
    "# Add the legend before saving, otherwise psd_latent.png has no legend.\n",
    "plt.legend()\n",
    "plt.savefig('psd_latent.png')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linear regression from training-set posterior latents to location,\n",
    "# evaluated on the test-set posterior latents.\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from scipy.ndimage import gaussian_filter1d\n",
    "\n",
    "model = LinearRegression()\n",
    "mazar = np.load('mazeLFP.npy')\n",
    "alla = np.load('train_hpc_11_10.npy')\n",
    "alla1 = np.load('test_hpc_11_10.npy')\n",
    "test_locs = np.load('test_loc.npy')\n",
    "train_locs = np.load('train_loc.npy')\n",
    "\n",
    "# Population spike count per bin, Gaussian-smoothed (not used by the fit below).\n",
    "alla = np.sum(alla, axis=0)\n",
    "alla = gaussian_filter1d(alla, sigma=10)\n",
    "print(np.shape(alla1))\n",
    "\n",
    "X = np.vstack((Z1_TRAIN, Z2_TRAIN, Z3_TRAIN, Z4_TRAIN)).T\n",
    "X_test = np.vstack((Z1, Z2, Z3, Z4)).T\n",
    "y = train_locs\n",
    "assert X.shape[0] == len(y), f'train latents ({X.shape[0]}) and train_locs ({len(y)}) differ in length'\n",
    "assert X_test.shape[0] == len(test_locs), f'test latents ({X_test.shape[0]}) and test_locs ({len(test_locs)}) differ in length'\n",
    "\n",
    "model.fit(X, y)\n",
    "y_pred = model.predict(X_test)\n",
    "\n",
    "plt.figure(figsize=(6, 2))\n",
    "plt.plot(y_pred, label='predicted')\n",
    "plt.plot(test_locs, label='actual')\n",
    "plt.legend()\n",
    "print(test_locs.shape, y_pred.shape)\n",
    "r2 = model.score(X_test, test_locs)\n",
    "plt.title(f'r2: {r2:.5f}')\n",
    "# NOTE(review): corrcoef(...)[0, 1] is a single Pearson r only if test_locs is 1-D.\n",
    "print(np.corrcoef(test_locs, y_pred)[0, 1])\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import signal\n",
    "\n",
    "def xcorr_peak_normalised(a, b):\n",
    "    \"\"\"Full cross-correlation of `a` with `b` plus its lags, scaled so max |corr| is 1.\"\"\"\n",
    "    corr = signal.correlate(a, b)\n",
    "    lags = signal.correlation_lags(len(a), len(b))\n",
    "    # Scale by the largest absolute value: dividing by np.max flips the sign of\n",
    "    # the whole curve whenever the dominant peak is negative.\n",
    "    peak = np.max(np.abs(corr))\n",
    "    if peak > 0:\n",
    "        corr = corr / peak\n",
    "    return corr, lags\n",
    "\n",
    "corr1, lags1 = xcorr_peak_normalised(test_lfp, Z1)\n",
    "corr2, lags2 = xcorr_peak_normalised(test_lfp, Z2)\n",
    "corr3, lags3 = xcorr_peak_normalised(test_lfp, Z3)\n",
    "corr4, lags4 = xcorr_peak_normalised(test_lfp, Z4)\n",
    "\n",
    "# 2 by 2 grid, one panel per latent.\n",
    "plt.figure(figsize=(12, 6))\n",
    "panels = [(lags1, corr1, 'Z1'), (lags2, corr2, 'Z2'), (lags3, corr3, 'Z3'), (lags4, corr4, 'Z4')]\n",
    "for panel, (lags, corr, name) in enumerate(panels, start=1):\n",
    "    plt.subplot(2, 2, panel)\n",
    "    plt.plot(lags, corr)\n",
    "    plt.title(f'Correlation with {name}')\n",
    "    plt.xlabel('Lags')\n",
    "    plt.ylabel('Correlation')\n",
    "\n",
    "plt.tight_layout()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.signal import coherence\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.ndimage import gaussian_filter1d\n",
    "\n",
    "# Sampling frequency in Hz; 100 Hz matches the '10ms' tag on the loaded runs.\n",
    "# NOTE(review): the PSD/band-pass cells use fs = 39; make these consistent.\n",
    "fs = 100\n",
    "nperseg = 1024\n",
    "\n",
    "# Coherence of each latent with location. Z* are posterior latents;\n",
    "# pZ*_filtered are the band-passed prior (RNN rollout) latents.\n",
    "f1, Cxy1 = coherence(Z1, locs, fs=fs, nperseg=nperseg)\n",
    "f2, Cxy2 = coherence(Z2, locs, fs=fs, nperseg=nperseg)\n",
    "f3, Cxy3 = coherence(Z3, locs, fs=fs, nperseg=nperseg)\n",
    "f4, Cxy4 = coherence(Z4, locs, fs=fs, nperseg=nperseg)\n",
    "pf1, pCxy1 = coherence(pZ1_filtered, locs, fs=fs, nperseg=nperseg)\n",
    "pf2, pCxy2 = coherence(pZ2_filtered, locs, fs=fs, nperseg=nperseg)\n",
    "pf3, pCxy3 = coherence(pZ3_filtered, locs, fs=fs, nperseg=nperseg)\n",
    "pf4, pCxy4 = coherence(pZ4_filtered, locs, fs=fs, nperseg=nperseg)\n",
    "# Smooth each coherence spectrum over frequency.\n",
    "Cxy1 = gaussian_filter1d(Cxy1, sigma=10)\n",
    "pCxy1 = gaussian_filter1d(pCxy1, sigma=10)\n",
    "Cxy2 = gaussian_filter1d(Cxy2, sigma=10)\n",
    "pCxy2 = gaussian_filter1d(pCxy2, sigma=10)\n",
    "Cxy3 = gaussian_filter1d(Cxy3, sigma=10)\n",
    "pCxy3 = gaussian_filter1d(pCxy3, sigma=10)\n",
    "Cxy4 = gaussian_filter1d(Cxy4, sigma=10)\n",
    "pCxy4 = gaussian_filter1d(pCxy4, sigma=10)\n",
    "\n",
    "# scipy.signal.coherence already returns the magnitude-squared coherence, so it\n",
    "# is plotted as-is (squaring it again showed |C|^4). pZ* curves come from the\n",
    "# prior (rnn) and Z* curves from the posterior.\n",
    "plt.plot(pf1, pCxy1, label='rnn 1')\n",
    "plt.plot(f1, Cxy1, label='posterior 1')\n",
    "\n",
    "plt.plot(pf3, pCxy3, label='rnn 3')\n",
    "plt.plot(f3, Cxy3, label='posterior 3')\n",
    "\n",
    "plt.plot(pf2, pCxy2, label='rnn 2')\n",
    "plt.plot(f2, Cxy2, label='posterior 2')\n",
    "\n",
    "plt.plot(pf4, pCxy4, label='rnn 4')\n",
    "plt.plot(f4, Cxy4, label='posterior 4')\n",
    "plt.legend()\n",
    "plt.xlabel('Frequency [Hz]')\n",
    "plt.ylabel('Magnitude-squared coherence')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.gridspec as gridspec\n",
    "import matplotlib as mpl\n",
    "\n",
    "# Purple shades for the latents z1..z4, plus the two comparison colours\n",
    "# used in the scatter panels (tg: test vs gen, tr: test vs train).\n",
    "color1, color2, color3, color4 = '#7B46C1', '#A860AF', '#7C277D', '#8A44A4'\n",
    "tg, tr = 'teal', 'firebrick'\n",
    "\n",
    "with mpl.rc_context(fname=\"style_matplotlib.rc\"):\n",
    "    # 2 x 4 grid. Top row: data raster | inferred raster | latents | (empty);\n",
    "    # bottom row: mean rates | cv | psd | zoomed latents.\n",
    "    fig, axes = plt.subplots(2, 4, figsize=(6, 3))\n",
    "    fig.subplots_adjust(hspace=0.2, wspace=0.6)\n",
    "\n",
    "    ax1, ax3, ax2, ax8 = axes[0]\n",
    "    ax4, ax5, ax6, ax7 = axes[1]\n",
    "\n",
    "    # The fourth slot of the top row stays empty.\n",
    "    ax8.axis('off')\n",
    "\n",
    "    # Widen the three used top-row panels by width_scaling_factor and repack them\n",
    "    # left to right: the first keeps its x0, and each next one starts panel_gap\n",
    "    # (figure fraction) after the right edge of the previous one.\n",
    "    width_scaling_factor = 1.6\n",
    "    panel_gap = 0.066\n",
    "    new_positions = []\n",
    "    for ax in axes[0, :3]:\n",
    "        box = ax.get_position()\n",
    "        if new_positions:\n",
    "            left = new_positions[-1][0] + new_positions[-1][2] + panel_gap\n",
    "        else:\n",
    "            left = box.x0\n",
    "        new_positions.append([left, box.y0, box.width * width_scaling_factor, box.height])\n",
    "\n",
    "    # Move the axes only after every target box has been computed.\n",
    "    for ax, target in zip(axes[0, :3], new_positions):\n",
    "        ax.set_position(target)\n",
    "\n",
    "    # Shared 20 s window for both rasters and the latent traces (40 time bins per second).\n",
    "    bins_per_s = 40\n",
    "    init = 750\n",
    "    duration = 20 * bins_per_s\n",
    "    raster_xticks = [0, 10 * bins_per_s, 20 * bins_per_s]\n",
    "\n",
    "    # Observed spikes: the window is sliced on axis 0 and transposed, so neurons run along y.\n",
    "    ax1.imshow(spikes[init:init+duration, :].T, aspect='auto', cmap='Greys', interpolation='none', vmax=1)\n",
    "    ax1.set_xticks(raster_xticks)\n",
    "    ax1.set_xticklabels([0, 10, 20])\n",
    "    ax1.set_title('data')\n",
    "    ax1.set_ylabel('neuron')\n",
    "    ax1.set_yticks([])\n",
    "    ax1.set_xlabel('time (s)')\n",
    "\n",
    "    # Latent traces over the same window, vertically offset so they do not overlap.\n",
    "    # Bin k sits at k / bins_per_s seconds. np.linspace(0, 20, duration) would include\n",
    "    # the endpoint and stretch the spacing to 20 / (duration - 1).\n",
    "    t = np.arange(duration) / bins_per_s\n",
    "    ax2.plot(t, Z[0][init:init+duration]+50, alpha=0.9, label=\"Z1\", color=color1)\n",
    "    ax2.plot(t, Z[1][init:init+duration]+60, alpha=0.9, label=\"Z2\", color=color2)\n",
    "    ax2.plot(t, Z[2][init:init+duration]-35, alpha=0.9, label=\"Z3\", color=color3)\n",
    "    ax2.plot(t, Z[3][init:init+duration]-100, alpha=0.9, label=\"Z4\", color=color4)\n",
    "    ax2.set_xticks([0, 10, 20])\n",
    "    ax2.set_yticks([])\n",
    "    ax2.set_xlabel('time (s)')\n",
    "\n",
    "    # Model spikes (spikes_hat) over the same window, laid out like the data raster.\n",
    "    ax3.imshow(spikes_hat[init:init+duration, :].T, aspect='auto', cmap='Greys', interpolation='none', vmax=1)\n",
    "    ax3.set_xticks(raster_xticks)\n",
    "    ax3.set_xticklabels([0, 10, 20])\n",
    "    ax3.set_title(\"inferred\")\n",
    "    ax3.set_xlabel('time (s)')\n",
    "    ax3.set_yticks([])\n",
    "\n",
    "    # Per-neuron ISI statistics for test (spikes), generated (spikes_hat) and training data;\n",
    "    # each array gets a leading singleton axis before compute_isi_stats.\n",
    "    # NOTE(review): spikes / spikes_hat are sliced on axis 0 as time above, while train is\n",
    "    # transposed here. Assumes all three end up as (1, time, neuron); confirm against compute_isi_stats.\n",
    "    spikesS = np.expand_dims(spikes, axis=0)\n",
    "    spikes_hatS = np.expand_dims(spikes_hat, axis=0)\n",
    "    trainS = np.expand_dims(train.T, axis=0)\n",
    "    print(spikesS.shape, spikes_hatS.shape, trainS.shape)\n",
    "    iM_data, iV_data = compute_isi_stats(spikesS)\n",
    "    iM_gen, iV_gen = compute_isi_stats(spikes_hatS)\n",
    "    iM_train, iV_train = compute_isi_stats(trainS)\n",
    "\n",
    "    # Per-neuron ratio of the two ISI statistics, shown as 'cv'.\n",
    "    # NOTE(review): a coefficient of variation is std / mean. This is only a CV if the second\n",
    "    # output of compute_isi_stats is the ISI standard deviation rather than the variance.\n",
    "    cvD = iV_data / iM_data\n",
    "    cvG = iV_gen / iM_gen\n",
    "    cvT = iV_train / iM_train\n",
    "\n",
    "    # One random z-order per point and comparison (column 0: gen, column 1: train), so neither\n",
    "    # cloud systematically hides the other; the mean-rate panel reuses these z-orders. Each point\n",
    "    # is drawn exactly once, because an extra whole-array scatter would double its opacity.\n",
    "    zorders = np.random.randint(1, 20, size=(len(iM_data), 2))\n",
    "    ax5.plot([0, 3.4], [0, 3.4], color='gray', linestyle='--', zorder=0)\n",
    "    for i in range(len(iM_data)):\n",
    "        ax5.scatter(cvD[i], cvG[i], s=10, alpha=0.7, color=tg, zorder=zorders[i][0])\n",
    "        ax5.scatter(cvD[i], cvT[i], s=10, alpha=0.7, color=tr, zorder=zorders[i][1])\n",
    "\n",
    "    ax5.set_title('cv')\n",
    "    ax5.set_xticks([1, 3])\n",
    "    ax5.set_yticks([1, 3])\n",
    "    ax5.set_xlabel('test')\n",
    "    ax5.set_ylabel('gen / train')\n",
    "    # Mean firing rate per neuron in Hz: activity summed over time divided by the\n",
    "    # recording length in seconds (number of time bins / 40 bins per second).\n",
    "    # NOTE(review): assumes loaded_binary_matrix, lam and train are (neuron, time), and that lam\n",
    "    # covers the same number of bins as testing (it is normalised by teS). Confirm.\n",
    "    bins_per_s = 40\n",
    "    teS = testing.shape[1] / bins_per_s\n",
    "    trS = train.shape[1] / bins_per_s\n",
    "    fr = np.sum(loaded_binary_matrix.T, axis=0) / teS\n",
    "    fr_hat = np.sum(lam.T.cpu().detach().numpy(), axis=0) / teS\n",
    "    train_fr = np.sum(train.T, axis=0) / trS\n",
    "\n",
    "    # Identity line, then each neuron drawn once using the random z-orders drawn for the cv panel.\n",
    "    ax4.plot([0, 4.8], [0, 4.8], color='gray', linestyle='--', zorder=0)\n",
    "    for i in range(len(fr)):\n",
    "        ax4.scatter(fr[i], fr_hat[i], s=10, alpha=0.7, color=tg, zorder=zorders[i][0])\n",
    "        ax4.scatter(fr[i], train_fr[i], s=10, alpha=0.7, color=tr, zorder=zorders[i][1])\n",
    "\n",
    "    ax4.set_title('mean rates (hz)')\n",
    "    ax4.set_xticks([1, 4])\n",
    "    ax4.set_yticks([1, 4])\n",
    "    ax4.set_xlabel('test')\n",
    "    ax4.set_ylabel('gen / train')\n",
    "\n",
    "    def _colored_text_legend(ax, labels, colors, **legend_kwargs):\n",
    "        \"\"\"Attach a text-only legend to `ax`, with each label drawn in its own colour.\n",
    "\n",
    "        Invisible proxy handles (lw=0, handlelength=0) are used, so every label is shown\n",
    "        regardless of which artists on `ax` carry labels. Returns the Legend.\n",
    "        \"\"\"\n",
    "        proxies = [plt.Line2D([0], [0], color=col, lw=0, label=lab) for col, lab in zip(colors, labels)]\n",
    "        legend = ax.legend(handles=proxies, handletextpad=0, handlelength=0, fancybox=True, **legend_kwargs)\n",
    "        for text, col in zip(legend.get_texts(), colors):\n",
    "            text.set_color(col)\n",
    "        return legend\n",
    "\n",
    "    # Key for the mean-rate / cv scatter panels. It is built once and never overwritten:\n",
    "    # no ax4 artist carries a label, so a bare ax4.legend(...) call would replace it with an empty legend.\n",
    "    _colored_text_legend(ax4, ['test/gen', 'test/train'], [tg, tr], loc='upper right')\n",
    "\n",
    "    # Power spectral density of one latent on a log y-axis.\n",
    "    ax6.semilogy(frequencies0, psd1, color=color2, alpha=0.9, zorder=0, label='latent 1')\n",
    "    # Disabled alternatives (other latents / LFP):\n",
    "    #ax6.semilogy(frequencies1, psd3, color=color4, alpha=0.9, zorder=0, label='latent 2')\n",
    "    #ax6.semilogy(frequencies2, psd2, color=color3, alpha=0.9, zorder=0, label='latent 3')\n",
    "    #ax6.semilogy(frequencies3, psd0, color=color1, alpha=0.9, zorder=0, label='latent 4')\n",
    "    #ax6.semilogy(frequencies6, psd6, color='black', alpha=0.7, zorder=0, label='LFP')\n",
    "\n",
    "    # Colour key for z1..z4 and the LFP (the colours used by the latent trace panels),\n",
    "    # anchored well outside ax6. Passing bare labels to ax6.legend would pair them with\n",
    "    # ax6's single line and show only '$z_1$', so proxy handles are used instead.\n",
    "    _colored_text_legend(\n",
    "        ax6,\n",
    "        ['$z_1$', '$z_2$', '$z_3$', '$z_4$', 'LFP'],\n",
    "        [color1, color2, color3, color4, 'black'],\n",
    "        loc='upper right',\n",
    "        bbox_to_anchor=(3.2, 2.95),\n",
    "    )\n",
    "\n",
    "    ax2.set_xlim(0, 20)\n",
    "    ax2.set_title('latents')\n",
    "\n",
    "    # Sparse ticks on the PSD panel: y at 0.001 and 0.1, x at 0.2 and 10.\n",
    "    ax6.set_yticks([0.001, 0.1])\n",
    "    ax6.set_yticklabels(['0.001', '0.1'])\n",
    "    ax6.set_xticks([0.2, 10])\n",
    "    ax6.set_xticklabels(['0.2', '10'])\n",
    "    ax6.set_xlim([0, 20])\n",
    "\n",
    "    # 1 s zoom starting at bin int(5.3 * 40): the four latents with the same offsets and\n",
    "    # colours as the latents panel, plus the LFP scaled by 10 and shifted up by 150.\n",
    "    bins_per_s = 40\n",
    "    init = int(5.3 * bins_per_s)\n",
    "    duration = 1 * bins_per_s\n",
    "    t = np.arange(duration) / bins_per_s\n",
    "    latent_offsets = [50, 60, -35, -100]\n",
    "    latent_colors = [color1, color2, color3, color4]\n",
    "    for k, (offset, col) in enumerate(zip(latent_offsets, latent_colors)):\n",
    "        # Labels follow the latent index: Z[k] is labelled 'Z{k+1}'.\n",
    "        ax7.plot(t, Z[k][init:init+duration] + offset, alpha=0.9, label=f'Z{k + 1}', color=col)\n",
    "\n",
    "    cc = 10*test_lfp[init:init+duration]+150\n",
    "    ax7.plot(t, cc, alpha=0.7, label=\"LFP\", color='black')\n",
    "    ax7.set_xlim([0, 1])\n",
    "    ax7.set_xticks([0, 1])\n",
    "    ax7.set_yticks([])\n",
    "\n",
    "    ax7.set_title('latents')\n",
    "\n",
    "    # Final layout: set the figure size, give the top-row panels a 0.625\n",
    "    # height/width box aspect and make the bottom-row panels square.\n",
    "    fig.set_size_inches(5.2, 3)\n",
    "    for ax in (ax1, ax2, ax3):\n",
    "        ax.set_box_aspect(0.625)\n",
    "    for ax in (ax7, ax4, ax5, ax6):\n",
    "        ax.set_box_aspect(1)\n",
    "\n",
    "    # Write PNG and PDF copies of the figure, then display it.\n",
    "    for out_path in ('hpc11_main.png', 'hpc11_main.pdf'):\n",
    "        fig.savefig(out_path, dpi=300)\n",
    "\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay of y[900:1000] and y_pred.\n",
    "# NOTE(review): y_pred is drawn from index 0, i.e. it is assumed to be aligned with y[900:1000]. Confirm.\n",
    "fig_pred, ax_pred = plt.subplots()\n",
    "ax_pred.plot(y[900:1000], label='y[900:1000]')\n",
    "ax_pred.plot(y_pred, label='y_pred')\n",
    "ax_pred.set_xlabel('sample')\n",
    "ax_pred.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay y[900:1000] and y_pred on the current axes.\n",
    "current_ax = plt.gca()\n",
    "current_ax.plot(y[900:1000])\n",
    "current_ax.plot(y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "labrot",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
