{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "%load_ext autoreload\n",
                "%autoreload 2\n",
                "\n",
                "\n",
                "import argparse\n",
                "import os\n",
                "import sys\n",
                "\n",
                "ANONAUTHOR = False\n",
                "\n",
                "if ANONAUTHOR:\n",
                "    os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
                "    # append parent directory to path (../notebooks -> ..)\n",
                "    sys.path.append(os.path.dirname(os.getcwd()))\n",
                "    os.chdir(os.path.dirname(os.getcwd()))\n",
                "\n",
                "else:\n",
                "    os.chdir('../')\n",
                "\n",
                "\n",
                "import accelerate\n",
                "import lovely_tensors as lt\n",
                "import matplotlib.pyplot as plt\n",
                "import matplotlib\n",
                "import numpy as np\n",
                "import scipy.io as io\n",
                "import seaborn as sns\n",
                "import torch\n",
                "import torch.nn as nn\n",
                "import yaml\n",
                "from diffusers.optimization import get_scheduler\n",
                "from omegaconf import OmegaConf\n",
                "from tqdm.auto import tqdm\n",
                "from einops import rearrange\n",
                "\n",
                "from ntldm.data.lds import get_lds_dataset\n",
                "from ntldm.networks import S4AE, AutoEncoder, CountWrapper\n",
                "from ntldm.utils.plotting_utils import *\n",
                "from ntldm.networks import Denoiser\n",
                "from ntldm.utils.plotting_utils import angle_to_color\n",
                "import math\n",
                "from ntldm.data.monkey import get_monkey_dataloaders\n",
                "\n",
                "from diffusers.training_utils import EMAModel\n",
                "from diffusers.schedulers import DDPMScheduler\n",
                "# always run from ../ntldm\n",
                "\n",
                "\n",
                "lt.monkey_patch()\n",
                "matplotlib.rc_file('matplotlibrc')\n",
                "\n",
                "\n",
                "# specify the path \n",
                "save_path = '/home/anonauthor/anonloc1/results/projects/latent-diffusion/MONKEY/PAPER/unconditional/'\n",
                "os.makedirs(save_path, exist_ok=True)\n",
                "save_data_path = save_path + 'DATA/'\n",
                "os.makedirs(save_data_path, exist_ok=True)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "## load config and model path\n",
                "\n",
                "#cfg_ae = OmegaConf.load(\"conf/sweeps_count/autoencoder-count_s4-monkey_new.yaml\")\n",
                "cfg_ae = OmegaConf.load(\"conf/sweeps_count/autoencoder-count_s4-monkey_new_new_regularisation.yaml\")\n",
                "\n",
                "# cfg_yaml = \"\"\"\n",
                "# denoiser_model:\n",
                "#   C_in: 16\n",
                "#   C: 256\n",
                "#   kernel: s4\n",
                "#   num_blocks: 6\n",
                "#   bidirectional: True\n",
                "#   num_train_timesteps: 1000\n",
                "# training:\n",
                "#   lr: 0.001\n",
                "#   weight_decay: 0.0\n",
                "#   num_epochs: 2000\n",
                "#   num_warmup_epochs: 50\n",
                "#   batch_size: 512\n",
                "#   random_seed: 42\n",
                "#   precision: \"no\"\n",
                "# exp_name: diffusion_s4-monkey_epoch_140\n",
                "# \"\"\"\n",
                "\n",
                "# cfg = OmegaConf.create(yaml.safe_load(cfg_yaml))\n",
                "# cfg.dataset = cfg_ae.dataset\n",
                "\n",
                "# load the diffusion config\n",
                "with open(f\"conf/sweeps_count/diffusion_s4-monkey_epoch_140_save_after_train_new_regularisation.yaml\") as f:\n",
                "        cfg = OmegaConf.create(yaml.safe_load(f))"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "\n",
                "### 0. Load autoencoder (with checkpoint) and autoencoder dataset (run for all points 2-5)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "    \n",
                "ae_model = AutoEncoder(\n",
                "    C_in=cfg_ae.model.C_in,\n",
                "    C=cfg_ae.model.C,\n",
                "    C_latent=cfg_ae.model.C_latent,\n",
                "    L=cfg_ae.dataset.signal_length,\n",
                "    kernel=cfg_ae.model.kernel,\n",
                "    num_blocks=cfg_ae.model.num_blocks,\n",
                "    num_blocks_decoder=cfg_ae.model.get(\"num_blocks_decoder\", cfg_ae.model.num_blocks),\n",
                "    num_lin_per_mlp=cfg_ae.model.get(\"num_lin_per_mlp\", 2),  # default 2\n",
                "    bidirectional=cfg_ae.model.get(\"bidirectional\", True),\n",
                ")\n",
                "\n",
                "ae_model = CountWrapper(ae_model, use_sin_enc=cfg_ae.model.get(\"use_sin_enc\", False))\n",
                "\n",
                "\n",
                "\n",
                "# set seed\n",
                "torch.manual_seed(cfg.training.random_seed)\n",
                "np.random.seed(cfg.training.random_seed)\n",
                "\n",
                "train_dataloader, val_dataloader, test_dataloader = get_monkey_dataloaders(\n",
                "        cfg_ae.dataset.task, cfg_ae.dataset.datapath, bin_width=5, batch_size=cfg_ae.training.batch_size\n",
                "    )\n",
                "\n",
                "accelerator = accelerator = accelerate.Accelerator(\n",
                "    mixed_precision='no',\n",
                "    log_with=\"wandb\",\n",
                ")\n",
                "\n",
                "# prepare the ae model and dataset\n",
                "ae_model = accelerator.prepare(ae_model)\n",
                "    \n",
                "print(cfg_ae.exp_name)\n",
                "accelerator.load_state(f\"exp/{cfg_ae.exp_name}/epoch_140\") # best checkpoint CHANGED\n",
                "\n",
                "(\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ") = accelerator.prepare(\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ")\n",
                "\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Load the latent diffusion monkey dataset and plot stats"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.data.monkey import LatentMonkeyDataset\n",
                "latent_dataset_train = LatentMonkeyDataset(train_dataloader, ae_model, clip=False)\n",
                "latent_dataset_val = LatentMonkeyDataset(\n",
                "    val_dataloader,\n",
                "    ae_model,\n",
                "    latent_means=latent_dataset_train.latent_means,\n",
                "    latent_stds=latent_dataset_train.latent_stds,\n",
                "    clip=False,\n",
                ")\n",
                "latent_dataset_test = LatentMonkeyDataset(\n",
                "    test_dataloader,\n",
                "    ae_model,\n",
                "    latent_means=latent_dataset_train.latent_means,\n",
                "    latent_stds=latent_dataset_train.latent_stds,\n",
                "    clip=False,\n",
                ")\n",
                "\n",
                "train_latent_dataloader = torch.utils.data.DataLoader(\n",
                "    latent_dataset_train,\n",
                "    batch_size=cfg.training.batch_size,\n",
                "    shuffle=True,\n",
                "    num_workers=4,\n",
                "    pin_memory=True,\n",
                ")\n",
                "\n",
                "val_latent_dataloader = torch.utils.data.DataLoader(\n",
                "    latent_dataset_val,\n",
                "    batch_size=cfg.training.batch_size,\n",
                "    shuffle=False,\n",
                "    num_workers=4,\n",
                "    pin_memory=True,\n",
                ")\n",
                "\n",
                "test_latent_dataloader = torch.utils.data.DataLoader(\n",
                "    latent_dataset_test,\n",
                "    batch_size=cfg.training.batch_size,\n",
                "    shuffle=False,\n",
                "    num_workers=4,\n",
                "    pin_memory=True,\n",
                ")\n",
                "\n",
                "num_batches = len(train_latent_dataloader)\n",
                "\n",
                "# check if signal length is power of 2\n",
                "if cfg.dataset.signal_length & (cfg.dataset.signal_length - 1) != 0:\n",
                "    cfg.training.precision = \"no\"  # torch.fft doesnt support half if L!=2^x\n",
                "\n",
                "# prepare the denoiser model and dataset\n",
                "(\n",
                "    train_latent_dataloader,\n",
                "    val_latent_dataloader,\n",
                "    test_latent_dataloader,\n",
                ") = accelerator.prepare(\n",
                "    train_latent_dataloader,\n",
                "    val_latent_dataloader,\n",
                "    test_latent_dataloader,\n",
                ")\n",
                "\n",
                "# ------------------ visualize dataset ------------------\n",
                "\n",
                "\n",
                "plt.figure(figsize=cm2inch(5, 3))\n",
                "hist = plt.hist(latent_dataset_train.latents[:100].flatten(), bins=200, density=True)\n",
                "hist = plt.hist(latent_dataset_val.latents[:100].flatten(), bins=200, density=True)\n",
                "hist = plt.hist(latent_dataset_test.latents[:100].flatten(), bins=200, density=True)\n",
                "\n",
                "plt.title(\"Latent dataset histogram\")\n",
                "plt.show()\n",
                "\n",
                "plt.figure(figsize=cm2inch(5, 3))\n",
                "plt.hist(latent_dataset_train.behavior_angles.flatten(), bins=200, label=\"train\")\n",
                "plt.hist(latent_dataset_test.behavior_angles.flatten(), bins=200, label=\"test\")\n",
                "plt.legend()\n",
                "# xticks in terms of pi\n",
                "plt.xticks(\n",
                "    np.linspace(-np.pi, np.pi, 5),\n",
                "    [r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"],\n",
                ")\n",
                "plt.ylabel(\"number of reaches\")\n",
                "plt.xlabel(\"behavior angle\")\n",
                "plt.savefig(save_path + \"behavior_angles.png\")\n",
                "plt.savefig(save_path + \"behavior_angles.pdf\")\n",
                "\n",
                "# create cmap based on behavior angles and plot behavior_cumsum on 2d plot \n",
                "cmap = plt.get_cmap(\"hsv\")\n",
                "colors = cmap((latent_dataset_train.behavior_angles.squeeze() + np.pi) / (2 * np.pi))\n",
                "\n",
                "plt.figure(figsize=cm2inch(2, 2))\n",
                "\n",
                "for i in range(0, len(latent_dataset_train.behavior), 10):\n",
                "    plt.plot(\n",
                "        latent_dataset_train.behavior_cumsum[i, 0],\n",
                "        latent_dataset_train.behavior_cumsum[i, 1],\n",
                "        color=colors[i],\n",
                "        alpha=0.3,\n",
                "        lw=0.4\n",
                "    )\n",
                "\n",
                "# switch off axis\n",
                "plt.axis(\"off\")\n",
                "plt.savefig(save_path + \"Fig_3_all_reaches_colored_by_angles.png\")\n",
                "plt.savefig(save_path + \"Fig_3_all_reaches_colored_by_angles.pdf\")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Load the unconditional diffusion model "
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.sampling_utils import sample, sample_spikes\n",
                "\n",
                "denoiser = Denoiser(\n",
                "    C_in=cfg.denoiser_model.C_in,\n",
                "    C=cfg.denoiser_model.C,\n",
                "    L=cfg.dataset.signal_length,\n",
                "    kernel=cfg.denoiser_model.kernel,\n",
                "    num_blocks=cfg.denoiser_model.num_blocks,\n",
                "    bidirectional=cfg.denoiser_model.get(\"bidirectional\", True),\n",
                ")\n",
                "\n",
                "# initial values may be way off so better to scale down the output layer\n",
                "denoiser.conv_out.weight.data = denoiser.conv_out.weight.data * 0.1\n",
                "denoiser.conv_out.bias.data = denoiser.conv_out.bias.data * 0.1\n",
                "\n",
                "scheduler = DDPMScheduler(\n",
                "    num_train_timesteps=cfg.denoiser_model.num_train_timesteps,\n",
                "    clip_sample=False,\n",
                "    beta_schedule=\"linear\", # ddpm doesnt support cosine\n",
                ")\n",
                "\n",
                "\n",
                "optimizer = torch.optim.AdamW(\n",
                "    denoiser.parameters(), lr=cfg.training.lr\n",
                ")  # default wd=0.01 for now\n",
                "\n",
                "\n",
                "\n",
                "num_batches = len(train_latent_dataloader)\n",
                "lr_scheduler = get_scheduler(\n",
                "    name=\"cosine\",\n",
                "    optimizer=optimizer,\n",
                "    num_warmup_steps=num_batches * cfg.training.num_warmup_epochs,  # warmup for 10% of epochs\n",
                "    num_training_steps=num_batches * cfg.training.num_epochs * 1.3,  # total number of steps\n",
                ")\n",
                "\n",
                "denoiser.load_state_dict(torch.load(f\"exp/{cfg.exp_name}/model_new_reg.pt\", map_location=\"cpu\"))\n",
                "\n",
                "# prepare the denoiser model and dataset\n",
                "(\n",
                "    denoiser,\n",
                "    optimizer,\n",
                "    lr_scheduler,\n",
                ") = accelerator.prepare(\n",
                "    denoiser,\n",
                "    optimizer,\n",
                "    lr_scheduler,\n",
                ")\n",
                "\n",
                "ema_model = EMAModel(denoiser)\n",
                "\n",
                "    \n",
                "print(f\"loaded model from exp/{cfg.exp_name}/model_new_reg.pt\")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Create tensors for comparison"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.sampling_utils import sample\n",
                "from ntldm.utils.utils import set_seed\n",
                "import pickle\n",
                "fps_monkey = 1000/5\n",
                "\n",
                "load_from_disc = True\n",
                "\n",
                "if load_from_disc:\n",
                "    with open(save_data_path + 'unconditional_diffusion_data_dict.pkl', 'rb') as f:\n",
                "        unconditional_diffusion_data_dict = pickle.load(f)\n",
                "\n",
                "    ae_rates = unconditional_diffusion_data_dict['ae_rates']\n",
                "    ae_latents = unconditional_diffusion_data_dict['ae_latents']\n",
                "    diffusion_rates = unconditional_diffusion_data_dict['diffusion_rates']\n",
                "    diffusion_latents = unconditional_diffusion_data_dict['diffusion_latents']\n",
                "    gt_spikes = unconditional_diffusion_data_dict['gt_spikes']\n",
                "\n",
                "else: # takes 4-5 min\n",
                "    # ensure reproducibility\n",
                "    set_seed(42)\n",
                "\n",
                "    true_data = True\n",
                "    n_siglen = 1\n",
                "    sample_cutoff = int(10e5)\n",
                "    ae_model.eval()\n",
                "\n",
                "    ae_rates = []\n",
                "    ae_latents = []\n",
                "    diffusion_rates = []\n",
                "    diffusion_latents = []\n",
                "    gt_spikes = []\n",
                "\n",
                "    if not true_data:\n",
                "        gt_rates = []\n",
                "        gt_latents = []\n",
                "\n",
                "\n",
                "    # autoencoder eval \n",
                "    for batch in train_dataloader:\n",
                "        signal = batch[\"signal\"]\n",
                "        with torch.no_grad():\n",
                "            output_rates, latent = ae_model(signal)\n",
                "            #output_rates = ae_model(signal)[0].cpu()\n",
                "        ae_rates.append(output_rates.cpu())\n",
                "        ae_latents.append(latent.cpu())\n",
                "        gt_spikes.append(signal.cpu())\n",
                "        if not true_data:\n",
                "            gt_rates.append(batch[\"rates\"].cpu())\n",
                "            gt_latents.append(batch[\"latents\"].cpu())\n",
                "\n",
                "        #break\n",
                "\n",
                "    # concatenate along batch dimension\n",
                "    ae_rates = torch.cat(ae_rates, dim=0)\n",
                "    ae_latents = torch.cat(ae_latents, dim=0)\n",
                "    gt_spikes = torch.cat(gt_spikes, dim=0)\n",
                "    if not true_data:\n",
                "        gt_rates = torch.cat(gt_rates, dim=0)\n",
                "        gt_latents = torch.cat(gt_latents, dim=0)\n",
                "        \n",
                "\n",
                "\n",
                "    # diffusion eval\n",
                "    sampled_latents = sample(\n",
                "        ema_denoiser=ema_model, scheduler=scheduler, cfg=cfg,\n",
                "        batch_size=ae_rates.shape[0], device=\"cuda\", signal_length=cfg.dataset.signal_length * n_siglen\n",
                "    )\n",
                "\n",
                "    # project back to non standardized space\n",
                "    sampled_latents = sampled_latents * latent_dataset_train.latent_stds.to(\n",
                "        sampled_latents.device\n",
                "    ) + latent_dataset_train.latent_means.to(sampled_latents.device)\n",
                "\n",
                "    with torch.no_grad():\n",
                "        sampled_rates = ae_model.decode(sampled_latents).cpu()\n",
                "        \n",
                "    diffusion_rates.append(sampled_rates)\n",
                "    diffusion_latents.append(sampled_latents.cpu())\n",
                "\n",
                "    # concatenate along batch dimension\n",
                "    diffusion_rates = torch.cat(diffusion_rates, dim=0)\n",
                "    diffusion_latents = torch.cat(diffusion_latents, dim=0)\n",
                "\n",
                "    vecs = [ae_rates, ae_latents, diffusion_rates, diffusion_latents, gt_spikes]\n",
                "\n",
                "    if not true_data:\n",
                "        vecs.append(gt_rates)\n",
                "        vecs.append(gt_latents)\n",
                "\n",
                "    vecs = [vec.cpu().numpy() for vec in vecs]\n",
                "    vecs = [vec[:sample_cutoff] for vec in vecs]\n",
                "    vecs = [rearrange(vec, 'b n t -> b t n') for vec in vecs]\n",
                "\n",
                "\n",
                "    if not true_data:\n",
                "        ae_rates, ae_latents, diffusion_rates, diffusion_latents, gt_spikes, gt_rates, gt_latents = vecs\n",
                "    else:\n",
                "        ae_rates, ae_latents, diffusion_rates, diffusion_latents, gt_spikes = vecs\n",
                "\n",
                "    unconditional_diffusion_data_dict = {\n",
                "        'ae_rates': ae_rates,\n",
                "        'ae_latents': ae_latents,\n",
                "        'diffusion_rates': diffusion_rates,\n",
                "        'diffusion_latents': diffusion_latents,\n",
                "        'gt_spikes': gt_spikes\n",
                "    }\n",
                "    import pickle\n",
                "    with open(save_data_path + 'unconditional_diffusion_data_dict.pkl', 'wb') as f:\n",
                "        pickle.dump(unconditional_diffusion_data_dict, f)\n",
                "\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(7.5, 2.5), sharey=True)\n",
                "\n",
                "# ----------- Binarize the data ------------\n",
                "\n",
                "data_binary_spikes = gt_spikes[0].T > 0\n",
                "sampled_binary_spikes = np.random.poisson(diffusion_rates[0].T) > 0\n",
                "\n",
                "# Display data\n",
                "im = ax[0].imshow(data_binary_spikes, aspect=\"auto\", cmap='Greys', vmax=1)\n",
                "ax[0].set_title(\"data\")\n",
                "#fig.colorbar(im, ax=ax[0], orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "\n",
                "# Set the x-ticks to display time in seconds for data\n",
                "time_points = data_binary_spikes.shape[1]\n",
                "fps = fps_monkey\n",
                "time_interval = 1 / fps\n",
                "second_ticks = fps / 2  # ticks every 0.25 seconds\n",
                "xticks = np.arange(0, time_points, second_ticks)\n",
                "ax[0].set_xticks(xticks)\n",
                "ax[0].set_xticklabels([f\"{x * time_interval:.1f}\" for x in xticks])\n",
                "\n",
                "# Display sampled spikes\n",
                "im = ax[1].imshow(sampled_binary_spikes, aspect=\"auto\", cmap='Greys', vmax=1)\n",
                "ax[1].set_title(\"sampled\")\n",
                "#fig.colorbar(im, ax=ax[1], orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "\n",
                "# Set the x-ticks to display time in seconds for sampled data\n",
                "ax[1].set_xticks(xticks)\n",
                "ax[1].set_xticklabels([f\"{x * time_interval:.1f}\" for x in xticks])\n",
                "ax[0].set_xlabel(\"time (s)\")\n",
                "ax[1].set_xlabel(\"time (s)\")\n",
                "ax[0].set_ylabel(\"neuron\")\n",
                "\n",
                "plt.savefig(save_path + \"Fig_3_sampled_vs_data_binary.png\", dpi=300)\n",
                "plt.savefig(save_path + \"Fig_3_sampled_vs_data_binary.pdf\")\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Read in baselines "
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# read in baseline data for comparison\n",
                "baseline_path ='/home/anonauthor/anonloc1/results/projects/latent-diffusion/DATA_SHARE/'\n",
                "import pickle\n",
                "# lfads samples sampled from the prior and passed through the decoder\n",
                "\n",
                "with open(baseline_path + 'lfads_samples.pkl', 'rb') as f:\n",
                "    lfads_samples = pickle.load(f)\n",
                "    \n",
                "\n",
                "lfads_spikes = lfads_samples['model_spikes']\n",
                "lfads_rates = lfads_samples['model_rates']\n",
                "\n",
                "\n",
                "with open(baseline_path + 'history_sampled_dict.pkl', 'rb') as f:\n",
                "    history_sampled_dict = pickle.load(f)\n",
                "\n",
                "hist_spikes = history_sampled_dict['hist'][1]\n",
                "hist_spikes_list = history_sampled_dict['hist']\n",
                "\n",
                "\n",
                "\n",
                "load_from_disc = True\n",
                "\n",
                "if load_from_disc:\n",
                "    with open(save_data_path + 'unconditional_diffusion_data_dict.pkl', 'rb') as f:\n",
                "        unconditional_diffusion_data_dict = pickle.load(f)\n",
                "\n",
                "    ae_rates = unconditional_diffusion_data_dict['ae_rates']\n",
                "    ae_latents = unconditional_diffusion_data_dict['ae_latents']\n",
                "    diffusion_rates = unconditional_diffusion_data_dict['diffusion_rates']\n",
                "    diffusion_latents = unconditional_diffusion_data_dict['diffusion_latents']\n",
                "    gt_spikes = unconditional_diffusion_data_dict['gt_spikes']"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.utils import set_seed\n",
                "\n",
                "set_seed(42)\n",
                "spikes_dict_single = {\n",
                "    'lfads': lfads_spikes,\n",
                "    'ldns': np.random.poisson(diffusion_rates),\n",
                "    'hist': hist_spikes,\n",
                "    'train': gt_spikes,\n",
                "}\n",
                "\n",
                "list_indices = [np.random.choice(len(gt_spikes), len(gt_spikes), replace=True) for _ in range(5)]\n",
                "\n",
                "def get_folds(spikes, indices):\n",
                "    \"\"\"Return spikes sampled according to the provided indices.\"\"\"\n",
                "    return spikes[indices]\n",
                "\n",
                "# Create the spikes dictionary for each model across 5 folds\n",
                "spikes_dict = {\n",
                "    'lfads': [get_folds(spikes_dict_single['lfads'], idx) for idx in list_indices],\n",
                "}\n",
                "\n",
                "for key in spikes_dict_single.keys():\n",
                "    if key not in spikes_dict:\n",
                "        spikes_dict[key] = [get_folds(spikes_dict_single[key], idx) for idx in list_indices]"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# plot the first trial spikes for all models \n",
                "fig, ax = plt.subplots(1, 4, figsize=cm2inch(15, 3), sharey=True)\n",
                "for i, (key, spikes) in enumerate(spikes_dict_single.items()):\n",
                "    ax[i].imshow(spikes[list_indices[0]][0].T > 0, aspect='auto', cmap='Greys', vmax=1)\n",
                "    ax[i].set_title(key if key != 'train' else 'data')\n",
                "    ax[i].set_xticks([])\n",
                "    ax[i].set_yticks([])\n",
                "ax[0].set_ylabel('neuron')\n",
                "ax[0].set_xlabel('time')\n",
                "plt.savefig(save_path + 'Supp_Fig_model_sampling_comparison.png')\n",
                "plt.savefig(save_path + 'Supp_Fig_model_sampling_comparison.pdf')"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# make main comparison loop here"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import numpy as np\n",
                "\n",
                "\n",
                "methods = [ \"lfads\",\"ldns\", \"hist\", \"train\"]\n",
                "comparison_stats = [ \"kl_div_psc\", \"rmse_corr\",\"rmse_mean_isi\", \"rmse_std_isi\",\n",
                "]\n",
                "\n",
                "# Initialize the comparison statistics dictionary\n",
                "comp_dict = {method: {stat: [] for stat in comparison_stats} for method in methods}\n",
                "\n",
                "# Define a function to calculate your statistics, here using dummy functions for illustration\n",
                "def calculate_statistics_dummy(spikes, gt):\n",
                "    # Here you would include calls to actual functions like rmse, kl_div, corr, etc.\n",
                "    return {\n",
                "        \"rmse_mean_rate\": np.mean(np.abs(spikes - gt)),\n",
                "        \"rmse_std_rate\": np.std(np.abs(spikes - gt)),\n",
                "        \"rmse_mean_isi\": np.mean(spikes - gt),\n",
                "        \"rmse_std_isi\": np.std(spikes - gt),\n",
                "        \"kl_div_psc\": np.mean(spikes - gt),  # Placeholder\n",
                "        \"rmse_corr\": np.corrcoef(spikes.reshape(-1), gt.reshape(-1))[0, 1],\n",
                "        \"rmse_auto_corr\": np.mean(np.abs(spikes - gt)),  # Placeholder\n",
                "        \"rmse_cross_corr\": np.mean(spikes - gt)  # Placeholder\n",
                "    }\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import rmse_nan, average_rates, std_rates, kl_div_nan, correlation_matrix\n",
                "from ntldm.utils.eval_utils import compute_spike_stats_per_neuron, counts_to_spike_trains\n",
                "from scipy.stats import gaussian_kde\n",
                "from ntldm.utils.eval_utils import group_neurons_temp_corr, get_temp_corr_summary\n",
                "\n",
                "\n",
                "\n",
                "def rmse_mean_rate(spikes, gt):\n",
                "    gt_m = average_rates(gt, mode='neur', fps=fps_monkey,).flatten()\n",
                "    spikes_m = average_rates(spikes, mode='neur', fps=fps_monkey).flatten()\n",
                "    return rmse_nan(gt_m, spikes_m)\n",
                "\n",
                "def rmse_std_rate(spikes, gt):\n",
                "    gt_s = std_rates(gt, mode='neur', fps=fps_monkey,).flatten()\n",
                "    spikes_s = std_rates(spikes, mode='neur', fps=fps_monkey).flatten()\n",
                "    return rmse_nan(gt_s, spikes_s)\n",
                "\n",
                "def get_spike_train_and_stats(spikes, fps=200):\n",
                "    \"\"\"Turn binned counts into spike trains and return ``compute_spike_stats_per_neuron`` on them.\n",
                "\n",
                "    ``spikes.shape[0]`` is passed as the number of samples and ``spikes.shape[2]`` as the\n",
                "    number of neurons; the statistics are requested with ``mean_output=False``.\n",
                "    \"\"\"\n",
                "    trains = counts_to_spike_trains(spikes, fps=fps)\n",
                "    num_samples, num_neurons = spikes.shape[0], spikes.shape[2]\n",
                "    return compute_spike_stats_per_neuron(\n",
                "        trains, n_samples=num_samples, n_neurons=num_neurons, mean_output=False\n",
                "    )\n",
                "\n",
                "\n",
                "def _lower_triangle_correlation(counts):\n",
                "    \"\"\"Return ``correlation_matrix(counts, mode='concatenate')`` with diagonal and upper triangle zeroed.\"\"\"\n",
                "    # np.tril with k=-1 keeps only the entries strictly below the diagonal\n",
                "    return np.tril(correlation_matrix(counts, mode=\"concatenate\"), k=-1)\n",
                "\n",
                "\n",
                "def calculate_statistics(spikes, gt, spike_stats_gt=None, maxval=17, calc_isi=True, temp_corr=False, groups_dict=None):\n",
                "    \"\"\"Compare model spike counts with ground-truth spike counts on a set of summary statistics.\n",
                "\n",
                "    Args:\n",
                "        spikes: model spike counts; summed over axis 2 for the population spike count\n",
                "            (presumably (trials, time, neurons) -- TODO confirm against the caller).\n",
                "        gt: ground-truth spike counts, used the same way as ``spikes``.\n",
                "        spike_stats_gt: output of ``get_spike_train_and_stats`` for ``gt``; pass it to save\n",
                "            compute time, it is recomputed from ``gt`` when None.\n",
                "        maxval: the population-count densities are evaluated on the integers 0..maxval.\n",
                "        calc_isi: when False the model ISI statistics are not computed and the ground-truth\n",
                "            ones are reused, i.e. both ISI errors compare the ground truth with itself.\n",
                "        temp_corr: when True the grouped temporal correlations of ``spikes`` are computed\n",
                "            (they are not part of the returned dict yet).\n",
                "        groups_dict: dict with a 'groups' entry; required when ``temp_corr`` is True.\n",
                "\n",
                "    Returns:\n",
                "        dict with the keys 'kl_div_psc', 'rmse_corr', 'rmse_mean_isi' and 'rmse_std_isi'.\n",
                "\n",
                "    Raises:\n",
                "        ValueError: if ``temp_corr`` is True but ``groups_dict`` is None.\n",
                "    \"\"\"\n",
                "    # fail early with a readable message instead of a late \"NoneType is not subscriptable\"\n",
                "    if temp_corr and groups_dict is None:\n",
                "        raise ValueError(\"temp_corr=True requires groups_dict with a 'groups' entry\")\n",
                "\n",
                "    # pass gt spike stats to save compute time\n",
                "    if spike_stats_gt is None:\n",
                "        spike_stats_gt = get_spike_train_and_stats(gt, fps=fps_monkey)\n",
                "\n",
                "    if calc_isi:\n",
                "        spike_stats_model = get_spike_train_and_stats(spikes, fps=fps_monkey)\n",
                "    else:\n",
                "        spike_stats_model = spike_stats_gt\n",
                "        print('CAUTION NOT CALCULATING SPIKE STATS')\n",
                "\n",
                "    # population spike count histogram:\n",
                "    # approximate both with a gaussian kde, evaluate on the integer grid and normalise,\n",
                "    # then calculate the kl div between the two\n",
                "    n_grid = int(maxval) + 1  # int() so a float maxval (e.g. a numpy scalar) also works\n",
                "    x_eval = np.linspace(0, n_grid - 1, n_grid)\n",
                "    density_model = gaussian_kde(spikes.sum(2).flatten())(x_eval)\n",
                "    density_gt = gaussian_kde(gt.sum(2).flatten())(x_eval)\n",
                "    density_model /= density_model.sum()\n",
                "    density_gt /= density_gt.sum()\n",
                "\n",
                "    # get the correlation structure (strictly lower triangle only)\n",
                "    C_model = _lower_triangle_correlation(spikes)\n",
                "    C_gt = _lower_triangle_correlation(gt)\n",
                "\n",
                "    if temp_corr:\n",
                "        # TODO: these summaries are computed but not yet turned into a metric; the original\n",
                "        # note was to limit the comparison to two groups since the lower two groups have\n",
                "        # very low correlation\n",
                "        cross_corr_groups_model, auto_corr_groups_model = get_temp_corr_summary(\n",
                "            spikes.transpose(1,0,2), groups_dict['groups'], nlags=30, mode='biased', batch_first=False\n",
                "        )\n",
                "\n",
                "    # the rate RMSEs (rmse_mean_rate / rmse_std_rate) and the auto-/cross-correlation\n",
                "    # metrics are currently disabled\n",
                "    return {\n",
                "        \"kl_div_psc\": kl_div_nan(density_model, density_gt),\n",
                "        \"rmse_corr\": rmse_nan(C_model, C_gt),\n",
                "        \"rmse_mean_isi\": rmse_nan(spike_stats_model['mean_isi'], spike_stats_gt['mean_isi']),\n",
                "        \"rmse_std_isi\": rmse_nan(spike_stats_model['std_isi'], spike_stats_gt['std_isi']),\n",
                "    }"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "import pickle  # needed for the group cache below; harmless if it is already imported\n",
                "\n",
                "# one list of per-fold values for every (method, statistic) pair\n",
                "comp_dict = {method: {stat: [] for stat in comparison_stats} for method in methods}\n",
                "# this calculation takes time so worth calculating beforehand\n",
                "fps_monkey = 1000/5\n",
                "spike_stats_gt = get_spike_train_and_stats(gt_spikes, fps=fps_monkey)\n",
                "temp_corr = False  # should temporal correlations be included\n",
                "n_folds = 5\n",
                "\n",
                "if temp_corr:\n",
                "    # reuse the cached grouping when it exists, otherwise compute it and cache it\n",
                "    groups_cache_file = save_data_path + 'cross_corr_gt.pkl'\n",
                "    if os.path.exists(groups_cache_file):\n",
                "        with open(groups_cache_file, 'rb') as f:\n",
                "            groups_dict = pickle.load(f)\n",
                "    else:\n",
                "        groups = group_neurons_temp_corr(gt_spikes.transpose(1,0,2), num_groups=4)\n",
                "        cross_corr_groups, auto_corr_groups = get_temp_corr_summary(\n",
                "            gt_spikes.transpose(1,0,2), groups, nlags=30, mode='biased', batch_first=False\n",
                "        )\n",
                "        groups_dict = {\n",
                "            'cross_corr_groups': cross_corr_groups,\n",
                "            'auto_corr_groups': auto_corr_groups,\n",
                "            'groups': groups,\n",
                "        }\n",
                "        with open(groups_cache_file, 'wb') as f:\n",
                "            pickle.dump(groups_dict, f)\n",
                "else:\n",
                "    groups_dict = None\n",
                "\n",
                "# get the overall maximum population spike count to calculate the kl div\n",
                "max_vals = [val.sum(2).max() for val in spikes_dict_single.values()]\n",
                "maxval = int(np.max(max_vals))\n",
                "\n",
                "# Loop over methods and folds to compute and store the statistics\n",
                "for method in methods:\n",
                "    for fold in range(n_folds):\n",
                "        model_spikes = spikes_dict[method][fold]\n",
                "        gt_data = gt_spikes  # gt_spikes is the ground truth for all methods\n",
                "        fold_stats_gt = spike_stats_gt\n",
                "        if method=='train':\n",
                "            # baseline: compare two disjoint random halves of the data with each other\n",
                "            indices = np.random.choice(len(gt_data), int(len(gt_data)/2), replace=False)\n",
                "            other_indices = np.setdiff1d(np.arange(len(gt_data)), indices)\n",
                "            gt_data = gt_data[indices]\n",
                "            model_spikes = model_spikes[other_indices]\n",
                "            # the precomputed stats describe the full data set, not this half, so let\n",
                "            # calculate_statistics recompute them from gt_data\n",
                "            fold_stats_gt = None\n",
                "        stats = calculate_statistics(model_spikes, gt_data, spike_stats_gt=fold_stats_gt, maxval=maxval, groups_dict=groups_dict, temp_corr=temp_corr)\n",
                "\n",
                "        for stat in comparison_stats:\n",
                "            comp_dict[method][stat].append(stats[stat])\n",
                "\n",
                "    print(f\"Finished computing stats for {method}\")\n",
                "\n",
                "\n",
                "\n",
                "def compute_stats(comp_dict):\n",
                "    \"\"\"Collapse every list of per-fold values into a ``(mean, std)`` tuple.\n",
                "\n",
                "    Takes ``{method: {metric: [values]}}`` and returns the same nesting with each list\n",
                "    replaced by ``(np.mean(values), np.std(values))``.\n",
                "    \"\"\"\n",
                "    return {\n",
                "        method: {\n",
                "            metric: (np.mean(values), np.std(values))\n",
                "            for metric, values in metrics.items()\n",
                "        }\n",
                "        for method, metrics in comp_dict.items()\n",
                "    }\n",
                "\n",
                "\n",
                "\n",
                "# bar colour per method\n",
                "color_dict = {\n",
                "    'lfads': '#50547C',\n",
                "    'ldns': '#A44A3F',\n",
                "    'hist': '#495F41',\n",
                "    'train': 'grey'\n",
                "}\n",
                "\n",
                "\n",
                "def plot_metrics(comp_dict, ncols=3, figsize=cm2inch(15, 10)):\n",
                "    \"\"\"Bar-plot the mean (bar) and std (error bar) over folds, one panel per metric.\n",
                "\n",
                "    Args:\n",
                "        comp_dict: ``{method: {metric: [per-fold values]}}``; the metrics of the first\n",
                "            method define the panels.\n",
                "        ncols: number of panel columns; as many rows as needed are created.\n",
                "        figsize: figure size handed to ``plt.subplots``.\n",
                "\n",
                "    Returns:\n",
                "        The matplotlib figure. Bar colours are looked up in the global ``color_dict``.\n",
                "    \"\"\"\n",
                "    stats_results = compute_stats(comp_dict)\n",
                "    metrics = list(next(iter(comp_dict.values())).keys())  # Get the list of metrics from any method\n",
                "    num_metrics = len(metrics)\n",
                "    # round up: flooring dropped the trailing metrics whenever num_metrics was not a\n",
                "    # multiple of ncols (and asked for zero rows when num_metrics < ncols)\n",
                "    nrows = max(1, (num_metrics + ncols - 1) // ncols)\n",
                "    # squeeze=False keeps axes a 2-D array even for a single panel\n",
                "    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, constrained_layout=True, squeeze=False)\n",
                "    axes = axes.flatten()\n",
                "\n",
                "    methods = list(stats_results.keys())\n",
                "    method_colors = [color_dict[method] for method in methods]  # Get colors based on color_dict\n",
                "    for ax, metric in zip(axes, metrics):\n",
                "        means = [stats_results[method][metric][0] for method in methods]\n",
                "        stds = [stats_results[method][metric][1] for method in methods]\n",
                "\n",
                "        ax.bar(methods, means, yerr=stds, capsize=5, alpha=0.7, color=method_colors)\n",
                "        ax.set_title(metric.replace('_', ' '))\n",
                "        ax.set_ylabel('value')\n",
                "\n",
                "    # hide every panel that has no metric\n",
                "    for unused_ax in axes[num_metrics:]:\n",
                "        unused_ax.axis('off')\n",
                "    return fig\n",
                "\n",
                "    \n",
                "import pickle  # make sure pickle is available even if the later import cell has not been run\n",
                "\n",
                "# save pickle files in save_data_path\n",
                "with open(save_data_path + 'comp_dict_5_folds_relevant_metrics.pkl', 'wb') as f:\n",
                "    pickle.dump(comp_dict, f)\n",
                "\n",
                "# bar plots of all metrics in comp_dict\n",
                "fig = plot_metrics(comp_dict)\n",
                "\n",
                "fig.savefig(save_path + 'Supp_Fig_model_metrics_comparison_5_210524.png')\n",
                "fig.savefig(save_path + 'Supp_Fig_model_metrics_comparison_5_210524.pdf')"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Make comparison plot"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import pickle\n",
                "# bar colour per method\n",
                "color_dict = {\n",
                "    'lfads': '#9f86c0',  # earlier candidates: '#632075', '#50547C'\n",
                "    'ldns': '#A44A3F',\n",
                "    'hist': '#495F41',\n",
                "    'train': 'grey',\n",
                "}\n",
                "selected_methods = [\"lfads\", \"ldns\", \"hist\", \"train\"]\n",
                "\n",
                "# \"rmse_auto_corr\" and \"rmse_cross_corr\" are currently left out\n",
                "selected_comparison_stats = [\"kl_div_psc\", \"rmse_corr\", \"rmse_mean_isi\", \"rmse_std_isi\"]\n",
                "\n",
                "def compute_stats(comp_dict):\n",
                "    \"\"\"Collapse every list of per-fold values into a ``(mean, std)`` tuple.\n",
                "\n",
                "    Takes ``{method: {metric: [values]}}`` and returns the same nesting with each list\n",
                "    replaced by ``(np.mean(values), np.std(values))``.\n",
                "    \"\"\"\n",
                "    return {\n",
                "        method: {\n",
                "            metric: (np.mean(values), np.std(values))\n",
                "            for metric, values in metrics.items()\n",
                "        }\n",
                "        for method, metrics in comp_dict.items()\n",
                "    }\n",
                "\n",
                "\n",
                "def plot_metrics(comp_dict, selected_methods=None, selected_comparison_stats=None, ncols=3, figsize=cm2inch(15, 10)):\n",
                "    \"\"\"Bar-plot the mean (bar) and std (error bar) over folds, one panel per metric.\n",
                "\n",
                "    Args:\n",
                "        comp_dict: ``{method: {metric: [per-fold values]}}``.\n",
                "        selected_methods: methods to show, in plotting order; all methods of\n",
                "            ``comp_dict`` in its key order when None.\n",
                "        selected_comparison_stats: metrics to show; all metrics of the first method\n",
                "            when None.\n",
                "        ncols: number of panel columns; as many rows as needed are created.\n",
                "        figsize: figure size handed to ``plt.subplots``.\n",
                "\n",
                "    Returns:\n",
                "        The matplotlib figure. Bar colours are looked up in the global ``color_dict``.\n",
                "    \"\"\"\n",
                "    stats_results = compute_stats(comp_dict)\n",
                "    if selected_comparison_stats is None:\n",
                "        metrics = list(next(iter(comp_dict.values())).keys())  # Get the list of metrics from any method\n",
                "    else:\n",
                "        metrics = list(selected_comparison_stats)\n",
                "    # selected_methods used to be accepted but ignored\n",
                "    if selected_methods is None:\n",
                "        methods = list(stats_results.keys())\n",
                "    else:\n",
                "        methods = list(selected_methods)\n",
                "    num_metrics = len(metrics)\n",
                "    # round up: flooring dropped the trailing metrics whenever num_metrics was not a\n",
                "    # multiple of ncols (and asked for zero rows when num_metrics < ncols)\n",
                "    nrows = max(1, (num_metrics + ncols - 1) // ncols)\n",
                "    # squeeze=False keeps axes a 2-D array even for a single panel\n",
                "    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, constrained_layout=True, squeeze=False)\n",
                "    axes = axes.flatten()\n",
                "\n",
                "    method_colors = [color_dict[method] for method in methods]  # Get colors based on color_dict\n",
                "    for ax, metric in zip(axes, metrics):\n",
                "        means = [stats_results[method][metric][0] for method in methods]\n",
                "        stds = [stats_results[method][metric][1] for method in methods]\n",
                "\n",
                "        ax.bar(methods, means, yerr=stds, capsize=5, alpha=0.7, color=method_colors)\n",
                "        ax.set_title(metric.replace('_', ' '))\n",
                "        ax.set_ylabel('value')\n",
                "\n",
                "    # hide every panel that has no metric\n",
                "    for unused_ax in axes[num_metrics:]:\n",
                "        unused_ax.axis('off')\n",
                "    return fig\n",
                "\n",
                "\n",
                "# Older results file, kept for reference:\n",
                "# with open(save_data_path + 'comp_dict_5.pkl', 'rb') as f:\n",
                "#     comp_dict = pickle.load(f)\n",
                "        \n",
                "\n",
                "# Reload the per-fold metrics from the pickle file in save_data_path.\n",
                "# NOTE(review): pickle.load can execute arbitrary code -- only load files produced by this notebook.\n",
                "with open(save_data_path + 'comp_dict_5_folds_relevant_metrics.pkl', 'rb') as f:\n",
                "    comp_dict = pickle.load(f)\n",
                "    \n",
                "    \n",
                "# One row of four panels, restricted to the metrics in selected_comparison_stats.\n",
                "fig = plot_metrics(comp_dict, selected_comparison_stats=selected_comparison_stats,figsize=cm2inch(18, 5), ncols=4)\n",
                "\n",
                "# Save the figure as PNG and PDF.\n",
                "fig.savefig(save_path + 'Supp_Fig_model_metrics_comparison_barplots.png')\n",
                "fig.savefig(save_path + 'Supp_Fig_model_metrics_comparison_barplots.pdf')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# sort keys differently \n",
                "\n",
                "# same entries as comp_dict, with the keys in the order lfads, ldns, hist, train\n",
                "comp_dict_ordered = {\n",
                "    method: comp_dict[method] for method in ('lfads', 'ldns', 'hist', 'train')\n",
                "}\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import pickle\n",
                "# bar colour per method\n",
                "color_dict = {\n",
                "    'lfads': '#9f86c0',  # earlier candidates: '#632075', '#50547C'\n",
                "    'ldns': '#A44A3F',\n",
                "    'hist': '#495F41',\n",
                "    'train': 'grey',\n",
                "}\n",
                "selected_methods = [\"lfads\", \"ldns\", \"hist\", \"train\"]\n",
                "\n",
                "# \"rmse_auto_corr\" and \"rmse_cross_corr\" are currently left out\n",
                "selected_comparison_stats = [\"kl_div_psc\", \"rmse_corr\", \"rmse_mean_isi\", \"rmse_std_isi\"]\n",
                "\n",
                "def compute_stats(comp_dict):\n",
                "    \"\"\"Collapse every list of per-fold values into a ``(mean, std)`` tuple.\n",
                "\n",
                "    Takes ``{method: {metric: [values]}}`` and returns the same nesting with each list\n",
                "    replaced by ``(np.mean(values), np.std(values))``.\n",
                "    \"\"\"\n",
                "    return {\n",
                "        method: {\n",
                "            metric: (np.mean(values), np.std(values))\n",
                "            for metric, values in metrics.items()\n",
                "        }\n",
                "        for method, metrics in comp_dict.items()\n",
                "    }\n",
                "\n",
                "\n",
                "def plot_metrics(comp_dict, selected_methods=None, selected_comparison_stats=None, ncols=3, figsize=cm2inch(15, 10)):\n",
                "    \"\"\"Bar-plot the mean (bar) and std (error bar) over folds, one panel per metric.\n",
                "\n",
                "    Args:\n",
                "        comp_dict: ``{method: {metric: [per-fold values]}}``.\n",
                "        selected_methods: methods to show, in plotting order; all methods of\n",
                "            ``comp_dict`` in its key order when None.\n",
                "        selected_comparison_stats: metrics to show; all metrics of the first method\n",
                "            when None.\n",
                "        ncols: number of panel columns; as many rows as needed are created.\n",
                "        figsize: figure size handed to ``plt.subplots``.\n",
                "\n",
                "    Returns:\n",
                "        The matplotlib figure. Bar colours are looked up in the global ``color_dict``.\n",
                "    \"\"\"\n",
                "    stats_results = compute_stats(comp_dict)\n",
                "    if selected_comparison_stats is None:\n",
                "        metrics = list(next(iter(comp_dict.values())).keys())  # Get the list of metrics from any method\n",
                "    else:\n",
                "        metrics = list(selected_comparison_stats)\n",
                "    # selected_methods used to be accepted but ignored\n",
                "    if selected_methods is None:\n",
                "        methods = list(stats_results.keys())\n",
                "    else:\n",
                "        methods = list(selected_methods)\n",
                "    num_metrics = len(metrics)\n",
                "    # round up: flooring dropped the trailing metrics whenever num_metrics was not a\n",
                "    # multiple of ncols (and asked for zero rows when num_metrics < ncols)\n",
                "    nrows = max(1, (num_metrics + ncols - 1) // ncols)\n",
                "    # squeeze=False keeps axes a 2-D array even for a single panel\n",
                "    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, constrained_layout=True, squeeze=False)\n",
                "    axes = axes.flatten()\n",
                "\n",
                "    method_colors = [color_dict[method] for method in methods]  # Get colors based on color_dict\n",
                "    for ax, metric in zip(axes, metrics):\n",
                "        means = [stats_results[method][metric][0] for method in methods]\n",
                "        stds = [stats_results[method][metric][1] for method in methods]\n",
                "\n",
                "        ax.bar(methods, means, yerr=stds, capsize=5, alpha=0.7, color=method_colors)\n",
                "        ax.set_title(metric.replace('_', ' '))\n",
                "        ax.set_ylabel('value')\n",
                "\n",
                "    # hide every panel that has no metric\n",
                "    for unused_ax in axes[num_metrics:]:\n",
                "        unused_ax.axis('off')\n",
                "    return fig\n",
                "\n",
                "\n",
                "# Older results file, kept for reference:\n",
                "# with open(save_data_path + 'comp_dict_5.pkl', 'rb') as f:\n",
                "#     comp_dict = pickle.load(f)\n",
                "        \n",
                "\n",
                "# Reload the per-fold metrics from the pickle file in save_data_path.\n",
                "# NOTE(review): pickle.load can execute arbitrary code -- only load files produced by this notebook.\n",
                "with open(save_data_path + 'comp_dict_5_folds_relevant_metrics.pkl', 'rb') as f:\n",
                "    comp_dict = pickle.load(f)\n",
                "    \n",
                "    \n",
                "# NOTE(review): the figure is drawn from comp_dict_ordered, which is not rebuilt from the\n",
                "# comp_dict loaded just above -- confirm that it was created from the same data.\n",
                "fig = plot_metrics(comp_dict_ordered,figsize=cm2inch(18, 5), ncols=4)\n",
                "\n",
                "# Save the figure as PNG and PDF.\n",
                "fig.savefig(save_path + 'Supp_Fig_model_metrics_comparison_barplots_ordered.png')\n",
                "fig.savefig(save_path + 'Supp_Fig_model_metrics_comparison_barplots_ordered.pdf')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# (mean, std) per method and metric, in the key order of comp_dict_ordered\n",
                "stats_results = compute_stats(comp_dict_ordered)\n",
                "stats_results"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# display the raw per-fold values\n",
                "comp_dict"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# (mean, std) per method and metric, in the key order of comp_dict\n",
                "stats_results = compute_stats(comp_dict)\n",
                "stats_results"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from scipy.stats import gaussian_kde\n",
                "color_dict = {\n",
                "    'lfads': 'purple',\n",
                "    'ldns': 'darkred',\n",
                "    'hist': 'darkgreen',\n",
                "    'train': 'grey'\n",
                "}\n",
                "\n",
                "# Population spike counts are the integers 0..maxval: use one unit-width bin centred on\n",
                "# every count. (The previous upper edge of maxval-0.5 left out the bins for the counts\n",
                "# maxval-1 and maxval.)\n",
                "bins_psc = np.arange(-0.5, maxval + 1.5, 1)\n",
                "# Evaluating densities over a common range derived from data\n",
                "x_eval = np.linspace(0, maxval, maxval+1)\n",
                "\n",
                "# the ground-truth histogram and KDE do not depend on the model: compute them once\n",
                "gt_psc = gt_spikes.sum(2).flatten()\n",
                "density_gt = gaussian_kde(gt_psc)(x_eval)\n",
                "density_gt /= density_gt.sum()\n",
                "\n",
                "fig, ax = plt.subplots(2, 2, figsize=cm2inch(9, 7), sharex=True, sharey=True)\n",
                "ax = ax.flatten()\n",
                "for i, (key, val) in enumerate(spikes_dict_single.items()):\n",
                "    model_psc = val.sum(2).flatten()\n",
                "    density_model = gaussian_kde(model_psc)(x_eval)\n",
                "    density_model /= density_model.sum()\n",
                "\n",
                "    # not plotted; kept for interactive inspection\n",
                "    kl_divergence = kl_div_nan(density_model, density_gt)\n",
                "\n",
                "    # plot the population spike count histogram for the model and the data\n",
                "    ax[i].hist(model_psc, bins=bins_psc, density=True, alpha=0.5, label=key, color=color_dict[key])\n",
                "    ax[i].hist(gt_psc, bins=bins_psc, density=True, alpha=0.5, label='data', color='grey')\n",
                "\n",
                "    # now plot the density estimate\n",
                "    ax[i].plot(x_eval, density_gt, '.-', label='data kde', color='black')\n",
                "    ax[i].plot(x_eval, density_model, '.-', label=key+' kde', color=color_dict[key])\n",
                "\n",
                "    ax[i].legend(fontsize=6)\n",
                "ax[2].set_xlabel('spike count')\n",
                "ax[2].set_ylabel('density')\n",
                "\n",
                "# save the figure\n",
                "fig.savefig(save_path + 'Supp_Fig_population_spike_count.png')\n",
                "fig.savefig(save_path + 'Supp_Fig_population_spike_count.pdf')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# plot population spike histogram for diffusion as well as correlation and isi\n",
                "color_dict = {\n",
                "    'lfads': 'purple',\n",
                "    'ldns': 'darkred',\n",
                "    'hist': 'darkgreen',\n",
                "    'train': 'grey'\n",
                "}\n",
                "\n",
                "\n",
                "def _scatter_against_gt(ax, gt_values, model_values, color, xlabel, ylabel):\n",
                "    \"\"\"Scatter model against ground-truth values with an identity line and equal, padded axes.\"\"\"\n",
                "    gt_values = np.asarray(gt_values).flatten()\n",
                "    model_values = np.asarray(model_values).flatten()\n",
                "    ax.plot(gt_values, model_values, '.', alpha=0.5, color=color, ms=2)\n",
                "    # NaN-aware extrema: with plain min()/max() a single NaN entry turns the axis limits\n",
                "    # into NaN, which matplotlib does not accept\n",
                "    data_limits = [\n",
                "        min(np.nanmin(gt_values), np.nanmin(model_values)),\n",
                "        max(np.nanmax(gt_values), np.nanmax(model_values)),\n",
                "    ]\n",
                "    # 1:1 line spanning the data range\n",
                "    ax.plot(data_limits, data_limits, '--', color='black')\n",
                "    ax.set_xlabel(xlabel)\n",
                "    ax.set_ylabel(ylabel)\n",
                "    ax.set_aspect('equal')\n",
                "    # pad both ends by 15 % of their magnitude\n",
                "    padded_limits = [\n",
                "        data_limits[0] - 0.15 * np.abs(data_limits[0]),\n",
                "        data_limits[1] + 0.15 * np.abs(data_limits[1]),\n",
                "    ]\n",
                "    ax.set_xlim(padded_limits)\n",
                "    ax.set_ylim(padded_limits)\n",
                "\n",
                "\n",
                "# Population spike counts are the integers 0..maxval: use one unit-width bin centred on\n",
                "# every count. (The previous upper edge of maxval-0.5 left out the bins for the counts\n",
                "# maxval-1 and maxval.)\n",
                "bins_psc = np.arange(-0.5, maxval + 1.5, 1)\n",
                "# Evaluating densities over a common range derived from data\n",
                "x_eval = np.linspace(0, maxval, maxval+1)\n",
                "\n",
                "# everything that only depends on the ground truth is computed once, not once per model\n",
                "gt_psc = gt_spikes.sum(2).flatten()\n",
                "density_gt = gaussian_kde(gt_psc)(x_eval)\n",
                "density_gt /= density_gt.sum()\n",
                "# strictly lower triangle of the correlation matrix (k=-1 also zeroes the diagonal)\n",
                "C_gt = np.tril(correlation_matrix(gt_spikes, mode=\"concatenate\"), k=-1)\n",
                "spike_stats_gt = get_spike_train_and_stats(gt_spikes, fps=fps_monkey)\n",
                "\n",
                "# one figure per model: population spike count, pairwise correlations, mean ISI, std ISI\n",
                "for key, val in spikes_dict_single.items():\n",
                "    fig, ax = plt.subplots(2, 2, figsize=cm2inch(9, 7))\n",
                "    ax = ax.flatten()\n",
                "\n",
                "    model_psc = val.sum(2).flatten()\n",
                "    density_model = gaussian_kde(model_psc)(x_eval)\n",
                "    density_model /= density_model.sum()\n",
                "\n",
                "    # plot the population spike count histogram for the model and the data\n",
                "    ax[0].hist(gt_psc, bins=bins_psc, density=True, alpha=0.5, label='data', color='grey')\n",
                "    ax[0].hist(model_psc, bins=bins_psc, density=True, alpha=0.5, label=key, color=color_dict[key])\n",
                "\n",
                "    # now plot the density estimate\n",
                "    ax[0].plot(x_eval, density_gt, '.-', label='data kde', color='black')\n",
                "    ax[0].plot(x_eval, density_model, '.-', label=key+' kde', color=color_dict[key])\n",
                "\n",
                "    ax[0].legend(fontsize=6)\n",
                "    ax[0].set_xlabel('spike count')\n",
                "    ax[0].set_ylabel('density')\n",
                "\n",
                "    # pairwise correlations: model against ground truth\n",
                "    C_model = np.tril(correlation_matrix(val, mode=\"concatenate\"), k=-1)\n",
                "    _scatter_against_gt(ax[1], C_gt, C_model, color_dict[key], 'gt', key)\n",
                "\n",
                "    # compute the spike stats of the model and compare the ISI statistics\n",
                "    spike_stats_model = get_spike_train_and_stats(val, fps=fps_monkey)\n",
                "    _scatter_against_gt(ax[2], spike_stats_gt['mean_isi'], spike_stats_model['mean_isi'],\n",
                "                        color_dict[key], 'gt mean isi', key + ' mean isi')\n",
                "    _scatter_against_gt(ax[3], spike_stats_gt['std_isi'], spike_stats_model['std_isi'],\n",
                "                        color_dict[key], 'gt std isi', key + ' std isi')\n",
                "\n",
                "    fig.tight_layout()\n",
                "    # save the figure\n",
                "    fig.savefig(save_path + f'Fig_population_and_single_spike_stats_{key}.png')\n",
                "    fig.savefig(save_path + f'Fig_population_and_single_spike_stats_{key}.pdf')\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# set up comparison stats"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import rmse_nan, kl_div_nan, corr\n",
                "methods = [\"ldns\", \"hist\", \"lfads\", \"train\"]\n",
                "\n",
                "# scalar comparison statistics collected per method\n",
                "comparison_stats = [\n",
                "    \"rmse_mean_rate\", \"rmse_std_rate\",\n",
                "    \"rmse_mean_isi\", \"rmse_std_isi\",\n",
                "    \"kl_div_psc\", \"rmse_corr\",\n",
                "    \"rmse_auto_corr\", \"rmse_cross_corr\",\n",
                "]\n",
                "\n",
                "# comp_dict[method][statistic] starts as an empty list\n",
                "comp_dict = {\n",
                "    method: {metric: [] for metric in comparison_stats}\n",
                "    for method in methods\n",
                "}\n",
                "\n",
                "# which vectors to store for final comparisons\n",
                "comparison_vectors = [\n",
                "    \"mean_rate\", \"std_rate\",\n",
                "    \"mean_isi\", \"std_isi\",\n",
                "    \"psc\", \"corr\",\n",
                "    \"auto_corr\", \"cross_corr\",\n",
                "]\n",
                "\n",
                "# summary_dict[method][vector] holds a \"gt\" and a \"model\" slot, both initialised to None\n",
                "summary_dict = {\n",
                "    method: {metric: dict.fromkeys([\"gt\", \"model\"]) for metric in comparison_vectors}\n",
                "    for method in methods\n",
                "}\n",
                "\n",
                "# spike counts per method; the 'ldns' entry is a Poisson sample of diffusion_rates\n",
                "spikes_dict = {\n",
                "    'lfads': lfads_spikes,\n",
                "    'ldns': np.random.poisson(diffusion_rates),\n",
                "    'hist': hist_spikes,\n",
                "    'train': gt_spikes,\n",
                "}\n",
                "\n",
                "# five index sets drawn with replacement, each as long as gt_spikes\n",
                "list_indices = [\n",
                "    np.random.choice(len(gt_spikes), size=len(gt_spikes), replace=True)\n",
                "    for _ in range(5)\n",
                "]\n",
                "\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# line/marker colour per data source; the keys ending in 'l' are paired with the\n",
                "# 'latents ...' labels in lab_dict below\n",
                "color_dict = {\n",
                "    \"ae\": \"midnightblue\",  # earlier candidate: #2d3047\n",
                "    \"ael\": \"royalblue\",  # earlier candidates: cornflowerblue, #6ca6c1\n",
                "    \"diff\": \"darkred\",  # earlier candidate: #A44A3F\n",
                "    \"diffl\": \"orangered\",  # earlier candidate: #84271F\n",
                "    # disabled entries: \"sim\": \"#afcb90\", \"siml\": \"#495f41\"\n",
                "    \"gt\": \"darkgrey\",\n",
                "    \"gtl\": \"#808080\",\n",
                "}\n",
                "\n",
                "# label per data source\n",
                "lab_dict = {\n",
                "    \"ae\": \"ae\",\n",
                "    \"ael\": \"latents ae\",\n",
                "    \"diff\": \"diffusion\",\n",
                "    \"diffl\": \"latents diffusion\",\n",
                "    \"gt\": \"gt\",\n",
                "    \"gtl\": \"latents gt\",\n",
                "}\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_rate_comparisons\n",
                "from ntldm.utils.eval_utils import average_rates, std_rates\n",
                "fps_monkey = 1000/5\n",
                "\n",
                "# set numpy seed\n",
                "np.random.seed(42)\n",
                "ae_spikes = np.random.poisson(ae_rates)\n",
                "diffusion_spikes = np.random.poisson(diffusion_rates)\n",
                "figsize = cm2inch((8, 8)) \n",
                "plot_rate_comparisons(gt_spikes, [ae_spikes, diffusion_spikes],fps=fps_monkey,\n",
                "                      mode='neur', figsize = figsize,\n",
                "                      colors=['midnightblue', 'darkred'],\n",
                "                      save = True, xlabel='firing rate (Hz)',\n",
                "                      save_path=save_path + \"mean_spike_comparison_gt_ae_diffusion\")\n",
                "\n",
                "\n",
                "plot_rate_comparisons(gt_spikes, [ae_spikes, diffusion_spikes],\n",
                "                      fn = std_rates, mode='neur',figsize = figsize,fps=fps_monkey,\n",
                "                      colors=['midnightblue', 'darkred'],\n",
                "                      save = True, xlabel='std fr (Hz)',\n",
                "                    save_path=save_path + \"std_std_comparison_gt_ae_diffusion\")\n",
                "\n",
                "\n",
                "summary_dict['ldns']['mean_rate']['gt'] = average_rates(gt_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "summary_dict['ldns']['mean_rate']['model'] = average_rates(diffusion_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "comp_dict['ldns']['rmse_mean_rate'] = rmse(summary_dict['ldns']['mean_rate']['gt'],\n",
                "                                           summary_dict['ldns']['mean_rate']['model'])\n",
                "\n",
                "\n",
                "summary_dict['ldns']['std_rate']['gt'] = std_rates(gt_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "summary_dict['ldns']['std_rate']['model'] = std_rates(diffusion_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "comp_dict['ldns']['rmse_std_rate'] = rmse(summary_dict['ldns']['std_rate']['gt'],\n",
                "                                          summary_dict['ldns']['std_rate']['model'])\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_rate_comparisons\n",
                "from ntldm.utils.eval_utils import average_rates, std_rates\n",
                "fps_monkey = 1000/5\n",
                "\n",
                "# set numpy seed\n",
                "np.random.seed(42)\n",
                "diffusion_spikes = np.random.poisson(diffusion_rates)\n",
                "ae_spikes = np.random.poisson(ae_rates)\n",
                "\n",
                "figsize = cm2inch((8, 8)) \n",
                "plot_rate_comparisons(gt_spikes, [lfads_spikes, diffusion_spikes],fps=fps_monkey,\n",
                "                      mode='neur', figsize = figsize,\n",
                "                      colors=['purple', 'darkred'],labels=['lfads', 'diffusion'],\n",
                "                      save = True, xlabel='firing rate (Hz)',\n",
                "                      save_path=save_path + \"mean_spike_comparison_gt_lfads_diffusion\")\n",
                "\n",
                "\n",
                "plot_rate_comparisons(gt_spikes, [lfads_spikes, diffusion_spikes],fps=fps_monkey,\n",
                "                      fn = std_rates, mode='neur',figsize = figsize,\n",
                "                      colors=['purple', 'darkred'],labels=['lfads', 'diffusion'],\n",
                "                      save = True, xlabel='std fr (Hz)',\n",
                "                    save_path=save_path + \"std_std_comparison_gt_ae_diffusion\")\n",
                "\n",
                "\n",
                "plot_rate_comparisons(gt_spikes, [hist_spikes, diffusion_spikes],fps=fps_monkey,\n",
                "                      mode='neur', figsize = figsize,\n",
                "                      colors=['green', 'darkred'],labels=['lfads', 'diffusion'],\n",
                "                      save = True, xlabel='firing rate (Hz)',\n",
                "                      save_path=save_path + \"mean_spike_comparison_gt_history_diffusion\")\n",
                "\n",
                "plot_rate_comparisons(gt_spikes, [hist_spikes, diffusion_spikes],fps=fps_monkey,\n",
                "                      fn = std_rates, mode='neur',figsize = figsize,\n",
                "                      colors=['green', 'darkred'],labels=['lfads', 'diffusion'],\n",
                "                      save = True, xlabel='std fr (Hz)',\n",
                "                    save_path=save_path + \"std_comparison_gt_history_diffusion\")\n",
                "\n",
                "# summary_dict['ldns']['mean_rate']['gt'] = average_rates(gt_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "# summary_dict['ldns']['mean_rate']['model'] = average_rates(diffusion_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "# comp_dict['ldns']['rmse_mean_rate'] = rmse(summary_dict['ldns']['mean_rate']['gt'],summary_dict['ldns']['mean_rate']['model'])\n",
                "\n",
                "\n",
                "# summary_dict['ldns']['std_rate']['gt'] = std_rates(gt_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "# summary_dict['ldns']['std_rate']['model'] = std_rates(diffusion_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "# comp_dict['ldns']['rmse_std_rate'] = rmse(summary_dict['ldns']['std_rate']['gt'],summary_dict['ldns']['std_rate']['model'])\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# ndls samples\n",
                "\n",
                "summary_dict['ldns']['mean_rate']['gt'] = average_rates(gt_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "summary_dict['ldns']['mean_rate']['model'] = average_rates(diffusion_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "\n",
                "summary_dict['ldns']['std_rate']['gt'] = std_rates(gt_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "summary_dict['ldns']['std_rate']['model'] = std_rates(diffusion_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "\n",
                "# lfads samples from the prior \n",
                "\n",
                "summary_dict['lfads']['mean_rate']['gt'] = average_rates(gt_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "summary_dict['lfads']['mean_rate']['model'] = average_rates(lfads_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "\n",
                "summary_dict['lfads']['std_rate']['gt'] = std_rates(gt_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "summary_dict['lfads']['std_rate']['model'] = std_rates(lfads_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "\n",
                "# train compare one half of the training set with the other\n",
                "summary_dict['train']['mean_rate']['gt'] = average_rates(gt_spikes[:int(len(gt_spikes)/2)], mode='neur', fps=fps_monkey,).flatten()\n",
                "summary_dict['train']['mean_rate']['model'] = average_rates(gt_spikes[int(len(gt_spikes)/2):], mode='neur', fps=fps_monkey,).flatten()\n",
                "\n",
                "summary_dict['train']['std_rate']['gt'] = std_rates(gt_spikes[:int(len(gt_spikes)/2)], mode='neur', fps=fps_monkey,).flatten()\n",
                "summary_dict['train']['std_rate']['model'] = std_rates(gt_spikes[int(len(gt_spikes)/2):], mode='neur', fps=fps_monkey,).flatten()\n",
                "\n",
                "# lfads samples from the prior \n",
                "\n",
                "summary_dict['hist']['mean_rate']['gt'] = average_rates(gt_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "summary_dict['hist']['mean_rate']['model'] = average_rates(hist_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "\n",
                "summary_dict['hist']['std_rate']['gt'] = std_rates(gt_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "summary_dict['hist']['std_rate']['model'] = std_rates(hist_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "\n",
                "#hist_spikes = history_sampled_dict['hist'][0] # 1 of 5 samples\n",
                "\n",
                "\n",
                "for method in ['train', 'lfads', 'ldns', 'hist']:#methods:\n",
                "    comp_dict[method]['rmse_std_rate'] = rmse(summary_dict[method]['std_rate']['gt'],\n",
                "                                           summary_dict[method]['std_rate']['model'])\n",
                "\n",
                "    comp_dict[method]['rmse_mean_rate'] = rmse(summary_dict[method]['mean_rate']['gt'],\n",
                "                                            summary_dict[method]['mean_rate']['model'])\n",
                "    \n",
                "\n",
                "\n",
                "for var in comparison_stats:\n",
                "    print(f\"----   {var} ----\" )\n",
                "    for method in ['train', 'lfads', 'ldns', 'hist']:#methods:\n",
                "    #     print(f\"{method} {var} {comp_dict[method][var]}\")\n",
                "        print(f\"{method} {comp_dict[method][var]} \")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import compute_spike_stats, counts_to_spike_trains\n",
                "from ntldm.utils.eval_utils import compute_spike_stats_per_neuron\n",
                "from ntldm.utils.plotting_utils import cm2inch, plot_spiketrain_stats\n",
                "\n",
                "\n",
                "spike_trains_gt = counts_to_spike_trains(gt_spikes, fps=fps_monkey)\n",
                "spike_trains_diff = counts_to_spike_trains(np.random.poisson(diffusion_rates), fps=fps_monkey)\n",
                "spike_trains_lfads = counts_to_spike_trains(lfads_spikes, fps=fps_monkey)\n",
                "spike_trains_hist = counts_to_spike_trains(hist_spikes, fps=fps_monkey)\n",
                "\n",
                "spike_stats_gt = compute_spike_stats_per_neuron(\n",
                "    spike_trains_gt,\n",
                "    n_samples=gt_spikes.shape[0],\n",
                "    n_neurons=gt_spikes.shape[2],\n",
                "    mean_output=False,\n",
                ")\n",
                "\n",
                "spike_stats_diff = compute_spike_stats_per_neuron(\n",
                "    spike_trains_diff,\n",
                "    n_samples=diffusion_rates.shape[0],\n",
                "    n_neurons=diffusion_rates.shape[2],\n",
                "    mean_output=False,\n",
                ")\n",
                "\n",
                "spike_stats_lfads = compute_spike_stats_per_neuron(\n",
                "    spike_trains_lfads,\n",
                "    n_samples=diffusion_rates.shape[0],\n",
                "    n_neurons=diffusion_rates.shape[2],\n",
                "    mean_output=False,\n",
                ")\n",
                "\n",
                "spike_stats_hist = compute_spike_stats_per_neuron(\n",
                "    spike_trains_hist,\n",
                "    n_samples=diffusion_rates.shape[0],\n",
                "    n_neurons=diffusion_rates.shape[2],\n",
                "    mean_output=False,\n",
                ")\n",
                "\n",
                "plot_spiketrain_stats(\n",
                "    spike_stats_gt,\n",
                "    spike_stats_diff,\n",
                "    figsize=cm2inch(12, 4),\n",
                "    color=\"darkred\",\n",
                "    labels=[\"gt\", \"diffusion\"],\n",
                "    save = True, \n",
                "    save_path=save_path + \"Fig_3_compute_spike_stats_per_neuron_gt_diff\"\n",
                ")\n",
                "\n",
                "plot_spiketrain_stats(\n",
                "    spike_stats_gt,\n",
                "    spike_stats_lfads,\n",
                "    figsize=cm2inch(12, 4),\n",
                "    color=\"purple\",\n",
                "    labels=[\"gt\", \"lfads\"],\n",
                "    save = True, \n",
                "    save_path=save_path + \"compute_spike_stats_per_neuron_gt_lfads\"\n",
                ")\n",
                "\n",
                "plot_spiketrain_stats(\n",
                "    spike_stats_gt,\n",
                "    spike_stats_hist,\n",
                "    figsize=cm2inch(12, 4),\n",
                "    color=\"green\",\n",
                "    labels=[\"gt\", \"hist\"],\n",
                "    save = True, \n",
                "    save_path=save_path + \"compute_spike_stats_per_neuron_gt_lfads\"\n",
                ")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "\n",
                "plot_spiketrain_stats(\n",
                "    spike_stats_gt,\n",
                "    spike_stats_diff,\n",
                "    figsize=cm2inch(10.5, 3),\n",
                "    color=\"darkred\",\n",
                "    labels=[\"gt\", \"diffusion\"],\n",
                "    save = True, \n",
                "    save_path=save_path + \"Fig_3_compute_spike_stats_per_neuron_gt_diff\"\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def rmse(y_pred, y):\n",
                "    return np.sqrt(np.nanmean((y_pred - y)**2))\n",
                "\n",
                "summary_dict['ldns']['mean_isi']['gt'] = spike_stats_gt['mean_isi']\n",
                "summary_dict['ldns']['mean_isi']['model'] = spike_stats_diff['mean_isi']\n",
                "comp_dict['ldns']['rmse_mean_isi'] = rmse(summary_dict['ldns']['mean_isi']['gt'],summary_dict['ldns']['mean_isi']['model'])\n",
                "\n",
                "\n",
                "summary_dict['ldns']['std_isi']['gt'] = spike_stats_gt['std_isi']\n",
                "summary_dict['ldns']['std_isi']['model'] = spike_stats_diff['std_isi']\n",
                "comp_dict['ldns']['rmse_std_isi'] = rmse(summary_dict['ldns']['std_isi']['gt'],summary_dict['ldns']['std_isi']['model'])\n",
                "\n",
                "\n",
                "summary_dict['lfads']['mean_isi']['gt'] = spike_stats_gt['mean_isi']\n",
                "summary_dict['lfads']['mean_isi']['model'] = spike_stats_lfads['mean_isi']\n",
                "comp_dict['lfads']['rmse_mean_isi'] = rmse(summary_dict['lfads']['mean_isi']['gt'],\n",
                "                                           summary_dict['lfads']['mean_isi']['model'])\n",
                "\n",
                "\n",
                "summary_dict['lfads']['std_isi']['gt'] = spike_stats_gt['std_isi']\n",
                "summary_dict['lfads']['std_isi']['model'] = spike_stats_lfads['std_isi']\n",
                "comp_dict['lfads']['rmse_std_isi'] = rmse(summary_dict['lfads']['std_isi']['gt'],\n",
                "                                          summary_dict['lfads']['std_isi']['model'])\n",
                "\n",
                "\n",
                "\n",
                "summary_dict['hist']['mean_isi']['gt'] = spike_stats_gt['mean_isi']\n",
                "summary_dict['hist']['mean_isi']['model'] = spike_stats_hist['mean_isi']\n",
                "comp_dict['hist']['rmse_mean_isi'] = rmse(summary_dict['hist']['mean_isi']['gt'],\n",
                "                                           summary_dict['hist']['mean_isi']['model'])\n",
                "\n",
                "\n",
                "summary_dict['hist']['std_isi']['gt'] = spike_stats_gt['std_isi']\n",
                "summary_dict['hist']['std_isi']['model'] = spike_stats_hist['std_isi']\n",
                "comp_dict['hist']['rmse_std_isi'] = rmse(summary_dict['hist']['std_isi']['gt'],\n",
                "                                          summary_dict['hist']['std_isi']['model'])\n",
                "\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "for var in comparison_stats:\n",
                "    print(f\"----   {var} ----\" )\n",
                "    for method in ['train', 'lfads', 'ldns', 'hist']:#methods:\n",
                "    #     print(f\"{method} {var} {comp_dict[method][var]}\")\n",
                "        print(f\"{method} {comp_dict[method][var]} \")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# plot reconstructed spikes\n",
                "plt.figure(figsize=cm2inch((6, 4)))\n",
                "bins = np.linspace(-0.5, 15.5, 17)\n",
                "plt.hist(gt_spikes.sum(2).flatten(), density=True, color='grey', bins=bins, alpha=0.5)\n",
                "plt.hist(lfads_spikes.sum(2).flatten(), density=True, color='purple', bins=bins, alpha=0.5)\n",
                "\n",
                "plt.legend(['gt', 'lfads'])\n",
                "plt.title('population spike count')\n",
                "plt.ylabel('frequency')\n",
                "plt.xlabel('number of spikes per time bin')\n",
                "plt.savefig(save_path + 'population_spike_count_lfads.png')\n",
                "plt.savefig(save_path + 'population_spike_count_lfads.pdf')\n",
                "\n",
                "\n",
                "\n",
                "# plot reconstructed spikes\n",
                "plt.figure(figsize=cm2inch((6, 4)))\n",
                "plt.hist(gt_spikes.sum(2).flatten(), density=True, color='grey', bins=bins, alpha=0.5)\n",
                "plt.hist(diffusion_spikes.sum(2).flatten(), density=True, color='darkred', bins=bins, alpha=0.5)\n",
                "\n",
                "plt.legend(['gt', 'ldns'])\n",
                "plt.title('population spike count')\n",
                "plt.ylabel('frequency')\n",
                "plt.xlabel('number of spikes per time bin')\n",
                "plt.savefig(save_path + 'population_spike_count_diff.png')\n",
                "plt.savefig(save_path + 'population_spike_count_diff.pdf')\n",
                "\n",
                "\n",
                "\n",
                "plt.figure(figsize=cm2inch((6, 4)))\n",
                "plt.hist(diffusion_spikes.sum(2).flatten(), density=True, color='darkred', bins=bins, alpha=0.5)\n",
                "plt.hist(lfads_spikes.sum(2).flatten(), density=True, color='purple', bins=bins, alpha=0.5)\n",
                "\n",
                "plt.legend(['diffusion', 'lfads'])\n",
                "plt.title('population spike count')\n",
                "plt.ylabel('frequency')\n",
                "plt.xlabel('number of spikes per time bin')\n",
                "plt.savefig(save_path + 'population_spike_count_lfads_diffusion.png')\n",
                "plt.savefig(save_path + 'population_spike_count_lfads_diffusion.pdf')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "bins"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "diffusion_spikes.sum(2).max()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "max_val_pop_count = np.max([diffusion_spikes.sum(2).max(), lfads_spikes.sum(2).max(), gt_spikes.sum(2).max()])\n",
                "max_val_pop_count = int(max_val_pop_count)\n",
                "bins = np.linspace(-0.5, max_val_pop_count-0.5, max_val_pop_count+1)\n",
                "\n",
                "# calculate kl divergence between these two histograms \n",
                "hist_diff = plt.hist(diffusion_spikes.sum(2).flatten(), density=True, color='darkred', bins=bins, alpha=0.5)\n",
                "hist_gt = plt.hist(gt_spikes.sum(2).flatten(), density=True, color='grey', bins=bins, alpha=0.5)\n",
                "hist_lfads = plt.hist(lfads_spikes.sum(2).flatten(), density=True, color='purple', bins=bins, alpha=0.5)\n",
                "\n",
                "\n",
                "summary_dict['lfads']['psc']['gt'] = hist_gt[0]\n",
                "summary_dict['lfads']['psc']['model'] = hist_lfads[0]\n",
                "summary_dict['ldns']['psc']['gt'] = hist_gt[0]\n",
                "summary_dict['ldns']['psc']['model'] = hist_diff[0]\n",
                "\n",
                "for method in ['lfads', 'ldns']:\n",
                "    comp_dict[method]['kl_div_psc'] = kl_div(summary_dict[method]['psc']['gt'],\n",
                "                                             summary_dict[method]['psc']['model'])\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# calculate correlation matrices"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "C_gt = correlation_matrix(gt_spikes, mode=\"concatenate\")\n",
                "np.fill_diagonal(C_gt, 0)\n",
                "min_max = np.nanmax(np.abs(C_gt))\n",
                "\n",
                "ae_spikes = np.random.poisson(ae_rates)\n",
                "C_ae = correlation_matrix(ae_spikes, mode=\"concatenate\")\n",
                "np.fill_diagonal(C_ae, 0)\n",
                "min_max_ae = np.nanmax(np.abs(C_ae))\n",
                "\n",
                "diff_spikes = np.random.poisson(diffusion_rates)\n",
                "C_diff = correlation_matrix(diff_spikes, mode=\"concatenate\")\n",
                "np.fill_diagonal(C_diff, 0)\n",
                "min_max_diff = np.nanmax(np.abs(C_diff))\n",
                "\n",
                "C_lfads = correlation_matrix(lfads_spikes, mode=\"concatenate\")\n",
                "np.fill_diagonal(C_lfads, 0)\n",
                "min_max_lfads = np.nanmax(np.abs(C_lfads))\n",
                "\n",
                "# only get lower triangluar half\n",
                "C_gt_lower = np.tril(C_gt, k=-1)\n",
                "C_diff_lower = np.tril(C_diff, k=-1)\n",
                "C_lfads_lower = np.tril(C_lfads, k=-1)\n",
                "\n",
                "comp_dict['ldns']['rmse_corr'] = rmse(C_gt_lower, C_diff_lower)\n",
                "summary_dict['ldns']['corr']['gt'] = C_gt_lower.flatten()\n",
                "summary_dict['ldns']['corr']['model'] = C_diff_lower.flatten()\n",
                "\n",
                "comp_dict['lfads']['rmse_corr'] = rmse(C_gt_lower, C_lfads_lower)\n",
                "summary_dict['lfads']['corr']['gt'] = C_gt_lower.flatten()\n",
                "summary_dict['lfads']['corr']['model'] = C_lfads_lower.flatten()\n",
                "\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "for var in comparison_stats:\n",
                "    print(f\"----   {var} ----\" )\n",
                "    for method in ['train', 'lfads', 'ldns']:#methods:\n",
                "    #     print(f\"{method} {var} {comp_dict[method][var]}\")\n",
                "        print(f\"{method} {comp_dict[method][var]} \")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "C_gt = correlation_matrix(gt_spikes, mode=\"concatenate\")\n",
                "np.fill_diagonal(C_gt, 0)\n",
                "min_max = np.nanmax(np.abs(C_gt))\n",
                "plt.figure(figsize=cm2inch(8, 6))\n",
                "\n",
                "plt.imshow(C_gt, vmax=min_max,vmin=-min_max, cmap=\"coolwarm\")\n",
                "plt.colorbar()\n",
                "plt.tight_layout()\n",
                "plt.xlabel('neuron index')\n",
                "plt.ylabel('neuron index')\n",
                "plt.savefig(save_path + 'gt_correlation_matrix.png')\n",
                "plt.savefig(save_path + 'gt_correlation_matrix.pdf')\n",
                "\n",
                "\n",
                "plt.figure(figsize=cm2inch(8, 6))\n",
                "ae_spikes = np.random.poisson(ae_rates)\n",
                "C_ae = correlation_matrix(ae_spikes, mode=\"concatenate\")\n",
                "np.fill_diagonal(C_ae, 0)\n",
                "min_max_ae = np.nanmax(np.abs(C_ae))\n",
                "\n",
                "plt.imshow(C_ae, vmax=min_max_ae,vmin=-min_max_ae, cmap=\"coolwarm\")\n",
                "plt.colorbar()\n",
                "plt.tight_layout()\n",
                "plt.xlabel('neuron index')\n",
                "plt.ylabel('neuron index')\n",
                "plt.savefig(save_path + 'ae_correlation_matrix.png')\n",
                "plt.savefig(save_path + 'ae_correlation_matrix.pdf')\n",
                "\n",
                "\n",
                "plt.figure(figsize=cm2inch(8, 6))\n",
                "diff_spikes = np.random.poisson(diffusion_rates)\n",
                "C_diff = correlation_matrix(diff_spikes, mode=\"concatenate\")\n",
                "np.fill_diagonal(C_diff, 0)\n",
                "min_max_diff = np.nanmax(np.abs(C_diff))\n",
                "\n",
                "plt.imshow(C_diff, vmax=min_max_diff,vmin=-min_max_diff, cmap=\"coolwarm\")\n",
                "plt.colorbar()\n",
                "plt.tight_layout()\n",
                "plt.xlabel('neuron index')\n",
                "plt.ylabel('neuron index')\n",
                "plt.savefig(save_path + 'diffusion_correlation_matrix.png')\n",
                "plt.savefig(save_path + 'diffusion_correlation_matrix.pdf')\n",
                "\n",
                "\n",
                "max_all = np.max([min_max, min_max_ae, min_max_diff])\n",
                "\n",
                "fig, ax = plt.subplots(1, 4, figsize=cm2inch(18, 6))\n",
                "\n",
                "ax[0].imshow(C_gt, vmax=max_all,vmin=-max_all, cmap=\"coolwarm\")\n",
                "ax[1].imshow(C_ae, vmax=max_all,vmin=-max_all, cmap=\"coolwarm\")\n",
                "ax[2].imshow(C_diff, vmax=max_all,vmin=-max_all, cmap=\"coolwarm\")\n",
                "ax[3].plot(C_gt.flatten(), C_ae.flatten(), 'o', alpha=0.5, color=\"midnightblue\")\n",
                "ax[3].plot(C_gt.flatten(), C_diff.flatten(), 'o', alpha=0.5, color=\"darkred\")\n",
                "ax[3].plot([-max_all, max_all], [-max_all, max_all], 'k--')\n",
                "\n",
                "plt.tight_layout()\n",
                "\n",
                "\n",
                "# only get lower triangluar half\n",
                "C_gt_lower = np.tril(C_gt, k=-1)\n",
                "C_diff_lower = np.tril(C_diff, k=-1)\n",
                "\n",
                "\n",
                "comp_dict['ldns']['rmse_corr'] = rmse(C_gt, C_diff)\n",
                "summary_dict['ldns']['corr']['gt'] = C_gt_lower.flatten()\n",
                "summary_dict['ldns']['corr']['model'] = C_diff_lower.flatten()\n",
                "\"\"\"\n",
                "methods = ['ldns', 'glm', 'lfads', 'train']\n",
                "['rmse_mean_rate', 'rmse_mean_isi', 'rmse_std_isi', 'kl_div_psc', 'rmse_corr']\n",
                "['mean_rate', 'mean_isi', 'std_isi', 'psc', 'corr', 'auto_corr', 'cross_corr']\n",
                "summary_dict\n",
                "comp_dict\n",
                "\n",
                "\"\"\""
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "#### Store dict"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "comp_dict['ldns']['rmse_corr'] = rmse(C_gt, C_diff)\n",
                "summary_dict['ldns']['corr']['gt'] = C_gt_lower.flatten()\n",
                "summary_dict['ldns']['corr']['model'] = C_diff_lower.flatten()\n",
                "\n",
                "\n",
                "comp_dict['ldns']['rmse_corr'] = rmse(C_gt, C_diff)\n",
                "plt.plot(summary_dict['ldns']['corr']['gt'], summary_dict['ldns']['corr']['model'], 'o', alpha=0.5)\n",
                "print(comp_dict['ldns']['rmse_corr'])\n",
                "\"\"\"\n",
                "methods = ['ldns', 'glm', 'lfads', 'train']\n",
                "['rmse_mean_rate', 'rmse_mean_isi', 'rmse_std_isi', 'kl_div_psc', 'rmse_corr']\n",
                "['mean_rate', 'mean_isi', 'std_isi', 'psc', 'corr', 'auto_corr', 'cross_corr']\n",
                "summary_dict\n",
                "comp_dict\n",
                "\n",
                "\"\"\""
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import correlation_matrix\n",
                "from ntldm.utils.plotting_utils import plot_correlation_matrices_monkey\n",
                "\n",
                "# Draw one Poisson realisation from the AE rates and one from the diffusion rates.\n",
                "ae_spikes = np.random.poisson(ae_rates)\n",
                "diffusion_spikes = np.random.poisson(diffusion_rates)\n",
                "\n",
                "# Correlation matrices of the data next to both models.\n",
                "# The call is kept as the final expression, so the cell still shows its return value.\n",
                "plot_correlation_matrices_monkey(\n",
                "    gt_spikes,\n",
                "    [ae_spikes, diffusion_spikes],\n",
                "    sample=None,\n",
                "    mode='concatenate',\n",
                "    model_labels=['ae', 'diffusion'],\n",
                "    model_colors=['midnightblue', 'darkred'],\n",
                "    figsize=cm2inch((25, 6)),\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Pairwise correlation matrices of the data (gt), AE samples and diffusion samples.\n",
                "# Uses correlation_matrix, cm2inch, save_path and the spike/rate arrays from other cells.\n",
                "\n",
                "\n",
                "def _zero_diag_corr(spikes):\n",
                "    \"\"\"Return the correlation matrix of `spikes` with its diagonal set to 0, and its max |entry|.\"\"\"\n",
                "    corr = correlation_matrix(spikes, mode='concatenate')\n",
                "    np.fill_diagonal(corr, 0)  # the diagonal must not set the colour range computed below\n",
                "    return corr, np.nanmax(np.abs(corr))\n",
                "\n",
                "\n",
                "def _save_corr_heatmap(corr, limit, stem):\n",
                "    \"\"\"Plot `corr` on a colour scale symmetric around 0 and save <stem>.png / <stem>.pdf.\"\"\"\n",
                "    plt.figure(figsize=cm2inch(8, 6))\n",
                "    plt.imshow(corr, vmax=limit, vmin=-limit, cmap='coolwarm')\n",
                "    plt.colorbar()\n",
                "    plt.xlabel('neuron index')\n",
                "    plt.ylabel('neuron index')\n",
                "    plt.tight_layout()  # called after the labels exist so they are part of the layout\n",
                "    plt.savefig(save_path + stem + '.png')\n",
                "    plt.savefig(save_path + stem + '.pdf')\n",
                "\n",
                "\n",
                "C_gt, min_max = _zero_diag_corr(gt_spikes)\n",
                "_save_corr_heatmap(C_gt, min_max, 'gt_correlation_matrix')\n",
                "\n",
                "ae_spikes = np.random.poisson(ae_rates)\n",
                "C_ae, min_max_ae = _zero_diag_corr(ae_spikes)\n",
                "_save_corr_heatmap(C_ae, min_max_ae, 'ae_correlation_matrix')\n",
                "\n",
                "diff_spikes = np.random.poisson(diffusion_rates)\n",
                "C_diff, min_max_diff = _zero_diag_corr(diff_spikes)\n",
                "_save_corr_heatmap(C_diff, min_max_diff, 'diffusion_correlation_matrix')\n",
                "\n",
                "# Data and diffusion samples side by side on one shared colour scale.\n",
                "max_all = np.max([min_max, min_max_ae, min_max_diff])\n",
                "\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(18, 6))\n",
                "for panel, panel_corr, panel_title in zip(ax, (C_gt, C_diff), ('data monkey', 'LDNS samples')):\n",
                "    image = panel.imshow(panel_corr, vmax=max_all, vmin=-max_all, cmap='coolwarm')\n",
                "    # Axes objects have no colorbar() method; the figure creates it for this image/axes pair.\n",
                "    fig.colorbar(image, ax=panel)\n",
                "    panel.set_title(panel_title)\n",
                "    panel.locator_params(nbins=4)  # applied to both panels\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import correlation_matrix\n",
                "from ntldm.utils.plotting_utils import plot_correlation_matrices_monkey\n",
                "\n",
                "# Use the 'ldns' entry of spikes_dict_single as the diffusion samples for this comparison.\n",
                "diffusion_spikes = spikes_dict_single['ldns']\n",
                "\n",
                "# Data vs. LFADS vs. LDNS correlation matrices, with save=True and a save_path prefix.\n",
                "# (xlabel='neuron index' / ylabel='neuron index' are left disabled for this call.)\n",
                "plot_correlation_matrices_monkey(\n",
                "    gt_spikes,\n",
                "    [lfads_spikes, diffusion_spikes],\n",
                "    sample=None,\n",
                "    mode='concatenate',\n",
                "    model_labels=['lfads', 'ldns'],\n",
                "    model_colors=['purple', 'darkred'],\n",
                "    figsize=cm2inch((30, 10)),\n",
                "    save=True,\n",
                "    save_path=save_path + 'Fig_3_correlation_matrices_lfads_diffusion',\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import group_neurons_temp_corr, get_temp_corr_summary\n",
                "from ntldm.utils.plotting_utils import plot_cross_corr_summary, plot_temp_corr_summary\n",
                "\n",
                "# Temporal auto-/cross-correlation summaries for the data, AE samples and diffusion samples.\n",
                "ae_spikes = np.random.poisson(ae_rates)\n",
                "diff_spikes = np.random.poisson(diffusion_rates)\n",
                "\n",
                "# The neuron groups are built from the data once and reused for both models.\n",
                "# Every array is passed with its first two axes swapped, together with batch_first=False.\n",
                "groups = group_neurons_temp_corr(gt_spikes.transpose(1, 0, 2), num_groups=4)\n",
                "\n",
                "cross_corr_groups, auto_corr_groups = get_temp_corr_summary(\n",
                "    gt_spikes.transpose(1, 0, 2), groups, nlags=30, mode='biased', batch_first=False\n",
                ")\n",
                "fig_cross, fig_auto = plot_temp_corr_summary(cross_corr_groups, auto_corr_groups, name='Data')\n",
                "\n",
                "cross_corr_groups_sampled, auto_corr_groups_sampled = get_temp_corr_summary(\n",
                "    ae_spikes.transpose(1, 0, 2), groups, nlags=30, mode='biased', batch_first=False\n",
                ")\n",
                "fig_cross_sampled, fig_auto_sampled = plot_temp_corr_summary(\n",
                "    cross_corr_groups_sampled, auto_corr_groups_sampled, name='AE Samples'\n",
                ")\n",
                "\n",
                "cross_corr_groups_diff, auto_corr_groups_diff = get_temp_corr_summary(\n",
                "    diff_spikes.transpose(1, 0, 2), groups, nlags=30, mode='biased', batch_first=False\n",
                ")\n",
                "# NOTE: the fig_*_sampled names are rebound here, replacing the AE figure handles above.\n",
                "fig_cross_sampled, fig_auto_sampled = plot_temp_corr_summary(\n",
                "    cross_corr_groups_diff, auto_corr_groups_diff, name='Diffusion Samples'\n",
                ")\n",
                "\n",
                "# Store the data and diffusion summaries under 'ldns' and score them with rmse.\n",
                "summary_dict['ldns']['auto_corr']['gt'] = auto_corr_groups\n",
                "summary_dict['ldns']['auto_corr']['model'] = auto_corr_groups_diff\n",
                "summary_dict['ldns']['cross_corr']['gt'] = cross_corr_groups\n",
                "summary_dict['ldns']['cross_corr']['model'] = cross_corr_groups_diff\n",
                "comp_dict['ldns']['rmse_auto_corr'] = rmse(np.array(auto_corr_groups), np.array(auto_corr_groups_diff))\n",
                "# The model side of the cross-corr score must be the cross-correlations, not the auto-correlations.\n",
                "comp_dict['ldns']['rmse_cross_corr'] = rmse(np.array(cross_corr_groups), np.array(cross_corr_groups_diff))\n",
                "\n",
                "# Reference notes, kept as comments so the cell does not echo a bare string as output:\n",
                "#   methods = ['ldns', 'glm', 'lfads', 'train']\n",
                "#   ['rmse_mean_rate', 'rmse_mean_isi', 'rmse_std_isi', 'kl_div_psc', 'rmse_corr']\n",
                "#   ['mean_rate', 'mean_isi', 'std_isi', 'psc', 'corr', 'auto_corr', 'cross_corr']\n",
                "#   summary_dict\n",
                "#   comp_dict\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_cross_corr_summary\n",
                "\n",
                "# Saved cross-/auto-correlation summary figures for data (gt), diffusion and AE samples.\n",
                "save = True\n",
                "corr_panel_size = cm2inch(6, 4)\n",
                "auto_corr_text = dict(title='auto-corr', ylabel='auto-corr')\n",
                "\n",
                "# name -> (cross-corr groups, auto-corr groups, colormap, legend prefix)\n",
                "corr_sources = {\n",
                "    'gt': (cross_corr_groups, auto_corr_groups, 'Greys', 'gt '),\n",
                "    'diffusion': (cross_corr_groups_diff, auto_corr_groups_diff, 'Reds', 'diff '),\n",
                "    'ae': (cross_corr_groups_sampled, auto_corr_groups_sampled, 'Blues', 'ae '),\n",
                "}\n",
                "\n",
                "\n",
                "def _overlay_corr_summaries(kind, entries, fig_cm, ncol, stem):\n",
                "    \"\"\"Draw several sources on one axis and, if `save` is set, write <stem>.png / <stem>.pdf.\n",
                "\n",
                "    kind    -- 0 for cross-correlations, 1 for auto-correlations (index into corr_sources values)\n",
                "    entries -- (source name, extra keyword arguments) pairs, drawn in the given order\n",
                "    fig_cm  -- figure size, unpacked into cm2inch\n",
                "    ncol    -- forwarded to plot_cross_corr_summary\n",
                "    stem    -- file name without extension, appended to save_path\n",
                "\n",
                "    Returns the (fig, ax) pair that was drawn on.\n",
                "    \"\"\"\n",
                "    fig, ax = plt.subplots(1, 1, figsize=cm2inch(*fig_cm))\n",
                "    for name, extra in entries:\n",
                "        cmap, prefix = corr_sources[name][2:]\n",
                "        plot_cross_corr_summary(\n",
                "            corr_sources[name][kind],\n",
                "            name=name,\n",
                "            figsize=corr_panel_size,\n",
                "            cmap=cmap,\n",
                "            ax_corr=ax,\n",
                "            labels=prefix,\n",
                "            ncol=ncol,\n",
                "            **extra,\n",
                "        )\n",
                "    if save and save_path is not None:\n",
                "        fig.savefig(save_path + stem + '.png')\n",
                "        fig.savefig(save_path + stem + '.pdf')\n",
                "    return fig, ax\n",
                "\n",
                "\n",
                "# One stand-alone cross-correlation figure per source: cross_corr_gt / _diffusion / _ae.\n",
                "for name, (cross_groups, _auto_groups, cmap, _prefix) in corr_sources.items():\n",
                "    plot_cross_corr_summary(\n",
                "        cross_groups,\n",
                "        name=name,\n",
                "        figsize=corr_panel_size,\n",
                "        cmap=cmap,\n",
                "        save=True,\n",
                "        save_path=save_path + 'cross_corr_' + name,\n",
                "    )\n",
                "\n",
                "# Cross-correlations: all three sources, then data vs. diffusion only.\n",
                "fig, ax = _overlay_corr_summaries(\n",
                "    0, [('gt', {}), ('diffusion', {}), ('ae', {})], (8, 6), 3, 'cross_corr_all'\n",
                ")\n",
                "fig, ax = _overlay_corr_summaries(\n",
                "    0, [('gt', {}), ('diffusion', {})], (6, 4), 2, 'cross_corr_gt_diff'\n",
                ")\n",
                "\n",
                "# Auto-correlations: all three sources, then data vs. diffusion only.\n",
                "fig, ax = _overlay_corr_summaries(\n",
                "    1,\n",
                "    [('gt', auto_corr_text), ('diffusion', auto_corr_text), ('ae', auto_corr_text)],\n",
                "    (8, 6),\n",
                "    3,\n",
                "    'auto_corr_all',\n",
                ")\n",
                "# gt is drawn without title/ylabel here; the diffusion call that follows passes them.\n",
                "fig, ax = _overlay_corr_summaries(\n",
                "    1, [('gt', {}), ('diffusion', auto_corr_text)], (6, 4), 2, 'auto_corr_gt_diff'\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Cross-/auto-correlation summaries again; no save/save_path argument is passed in this cell.\n",
                "panel_size = cm2inch(6, 4)\n",
                "auto_kw = dict(title='auto-corr', ylabel='auto-corr')\n",
                "# (name, cross-corr groups, auto-corr groups, colormap, legend label)\n",
                "gt_src = ('gt', cross_corr_groups, auto_corr_groups, 'Greys', 'gt g')\n",
                "diff_src = ('diffusion', cross_corr_groups_diff, auto_corr_groups_diff, 'Reds', 'diff g')\n",
                "ae_src = ('ae', cross_corr_groups_sampled, auto_corr_groups_sampled, 'Blues', 'ae g')\n",
                "\n",
                "# One figure per source: cross-correlations.\n",
                "for src_name, cross, _auto, cmap, _label in (gt_src, diff_src, ae_src):\n",
                "    plot_cross_corr_summary(cross, name=src_name, figsize=panel_size, cmap=cmap)\n",
                "\n",
                "# All three sources on one axis: cross-correlations.\n",
                "fig, ax = plt.subplots(1, 1, figsize=cm2inch(8,6))\n",
                "for src_name, cross, _auto, cmap, label in (gt_src, diff_src, ae_src):\n",
                "    plot_cross_corr_summary(cross, name=src_name, figsize=panel_size, cmap=cmap, ax_corr=ax, labels=label, ncol=3)\n",
                "\n",
                "# All three sources on one axis: auto-correlations.\n",
                "fig, ax = plt.subplots(1, 1, figsize=cm2inch(8,6))\n",
                "for src_name, _cross, auto, cmap, label in (gt_src, diff_src, ae_src):\n",
                "    plot_cross_corr_summary(auto, name=src_name, figsize=panel_size, cmap=cmap, ax_corr=ax, labels=label, ncol=3, **auto_kw)\n",
                "\n",
                "# One figure per source: auto-correlations.\n",
                "for src_name, _cross, auto, cmap, _label in (gt_src, diff_src, ae_src):\n",
                "    plot_cross_corr_summary(auto, name=src_name, figsize=panel_size, cmap=cmap, **auto_kw)\n",
                "\n",
                "# Data vs. AE only: cross-correlations.\n",
                "fig, ax = plt.subplots(1, 1, figsize=cm2inch(8,6))\n",
                "for src_name, cross, _auto, cmap, label in (gt_src, ae_src):\n",
                "    plot_cross_corr_summary(cross, name=src_name, figsize=panel_size, cmap=cmap, ax_corr=ax, labels=label, ncol=2)\n",
                "\n",
                "# Data vs. AE only: auto-correlations. Both calls are written out so that the second one\n",
                "# stays the cell's final expression.\n",
                "fig, ax = plt.subplots(1, 1, figsize=cm2inch(8,6))\n",
                "plot_cross_corr_summary(auto_corr_groups, name='gt', figsize=panel_size, cmap='Greys', ax_corr=ax, labels='gt g', ncol=2, **auto_kw)\n",
                "plot_cross_corr_summary(auto_corr_groups_sampled, name='ae', figsize=panel_size, cmap='Blues', ax_corr=ax, labels='ae g', ncol=2, **auto_kw)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_rate_comparisons\n",
                "from ntldm.utils.eval_utils import average_rates, std_rates\n",
                "\n",
                "# Rate comparison of the data against AE and diffusion spikes with mode='neur'.\n",
                "figsize = cm2inch((5, 5))\n",
                "\n",
                "# No fn is passed here, so plot_rate_comparisons uses its default statistic.\n",
                "plot_rate_comparisons(\n",
                "    gt_spikes,\n",
                "    [ae_spikes, diffusion_spikes],\n",
                "    mode='neur',\n",
                "    fps=fps_monkey,\n",
                "    figsize=figsize,\n",
                "    colors=['midnightblue', 'darkred'],\n",
                ")\n",
                "\n",
                "# Same comparison with std_rates as the statistic; kept as the cell's final expression.\n",
                "plot_rate_comparisons(\n",
                "    gt_spikes,\n",
                "    [ae_spikes, diffusion_spikes],\n",
                "    fn=std_rates,\n",
                "    mode='neur',\n",
                "    fps=fps_monkey,\n",
                "    figsize=figsize,\n",
                "    colors=['midnightblue', 'darkred'],\n",
                ")\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Store dict"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_rate_comparisons\n",
                "from ntldm.utils.eval_utils import average_rates, std_rates\n",
                "\n",
                "# Seed NumPy's global generator before the two Poisson draws below.\n",
                "np.random.seed(42)\n",
                "ae_spikes = np.random.poisson(ae_rates)\n",
                "diffusion_spikes = np.random.poisson(diffusion_rates)\n",
                "\n",
                "figsize = cm2inch((7, 7))\n",
                "\n",
                "# Comparison with the default statistic (no fn passed), saved with a Fig_3_ prefix.\n",
                "plot_rate_comparisons(\n",
                "    gt_spikes,\n",
                "    [ae_spikes, diffusion_spikes],\n",
                "    mode='neur',\n",
                "    fps=fps_monkey,\n",
                "    figsize=figsize,\n",
                "    colors=['midnightblue', 'darkred'],\n",
                "    xlabel='firing rate (Hz)',\n",
                "    save=True,\n",
                "    save_path=save_path + 'Fig_3_mean_spike_comparison_gt_ae_diffusion',\n",
                ")\n",
                "\n",
                "# Comparison with std_rates as the statistic; this call is the cell's final expression.\n",
                "plot_rate_comparisons(\n",
                "    gt_spikes,\n",
                "    [ae_spikes, diffusion_spikes],\n",
                "    fn=std_rates,\n",
                "    mode='neur',\n",
                "    fps=fps_monkey,\n",
                "    figsize=figsize,\n",
                "    colors=['midnightblue', 'darkred'],\n",
                "    xlabel='std fr (Hz)',\n",
                "    save=True,\n",
                "    save_path=save_path + 'Fig_3_std_std_comparison_gt_ae_diffusion',\n",
                ")\n",
                "\n",
                "# Disabled bookkeeping, kept for reference:\n",
                "# summary_dict['ldns']['mean_rate']['gt'] = average_rates(gt_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "# summary_dict['ldns']['mean_rate']['model'] = average_rates(diffusion_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "# comp_dict['ldns']['rmse_mean_rate'] = rmse(summary_dict['ldns']['mean_rate']['gt'],summary_dict['ldns']['mean_rate']['model'])\n",
                "\n",
                "# summary_dict['ldns']['std_rate']['gt'] = std_rates(gt_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "# summary_dict['ldns']['std_rate']['model'] = std_rates(diffusion_spikes, mode='neur', fps=fps_monkey,).flatten()\n",
                "# comp_dict['ldns']['rmse_std_rate'] = rmse(summary_dict['ldns']['std_rate']['gt'],summary_dict['ldns']['std_rate']['model'])\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_rate_comparisons\n",
                "from ntldm.utils.eval_utils import average_rates, std_rates\n",
                "\n",
                "# Rate comparisons for the modes 'neur', 'neurtime' and 'neursample'.\n",
                "# Only the two 'neur' calls pass fps and save their figure.\n",
                "figsize = cm2inch((8, 8))\n",
                "\n",
                "# Default statistic (no fn passed).\n",
                "plot_rate_comparisons(\n",
                "    gt_spikes,\n",
                "    [ae_spikes, diffusion_spikes],\n",
                "    mode='neur',\n",
                "    fps=fps_monkey,\n",
                "    figsize=figsize,\n",
                "    colors=['midnightblue', 'darkred'],\n",
                "    xlabel='firing rate (Hz)',\n",
                "    save=True,\n",
                "    save_path=save_path + 'mean_spike_comparison_gt_ae_diffusion',\n",
                ")\n",
                "for agg_mode in ('neurtime', 'neursample'):\n",
                "    plot_rate_comparisons(\n",
                "        gt_spikes,\n",
                "        [ae_spikes, diffusion_spikes],\n",
                "        mode=agg_mode,\n",
                "        figsize=figsize,\n",
                "        colors=['midnightblue', 'darkred'],\n",
                "    )\n",
                "\n",
                "# std_rates as the statistic.\n",
                "plot_rate_comparisons(\n",
                "    gt_spikes,\n",
                "    [ae_spikes, diffusion_spikes],\n",
                "    fn=std_rates,\n",
                "    mode='neur',\n",
                "    fps=fps_monkey,\n",
                "    figsize=figsize,\n",
                "    colors=['midnightblue', 'darkred'],\n",
                "    xlabel='std fr (Hz)',\n",
                "    save=True,\n",
                "    save_path=save_path + 'std_std_comparison_gt_ae_diffusion',\n",
                ")\n",
                "plot_rate_comparisons(\n",
                "    gt_spikes,\n",
                "    [ae_spikes, diffusion_spikes],\n",
                "    fn=std_rates,\n",
                "    mode='neurtime',\n",
                "    figsize=figsize,\n",
                "    colors=['midnightblue', 'darkred'],\n",
                ")\n",
                "# Written out (not looped) so it stays the cell's final expression.\n",
                "plot_rate_comparisons(\n",
                "    gt_spikes,\n",
                "    [ae_spikes, diffusion_spikes],\n",
                "    fn=std_rates,\n",
                "    mode='neursample',\n",
                "    figsize=figsize,\n",
                "    colors=['midnightblue', 'darkred'],\n",
                ")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Turn spike counts into spike trains "
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import compute_spike_stats, counts_to_spike_trains\n",
                "from ntldm.utils.eval_utils import compute_spike_stats_per_neuron\n",
                "from ntldm.utils.plotting_utils import cm2inch, plot_spiketrain_stats\n",
                "\n",
                "# Spike trains: the data counts directly, the models via fresh Poisson draws from their rates.\n",
                "spike_trains_gt = counts_to_spike_trains(gt_spikes, fps=fps_monkey)\n",
                "spike_trains_ae = counts_to_spike_trains(np.random.poisson(ae_rates), fps=fps_monkey)\n",
                "spike_trains_diff = counts_to_spike_trains(np.random.poisson(diffusion_rates), fps=fps_monkey)\n",
                "\n",
                "# n_samples / n_neurons are read from axis 0 / axis 2 of the array each train came from.\n",
                "spike_stats_gt = compute_spike_stats_per_neuron(\n",
                "    spike_trains_gt, n_samples=gt_spikes.shape[0], n_neurons=gt_spikes.shape[2], mean_output=False\n",
                ")\n",
                "spike_stats_ae = compute_spike_stats_per_neuron(\n",
                "    spike_trains_ae, n_samples=ae_rates.shape[0], n_neurons=ae_rates.shape[2], mean_output=False\n",
                ")\n",
                "spike_stats_diff = compute_spike_stats_per_neuron(\n",
                "    spike_trains_diff, n_samples=diffusion_rates.shape[0], n_neurons=diffusion_rates.shape[2], mean_output=False\n",
                ")\n",
                "\n",
                "# The two comparisons against the data are saved; the ae-vs-diffusion one is not.\n",
                "plot_spiketrain_stats(\n",
                "    spike_stats_gt, spike_stats_ae,\n",
                "    figsize=cm2inch(12, 4), color='midnightblue', labels=['gt', 'ae'],\n",
                "    save=True, save_path=save_path + 'compute_spike_stats_per_neuron_gt_ae',\n",
                ")\n",
                "plot_spiketrain_stats(\n",
                "    spike_stats_gt, spike_stats_diff,\n",
                "    figsize=cm2inch(12, 4), color='darkred', labels=['gt', 'diffusion'],\n",
                "    save=True, save_path=save_path + 'compute_spike_stats_per_neuron_gt_diff',\n",
                ")\n",
                "plot_spiketrain_stats(\n",
                "    spike_stats_ae, spike_stats_diff,\n",
                "    figsize=cm2inch(12, 4), color='purple', labels=['ae', 'diffusion'],\n",
                ")\n",
                "\n",
                "# Store the ISI statistics of data and diffusion samples under 'ldns' and score them with rmse.\n",
                "for isi_key in ('mean_isi', 'std_isi'):\n",
                "    summary_dict['ldns'][isi_key]['gt'] = spike_stats_gt[isi_key]\n",
                "    summary_dict['ldns'][isi_key]['model'] = spike_stats_diff[isi_key]\n",
                "    comp_dict['ldns']['rmse_' + isi_key] = rmse(\n",
                "        summary_dict['ldns'][isi_key]['gt'], summary_dict['ldns'][isi_key]['model']\n",
                "    )\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Store dict"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Copy the ISI statistics of the data (spike_stats_gt) and of the diffusion samples\n",
                "# (spike_stats_diff) into the 'ldns' summary, then score each statistic with rmse.\n",
                "ldns_summary = summary_dict['ldns']\n",
                "ldns_scores = comp_dict['ldns']\n",
                "\n",
                "for isi_stat in ('mean_isi', 'std_isi'):\n",
                "    ldns_summary[isi_stat]['gt'] = spike_stats_gt[isi_stat]\n",
                "    ldns_summary[isi_stat]['model'] = spike_stats_diff[isi_stat]\n",
                "    ldns_scores['rmse_' + isi_stat] = rmse(ldns_summary[isi_stat]['gt'], ldns_summary[isi_stat]['model'])\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import compute_spike_stats_per_neuron\n",
                "from ntldm.utils.plotting_utils import cm2inch, plot_spiketrain_stats\n",
                "\n",
                "# Recompute the statistics with mean_output=False from the existing spike trains.\n",
                "spike_stats_gt = compute_spike_stats_per_neuron(\n",
                "    spike_trains_gt,\n",
                "    n_samples=gt_spikes.shape[0],\n",
                "    n_neurons=gt_spikes.shape[2],\n",
                "    mean_output=False,\n",
                ")\n",
                "spike_stats_ae = compute_spike_stats_per_neuron(\n",
                "    spike_trains_ae,\n",
                "    n_samples=ae_rates.shape[0],\n",
                "    n_neurons=ae_rates.shape[2],\n",
                "    mean_output=False,\n",
                ")\n",
                "spike_stats_diff = compute_spike_stats_per_neuron(\n",
                "    spike_trains_diff,\n",
                "    n_samples=diffusion_rates.shape[0],\n",
                "    n_neurons=diffusion_rates.shape[2],\n",
                "    mean_output=False,\n",
                ")\n",
                "\n",
                "# Pairwise comparisons; no save arguments are passed, and the last call is the cell's final expression.\n",
                "plot_spiketrain_stats(spike_stats_gt, spike_stats_ae, figsize=cm2inch(12, 4), color='midnightblue', labels=['gt', 'ae'])\n",
                "plot_spiketrain_stats(spike_stats_gt, spike_stats_diff, figsize=cm2inch(12, 4), color='darkred', labels=['gt', 'diffusion'])\n",
                "plot_spiketrain_stats(spike_stats_ae, spike_stats_diff, figsize=cm2inch(12, 4), color='purple', labels=['ae', 'diffusion'])\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Diffusion power spectral density checks "
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Value handed to plot_n_channel_sd as fps.\n",
                "# NOTE(review): 1000 / 5 presumably corresponds to 5 ms bins (200 Hz) - confirm.\n",
                "fps = 1000 / 5\n",
                "\n",
                "# Diffusion latents first, AE latents second; save=False for this call.\n",
                "plot_n_channel_sd(\n",
                "    diffusion_latents,\n",
                "    ae_latents,\n",
                "    channels=np.arange(16),\n",
                "    fps=fps,\n",
                "    save=False,\n",
                "    save_path=None,\n",
                "    colors=['darkred', 'midnightblue'],\n",
                "    labels=['diffusion', 'ae'],\n",
                "    ystack=4,\n",
                "    figsize=cm2inch(20, 20),\n",
                ")\n",
                "\n",
                "# Same channels with the argument order swapped, colours from color_dict, save=True.\n",
                "# Kept as the cell's final expression.\n",
                "plot_n_channel_sd(\n",
                "    ae_latents,\n",
                "    diffusion_latents,\n",
                "    channels=np.arange(16),\n",
                "    fps=fps,\n",
                "    save=True,\n",
                "    save_path=save_path+'latents_power_spec_density_diff_ae',\n",
                "    colors=[color_dict['ael'], color_dict['diffl']],\n",
                "    labels=['ae', 'diff'],\n",
                "    ystack=4,\n",
                "    figsize=cm2inch(20, 20),\n",
                "    lw=1,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "## eval to decode behavior from generated (smooth) rates, compare with real behavior\n",
                "\n",
                "\n",
                "def _collect_rates_and_behavior(ae, dataloader):\n",
                "    \"\"\"Run `ae` on every batch of `dataloader` and gather its first output with the behavior.\n",
                "\n",
                "    Each batch must provide the entries \"signal\" and \"behavior\". Only element 0 of\n",
                "    `ae(signal)` is kept (it is used as the rates by the caller). Both the rates and\n",
                "    the behavior are moved to the CPU.\n",
                "\n",
                "    Returns:\n",
                "        Tuple `(rates, behavior)`, each concatenated over all batches along dim 0.\n",
                "    \"\"\"\n",
                "    rates, behavior = [], []\n",
                "    for batch in dataloader:\n",
                "        with torch.no_grad():\n",
                "            rates.append(ae(batch[\"signal\"])[0].cpu())\n",
                "        behavior.append(batch[\"behavior\"].cpu())\n",
                "    return torch.cat(rates, 0), torch.cat(behavior, 0)\n",
                "\n",
                "\n",
                "def gen_rates_and_train_decoded_behavior(\n",
                "    ema_denoiser,\n",
                "    scheduler,\n",
                "    ae,\n",
                "    cfg,\n",
                "    train_latent_dataloader,\n",
                "    val_latent_dataloader,\n",
                "    num_samples=100,\n",
                "    device=\"cuda\",\n",
                "    latent_dataset=None,\n",
                "):\n",
                "    \"\"\"Fit a ridge decoder from autoencoder rates to behavior and apply it to diffusion samples.\n",
                "\n",
                "    Steps:\n",
                "      1. Pass the train and val loaders through `ae` to collect rates and behavior.\n",
                "      2. Fit `Ridge(alpha=1e-6)` on the flattened training pairs and print the R2 on val.\n",
                "      3. Draw latents with `sample(...)`, undo the latent standardisation, decode them\n",
                "         with `ae.decode` and predict behavior with the fitted decoder.\n",
                "\n",
                "    Args:\n",
                "        ema_denoiser: EMA wrapper; its `averaged_model` is switched to eval mode and the\n",
                "            wrapper itself is handed to `sample`.\n",
                "        scheduler: forwarded unchanged to `sample`.\n",
                "        ae: autoencoder; it is called on the signals and its `decode` is used on samples.\n",
                "        cfg: forwarded unchanged to `sample`.\n",
                "        train_latent_dataloader: iterable of dict batches with \"signal\" and \"behavior\".\n",
                "        val_latent_dataloader: iterable of dict batches with \"signal\" and \"behavior\".\n",
                "        num_samples: batch size requested from `sample`.\n",
                "        device: device passed to `sample`.\n",
                "        latent_dataset: object exposing `latent_means` and `latent_stds`, used to undo the\n",
                "            standardisation of the sampled latents. Defaults to the notebook-level\n",
                "            `latent_dataset_train`, which is what this function used implicitly before.\n",
                "\n",
                "    Returns:\n",
                "        dict with the keys \"predicted_val_behavior\", \"real_val_behavior\" and\n",
                "        \"predicted_sampled_behavior\", each reshaped back to [B, C, L].\n",
                "    \"\"\"\n",
                "    from sklearn.linear_model import Ridge\n",
                "    from sklearn.metrics import r2_score\n",
                "\n",
                "    if latent_dataset is None:\n",
                "        # Backward compatible fallback to the notebook global.\n",
                "        latent_dataset = latent_dataset_train\n",
                "\n",
                "    ema_denoiser.averaged_model.eval()\n",
                "    ae.eval()\n",
                "\n",
                "    train_rates, train_behavior = _collect_rates_and_behavior(ae, train_latent_dataloader)\n",
                "    print(train_rates, train_behavior)\n",
                "\n",
                "    val_rates, val_behavior = _collect_rates_and_behavior(ae, val_latent_dataloader)\n",
                "    print(val_rates, val_behavior)\n",
                "\n",
                "    # decode rates to behavior: every time step of every trial becomes one regression sample\n",
                "    bs_val = val_rates.shape[0]\n",
                "    flatten = \"b c l -> (b l) c\"\n",
                "    unflatten = \"(b l) c -> b c l\"\n",
                "    train_rates = rearrange(train_rates, flatten).numpy()\n",
                "    val_rates = rearrange(val_rates, flatten).numpy()\n",
                "    train_behavior = rearrange(train_behavior, flatten).numpy()\n",
                "    val_behavior = rearrange(val_behavior, flatten).numpy()\n",
                "\n",
                "    ridge = Ridge(alpha=1e-6)\n",
                "    ridge.fit(train_rates, train_behavior)\n",
                "    predicted_behavior = ridge.predict(val_rates)\n",
                "\n",
                "    r2 = r2_score(val_behavior, predicted_behavior)\n",
                "    print(f\"R2 score on val: {r2:.3f}\")\n",
                "\n",
                "    # sample from the denoiser\n",
                "    sampled_latents = sample(\n",
                "        ema_denoiser=ema_denoiser,\n",
                "        scheduler=scheduler,\n",
                "        cfg=cfg,\n",
                "        batch_size=num_samples,\n",
                "        device=device,\n",
                "    )\n",
                "    # undo the standardisation of the latents before decoding\n",
                "    latent_stds = latent_dataset.latent_stds.to(sampled_latents.device)\n",
                "    latent_means = latent_dataset.latent_means.to(sampled_latents.device)\n",
                "    sampled_latents = sampled_latents * latent_stds + latent_means\n",
                "\n",
                "    with torch.no_grad():\n",
                "        sampled_rates = ae.decode(sampled_latents).cpu()\n",
                "\n",
                "    # use the batch size that was actually returned when restoring the [B, C, L] layout\n",
                "    bs_sampled = sampled_rates.shape[0]\n",
                "    sampled_rates = rearrange(sampled_rates, flatten).numpy()\n",
                "    predicted_sampled_behavior = ridge.predict(sampled_rates)\n",
                "\n",
                "    return {\n",
                "        \"predicted_val_behavior\": rearrange(predicted_behavior, unflatten, b=bs_val),\n",
                "        \"real_val_behavior\": rearrange(val_behavior, unflatten, b=bs_val),\n",
                "        \"predicted_sampled_behavior\": rearrange(\n",
                "            predicted_sampled_behavior, unflatten, b=bs_sampled\n",
                "        ),\n",
                "    }\n",
                "\n",
                "\n",
                "\n",
                "# Fit the behavior decoder and apply it to the test rates and to 20 diffusion samples.\n",
                "behav_dict = gen_rates_and_train_decoded_behavior(\n",
                "    ema_denoiser=ema_model,\n",
                "    scheduler=scheduler,\n",
                "    ae=ae_model,\n",
                "    cfg=cfg,\n",
                "    train_latent_dataloader=train_latent_dataloader,\n",
                "    val_latent_dataloader=test_latent_dataloader,\n",
                "    num_samples=20,\n",
                "    device=\"cuda\",\n",
                ")\n",
                "\n",
                "predicted_val_behavior = behav_dict[\"predicted_val_behavior\"]\n",
                "real_val_behavior = behav_dict[\"real_val_behavior\"]\n",
                "predicted_sampled_behavior = behav_dict[\"predicted_sampled_behavior\"]\n",
                "\n",
                "# The cumulative sum along the last axis turns the behavior traces into the plotted trajectories.\n",
                "predicted_traj = np.cumsum(predicted_val_behavior, axis=-1)\n",
                "real_traj = np.cumsum(real_val_behavior, axis=-1)\n",
                "predicted_sampled_traj = np.cumsum(predicted_sampled_behavior, axis=-1)\n",
                "\n",
                "print(predicted_traj.shape, real_traj.shape, predicted_sampled_traj.shape)\n",
                "\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(10, 5))\n",
                "ax_val, ax_gen = ax\n",
                "\n",
                "# Left panel: every 7th of the first 70 trials, real (solid) against decoded (dashed).\n",
                "colors = sns.color_palette(\"tab10\")\n",
                "for i, idx in enumerate(range(0, 70, 7)):\n",
                "    trial_color = colors[i]\n",
                "    ax_val.plot(\n",
                "        real_traj[idx, 0, :],\n",
                "        real_traj[idx, 1, :],\n",
                "        label=\"Real trajectory\",\n",
                "        color=trial_color,\n",
                "    )\n",
                "    ax_val.plot(\n",
                "        predicted_traj[idx, 0, :],\n",
                "        predicted_traj[idx, 1, :],\n",
                "        label=\"Predicted trajectory\",\n",
                "        linestyle=\"--\",\n",
                "        color=trial_color,\n",
                "    )\n",
                "ax_val.set_xlabel(\"X position\")\n",
                "ax_val.set_ylabel(\"Y position\")\n",
                "ax_val.set_title(\"Real vs predicted trajectory (Val)\")\n",
                "\n",
                "# Right panel: trajectories decoded from the 20 sampled trials, one color each.\n",
                "colors = sns.color_palette(\"tab20\")\n",
                "for idx in range(20):\n",
                "    ax_gen.plot(\n",
                "        predicted_sampled_traj[idx, 0, :],\n",
                "        predicted_sampled_traj[idx, 1, :],\n",
                "        linestyle=\"-\",\n",
                "        color=colors[idx],\n",
                "    )\n",
                "ax_gen.set_xlabel(\"X position\")\n",
                "ax_gen.set_ylabel(\"Y position\")\n",
                "ax_gen.set_title(\"predicted trajectory (Sampled)\")\n",
                "\n",
                "fig.tight_layout()\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(10, 5))\n",
                "ax_val, ax_gen = ax\n",
                "colors = sns.color_palette(\"tab20\")\n",
                "\n",
                "# Never index past the available trials. The requested counts (70 test trials,\n",
                "# 50 sampled trials) can exceed what exists, e.g. only 20 samples are generated\n",
                "# with `num_samples=20`, which raised an IndexError on a fresh run.\n",
                "n_val_shown = min(70, len(real_traj), len(predicted_traj))\n",
                "n_sampled_shown = min(50, len(predicted_sampled_traj))\n",
                "\n",
                "# Left panel: real (solid) against decoded (dashed) trajectories, cycling through the palette.\n",
                "for idx in range(n_val_shown):\n",
                "    trial_color = colors[idx % len(colors)]\n",
                "    ax_val.plot(\n",
                "        real_traj[idx, 0, :],\n",
                "        real_traj[idx, 1, :],\n",
                "        label=\"real trajectory\",\n",
                "        color=trial_color,\n",
                "    )\n",
                "    ax_val.plot(\n",
                "        predicted_traj[idx, 0, :],\n",
                "        predicted_traj[idx, 1, :],\n",
                "        label=\"Predicted trajectory\",\n",
                "        linestyle=\"--\",\n",
                "        color=trial_color,\n",
                "    )\n",
                "ax_val.set_xlabel(\"x position\")\n",
                "ax_val.set_ylabel(\"y position\")\n",
                "ax_val.set_title(\"real vs predicted trajectory\")\n",
                "\n",
                "# Right panel: trajectories decoded from the unconditional diffusion samples.\n",
                "for idx in range(n_sampled_shown):\n",
                "    ax_gen.plot(\n",
                "        predicted_sampled_traj[idx, 0, :],\n",
                "        predicted_sampled_traj[idx, 1, :],\n",
                "        linestyle=\"-\",\n",
                "        color=colors[idx % len(colors)],\n",
                "    )\n",
                "ax_gen.set_xlabel(\"x position\")\n",
                "ax_gen.set_ylabel(\"y position\")\n",
                "ax_gen.set_title(\"unconditional diffusion\")\n",
                "\n",
                "fig.tight_layout()\n",
                "\n",
                "plt.savefig(save_path + \"unconditional_figure_no_color_bar_thick.png\")\n",
                "plt.savefig(save_path + \"unconditional_figure_no_color_bar_thick.pdf\")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Plotting autoencoder results"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "from ntldm.utils.eval_utils import reconstruct_spikes\n",
                "\n",
                "rec_dict = reconstruct_spikes(ae_model, test_dataloader)\n",
                "\n",
                "\n",
                "def _plot_spike_count_hist(rec_dict, sum_dim, max_count, title, xlabel, file_stem):\n",
                "    \"\"\"Overlay density histograms of summed spike counts: ground truth against reconstruction.\n",
                "\n",
                "    Args:\n",
                "        rec_dict: mapping with the entries 'spikes' and 'rec_spikes'.\n",
                "        sum_dim: dimension that is summed out before the values are flattened.\n",
                "        max_count: largest count that still gets its own bin.\n",
                "        title: figure title.\n",
                "        xlabel: x axis label.\n",
                "        file_stem: file name (without extension) appended to `save_path`; a .png and a\n",
                "            .pdf are written.\n",
                "    \"\"\"\n",
                "    # Unit-width bins centred on 0, 1, ..., max_count. The previous\n",
                "    # `np.linspace(0, max_count, max_count)` edges were slightly wider than 1, so the\n",
                "    # first bin collected the counts 0 and 1 together.\n",
                "    # NOTE(review): assumes the summed spike counts are integer-valued -- confirm.\n",
                "    bins = np.arange(max_count + 2) - 0.5\n",
                "\n",
                "    plt.figure(figsize=cm2inch((6, 4)))\n",
                "    for key, color in (('spikes', 'grey'), ('rec_spikes', 'darkblue')):\n",
                "        plt.hist(rec_dict[key].sum(sum_dim).flatten(), density=True, color=color, bins=bins, alpha=0.5)\n",
                "\n",
                "    plt.legend(['gt', 'ae'])\n",
                "    plt.title(title)\n",
                "    plt.ylabel('frequency')\n",
                "    plt.xlabel(xlabel)\n",
                "    plt.savefig(save_path + file_stem + '.png')\n",
                "    plt.savefig(save_path + file_stem + '.pdf')\n",
                "\n",
                "\n",
                "# counts summed over dim 1\n",
                "_plot_spike_count_hist(\n",
                "    rec_dict,\n",
                "    sum_dim=1,\n",
                "    max_count=15,\n",
                "    title='population spike count (test set)',\n",
                "    xlabel='number of spikes per time bin',\n",
                "    file_stem='population_spike_count_test',\n",
                ")\n",
                "\n",
                "# counts summed over dim 2\n",
                "_plot_spike_count_hist(\n",
                "    rec_dict,\n",
                "    sum_dim=2,\n",
                "    max_count=20,\n",
                "    title='spike count distribution (test set)',\n",
                "    xlabel='number of spikes per trial',\n",
                "    file_stem='trial_spike_count_test',\n",
                ")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Trial shown in all three figures and whether the figures are written to disk.\n",
                "save = True\n",
                "idx = 0\n",
                "trial_kwargs = dict(idx=idx, save=save)\n",
                "\n",
                "# rate traces of the selected test trial\n",
                "plot_rate_traces(\n",
                "    ae_model,\n",
                "    test_dataloader,\n",
                "    figsize=cm2inch(10, 10),\n",
                "    true_data=True,\n",
                "    save_path=save_path + \"rate_traces\",\n",
                "    **trial_kwargs,\n",
                ")\n",
                "\n",
                "# spikes of the same trial, plotted next to each other\n",
                "plot_spikes_next_to_each_other(\n",
                "    ae_model,\n",
                "    test_dataloader,\n",
                "    figsize=cm2inch(8, 4),\n",
                "    save_path=save_path + \"spikes_next_to_each_other\",\n",
                "    **trial_kwargs,\n",
                ")\n",
                "\n",
                "# inferred latents of the same trial (8 latents, y_stack=4)\n",
                "plot_inferred_latents(\n",
                "    ae_model,\n",
                "    test_dataloader,\n",
                "    n_latents=8,\n",
                "    y_stack=4,\n",
                "    figsize=cm2inch(10, 10),\n",
                "    color=\"royalblue\",\n",
                "    save_path=save_path + \"inferred_latents\",\n",
                "    **trial_kwargs,\n",
                ")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# NOTE: plotting of the autoencoder reconstructions of reach directions was moved to the conditional diffusion plotting cells."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "## sanity check: ae decoded signal, mapped to behavior (trained on train data) should be similar to the actual behavior\n",
                "\n",
                "from ntldm.utils.behav_eval_utils import compute_decoded_behavior\n",
                "\n",
                "behav_dict = compute_decoded_behavior(\n",
                "    ae_model, train_latent_dataloader, test_latent_dataloader\n",
                ")\n",
                "predicted_val_behavior = behav_dict[\"predicted_val_behavior\"]\n",
                "real_val_behavior = behav_dict[\"real_val_behavior\"]\n",
                "\n",
                "# The cumulative sum along the last axis turns the behavior traces into the plotted trajectories.\n",
                "predicted_traj = np.cumsum(predicted_val_behavior, axis=-1)\n",
                "real_traj = np.cumsum(real_val_behavior, axis=-1)\n",
                "\n",
                "# Reach angle taken from the trajectory at time index 50, clamped to the last\n",
                "# available index so that shorter trajectories do not raise an IndexError.\n",
                "angle_time_idx = min(50, real_traj.shape[-1] - 1)\n",
                "angle_real_val_behavior = np.arctan2(\n",
                "    real_traj[:, 1, angle_time_idx], real_traj[:, 0, angle_time_idx]\n",
                ")\n",
                "\n",
                "print(predicted_traj.shape, real_traj.shape)\n",
                "plt.figure(figsize=cm2inch(8, 6))\n",
                "\n",
                "# Every 5th of the first 70 trials (fewer if the set is smaller): real (solid) against\n",
                "# decoded (dashed), both colored by the real reach angle.\n",
                "for idx in range(0, min(70, len(real_traj)), 5):\n",
                "    trial_color = angle_to_color(angle_real_val_behavior[idx])\n",
                "    plt.plot(\n",
                "        real_traj[idx, 0, :],\n",
                "        real_traj[idx, 1, :],\n",
                "        label=\"real trajectory\",\n",
                "        color=trial_color,\n",
                "    )\n",
                "    plt.plot(\n",
                "        predicted_traj[idx, 0, :],\n",
                "        predicted_traj[idx, 1, :],\n",
                "        label=\"predicted trajectory\",\n",
                "        linestyle=\"--\",\n",
                "        color=trial_color,\n",
                "    )\n",
                "plt.title(\"real vs predicted trajectory\")\n",
                "\n",
                "# NOTE(review): `cmap` is not defined in this cell; it must already exist in the kernel and\n",
                "# should be the colormap used by `angle_to_color` -- confirm for Restart & Run All.\n",
                "sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(-np.pi, np.pi))\n",
                "sm.set_array([])\n",
                "# set ticks formatting for colorbar in terms of pi\n",
                "cbar = plt.colorbar(sm, ticks=np.linspace(-np.pi, np.pi, 5), label=\"behavior angle\", ax=plt.gca())\n",
                "cbar.ax.set_yticklabels([r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"])\n",
                "# The earlier \"x position\" / \"y position\" labels were overwritten by these and are dropped.\n",
                "plt.xlabel('x')\n",
                "plt.ylabel('y')\n",
                "plt.savefig(save_path + \"test_reaches_and_recs_colored_by_angles.png\")\n",
                "plt.savefig(save_path + \"test_reaches_and_recs_colored_by_angles.pdf\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Color each training reach by its behavior angle and draw the cumulative trajectories in 2D.\n",
                "cmap = plt.get_cmap(\"hsv\")\n",
                "# rescale the angles with (angle + pi) / (2 * pi) before the colormap lookup\n",
                "angle_fraction = (latent_dataset_train.behavior_angles.squeeze() + np.pi) / (2 * np.pi)\n",
                "colors = cmap(angle_fraction)\n",
                "\n",
                "plt.figure(figsize=cm2inch(8, 6))\n",
                "\n",
                "# plot every 10th reach as a thin, semi-transparent line\n",
                "reach_paths = latent_dataset_train.behavior_cumsum\n",
                "for i in range(0, len(latent_dataset_train.behavior), 10):\n",
                "    plt.plot(reach_paths[i, 0], reach_paths[i, 1], color=colors[i], alpha=0.3, lw=0.4)\n",
                "\n",
                "# hide the axes\n",
                "plt.axis(\"off\")\n",
                "\n",
                "# colorbar over the angle range [-pi, pi] with tick labels written in terms of pi\n",
                "sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(-np.pi, np.pi))\n",
                "sm.set_array([])\n",
                "angle_ticks = np.linspace(-np.pi, np.pi, 5)\n",
                "angle_tick_labels = [r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"]\n",
                "cbar = plt.colorbar(sm, ticks=angle_ticks, label=\"behavior angle\", ax=plt.gca())\n",
                "cbar.ax.set_yticklabels(angle_tick_labels)\n",
                "\n",
                "plt.xlabel('x')\n",
                "plt.ylabel('y')\n",
                "for extension in (\".png\", \".pdf\"):\n",
                "    plt.savefig(save_path + \"all_reaches_colored_by_angles\" + extension)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "timeseries",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.8.18"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
