{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "import os\n",
    "import sys\n",
    "file_dir = os.getcwd()\n",
    "sys.path.append(file_dir)\n",
    "from evaluation.klx_gmm import calc_kl_from_data\n",
    "from evaluation.pse import power_spectrum_helling\n",
    "import torch\n",
    "import numpy as np\n",
    "from rnn.datasets import Basic_dataset, datasetHPC11\n",
    "from rnn.saving import load_model\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from matplotlib.colors import ListedColormap\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from matplotlib.animation import FuncAnimation\n",
    "\n",
    "seed = 42\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "i_code = '#C23B22'\n",
    "real_code = '##779ECB'\n",
    "gen_code = '#56AE57'\n",
    "\n",
    "i_color = (0.7608, 0.2314, 0.1333, 1)\n",
    "real_color = (0.4667, 0.6196, 0.7961, 1)\n",
    "gen_color = (0.3373, 0.6824, 0.3412, 1)\n",
    "average_color = tuple((x + y) / 2 for x, y in zip(real_color, gen_color))\n",
    "\n",
    "##mix of real and gen\n",
    "mix_color = (0.5, 0.5, 0.5, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the held-out recording, drop its last 490 columns, and restore the trained model.\n",
    "testing = np.load('test_first_hpc.npy')\n",
    "testing = testing[:,0:-490]\n",
    "\n",
    "\n",
    "# NOTE(review): placeholder path - point it at the directory of the saved model.\n",
    "filename = '/YOURDIRECTORY/model_hpc2/_CNN_causal_PLRNN_Z_Date_32024_05_04_T_20_45_16'\n",
    "\n",
    "vae, params, task_params, training_params = load_model(str(filename))\n",
    "loaded_binary_matrix = testing\n",
    "# The dataset receives the transposed matrix; presumably rows become time bins and columns neurons - TODO confirm.\n",
    "task      = Basic_dataset(task_params, loaded_binary_matrix.T.astype(np.float32))\n",
    "dim_x     = task.data.shape[1]\n",
    "print(\"Data shape: \", task.data.shape)\n",
    "\n",
    "train = np.load('train_first_hpc.npy')\n",
    "\n",
    "# Re-seed both RNGs so the sampling below repeats when this cell is re-run on its own.\n",
    "seed = 42\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "T = task.data.shape[0]\n",
    "spikes = task.data[:T]                      # (T, X)\n",
    "x_sample = spikes.T.unsqueeze(0)            # (1, X, T)\n",
    "z_hat, _,_, _ = vae.encoder(x_sample)          # (1, Z, T), _, _\n",
    "# Initial latent state: the encoder output at the first time step.\n",
    "z0 = z_hat[0, :, 0] #+ torch.randn_like(z_hat[0, :, 0]) * 10\n",
    "# Roll the prior forward for T + 1000 steps from z0; the first 1000 steps are sliced off below.\n",
    "Z = vae.prior.get_latent_time_series(time_steps=T + 1000, cut_off=0, z0=z0, noise_scale=1) # (1, Z, T)\n",
    "lam = vae.prior.get_observation(Z, noise_scale=0) \n",
    "print('lam before', np.shape(lam))     \n",
    "lam = lam[0,:,1000:,0]\n",
    "print('lam after', np.shape(lam))     \n",
    "# Exponentiate the observation output; it is used as the Poisson rate below.\n",
    "lam = torch.exp(lam)          # (1, X, T) (from Z to N to X) \n",
    "Z = Z[0]\n",
    "print(np.shape(Z))\n",
    "Z = Z[:,1000:,0]\n",
    "# Generated spike counts, transposed so that the first axis follows lam's second axis.\n",
    "spikes_hat = torch.poisson(lam).T\n",
    "print(np.shape(Z))\n",
    "\n",
    "# Project the generated latents with the leading left singular vectors of J = m @ n.\n",
    "# NOTE(review): the meaning of transition.m / transition.n is not visible here - confirm against the model code.\n",
    "m_or = vae.prior.transition.m #20,2\n",
    "n_or = vae.prior.transition.n\n",
    "J = m_or@n_or\n",
    "u,s,v = torch.linalg.svd(J)\n",
    "projection_matrix = u[:,:vae.dim_z].T@m_or\n",
    "\n",
    "Z = projection_matrix@Z\n",
    "\n",
    "\n",
    "# Convert everything used for plotting to NumPy arrays.\n",
    "Z = Z.detach().numpy()\n",
    "spikes = spikes.detach().numpy()\n",
    "spikes_hat = spikes_hat.detach().numpy()\n",
    "\n",
    "# Posterior latents: average several predict_NLB passes over the test recording.\n",
    "truncated_spk = loaded_binary_matrix[:, :]\n",
    "task      = Basic_dataset(task_params, truncated_spk.T.astype(np.float32))\n",
    "x = task.data.T.unsqueeze(0)\n",
    "print(x.shape)\n",
    "marginal_smoothing = False\n",
    "t_held_in = truncated_spk.shape[1]\n",
    "k = 128\n",
    "n_repeats = 3  # number of predict_NLB passes that are averaged\n",
    "\n",
    "with torch.no_grad():\n",
    "    totals = None\n",
    "    for _ in range(n_repeats):\n",
    "        run = vae.predict_NLB(x, u=None, k=k, t_held_in=t_held_in, t_forward=0, marginal_smoothing=marginal_smoothing)\n",
    "        if totals is None:\n",
    "            # The first pass provides the accumulators.\n",
    "            totals = list(run)\n",
    "        else:\n",
    "            # Accumulate in place so only one extra pass is held in memory.\n",
    "            for acc, new in zip(totals, run):\n",
    "                acc += new\n",
    "    for acc in totals:\n",
    "        acc /= n_repeats\n",
    "    Qzs_filt_avg, Qzs_sm_avg, Xs_filt_avg, Xs_sm_avg = totals\n",
    "\n",
    "# Mean of the smoothed posterior over its last axis for the first three latent dimensions,\n",
    "# then z-score each trace (torch.std, i.e. the unbiased estimator by default).\n",
    "Z1, Z2, Z3 = (torch.mean(Qzs_sm_avg[0, d, :, :], axis=1) for d in range(3))\n",
    "Z1, Z2, Z3 = ((z - torch.mean(z)) / torch.std(z) for z in (Z1, Z2, Z3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from neo.io.neuroscopeio import NeuroScopeIO\n",
    "from neo.io.klustakwikio import KlustaKwikIO\n",
    "\n",
    "## For the LFP analysis download the data from CRCNS hc2, session ec013527\n",
    "path = \"/YOURDIRECTORY/ec013527\" # change to wherever you put the files\n",
    "filen= path+\"/ec013527.xml\"\n",
    "reader = NeuroScopeIO(filename=filen)\n",
    "\n",
    "from scipy.signal import resample\n",
    "\n",
    "seg = reader.read_segment(lazy=False)\n",
    "t, c = np.shape(seg.analogsignals[0])\n",
    "resample_rate = 12.5  # down-sampling factor (loop-invariant, so defined once)\n",
    "ds = []\n",
    "for i in range(c):\n",
    "    lfp = np.array(seg.analogsignals[0][:,i])\n",
    "    n_samples = int(len(lfp)/resample_rate)\n",
    "    lfp_ds = resample(lfp, n_samples)\n",
    "    ds.append(lfp_ds)\n",
    "\n",
    "alla = np.load('train_first_hpc.npy')\n",
    "\n",
    "lfp = np.array(ds)[:,:,0]\n",
    "print(lfp.shape)\n",
    "# Window that follows the training data and has as many samples as the posterior latents.\n",
    "lfp_start = int(alla.shape[1])\n",
    "lfp_stop = lfp_start + int(Qzs_filt_avg.shape[2])\n",
    "lfp_window = lfp[:, lfp_start:lfp_stop]\n",
    "print(lfp_window.shape)\n",
    "# z-score every channel into a NEW array; the original wrote through a view and silently modified lfp\n",
    "test_lfp = (lfp_window - lfp_window.mean(axis=1, keepdims=True)) / lfp_window.std(axis=1, keepdims=True)\n",
    "\n",
    "# take the mean along the channels\n",
    "test_lfp = np.mean(test_lfp, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cmap = plt.get_cmap('tab10')\n",
    "\n",
    "# Plotting parameters that the plot in this cell does not read\n",
    "init = 1000\n",
    "duration = 1000\n",
    "x_index = np.linspace(0, 400, 7)\n",
    "x_loc = [round(v, 2) for v in np.linspace(0, 10, 7)]\n",
    "tri = 0\n",
    "latent_ind = 0\n",
    "\n",
    "# z-score the first three rows of the generated latent matrix Z\n",
    "pZ1, pZ2, pZ3 = ((row - np.mean(row)) / np.std(row) for row in (Z[0], Z[1], Z[2]))\n",
    "\n",
    "init = 0\n",
    "dur = 6000\n",
    "window = slice(init, init + dur)\n",
    "\n",
    "# Posterior latents and the LFP over the same window\n",
    "for series, colour_idx, tag in ((Z1, 1, 'postZ1'), (Z2, 2, 'postZ2'), (Z3, 3, 'postZ3')):\n",
    "    plt.plot(series[window], color=cmap(colour_idx), linestyle='dashed', label=tag)\n",
    "plt.plot(test_lfp[window], color='black', alpha=0.7, linestyle='dashed', label='lfp')\n",
    "plt.xlabel('timesteps')\n",
    "plt.xlim(850, 1000)\n",
    "plt.legend()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.signal import welch\n",
    "# Welch power spectral density of the channel-averaged LFP\n",
    "fs = 100  # sampling frequency in Hz\n",
    "nperseg = 1024\n",
    "frequencies0, psd0 = welch(np.reshape(test_lfp, -1), fs=fs, nperseg=nperseg)\n",
    "\n",
    "psd_fig, psd_ax = plt.subplots(figsize=(10, 6))\n",
    "psd_ax.semilogy(frequencies0, psd0, label='lfp', alpha=0.8, zorder=0)\n",
    "psd_ax.set_xlabel('Frequency (f)', fontsize=14, labelpad=10)\n",
    "psd_ax.set_ylabel('PSD of LFP', fontsize=14, labelpad=10)\n",
    "# show only the 0-20 Hz range\n",
    "psd_ax.set_xlim(0, 20)\n",
    "psd_fig.subplots_adjust(bottom=0.3)\n",
    "psd_ax.legend()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cmap = plt.get_cmap(\"tab10\")\n",
    "# Posterior latents over a window of `dur` samples starting at sample `init`\n",
    "plt.figure(figsize=(12, 3))\n",
    "init = 500\n",
    "dur = 5500\n",
    "for latent, colour_idx, tag in ((Z1, 1, 'postZ1'), (Z2, 2, 'postZ2'), (Z3, 3, 'postZ3')):\n",
    "    plt.plot(latent[init:init+dur], color=cmap(colour_idx), label=tag)\n",
    "\n",
    "plt.legend()\n",
    "# Fixed labels: the x axis is the sample index inside the window (not seconds), and the\n",
    "# figure shows only the latents (no LFP, no interpolation).\n",
    "plt.xlabel('timesteps')\n",
    "plt.ylabel('Amplitude')\n",
    "plt.xlim(500, 1000)\n",
    "plt.title('Posterior latent variables')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the posterior latents to NumPy arrays.\n",
    "# Guarded with torch.is_tensor so the cell can be re-run: the original called .detach()\n",
    "# unconditionally, which raises AttributeError once Z1..Z3 are already NumPy arrays.\n",
    "Z1, Z2, Z3 = (\n",
    "    z.detach().cpu().numpy() if torch.is_tensor(z) else np.asarray(z)\n",
    "    for z in (Z1, Z2, Z3)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Band-pass filter the LFP and the latent time series\n",
    "import scipy.signal\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "fs = 100  # sampling frequency (Hz)\n",
    "t = np.arange(0, Qzs_filt_avg.shape[2]) / fs  # time axis in seconds\n",
    "print(Qzs_filt_avg.shape)\n",
    "\n",
    "# Pass band edges in Hz, converted to fractions of the Nyquist frequency (fs / 2) for butter()\n",
    "lowcut = 1\n",
    "highcut = 40\n",
    "low = lowcut / (fs / 2)\n",
    "high = highcut / (fs / 2)\n",
    "order = 4\n",
    "\n",
    "# Zero-phase filtering (filtfilt) with the same Butterworth coefficients for every signal\n",
    "b, a = scipy.signal.butter(order, [low, high], btype='band')\n",
    "lfp_filtered, Z1_filtered, Z2_filtered, Z3_filtered, pZ1_filtered, pZ2_filtered, pZ3_filtered = (\n",
    "    scipy.signal.filtfilt(b, a, sig, axis=0)\n",
    "    for sig in (test_lfp, Z1, Z2, Z3, pZ1, pZ2, pZ3)\n",
    ")\n",
    "\n",
    "# Filtered LFP drawn first, unfiltered LFP on top\n",
    "plt.figure(figsize=(12, 3))\n",
    "for trace, tag, colour in ((lfp_filtered, 'Filtered LFP', 'black'), (test_lfp, 'LFP', 'red')):\n",
    "    plt.plot(t, trace, label=tag, color=colour)\n",
    "plt.xlabel('Time (s)')\n",
    "plt.ylabel('Amplitude')\n",
    "plt.legend()\n",
    "plt.ylim(-2, 2)\n",
    "plt.xlim(6, 8)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.signal import welch\n",
    "\n",
    "# Welch PSDs: psd0-2 for the band-passed posterior latents, psd3-5 for the band-passed\n",
    "# generated latents, psd6 for the unfiltered LFP.\n",
    "fs = 100  # Hz\n",
    "nperseg = 1024\n",
    "Z1_filtered = Z1_filtered.reshape(-1)\n",
    "Z2_filtered = Z2_filtered.reshape(-1)\n",
    "Z3_filtered = Z3_filtered.reshape(-1)\n",
    "\n",
    "lfp_filtered = lfp_filtered.reshape(-1)\n",
    "frequencies0, psd0 = welch(Z1_filtered, fs=fs, nperseg=nperseg)\n",
    "frequencies1, psd1 = welch(Z2_filtered, fs=fs, nperseg=nperseg)\n",
    "frequencies2, psd2 = welch(Z3_filtered, fs=fs, nperseg=nperseg)\n",
    "\n",
    "frequencies3, psd3 = welch(pZ1_filtered, fs=fs, nperseg=nperseg)\n",
    "frequencies4, psd4 = welch(pZ2_filtered, fs=fs, nperseg=nperseg)\n",
    "frequencies5, psd5 = welch(pZ3_filtered, fs=fs, nperseg=nperseg)\n",
    "\n",
    "frequencies6, psd6 = welch(test_lfp, fs=fs, nperseg=nperseg)\n",
    "\n",
    "# Plot the PSDs of the generated latents and of the LFP\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.semilogy(frequencies3, psd3, color=cmap(1), alpha=0.8, zorder=0, label = 'rnn Z1')\n",
    "plt.semilogy(frequencies4, psd4,  color=cmap(3), alpha=0.8, zorder=0, label = 'rnn Z2')\n",
    "plt.semilogy(frequencies5, psd5,  color=cmap(0), alpha=0.8, zorder=0, label = 'rnn Z3')\n",
    "# NOTE(review): cmap is rebound here although nothing below in this cell reads it; kept in case a later cell relies on it.\n",
    "cmap = plt.get_cmap('tab20')\n",
    "plt.semilogy(frequencies6, psd6,  color='black', alpha=0.8, zorder=0, label = 'LFP')\n",
    "\n",
    "plt.xscale('log')\n",
    "plt.xlabel(\"Frequency (f)\", labelpad=10, fontsize=14)\n",
    "plt.ylabel(\"PSD of Latents\", labelpad=10, fontsize=14) \n",
    "# A log axis cannot start at 0: use the first non-zero Welch bin (fs / nperseg) as the lower limit.\n",
    "plt.xlim(fs / nperseg, 20)\n",
    "plt.ylim(1e-5, 10)\n",
    "plt.subplots_adjust(bottom=0.3)\n",
    "# Draw the legend BEFORE saving so that it is part of the saved image.\n",
    "plt.legend()\n",
    "plt.savefig('psd_latent.png')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#correlation between latents and LFP\n",
    "import numpy as np\n",
    "from scipy.stats import pearsonr\n",
    "from scipy.stats import spearmanr\n",
    "\n",
    "# Zero-lag Pearson correlation between the LFP and each posterior latent\n",
    "locs = np.load('test_loc.npy')\n",
    "corr1, corr2, corr3 = (np.corrcoef(test_lfp, latent)[0, 1] for latent in (Z1, Z2, Z3))\n",
    "\n",
    "for value in (corr1, corr2, corr3):\n",
    "    print('Corr: %.3f' % value)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#import signal\n",
    "from scipy import signal\n",
    "\n",
    "# Full cross-correlation of the LFP with each posterior latent, each divided by its own maximum\n",
    "xcorrs = []\n",
    "for latent in (Z1, Z2, Z3):\n",
    "    raw = signal.correlate(test_lfp, latent)\n",
    "    lag_axis = signal.correlation_lags(len(test_lfp), len(latent))\n",
    "    xcorrs.append((lag_axis, raw / np.max(raw)))\n",
    "(lags1, corr1), (lags2, corr2), (lags3, corr3) = xcorrs\n",
    "\n",
    "# One panel per latent on a 2 x 2 grid, restricted to lags within +/- 100 samples\n",
    "plt.figure(figsize=(12, 6))\n",
    "for panel, (lag_axis, scaled) in enumerate(xcorrs, start=1):\n",
    "    plt.subplot(2, 2, panel)\n",
    "    plt.plot(lag_axis, scaled)\n",
    "    plt.title('Correlation with Z%d' % panel)\n",
    "    plt.xlabel('Lags')\n",
    "    plt.ylabel('Correlation')\n",
    "    plt.xlim(-100, 100)\n",
    "plt.tight_layout()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.signal import coherence\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Sampling frequency in Hz\n",
    "fs = 100\n",
    "\n",
    "# Coherence with the LFP: f*/Cxy* for the posterior latents (Z1..Z3),\n",
    "# pf*/pCxy* for the latents generated by the prior (pZ1..pZ3).\n",
    "f1, Cxy1 = coherence(Z1, test_lfp, fs=fs, nperseg=1024)\n",
    "f2, Cxy2 = coherence(Z2, test_lfp, fs=fs, nperseg=1024)\n",
    "f3, Cxy3 = coherence(Z3, test_lfp, fs=fs, nperseg=1024)\n",
    "\n",
    "pf1, pCxy1 = coherence(pZ1, test_lfp, fs=fs, nperseg=1024)\n",
    "pf2, pCxy2 = coherence(pZ2, test_lfp, fs=fs, nperseg=1024)\n",
    "pf3, pCxy3 = coherence(pZ3, test_lfp, fs=fs, nperseg=1024)\n",
    "\n",
    "# Smooth every spectrum with a Gaussian kernel (sigma = 3 frequency bins)\n",
    "from scipy.ndimage import gaussian_filter1d\n",
    "Cxy1 = gaussian_filter1d(Cxy1, sigma=3)\n",
    "pCxy1 = gaussian_filter1d(pCxy1, sigma=3)\n",
    "Cxy2 = gaussian_filter1d(Cxy2, sigma=3)\n",
    "pCxy2 = gaussian_filter1d(pCxy2, sigma=3)\n",
    "Cxy3 = gaussian_filter1d(Cxy3, sigma=3)\n",
    "pCxy3 = gaussian_filter1d(pCxy3, sigma=3)\n",
    "\n",
    "# Fixed legend labels: the pCxy* curves belong to the generated ('rnn') latents and the\n",
    "# Cxy* curves to the posterior latents; the original labels were swapped.\n",
    "# NOTE(review): scipy.signal.coherence already returns the magnitude-squared coherence, so the\n",
    "# extra **2 plots its square - kept unchanged here, confirm this is intended.\n",
    "plt.plot(pf1, pCxy1**2, label='rnn 1')\n",
    "plt.plot(f1, Cxy1**2, label='posterior 1')\n",
    "\n",
    "plt.plot(pf3, pCxy3**2, label='rnn 3')\n",
    "plt.plot(f3, Cxy3**2, label='posterior 3')\n",
    "\n",
    "plt.plot(pf2, pCxy2**2, label='rnn 2')\n",
    "plt.plot(f2, Cxy2**2, label='posterior 2')\n",
    "\n",
    "plt.xlim(0, 20)\n",
    "plt.legend()\n",
    "plt.xlabel('Frequency [Hz]')\n",
    "plt.ylabel('Squared Coherence')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#linear regression from latents to LFP\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import mean_squared_error\n",
    "\n",
    "# Least-squares map from the three posterior latents to the LFP.\n",
    "# The model is fitted and evaluated on the same samples (no held-out split).\n",
    "X = np.vstack((Z1, Z2, Z3)).T\n",
    "y = test_lfp\n",
    "model = LinearRegression()\n",
    "model.fit(X, y)\n",
    "y_pred = model.predict(X)\n",
    "\n",
    "# Pearson correlation between target and prediction (stored as r2, but it is r, not R^2)\n",
    "r2 = np.corrcoef(y, y_pred)[0, 1]\n",
    "\n",
    "# Predicted vs. actual LFP for the first 500 samples\n",
    "plt.figure(figsize=(6, 2))\n",
    "for trace, tag, colour in ((y_pred, 'Predicted', 'black'), (y, 'Actual', 'cyan')):\n",
    "    plt.plot(trace, label=tag, color=colour)\n",
    "plt.legend()\n",
    "plt.title(f'corr: {r2:.5f}')\n",
    "print(r2)\n",
    "plt.xlim(0, 500)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size of the second axis of the generated spike matrix; used below to index the upper triangle of the correlation matrices.\n",
    "neurons = spikes_hat.shape[1]\n",
    "def calculate_cross_correlation(data):\n",
    "    \"\"\"Calculate the cross-correlation matrix for a dataset.\"\"\"\n",
    "    correlation_matrix = np.corrcoef(data, rowvar=False)\n",
    "    return correlation_matrix\n",
    "\n",
    "print(spikes.shape, spikes_hat.shape, train.shape)\n",
    "# Pairwise neuron correlations. `spikes` is the recorded test data and `spikes_hat` the model\n",
    "# samples; the original assigned them the other way round, so the 'real data' axis showed\n",
    "# generated data and the train comparison was made against generated data.\n",
    "test_correlation = calculate_cross_correlation(spikes)\n",
    "gen_correlation = calculate_cross_correlation(spikes_hat)\n",
    "train_correlation = calculate_cross_correlation(train.T)\n",
    "\n",
    "# Upper-triangle entries (diagonal excluded): one value per neuron pair\n",
    "i_upper = np.triu_indices(neurons, k=1)\n",
    "test_corr_values = test_correlation[i_upper]\n",
    "gen_corr_values = gen_correlation[i_upper]\n",
    "train_corr_values = train_correlation[i_upper]\n",
    "\n",
    "# Test-set correlations on x; generated (green) and training (blue) correlations on y\n",
    "plt.figure(figsize=(8, 6))\n",
    "plt.scatter(test_corr_values, gen_corr_values, color='green', alpha=0.6, label='generated')\n",
    "plt.scatter(test_corr_values, train_corr_values, color='blue', alpha=0.6, label='train')\n",
    "plt.title('Comparison of Neuron Cross-Correlation')\n",
    "plt.xlabel('Real Data Cross-Correlation')\n",
    "plt.ylabel('Generated / Train Data Cross-Correlation')\n",
    "# identity line\n",
    "plt.plot([-0.1, 0.15], [-0.1, 0.15], color='red', linestyle='--')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def binned_spikes_to_times(binned_spikes, bin_size=0.01):\n",
    "    idxs = np.nonzero(binned_spikes)[0]\n",
    "    spike_times = np.repeat(idxs, binned_spikes[idxs].astype(int))\n",
    "    spike_times = spike_times.astype(float)\n",
    "    for idx in idxs:\n",
    "        if binned_spikes[idx] > 1:\n",
    "            spike_times[spike_times == idx] += np.arange(binned_spikes[idx]) / binned_spikes[idx]\n",
    "    spike_times *= bin_size\n",
    "    return spike_times\n",
    "\n",
    "def compute_isi_stats(spikes, bin_size=0.01):\n",
    "    \"\"\"Mean and standard deviation of the inter-spike intervals of every neuron.\n",
    "\n",
    "    Args:\n",
    "        spikes: array of spike counts with shape (n_trials, n_timesteps, n_neurons).\n",
    "        bin_size: bin width passed on to ``binned_spikes_to_times``.\n",
    "\n",
    "    Returns:\n",
    "        Two arrays of length n_neurons: ISI means and ISI standard deviations,\n",
    "        with the intervals of all trials pooled per neuron.\n",
    "    \"\"\"\n",
    "    n_trials, _, n_neurons = spikes.shape\n",
    "    means = np.empty(n_neurons)\n",
    "    stds = np.empty(n_neurons)\n",
    "    for unit in range(n_neurons):\n",
    "        # intervals are taken within each trial, then pooled across trials\n",
    "        per_trial = [\n",
    "            np.diff(binned_spikes_to_times(spikes[trial, :, unit], bin_size=bin_size))\n",
    "            for trial in range(n_trials)\n",
    "        ]\n",
    "        isis = np.concatenate(per_trial)\n",
    "        means[unit] = np.mean(isis)\n",
    "        stds[unit] = np.std(isis)\n",
    "    return means, stds\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.gridspec as gridspec\n",
    "import matplotlib as mpl\n",
    "with mpl.rc_context(fname=\"style_matplotlib.rc\"): \n",
    "    # Create a figure with specified size\n",
    "    fig, axes = plt.subplots(2, 4, figsize=(6, 3))\n",
    "    fig.subplots_adjust(hspace=0.2, wspace=0.6)\n",
    "\n",
    "    # Panel handles. Top row, left to right: ax1, ax3, ax2, ax8 (ax8 is switched off below).\n",
    "    # Bottom row, left to right: ax4, ax5, ax6, ax7.\n",
    "    ax1=axes[0,0]\n",
    "    ax2=axes[0,2]\n",
    "    ax3=axes[0,1]\n",
    "    ax4=axes[1,0]\n",
    "    ax5=axes[1,1]\n",
    "    ax6=axes[1,2]\n",
    "    ax7=axes[1,3]\n",
    "    ax8=axes[0,3]\n",
    "\n",
    "    # Hide the last subplot in the first row\n",
    "    axes[0, 3].axis('off')\n",
    "\n",
    "    # The three visible top-row panels are widened by this factor.\n",
    "    width_scaling_factor = 1.6\n",
    "    new_positions = []\n",
    "\n",
    "    # Calculate new positions for the first row: the first panel keeps its left edge, every\n",
    "    # following panel starts 0.066 to the right of the previous panel's right edge.\n",
    "    for i, ax in enumerate(axes[0, :3]):\n",
    "        pos = ax.get_position()\n",
    "        if i == 0:\n",
    "            new_positions.append([pos.x0, pos.y0, pos.width * width_scaling_factor, pos.height])\n",
    "        else:\n",
    "            new_positions.append([new_positions[i-1][0] + new_positions[i-1][2] + 0.066, pos.y0, pos.width * width_scaling_factor, pos.height])\n",
    "\n",
    "    # Apply new positions\n",
    "    for i, ax in enumerate(axes[0, :3]):\n",
    "        ax.set_position(new_positions[i])\n",
    "\n",
    "    #hide yticks\n",
    "    ax2.set_yticks([])\n",
    "    #ax1.set_yticks([])\n",
    "    cmap = plt.get_cmap('tab20b')\n",
    "\n",
    "    # Window shown in the top row: `duration` bins starting at bin `init`\n",
    "    # (the tick labels below map 100 bins to 1 s).\n",
    "    init = 3100\n",
    "    duration = 3*100\n",
    "    ax1.imshow(spikes[init:init+duration,:].T, aspect='auto', cmap='Greys', interpolation='none', vmax=1)\n",
    "    #xticks 0 to 300 is 0 to 3\n",
    "    ax1.set_xticks([0, 1*100, 2*100, 3*100])\n",
    "    ax1.set_xticklabels([0, 1, 2, 3])\n",
    "    #hide y ticks\n",
    "    ax1.set_title('spikes from rat hippocampus')\n",
    "    #ax3.set_ylabel('neuron')\n",
    "    ax3.set_yticks([])\n",
    "    ax1.set_yticks([])\n",
    "    ax1.set_xlabel('time (s)')\n",
    "    # Colours for the three latent traces\n",
    "    color1= '#7B46C1'\n",
    "    color2= '#A860AF'\n",
    "    color3= '#7C277D' \n",
    "\n",
    "    t = np.linspace(0, 3, duration)\n",
    "    # The constant offsets (-55, -10, -130) and the factor 7 only separate and scale the traces vertically.\n",
    "    # NOTE(review): Z is the latent matrix generated by the prior, lfp_filtered the recorded LFP -\n",
    "    # confirm that plotting both over the same index window is intended.\n",
    "    ax2.plot(t, Z[0][init:init+duration]-55, alpha = 0.9, label=\"Z1\", color=color1)\n",
    "    ax2.plot(t, Z[2][init:init+duration], alpha = 0.9,label=\"Z3\", color=color2)\n",
    "    ax2.plot(t, Z[1][init:init+duration]-10, alpha = 0.9,label=\"Z2\", color=color3)\n",
    "    cc = 7*lfp_filtered[init:init+duration]-130\n",
    "    ax2.plot(t, cc, alpha = 0.7,label=\"LFP\", color='black')\n",
    "\n",
    "    #xticks\n",
    "    ax2.set_xticks([0, 1, 2,3])\n",
    "    ax2.set_xlim(0, 3)\n",
    "    ax2.set_title('latents')\n",
    "    ax2.set_yticks([])\n",
    "    ax2.set_xlabel('time (s)')\n",
    "\n",
    "    # Generated spikes over the same index window as the recorded spikes in ax1\n",
    "    ax3.imshow(spikes_hat[init:init+duration,:].T, aspect='auto', cmap='Greys', interpolation='none', vmax=1)\n",
    "    ax3.set_xticks([0, 1*100, 2*100, 300])\n",
    "    ax3.set_xticklabels([0, 1, 2, 3])\n",
    "    ax3.set_title(\"generated spikes\")\n",
    "    ax1.set_ylabel('neuron')\n",
    "    ax3.set_xlabel('time (s)')\n",
    "\n",
    "    # compute_isi_stats unpacks a 3-axis shape (trials, time, neurons): add a leading axis of length 1\n",
    "    spikesS = np.expand_dims(spikes, axis=0)\n",
    "    spikes_hatS = np.expand_dims(spikes_hat, axis=0)\n",
    "    trainS = np.expand_dims(train.T, axis=0)\n",
    "    print(spikesS.shape, spikes_hatS.shape, trainS.shape)\n",
    "    iM_data, iV_data = compute_isi_stats(spikesS)\n",
    "    iM_gen, iV_gen = compute_isi_stats(spikes_hatS)\n",
    "    iM_train, iV_train = compute_isi_stats(trainS)\n",
    "\n",
    "    cmap2 = plt.get_cmap('Dark2')\n",
    "    # Coefficient of variation of the ISIs (std / mean) per neuron\n",
    "    cvD = iV_data/iM_data\n",
    "    cvG = iV_gen/iM_gen\n",
    "    cvT = iV_train/iM_train\n",
    "    # identity line\n",
    "    ax5.plot([0,20], [0, 20], color='gray', linestyle='--', zorder =0 )\n",
    "    tg = 'teal'\n",
    "    tr = 'firebrick'\n",
    "    scatter1 = ax5.scatter(iM_data, iM_gen,  s=10, alpha=0.7,color = tg, edgecolors='black', label= 'gen', linewidths=0.5)\n",
    "    scatter2 = ax5.scatter(iM_data, iM_train, s=10, alpha=0.7,color = tr, edgecolors='black', label= 'train', linewidths=0.5)\n",
    "    zorders = np.random.randint(1, 20, size=(len(iM_data),2))\n",
    "    # Re-draw every point with a random zorder so neither group systematically covers the other\n",
    "    for i in range(len(iM_data)):\n",
    "        ax5.scatter(iM_data[i], iM_gen[i], s=10, alpha=0.7, color=tg, zorder=zorders[i][0])\n",
    "        ax5.scatter(iM_data[i], iM_train[i], s=10, alpha=0.7,color = tr,label= 'train',zorder=zorders[i][1])\n",
    "    ax5.tick_params(axis='x', which='both', width=1)\n",
    "    ax5.tick_params(axis='y', which='both', width=1)\n",
    "    # Fixed unit: compute_isi_stats uses bin_size=0.01, so with 10 ms bins the ISIs are in seconds, not ms\n",
    "    ax5.set_title('isi means (s)')\n",
    "    ax5.set_xscale('log')\n",
    "    ax5.set_yscale('log')\n",
    "    ax5.set_xlabel('test')\n",
    "    ax5.set_yticks([0.1, 1, 10])\n",
    "    ax5.set_yticklabels(['0.1', '1.0', '10'])\n",
    "    ax5.set_xticks([0.1, 1, 10])\n",
    "    ax5.set_xticklabels(['0.1', '1.0', '10'])  \n",
    "    # Number of columns divided by 100 - presumably the durations in seconds at 100 bins per second (TODO confirm)\n",
    "    teS = testing.shape[1]/100\n",
    "    trS = train.shape[1]/100\n",
    "    # Per-neuron mean rate: summed counts (or summed model rates lam for fr_hat) divided by the duration\n",
    "    fr = np.sum(loaded_binary_matrix.T, axis=0)/teS\n",
    "    fr_hat = np.sum(lam.T.cpu().detach().numpy(), axis = 0)/teS\n",
    "    train_fr = np.sum(train.T, axis=0)/trS\n",
    "\n",
    "    ax4.set_xscale('log')\n",
    "    ax4.set_yscale('log')\n",
    "    ax4.tick_params(axis='x', which='both', width=1)\n",
    "    ax4.tick_params(axis='y', which='both', width=1)\n",
    "    # identity line\n",
    "    ax4.plot(np.linspace(0, 40, 20), np.linspace(0, 40, 20), color='gray', linestyle='--', zorder=0)\n",
    "    \n",
    "    # NOTE(review): sized with len(iM_data) but indexed over len(fr) below - assumes both have one entry per neuron\n",
    "    zorders = np.random.randint(1, 20, size=(len(iM_data),2))\n",
    "    \n",
    "\n",
    "    scatter3 = ax4.scatter(fr, fr_hat, s=10, alpha=0.7, color = tg)\n",
    "    scatter4 = ax4.scatter(fr, train_fr, s=10, alpha=0.7, color = tr)\n",
    "    \n",
    "    # Re-draw every point with a random zorder so neither group systematically covers the other\n",
    "    for i in range(len(fr)):\n",
    "        ax4.scatter(fr[i], fr_hat[i], s=10, alpha=0.7, color = tg,  zorder=zorders[i][0])\n",
    "        ax4.scatter(fr[i], train_fr[i], s=10, alpha=0.7,color = tr, zorder=zorders[i][1])\n",
    "        \n",
    "    ax4.set_title('mean rates (hz)')\n",
    "    ax4.set_yticks([1, 10])\n",
    "    ax4.set_yticklabels(['1.0', '10'])\n",
    "    ax4.set_xticks([1, 10])\n",
    "    ax4.set_xticklabels(['1.0', '10'])\n",
    "    ax4.set_xlabel('test')\n",
    "    ax4.set_ylabel('gen / train')\n",
    "\n",
    "        # Custom legend labels and colors\n",
    "    legend_labels = ['test/gen', 'test/train']\n",
    "    legend_colors = [tg, tr]\n",
    "\n",
    "    # Add custom legend: zero-width handles (lw=0, handlelength=0), so only the coloured text is visible\n",
    "    legend_elements = [plt.Line2D([0], [0], color=color, lw=0, label=label) for color, label in zip(legend_colors, legend_labels)]\n",
    "    legend = ax4.legend(handles=legend_elements, handletextpad=0, handlelength=0, fancybox=True, loc='upper right', bbox_to_anchor=(1.12, 0.43))\n",
    "\n",
    "    # Colour each legend text like the group it names\n",
    "    for text, color in zip(legend.get_texts(), legend_colors):\n",
    "        text.set_color(color)\n",
    "        \n",
    "    # PSD panel. Uses color1 = #7B46C1, color2 = #A860AF, color3 = #7C277D.\n",
    "    # psd0 / psd1 / psd2 and psd6 are read from variables set in an earlier cell.\n",
    "    line1, = ax6.semilogy(frequencies1, psd1,  color=color1, alpha=0.9, zorder=0, label = 'z 2')\n",
    "    line2, = ax6.semilogy(frequencies2, psd2,  color=color2, alpha=0.9, zorder=0, label = 'z 3')\n",
    "    line3, = ax6.semilogy(frequencies0, psd0,  color=color3, alpha=0.9, zorder=0, label = 'z 1')\n",
    "    line4, = ax6.semilogy(frequencies6, psd6,  color='black', alpha=0.6, zorder=0, label = 'LFP')\n",
    "    ax6.set_xlim([1, 17])\n",
    "    ax6.set_ylim([10**-3, 1])\n",
    "    ax6.set_title('psd')\n",
    "    ax6.set_xlabel('frequency (hz)')\n",
    "    ax6.tick_params(axis='y', which='both', width=1)\n",
    "    # NOTE(review): this call targets ax4 (not ax6) and passes no handles or labels; it presumably\n",
    "    # replaces the custom ax4 legend created earlier in this cell - confirm whether it should be removed.\n",
    "    ax4.legend(loc='lower right', bbox_to_anchor=(1.3, 0.05))\n",
    "    # Custom legend handles\n",
    "    # NOTE(review): the lines above are drawn as psd1 (color1), psd2 (color2), psd0 (color3), LFP, while the\n",
    "    # labels/colours listed here are z_2/color3, z_1/color2, z_3/color1 - verify that text, colour and curve match.\n",
    "    legend_labels = ['$z_2$', '$z_1$', '$z_3$', 'LFP']\n",
    "    legend_colors = [color3, color2, color1, 'black']\n",
    "#\n",
    "    # Add the custom legend to the plot\n",
    "    legend = ax6.legend(legend_labels, handletextpad=0, handlelength=0, fancybox=True, loc='upper right', bbox_to_anchor=(3, 2.95))\n",
    "    # Colour each legend text with the matching entry of legend_colors\n",
    "    for text, color in zip(legend.get_texts(), legend_colors):\n",
    "        text.set_color(color)\n",
    "#\n",
    "    # Fixed y ticks at 0.01 and 0.1, x ticks at 1, 8 and 15\n",
    "    ax6.set_yticks([0.01, 0.1])\n",
    "    ax6.set_yticklabels(['0.01', '0.1'])\n",
    "    ax6.set_xticks([1, 8, 15])\n",
    "    # Coherence panel: second posterior latent (f2, Cxy2) vs. second generated latent (pf2, pCxy2)\n",
    "    ax7.plot(f2, Cxy2**2, color=cmap(15), label='posterior', linewidth=1.5)\n",
    "    ax7.plot(pf2, pCxy2**2, color=cmap(13), label='gen. latents', linewidth=1.5 )\n",
    "    legend_labels = ['posterior', 'gen. latents']\n",
    "    legend_colors = [cmap(15), cmap(13)]\n",
    "    # Text-only legend (zero-width handles); each entry is coloured like its curve\n",
    "    legend = ax7.legend(legend_labels, handletextpad=0, handlelength=0, fancybox=True, loc='upper right', bbox_to_anchor=(1.6, 0.8))\n",
    "    for text, color in zip(legend.get_texts(), legend_colors):\n",
    "        text.set_color(color)\n",
    "\n",
    "    # The x axis is frequency; the earlier set_xlabel('time (s)') was dead code (overwritten by the\n",
    "    # label below) and has been removed.\n",
    "    ax7.set_xlim(1, 20)\n",
    "    ax7.set_xticks([1, 8, 15])\n",
    "    ax7.set_yticks([0.1, 0.2])\n",
    "    ax7.set_title('coherence w.r.t. LFP')\n",
    "    ax7.set_xlabel('frequency (hz)')\n",
    "\n",
    "\n",
    "    # Top-row panels use a 0.625 box aspect, bottom-row panels are square\n",
    "    ax1.set_box_aspect(0.625)\n",
    "    ax2.set_box_aspect(0.625)\n",
    "    ax3.set_box_aspect(0.625)\n",
    "    ax7.set_box_aspect(1)\n",
    "    ax4.set_box_aspect(1)\n",
    "    ax5.set_box_aspect(1)\n",
    "    ax6.set_box_aspect(1)\n",
    "    # final figure size in inches\n",
    "    plt.gcf().set_size_inches(5.2, 3)\n",
    "\n",
    "\n",
    "    # Save a raster (PNG) and a vector (PDF) version\n",
    "    plt.savefig('hpc2_main.png', dpi=300)\n",
    "    plt.savefig('hpc2_main-lastpsd.pdf', dpi=300)\n",
    "    \n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def zero_count_ratio(data):\n",
    "    \"\"\"Return, for every column of `data`, the fraction of rows equal to zero.\n",
    "\n",
    "    The count runs along axis 0 and is divided by `data.shape[0]`, so the\n",
    "    result has one entry per column (one per neuron when rows are time bins\n",
    "    and columns are neurons).\n",
    "    \"\"\"\n",
    "    n_rows = data.shape[0]\n",
    "    is_zero = data == 0\n",
    "    return np.sum(is_zero, axis=0) / n_rows\n",
    "\n",
    "print(spikes_hat.shape, spikes.shape)\n",
    "# Per-column zero-count ratios for the three datasets, computed in the order\n",
    "# test, generated, train (the training array is transposed before the call).\n",
    "# NOTE(review): 'test' is derived from `spikes_hat` and 'gen' from `spikes`;\n",
    "# confirm this matches how those arrays are defined in the earlier cells.\n",
    "test_zero_ratios, gen_zero_ratios, train_zero_ratios = (\n",
    "    zero_count_ratio(arr) for arr in (spikes_hat, spikes, train.T)\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _scatter_vs_test(ax, test_vals, gen_vals, train_vals, title, diag, ticks, tick_labels, log=False, lim=None):\n",
    "    \"\"\"Scatter generated and training statistics against test statistics on `ax`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ax : matplotlib Axes to draw on.\n",
    "    test_vals, gen_vals, train_vals : equally long, indexable sequences; the\n",
    "        test values go on the x-axis, generated and training values on the y-axis.\n",
    "    title : axes title.\n",
    "    diag : two-element sequence; the dashed identity line is drawn over it.\n",
    "    ticks, tick_labels : tick positions and labels, applied to both axes.\n",
    "    log : if True, both axes are switched to log scale.\n",
    "    lim : optional limits applied to both axes.\n",
    "\n",
    "    Relies on the notebook-level `cmap` for the two group colours.\n",
    "    \"\"\"\n",
    "    c_gen, c_train = cmap(13), cmap(1)\n",
    "\n",
    "    ax.plot(diag, diag, color='gray', linestyle='--', zorder=0)\n",
    "    # Opaque base layer; these two collections carry the legend labels.\n",
    "    ax.scatter(test_vals, gen_vals, s=10, alpha=1, color=c_gen, edgecolors='black', label='gen', linewidths=0.5)\n",
    "    ax.scatter(test_vals, train_vals, s=10, alpha=1, color=c_train, edgecolors='black', label='train', linewidths=0.5)\n",
    "\n",
    "    # Redraw every point with a random z-order so neither group systematically\n",
    "    # hides the other. The loop is sized from the data plotted on this panel\n",
    "    # (not from the neuron count), so panels with a different number of points,\n",
    "    # e.g. pairwise correlations, are covered completely.\n",
    "    n_points = len(test_vals)\n",
    "    zorders = np.random.randint(1, 20, size=(n_points, 2))\n",
    "    for k in range(n_points):\n",
    "        # No label on the per-point artists: labelling each one would add one\n",
    "        # legend entry per point if a legend is ever requested.\n",
    "        ax.scatter(test_vals[k], gen_vals[k], s=10, alpha=0.8, color=c_gen, edgecolors='black', linewidths=0.3, zorder=zorders[k][0])\n",
    "        ax.scatter(test_vals[k], train_vals[k], s=10, alpha=0.8, color=c_train, edgecolors='black', linewidths=0.3, zorder=zorders[k][1])\n",
    "\n",
    "    ax.tick_params(axis='x', which='both', width=1)\n",
    "    ax.tick_params(axis='y', which='both', width=1)\n",
    "    ax.set_title(title)\n",
    "    if log:\n",
    "        ax.set_xscale('log')\n",
    "        ax.set_yscale('log')\n",
    "    ax.set_xlabel('test')\n",
    "    ax.set_yticks(ticks)\n",
    "    ax.set_yticklabels(tick_labels)\n",
    "    ax.set_xticks(ticks)\n",
    "    ax.set_xticklabels(tick_labels)\n",
    "    if lim is not None:\n",
    "        ax.set_xlim(lim)\n",
    "        ax.set_ylim(lim)\n",
    "\n",
    "\n",
    "with mpl.rc_context(fname=\"style_matplotlib.rc\"):\n",
    "    # 2 x 4 grid; only the top row is filled in this cell.\n",
    "    fig, axes = plt.subplots(2, 4, figsize=(6, 3))\n",
    "    fig.subplots_adjust(hspace=1, wspace=0.6)\n",
    "\n",
    "    ax1 = axes[0, 0]\n",
    "    ax2 = axes[0, 2]\n",
    "    ax3 = axes[0, 1]\n",
    "    ax4 = axes[0, 3]\n",
    "    ax5 = axes[1, 1]\n",
    "    ax6 = axes[1, 2]\n",
    "    ax7 = axes[1, 3]\n",
    "    ax8 = axes[1, 0]\n",
    "\n",
    "    # Panels are drawn in the order ax1, ax3, ax2, ax4 so the random z-order\n",
    "    # draws are consumed in a fixed sequence under the global NumPy seed.\n",
    "\n",
    "    # ISI standard deviations (log-log).\n",
    "    _scatter_vs_test(\n",
    "        ax1, iV_data, iV_gen, iV_train,\n",
    "        title='isi stds (ms)', diag=[0, 20],\n",
    "        ticks=[0.1, 1, 10], tick_labels=['0.1', '1.0', '10'], log=True,\n",
    "    )\n",
    "\n",
    "    # Coefficients of variation; cvD / cvG / cvT come from an earlier cell\n",
    "    # (presumably iV_* / iM_* -- confirm against that cell).\n",
    "    _scatter_vs_test(\n",
    "        ax3, cvD, cvG, cvT,\n",
    "        title='cvs', diag=[0, 5],\n",
    "        ticks=[1, 5], tick_labels=['1', '5'],\n",
    "    )\n",
    "\n",
    "    # Cross-correlations; the *_corr_values come from an earlier cell\n",
    "    # (presumably *_correlation[i_upper] -- confirm against that cell).\n",
    "    _scatter_vs_test(\n",
    "        ax2, test_corr_values, gen_corr_values, train_corr_values,\n",
    "        title='cross-correlations', diag=[-0.07, 0.10],\n",
    "        ticks=[-0.05, 0.10], tick_labels=['-0.05', '0.10'], lim=[-0.07, 0.13],\n",
    "    )\n",
    "\n",
    "    # Zero-count ratios.\n",
    "    _scatter_vs_test(\n",
    "        ax4, test_zero_ratios, gen_zero_ratios, train_zero_ratios,\n",
    "        title='zero-count ratios', diag=[0.9, 1],\n",
    "        ticks=[0.9, 1], tick_labels=['0.9', '1'], lim=[0.9, 1],\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: display the shapes of the `train` and `testing` arrays.\n",
    "train.shape, testing.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "labrot",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
