{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "%load_ext autoreload\n",
                "%autoreload 2\n",
                "\n",
                "\n",
                "import argparse\n",
                "import os\n",
                "import sys\n",
                "\n",
                "# Machine-specific switch: True pins the run to GPU 0 and adds the parent\n",
                "# directory to sys.path; False only changes the working directory.\n",
                "ANONAUTHOR = True\n",
                "\n",
                "# NOTE: both branches move the working directory one level up, so re-running\n",
                "# this cell without restarting the kernel climbs one more directory.\n",
                "if ANONAUTHOR:\n",
                "    os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
                "    # append parent directory to path (../notebooks -> ..)\n",
                "    parent_dir = os.path.dirname(os.getcwd())\n",
                "    sys.path.append(parent_dir)\n",
                "    os.chdir(parent_dir)\n",
                "\n",
                "else:\n",
                "    os.chdir('../')\n",
                "\n",
                "\n",
                "import accelerate\n",
                "import auraloss  # freq loss\n",
                "import lovely_tensors as lt\n",
                "import matplotlib.pyplot as plt\n",
                "import matplotlib\n",
                "import numpy as np\n",
                "import scipy.io as io\n",
                "import seaborn as sns\n",
                "import torch\n",
                "import torch.nn as nn\n",
                "import wandb\n",
                "import yaml\n",
                "from diffusers.optimization import get_scheduler\n",
                "from omegaconf import OmegaConf\n",
                "from scipy.signal import welch\n",
                "from tqdm.auto import tqdm\n",
                "from einops import rearrange\n",
                "\n",
                "from ntldm.networks import AutoEncoder, CountWrapper\n",
                "from ntldm.utils.plotting_utils import *\n",
                "from ntldm.losses import latent_regularizer\n",
                "from ntldm.networks import Denoiser\n",
                "from diffusers.training_utils import EMAModel\n",
                "from diffusers.schedulers import DDPMScheduler\n",
                "# always run from ../ntldm\n",
                "\n",
                "\n",
                "# lovely_tensors hook (presumably changes how tensors are printed - confirm in its docs)\n",
                "lt.monkey_patch()\n",
                "# load matplotlib settings from the file 'matplotlibrc', resolved against the current working directory\n",
                "matplotlib.rc_file('matplotlibrc')\n",
                "\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# print the installed diffusers version\n",
                "import diffusers\n",
                "print(diffusers.__version__)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "## load config and model path\n",
                "\n",
                "# autoencoder config: exactly one of the OmegaConf.load lines below should be active\n",
                "# cfg_ae = OmegaConf.load(\"conf/sweeps_new/Phoneme_autoencoder-count_s4-phoneme_v2loss.yaml\")\n",
                "cfg_ae = OmegaConf.load(\"conf/sweeps_new/Phoneme_autoencoder-count_s4-phoneme_v2loss_bigger_lesstktd.yaml\")\n",
                "# cfg_ae = OmegaConf.load(\"conf/sweeps_new/Phoneme_autoencoder-count_s4-phoneme_v2loss_big_lesstk.yaml\")\n",
                "\n",
                "\n",
                "# inline YAML with the denoiser and training settings used by this notebook\n",
                "cfg_yaml = \"\"\"\n",
                "denoiser_model:\n",
                "  C_in: 32\n",
                "  C: 384\n",
                "  kernel: s4\n",
                "  num_blocks: 6\n",
                "  bidirectional: True\n",
                "  num_train_timesteps: 1000\n",
                "training:\n",
                "  lr: 0.001\n",
                "  num_epochs: 2000\n",
                "  num_warmup_epochs: 100\n",
                "  batch_size: 256\n",
                "  random_seed: 42\n",
                "  precision: \"no\"\n",
                "exp_name: diffusion_s4-phoneme_newbigger\n",
                "\"\"\"\n",
                "# parse the YAML text and wrap the result in an OmegaConf object\n",
                "cfg = OmegaConf.create(yaml.safe_load(cfg_yaml))\n",
                "# the diffusion config reuses the dataset section of the autoencoder config\n",
                "cfg.dataset = cfg_ae.dataset\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Plan\n",
                "### 1. Create dataset containing sentences, sentence embeddings, spike dataset, latents from ae\n",
                "\n",
                "### 2. Train diffusion to generate rates unconditionally.\n",
                "\n",
                "### 3. Train diffusion to generate rates conditioned on sentence embeddings.\n",
                "\n",
                "### 4. Check whether these spikes have the highest log-likelihood under the matching sentence compared to other sentences"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "\n",
                "### 0. Load autoencoder (with checkpoint) and autoencoder dataset (run for all points 2-5)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "    \n",
                "\n",
                "\n",
                "import math\n",
                "from ntldm.data.phoneme import get_phoneme_dataloaders\n",
                "\n",
                "# seed torch and numpy with the configured value\n",
                "for seed_fn in (torch.manual_seed, np.random.seed):\n",
                "    seed_fn(cfg.training.random_seed)\n",
                "\n",
                "# data path and batch size come from the autoencoder config\n",
                "phoneme_loaders = get_phoneme_dataloaders(\n",
                "    cfg_ae.dataset.datapath,\n",
                "    batch_size=cfg_ae.training.batch_size,\n",
                ")\n",
                "train_dataloader, val_dataloader, test_dataloader = phoneme_loaders\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Build the autoencoder from its config; keys read with .get() fall back to the given defaults.\n",
                "ae_model = AutoEncoder(\n",
                "    C_in=cfg_ae.model.C_in,\n",
                "    C=cfg_ae.model.C,\n",
                "    C_latent=cfg_ae.model.C_latent,\n",
                "    L=cfg_ae.dataset.max_seqlen,\n",
                "    kernel=cfg_ae.model.kernel,\n",
                "    num_blocks=cfg_ae.model.num_blocks,\n",
                "    num_blocks_decoder=cfg_ae.model.get(\"num_blocks_decoder\", cfg_ae.model.num_blocks),\n",
                "    num_lin_per_mlp=cfg_ae.model.get(\"num_lin_per_mlp\", 2),  # default 2\n",
                "    bidirectional=cfg_ae.model.get(\"bidirectional\", False),\n",
                ")\n",
                "\n",
                "ae_model = CountWrapper(ae_model, use_sin_enc=cfg_ae.model.get(\"use_sin_enc\", False))\n",
                "\n",
                "\n",
                "accelerator = accelerate.Accelerator(\n",
                "    mixed_precision=\"no\",\n",
                "    log_with=\"wandb\",\n",
                ")\n",
                "\n",
                "# prepare the ae model before load_state is called\n",
                "# NOTE(review): presumably load_state only restores objects that were prepared earlier - confirm\n",
                "ae_model = accelerator.prepare(ae_model)\n",
                "\n",
                "print(cfg_ae.exp_name)\n",
                "accelerator.load_state(f\"exp/{cfg_ae.exp_name}/epoch_200\")  # best checkpoint\n",
                "\n",
                "# prepare the phoneme dataloaders\n",
                "(\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ") = accelerator.prepare(\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ")\n",
                "\n",
                "\n",
                "def reconstruct_spikes(model, dataloader):\n",
                "    \"\"\"Run the model over a dataloader and gather inputs and reconstructions.\n",
                "\n",
                "    Every batch must provide ``signal`` and ``mask`` entries. The model is put\n",
                "    into eval mode and called on the signal without gradient tracking; its\n",
                "    return value is unpacked as ``(rates, latent)``. Reconstructed spikes are\n",
                "    Poisson samples of the rates multiplied by the mask.\n",
                "\n",
                "    Returns:\n",
                "        dict of CPU tensors concatenated along dim 0 with the keys\n",
                "        ``latents``, ``spikes``, ``rec_spikes`` and ``signal_masks``.\n",
                "    \"\"\"\n",
                "    model.eval()\n",
                "    collected = {\"latents\": [], \"spikes\": [], \"rec_spikes\": [], \"signal_masks\": []}\n",
                "    for batch in dataloader:\n",
                "        signal, mask = batch[\"signal\"], batch[\"mask\"]\n",
                "        with torch.no_grad():\n",
                "            rates, latent = model(signal)\n",
                "        mask_cpu = mask.cpu()\n",
                "        collected[\"latents\"].append(latent.cpu())\n",
                "        collected[\"spikes\"].append(signal.cpu())\n",
                "        # sample spike counts from the predicted rates and zero the masked entries\n",
                "        collected[\"rec_spikes\"].append(torch.poisson(rates.cpu()) * mask_cpu)\n",
                "        collected[\"signal_masks\"].append(mask_cpu)\n",
                "\n",
                "    return {key: torch.cat(chunks, 0) for key, chunks in collected.items()}\n",
                "\n",
                "\n",
                "rec_dict = reconstruct_spikes(ae_model, test_dataloader)\n",
                "\n",
                "\n",
                "# histogram of masked spike counts summed over dim 2: data (grey) vs reconstruction (dark blue)\n",
                "plt.figure(figsize=cm2inch((6, 4)))\n",
                "bins = np.linspace(0, 1000, 100)\n",
                "for spikes_key, hist_color in ((\"spikes\", \"grey\"), (\"rec_spikes\", \"darkblue\")):\n",
                "    masked_counts = (rec_dict[spikes_key] * rec_dict[\"signal_masks\"]).sum(2).flatten()\n",
                "    plt.hist(masked_counts, density=True, color=hist_color, bins=bins, alpha=0.5)\n",
                "\n",
                "plt.legend([\"gt\", \"ae\"])\n",
                "plt.title(\"spike count distribution (test set)\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# histogram of masked spike counts summed over dim 1: data (grey) vs reconstruction (dark blue)\n",
                "plt.figure(figsize=cm2inch((6, 4)))\n",
                "bins = np.linspace(1, 200, 199)\n",
                "gt_counts = (rec_dict[\"spikes\"] * rec_dict[\"signal_masks\"]).sum(1).flatten()\n",
                "ae_counts = (rec_dict[\"rec_spikes\"] * rec_dict[\"signal_masks\"]).sum(1).flatten()\n",
                "# the second histogram reuses the bin edges returned by the first call\n",
                "counts, bins, patches = plt.hist(gt_counts, density=True, color=\"grey\", bins=bins, alpha=0.5)\n",
                "plt.hist(ae_counts, density=True, color=\"darkblue\", bins=bins, alpha=0.5)\n",
                "plt.xlim(20, 150)\n",
                "plt.yticks([])\n",
                "plt.legend([\"gt\", \"ae\"])\n",
                "plt.gca().spines['left'].set_visible(False)\n",
                "plt.title(\"spike count distribution (test set)\")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### 1. Create dataset containing sentences, sentence embeddings, spike dataset, latents from ae"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# create the latent dataset\n",
                "class LatentPhonemeDataset(torch.utils.data.Dataset):\n",
                "    \"\"\"Autoencoder latents paired with spikes, masks and sentence embeddings.\n",
                "\n",
                "    Every batch of ``dataloader`` is encoded once with ``ae_model`` at\n",
                "    construction time and the results are kept on the CPU. The latents are\n",
                "    then standardized per channel, either with statistics computed here\n",
                "    (restricted to entries where the spike mask is set) or with statistics\n",
                "    supplied by the caller.\n",
                "\n",
                "    Args:\n",
                "        dataloader: iterable of batches (dicts) with the keys ``signal``,\n",
                "            ``mask``, ``embedding``, ``embedding_mask``, ``original_sentence``\n",
                "            and ``phonemized_sentence``.\n",
                "        ae_model: object providing ``eval()`` and ``encode(signal)``.\n",
                "        clip: if True, clamp the standardized latents to [-5, 5].\n",
                "        latent_means: optional precomputed means, applied as given.\n",
                "        latent_stds: optional precomputed stds, applied as given. The\n",
                "            statistics are recomputed unless both values are passed.\n",
                "    \"\"\"\n",
                "\n",
                "    # Lower bound for the masked variance. Floating-point cancellation in\n",
                "    # E[x^2] - E[x]^2 can give zero or slightly negative values for\n",
                "    # (near-)constant channels, which would turn into NaN/inf latents.\n",
                "    _MIN_VARIANCE = 1e-12\n",
                "\n",
                "    def __init__(\n",
                "        self, dataloader, ae_model, clip=True, latent_means=None, latent_stds=None\n",
                "    ):\n",
                "        self.full_dataloader = dataloader\n",
                "        self.ae_model = ae_model\n",
                "        (\n",
                "            self.latents,\n",
                "            self.train_spikes,\n",
                "            self.train_spike_masks,\n",
                "            self.embeddings,\n",
                "            self.embedding_masks,\n",
                "            self.original_sentences,\n",
                "            self.phonemized_sentences,\n",
                "        ) = self.create_latents()\n",
                "\n",
                "        if latent_means is None or latent_stds is None:\n",
                "            print(self.latents.shape, self.train_spike_masks.shape)\n",
                "            # use as many mask channels as there are latent channels\n",
                "            channel_mask = self.train_spike_masks[:, : self.latents.shape[1], :]\n",
                "            masked_count = channel_mask.sum(dim=(0, 2))\n",
                "            latent_means = (self.latents * channel_mask).sum(dim=(0, 2)) / masked_count\n",
                "\n",
                "            # masked variance E[x^2] - E[x]^2, floored so the std stays strictly positive\n",
                "            masked_square_mean = (self.latents.pow(2) * channel_mask).sum(\n",
                "                dim=(0, 2)\n",
                "            ) / masked_count\n",
                "            masked_variance = (masked_square_mean - latent_means.pow(2)).clamp(\n",
                "                min=self._MIN_VARIANCE\n",
                "            )\n",
                "\n",
                "            # add leading and trailing singleton dims so the statistics broadcast over the latents\n",
                "            self.latent_means = latent_means.unsqueeze(0).unsqueeze(2)\n",
                "            self.latent_stds = masked_variance.sqrt().unsqueeze(0).unsqueeze(2)\n",
                "\n",
                "        else:\n",
                "            self.latent_means = latent_means\n",
                "            self.latent_stds = latent_stds\n",
                "\n",
                "        # standardize the latents channel-wise\n",
                "        self.latents = (self.latents - self.latent_means) / self.latent_stds\n",
                "\n",
                "        if clip:\n",
                "            self.latents = self.latents.clamp(-5, 5)\n",
                "\n",
                "    def symmetrically_pad_and_expand_embedding(\n",
                "        self, embedding, embedding_mask, latent_mask\n",
                "    ):\n",
                "        \"\"\"Centre the valid part of an embedding and stretch it to the latent length.\n",
                "\n",
                "        NOTE: due to a bug in the phoneme dataset, the embeddings\n",
                "        from the original dataset are padded on the left, so the valid\n",
                "        entries are taken from the end of the time axis.\n",
                "\n",
                "        Args:\n",
                "            embedding: tensor of shape [L, C].\n",
                "            embedding_mask: tensor permuted the same way as ``embedding``; the\n",
                "                sum over time of its first channel is used as the valid length.\n",
                "            latent_mask: only the size of its last dimension (Ll) is read.\n",
                "\n",
                "        Returns:\n",
                "            ``(embedding, embedding_mask)`` with shapes [C, Ll] and [1, Ll].\n",
                "        \"\"\"\n",
                "        L, C = embedding.shape\n",
                "\n",
                "        embedding = embedding.permute(1, 0)  # [C, L]\n",
                "        embedding_mask = embedding_mask.permute(1, 0)  # [C, L]\n",
                "\n",
                "        embedding_len = int(embedding_mask[0].sum(0).item())  # l < L\n",
                "\n",
                "        pad_left = (L - embedding_len) // 2\n",
                "        pad_right = L - embedding_len - pad_left\n",
                "\n",
                "        # keep the last embedding_len steps and pad both sides by repeating the edge values\n",
                "        embedding = torch.nn.functional.pad(\n",
                "            embedding[:, L - embedding_len :], (pad_left, pad_right), mode=\"replicate\"\n",
                "        )\n",
                "        embedding_mask = torch.zeros_like(embedding[:1])  # [1, L]\n",
                "        embedding_mask[:, pad_left : pad_left + embedding_len] = 1\n",
                "\n",
                "        latent_max_len = latent_mask.shape[-1]  # Ll > L\n",
                "\n",
                "        # interpolate with nearest neighbors in the time dim (L)\n",
                "        embedding = torch.nn.functional.interpolate(\n",
                "            embedding.unsqueeze(0), size=latent_max_len, mode=\"nearest\"\n",
                "        ).squeeze(\n",
                "            0\n",
                "        )  # [C, L] -> [1, C, L] -> [1, C, Ll] -> [C, Ll]\n",
                "\n",
                "        embedding_mask = torch.nn.functional.interpolate(\n",
                "            embedding_mask.unsqueeze(0), size=latent_max_len, mode=\"nearest\"\n",
                "        ).squeeze(\n",
                "            0\n",
                "        )  # [1, L] -> [1, 1, L] -> [1, 1, Ll] -> [1, Ll]\n",
                "\n",
                "        return embedding, embedding_mask\n",
                "\n",
                "    def symmetrically_pad_latent(self, latent, latent_mask):\n",
                "        \"\"\"Move the valid prefix of a latent to the centre of the time axis.\n",
                "\n",
                "        The valid length is the sum of the first row of ``latent_mask``. The\n",
                "        first ``latent_len`` time steps are kept and padded on both sides by\n",
                "        repeating the edge values, so the output length equals the input length.\n",
                "\n",
                "        Args:\n",
                "            latent: tensor of shape [C, L].\n",
                "            latent_mask: mask whose first row sums to the valid length.\n",
                "\n",
                "        Returns:\n",
                "            ``(latent, latent_mask)`` with shapes [C, L] and [1, L]; the new\n",
                "            mask is 1 on the centred valid span and 0 elsewhere.\n",
                "        \"\"\"\n",
                "        C, L = latent.shape\n",
                "\n",
                "        latent_len = int(latent_mask[0].sum().item())\n",
                "\n",
                "        pad_left = (L - latent_len) // 2\n",
                "        pad_right = L - latent_len - pad_left\n",
                "\n",
                "        latent = torch.nn.functional.pad(\n",
                "            latent[:, :latent_len], (pad_left, pad_right), mode=\"replicate\"\n",
                "        )\n",
                "        latent_mask = torch.zeros_like(latent[:1])  # [1, L]\n",
                "        latent_mask[:, pad_left : pad_left + latent_len] = 1\n",
                "\n",
                "        return latent, latent_mask\n",
                "\n",
                "    def create_latents(self):\n",
                "        \"\"\"Encode every batch of the wrapped dataloader and collect the results.\n",
                "\n",
                "        Returns:\n",
                "            A 7-tuple: latents, spikes, spike masks, embeddings and embedding\n",
                "            masks (each concatenated along dim 0 on the CPU), followed by the\n",
                "            lists of original and phonemized sentences.\n",
                "        \"\"\"\n",
                "        latent_dataset = []\n",
                "\n",
                "        train_spikes = []\n",
                "        train_spike_masks = []\n",
                "\n",
                "        embeddings = []\n",
                "        embedding_masks = []\n",
                "\n",
                "        original_sentences = []\n",
                "        phonemized_sentences = []\n",
                "\n",
                "        self.ae_model.eval()\n",
                "        for batch in tqdm(\n",
                "            self.full_dataloader,\n",
                "            total=len(self.full_dataloader),\n",
                "            desc=\"Creating latent dataset\",\n",
                "        ):\n",
                "            with torch.no_grad():\n",
                "                z = self.ae_model.encode(batch[\"signal\"])\n",
                "                latent_dataset.append(z.cpu())\n",
                "\n",
                "                train_spikes.append(batch[\"signal\"].cpu())\n",
                "                train_spike_masks.append(batch[\"mask\"].cpu())\n",
                "\n",
                "                embeddings.append(batch[\"embedding\"].cpu())\n",
                "                embedding_masks.append(batch[\"embedding_mask\"].cpu())\n",
                "\n",
                "                original_sentences.extend(batch[\"original_sentence\"])\n",
                "                phonemized_sentences.extend(batch[\"phonemized_sentence\"])\n",
                "\n",
                "        return (\n",
                "            torch.cat(latent_dataset),\n",
                "            torch.cat(train_spikes),\n",
                "            torch.cat(train_spike_masks),\n",
                "            torch.cat(embeddings),\n",
                "            torch.cat(embedding_masks),\n",
                "            original_sentences,\n",
                "            phonemized_sentences,\n",
                "        )\n",
                "\n",
                "    def __len__(self):\n",
                "        \"\"\"Number of cached samples.\"\"\"\n",
                "        return len(self.latents)\n",
                "\n",
                "    def __getitem__(self, idx):\n",
                "        \"\"\"Return one sample with centred latent and embedding plus the raw spikes.\"\"\"\n",
                "        embedding, embedding_mask = self.symmetrically_pad_and_expand_embedding(\n",
                "            self.embeddings[idx], self.embedding_masks[idx], self.train_spike_masks[idx]\n",
                "        )\n",
                "        latent, latent_mask = self.symmetrically_pad_latent(\n",
                "            self.latents[idx], self.train_spike_masks[idx]\n",
                "        )\n",
                "\n",
                "        return {\n",
                "            \"signal\": self.train_spikes[idx],  # not symmetrically padded\n",
                "            \"latent\": latent,\n",
                "            \"mask\": latent_mask,\n",
                "            \"embedding\": embedding,\n",
                "            \"embedding_mask\": embedding_mask,\n",
                "            \"original_sentence\": self.original_sentences[idx],\n",
                "            \"phonemized_sentence\": self.phonemized_sentences[idx],\n",
                "        }\n",
                "\n",
                "\n",
                "# the training split is built without precomputed statistics, so it computes its own\n",
                "latent_dataset_train = LatentPhonemeDataset(train_dataloader, ae_model, clip=False)\n",
                "\n",
                "for sample_idx in (0, 1):\n",
                "    display(latent_dataset_train[sample_idx])\n",
                "\n",
                "# validation and test reuse the means and stds of the training split\n",
                "train_stats = dict(\n",
                "    latent_means=latent_dataset_train.latent_means,\n",
                "    latent_stds=latent_dataset_train.latent_stds,\n",
                ")\n",
                "latent_dataset_val = LatentPhonemeDataset(val_dataloader, ae_model, clip=False, **train_stats)\n",
                "latent_dataset_test = LatentPhonemeDataset(test_dataloader, ae_model, clip=False, **train_stats)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# show the first two training samples\n",
                "display(latent_dataset_train[0])\n",
                "display(latent_dataset_train[1])\n",
                "\n",
                "# keep the first sample for the plot in a later cell\n",
                "element = latent_dataset_train[0]\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# first latent channel of the sample, overlaid with the first row of its mask\n",
                "plt.plot(element['latent'][0])\n",
                "plt.plot(element['mask'][0])"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# def strip_and_symmetrically_pad_latents(latents, latent_masks):\n",
                "#     # strip latents to the length of the longest latent\n",
                "    \n",
                "#     max_len = latent_masks.shape[-1]  # L\n",
                "\n",
                "#     latent_lengths = latent_masks[:, 0].sum(-1)\n",
                "\n",
                "#     new_latents = []\n",
                "#     new_latent_masks = []\n",
                "\n",
                "#     for i, (l, l_mask) in enumerate(zip(latents, latent_masks)):\n",
                "#         # pad the latents symmetrically\n",
                "#         pad_left = (max_len - l_mask[0].sum().item()) // 2\n",
                "#         pad_right = (max_len - l_mask[0].sum().item()) - pad_left\n",
                "#         pad_left = int(pad_left)\n",
                "#         pad_right = int(pad_right)\n",
                "#         l = l[:, : int(l_mask[0].sum().item())]\n",
                "#         l_mask_sum = int(l_mask[0].sum().item())\n",
                "#         l_mask = torch.zeros(1, max_len)\n",
                "#         l_mask[:, pad_left : pad_left + l_mask_sum] = 1\n",
                "#         # print(pad_left, pad_right, pad_left + l_mask_sum)\n",
                "\n",
                "#         new_latents.append(\n",
                "#             nn.functional.pad(\n",
                "#                 l,\n",
                "#                 (pad_left, pad_right),\n",
                "#                 mode=\"replicate\",\n",
                "#             )\n",
                "#         )\n",
                "#         new_latent_masks.append(l_mask)\n",
                "\n",
                "#     return torch.stack(new_latents), torch.stack(\n",
                "#         new_latent_masks\n",
                "#     )  # [B, C, L] and [B, 1, L]\n",
                "\n",
                "\n",
                "# # test\n",
                "\n",
                "# latents = batch[\"latent\"]\n",
                "# latent_masks = batch[\"mask\"]\n",
                "\n",
                "# latents, latent_masks = strip_and_symmetrically_pad_latents(latents, latent_masks)\n",
                "\n",
                "# plt.plot(batch[\"mask\"][0, 0].cpu(), label=\"mask before\")\n",
                "# plt.plot(latent_masks[0, 0].cpu(), label=\"mask after\")\n",
                "# plt.legend()\n",
                "# plt.show()\n",
                "\n",
                "\n",
                "# def strip_and_symmetrically_pad_embeddings(embeddings, embedding_masks):\n",
                "\n",
                "#     embeddings = rearrange(embeddings, \"b l c -> b c l\")\n",
                "#     embedding_masks = rearrange(embedding_masks, \"b l c -> b c l\")\n",
                "\n",
                "#     # strip embeddings to the length of the longest embedding\n",
                "#     max_len = embedding_masks.shape[-1]  # L\n",
                "\n",
                "#     embedding_lengths = embedding_masks[:, 0].sum(-1)\n",
                "\n",
                "#     new_embeddings = []\n",
                "#     new_embedding_masks = []\n",
                "\n",
                "#     for i, (e, e_mask) in enumerate(zip(embeddings, embedding_masks)):\n",
                "#         # pad the embeddings symmetrically\n",
                "#         pad_left = (max_len - e_mask[0].sum().item()) // 2\n",
                "#         pad_right = (max_len - e_mask[0].sum().item()) - pad_left\n",
                "#         pad_left = int(pad_left)\n",
                "#         pad_right = int(pad_right)\n",
                "#         e = e[:, : int(e_mask[0].sum().item())]\n",
                "#         e_mask_sum = int(e_mask[0].sum().item())\n",
                "#         e_mask = torch.zeros(1, max_len)\n",
                "#         e_mask[:, pad_left : pad_left + e_mask_sum] = 1\n",
                "#         print(pad_left, pad_right, pad_left + e_mask_sum)\n",
                "\n",
                "#         new_embeddings.append(\n",
                "#             nn.functional.pad(\n",
                "#                 e,\n",
                "#                 (pad_left, pad_right),\n",
                "#                 mode=\"replicate\",\n",
                "#             )\n",
                "#         )\n",
                "#         new_embedding_masks.append(e_mask)\n",
                "\n",
                "#     return torch.stack(new_embeddings), torch.stack(\n",
                "#         new_embedding_masks\n",
                "#     )  # [B, C, L] and [B, 1, L]\n",
                "\n",
                "\n",
                "# # test\n",
                "\n",
                "# embeddings = (batch[\"embedding\"])\n",
                "# embedding_masks = batch[\"embedding_mask\"]\n",
                "\n",
                "# embeddings, embedding_masks = strip_and_symmetrically_pad_embeddings(embeddings, embedding_masks)\n",
                "\n",
                "# plt.plot(rearrange(batch[\"embedding_mask\"], 'b l c -> b c l')[0, 0].cpu(), label=\"mask before\")\n",
                "# plt.plot(embedding_masks[0, 0].cpu(), label=\"mask after\")\n",
                "# plt.legend()\n",
                "# plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "print(\"latent dataset\", latent_dataset_train.latents)\n",
                "print(\"latent dataset means\", latent_dataset_train.latent_means)\n",
                "print(\"latent dataset stds\", latent_dataset_train.latent_stds)\n",
                "# overlay the standardized latents of the first 10 samples from the train and val splits\n",
                "plt.figure(figsize=cm2inch(5, 3))\n",
                "# the labels are required: plt.legend() without arguments only lists labelled artists\n",
                "hist = plt.hist(latent_dataset_train.latents[:10].flatten(), bins=200, density=True, alpha=0.5, label=\"train\")\n",
                "hist = plt.hist(latent_dataset_val.latents[:10].flatten(), bins=200, density=True, alpha=0.5, label=\"val\")\n",
                "plt.legend()\n",
                "plt.title(\"Latent dataset histogram\")\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# both loaders share every setting except shuffling\n",
                "loader_kwargs = dict(\n",
                "    batch_size=cfg.training.batch_size,\n",
                "    num_workers=4,\n",
                "    pin_memory=True,\n",
                ")\n",
                "train_latent_dataloader = torch.utils.data.DataLoader(\n",
                "    latent_dataset_train, shuffle=True, **loader_kwargs\n",
                ")\n",
                "val_latent_dataloader = torch.utils.data.DataLoader(\n",
                "    latent_dataset_val, shuffle=False, **loader_kwargs\n",
                ")\n",
                "\n",
                "num_batches = len(train_latent_dataloader)\n",
                "\n",
                "# x & (x - 1) is non-zero exactly when x is not a power of two (for x > 0)\n",
                "max_seqlen = cfg.dataset.max_seqlen\n",
                "if max_seqlen & (max_seqlen - 1) != 0:\n",
                "    cfg.training.precision = \"no\"  # torch.fft doesnt support half if L!=2^x\n",
                "\n",
                "# hand the latent dataloaders to the accelerator\n",
                "train_latent_dataloader, val_latent_dataloader = accelerator.prepare(\n",
                "    train_latent_dataloader,\n",
                "    val_latent_dataloader,\n",
                ")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Compare, per training item, the mask sum of the embedding with the mask sum of the\n",
                "# latent (presumably the unpadded lengths) and fit a straight line through the pairs.\n",
                "print(train_latent_dataloader.dataset.embedding_masks.shape)\n",
                "emb_lens = train_latent_dataloader.dataset.embedding_masks[:,:, 0].sum(-1)\n",
                "print(emb_lens)\n",
                "\n",
                "latent_lens = train_latent_dataloader.dataset.train_spike_masks[:,0].sum(-1)\n",
                "print(latent_lens)\n",
                "\n",
                "plt.scatter(emb_lens,latent_lens, alpha=0.1, s=1)\n",
                "\n",
                "# get best fit linear\n",
                "from scipy.stats import linregress\n",
                "slope, intercept, r_value, p_value, std_err = linregress(emb_lens,latent_lens)\n",
                "xs = np.linspace(emb_lens.min(), emb_lens.max(), 100)\n",
                "# \":+.2f\" prints the intercept with its own sign, so a negative intercept is shown as\n",
                "# \"e-1.23\" rather than \"e+-1.23\".\n",
                "plt.plot(xs, intercept + slope*xs, 'k:', label=f'fitted line, l={slope:.2f}e{intercept:+.2f}')\n",
                "plt.legend()\n",
                "\n",
                "plt.xlabel(\"embedding length\")\n",
                "plt.ylabel(\"latent length\")\n",
                "\n",
                "plt.show()\n",
                "print(slope, intercept, r_value, p_value, std_err)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Display the first two items of the training latent dataset as the cell output.\n",
                "latent_dataset_train[0], latent_dataset_train[1]"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "\n",
                "### 2. Train diffusion to unconditionally generate rates (and then spikes), check if these spikes follow the distribution of real spikes\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "## initialize (unconditional) denoiser\n",
                "\n",
                "from ntldm.networks import Denoiser\n",
                "\n",
                "# The network input is the latent channels plus one extra channel carrying the mask.\n",
                "denoiser = Denoiser(\n",
                "    C_in=cfg.denoiser_model.C_in + 1, # 1 for mask\n",
                "    C=cfg.denoiser_model.C,\n",
                "    L=cfg.dataset.max_seqlen,\n",
                "    kernel=cfg.denoiser_model.kernel,\n",
                "    num_blocks=cfg.denoiser_model.num_blocks,\n",
                "    bidirectional=cfg.denoiser_model.get(\"bidirectional\", True),\n",
                ")\n",
                "\n",
                "# initial values may be way off so better to scale down the output layer\n",
                "# NOTE(review): the checkpoint loaded further down presumably overwrites these scaled\n",
                "# values again - confirm whether the scaling is still wanted when resuming.\n",
                "denoiser.conv_out.weight.data = denoiser.conv_out.weight.data * 0.1\n",
                "denoiser.conv_out.bias.data = denoiser.conv_out.bias.data * 0.1\n",
                "\n",
                "start_epoch = 0\n",
                "\n",
                "# # load previous checkpoint\n",
                "# from safetensors.torch import load_file\n",
                "# state_dict = load_file(f'exp/{cfg.exp_name}/epoch_2100/model_1.safetensors')\n",
                "# display(state_dict)\n",
                "# denoiser.load_state_dict(state_dict)\n",
                "# start_epoch = 1370\n",
                "finetune = False\n",
                "\n",
                "# Load the stored EMA weights. map_location=\"cpu\" keeps the load independent of the device\n",
                "# the checkpoint was written from; accelerator.prepare below handles device placement.\n",
                "denoiser.load_state_dict(\n",
                "    torch.load(f\"exp/{cfg.exp_name}/epoch_1999/ema_model.pt\", map_location=\"cpu\")\n",
                ")\n",
                "\n",
                "scheduler = DDPMScheduler(\n",
                "    num_train_timesteps=cfg.denoiser_model.num_train_timesteps,\n",
                "    clip_sample=False,\n",
                "    beta_schedule=\"linear\", # ddpm doesnt support cosine\n",
                ")\n",
                "\n",
                "optimizer = torch.optim.AdamW(\n",
                "    denoiser.parameters(), lr=cfg.training.lr\n",
                ")  # default wd=0.01 for now\n",
                "\n",
                "num_batches = len(train_latent_dataloader)\n",
                "lr_scheduler = get_scheduler(\n",
                "    name=\"cosine\",\n",
                "    optimizer=optimizer,\n",
                "    # warmup lasts cfg.training.num_warmup_epochs epochs, expressed in optimizer steps\n",
                "    num_warmup_steps=num_batches * cfg.training.num_warmup_epochs,\n",
                "    # total number of steps; the 1.3 factor makes this a float, so it is cast to the\n",
                "    # integer step count the scheduler API expects\n",
                "    num_training_steps=int(num_batches * cfg.training.num_epochs * 1.3),\n",
                ")\n",
                "\n",
                "# prepare the denoiser model, optimizer and lr schedule\n",
                "(\n",
                "    denoiser,\n",
                "    optimizer,\n",
                "    lr_scheduler,\n",
                ") = accelerator.prepare(\n",
                "    denoiser,\n",
                "    optimizer,\n",
                "    lr_scheduler,\n",
                ")\n",
                "\n",
                "# EMA copy of the denoiser; min_value is 0.9999 only when finetuning, otherwise 0.\n",
                "ema_model = EMAModel(denoiser, min_value=(0.9999 if finetune else 0))\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "from ntldm.utils import count_parameters\n",
                "\n",
                "# Report the denoiser's parameter count divided by 1e6.\n",
                "print(str(count_parameters(denoiser) / 1e6) + \" M parameters\")\n",
                "\n",
                "# (disabled) restore a full accelerator state and rebuild the EMA wrapper\n",
                "# accelerator.load_state(f\"exp/{cfg.exp_name}/epoch_200\") # 200 because i restarted the training with ep=0\n",
                "# ema_model = EMAModel(denoiser)\n",
                "\n",
                "# Show both networks as the cell output.\n",
                "(ae_model, denoiser)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "\n",
                "def sample_spikes_with_mask(ema_denoiser, scheduler, ae, cfg, lengths=None, batch_size=1, device=\"cuda\"):\n",
                "    \"\"\"Sample latents with the EMA denoiser under per-sample length masks and decode them.\n",
                "\n",
                "    Starting from Gaussian noise, every timestep in ``scheduler.timesteps`` is run through\n",
                "    ``ema_denoiser.averaged_model`` with the mask concatenated as an extra input channel.\n",
                "    The final latents are rescaled with ``latent_dataset_train.latent_stds`` and\n",
                "    ``latent_dataset_train.latent_means`` (a notebook-level global), decoded with\n",
                "    ``ae.decode`` into rates and Poisson-sampled into spikes.\n",
                "\n",
                "    Args:\n",
                "        ema_denoiser: object exposing the averaged network as ``.averaged_model``.\n",
                "        scheduler: diffusion scheduler providing ``set_timesteps``, ``timesteps`` and ``step``.\n",
                "        ae: autoencoder; only ``ae.decode`` is used.\n",
                "        cfg: config; reads ``denoiser_model.C_in``, ``denoiser_model.num_train_timesteps``\n",
                "            and ``dataset.max_seqlen``.\n",
                "        lengths: number of unmasked steps per sample. ``None`` spreads the lengths evenly\n",
                "            between 100 and 512; an ``int`` is repeated for the whole batch; a list, tuple\n",
                "            or tensor gives one length per sample. Values are cast to integers and\n",
                "            clamped to ``[0, cfg.dataset.max_seqlen]``.\n",
                "        batch_size: number of samples to draw.\n",
                "        device: device on which sampling and decoding run.\n",
                "\n",
                "    Returns:\n",
                "        dict with ``rates``, ``spikes``, ``latents`` and ``masks`` on the CPU, plus\n",
                "        ``mask_lengths`` (integer tensor, left on ``device``).\n",
                "\n",
                "    Raises:\n",
                "        ValueError: if ``lengths`` does not hold exactly ``batch_size`` entries.\n",
                "    \"\"\"\n",
                "    max_seqlen = cfg.dataset.max_seqlen\n",
                "\n",
                "    # noise is drawn first and then moved to the device, so seeded runs keep their noise stream\n",
                "    z_t = torch.randn(\n",
                "        (batch_size, cfg.denoiser_model.C_in, max_seqlen)\n",
                "    ).to(device)\n",
                "\n",
                "    if lengths is None:\n",
                "        lengths = torch.linspace(100, 512, batch_size)\n",
                "    elif isinstance(lengths, int):\n",
                "        lengths = torch.tensor([lengths] * batch_size)\n",
                "    elif not isinstance(lengths, torch.Tensor):\n",
                "        lengths = torch.tensor(lengths)\n",
                "    # every input form ends up as a flat integer tensor; tensors used to be passed through\n",
                "    # unconverted, and a float tensor cannot be used for the slicing below\n",
                "    lengths = lengths.long().flatten().to(device)\n",
                "\n",
                "    if lengths.numel() != batch_size:\n",
                "        raise ValueError(\n",
                "            f\"expected {batch_size} mask lengths (one per sample), got {lengths.numel()}\"\n",
                "        )\n",
                "    # a length beyond max_seqlen would give a negative left padding and a wrong mask\n",
                "    lengths = lengths.clamp(min=0, max=max_seqlen)\n",
                "\n",
                "    masks = torch.zeros(batch_size, max_seqlen).to(device)\n",
                "    for i, length in enumerate(lengths.tolist()):\n",
                "        # centre the unmasked window; an odd leftover step ends up in the right-hand padding\n",
                "        padding_left = (max_seqlen - length) // 2\n",
                "        masks[i, padding_left:padding_left + length] = 1.0\n",
                "\n",
                "    masks = masks.unsqueeze(1)\n",
                "\n",
                "    ema_denoiser_avg = ema_denoiser.averaged_model\n",
                "    ema_denoiser_avg.eval()\n",
                "    scheduler.set_timesteps(cfg.denoiser_model.num_train_timesteps)\n",
                "\n",
                "    for t in tqdm(scheduler.timesteps, desc=\"Sampling DDPM (different masks)\"):\n",
                "        # same timestep for every sample in the batch\n",
                "        timesteps = torch.full((batch_size,), int(t), dtype=torch.long, device=device)\n",
                "        with torch.no_grad():\n",
                "            # the last output channel corresponds to the mask input and is dropped\n",
                "            model_output = ema_denoiser_avg(\n",
                "                torch.cat([z_t, masks], dim=1), timesteps\n",
                "            )[:, :-1]\n",
                "        z_t = scheduler.step(model_output, t, z_t, return_dict=False)[0]\n",
                "\n",
                "    # rescale the sampled latents with the training-set statistics before decoding\n",
                "    z_t = z_t * latent_dataset_train.latent_stds.to(z_t.device) + latent_dataset_train.latent_means.to(z_t.device)\n",
                "\n",
                "    with torch.no_grad():\n",
                "        rates = ae.decode(z_t).cpu()\n",
                "\n",
                "    spikes = torch.poisson(rates)\n",
                "\n",
                "    return {\n",
                "        \"rates\": rates,\n",
                "        \"spikes\": spikes,\n",
                "        \"latents\": z_t.cpu(),\n",
                "        \"masks\": masks.cpu(),\n",
                "        \"mask_lengths\": lengths,\n",
                "    }"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def plot_real_vs_sampled_rates_and_spikes(\n",
                "    real_rates,\n",
                "    sampled_rates,\n",
                "    real_spikes,\n",
                "    sampled_spikes,\n",
                "    real_masks,\n",
                "    sampled_masks,\n",
                "    batch_idx=0,\n",
                "):\n",
                "    \"\"\"Draw a 2x2 comparison for one batch entry: rates on top, spikes below, real left, sampled right.\n",
                "\n",
                "    Real panels show the first ``k`` columns, ``k`` being the number of non-zero entries\n",
                "    in row 0 of the real mask; sampled panels show exactly the columns where row 0 of the\n",
                "    sampled mask is non-zero. The shapes of the four cropped arrays are printed.\n",
                "    \"\"\"\n",
                "    # the unpacking doubles as a check that real_rates is 3-D\n",
                "    n_batch, n_channels, n_steps = real_rates.shape\n",
                "\n",
                "    fig, axs = plt.subplots(2, 2, figsize=cm2inch(12, 8), dpi=300)\n",
                "\n",
                "    real_cols = torch.arange(real_masks[batch_idx][0].nonzero().flatten().numel())\n",
                "    sampled_cols = sampled_masks[batch_idx][0].nonzero().flatten()\n",
                "\n",
                "    # (axes position, title, cropped data, colormap), listed in drawing order\n",
                "    panels = [\n",
                "        ((0, 0), \"Real rates\", real_rates[batch_idx][:, real_cols], \"viridis\"),\n",
                "        ((0, 1), \"Sampled rates\", sampled_rates[batch_idx][:, sampled_cols], \"viridis\"),\n",
                "        ((1, 0), \"Real spikes\", real_spikes[batch_idx][:, real_cols], \"Greys\"),\n",
                "        ((1, 1), \"Sampled spikes\", sampled_spikes[batch_idx][:, sampled_cols], \"Greys\"),\n",
                "    ]\n",
                "    for position, title, data, cmap in panels:\n",
                "        panel_ax = axs[position]\n",
                "        im = panel_ax.imshow(data, cmap=cmap, alpha=1.0, aspect=\"auto\")\n",
                "        panel_ax.set_title(title)\n",
                "        fig.colorbar(im, ax=panel_ax, orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "\n",
                "    print(*(data.shape for _, _, data, _ in panels))\n",
                "\n",
                "    # hide the y ticks on every second panel of the flattened grid (the right-hand column)\n",
                "    for panel_ax in axs.flatten()[1::2]:\n",
                "        panel_ax.set_yticks([])\n",
                "\n",
                "    fig.tight_layout()\n",
                "    plt.show()\n",
                "\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# loss_fn = torch.nn.SmoothL1Loss(\n",
                "#     beta=0.04, reduction=\"none\"\n",
                "# )  # faster convergence than mse\n",
                "\n",
                "# start_epoch = 430\n",
                "\n",
                "# # wandb.init(project=\"ntldm\", entity=\"anon-project\")\n",
                "# # print('initialized wandb')\n",
                "# pbar = tqdm(range(start_epoch, cfg.training.num_epochs + start_epoch), desc=\"epochs\")\n",
                "# for epoch in pbar:\n",
                "#     for i, batch in enumerate(train_latent_dataloader):\n",
                "#         denoiser.train()\n",
                "#         optimizer.zero_grad()\n",
                "\n",
                "#         z = batch[\"latent\"]\n",
                "#         z_mask = batch[\"mask\"]\n",
                "#         embedding = batch[\"embedding\"]\n",
                "#         embedding_mask = batch[\"embedding_mask\"]\n",
                "\n",
                "#         t = torch.randint(\n",
                "#             0, cfg.denoiser_model.num_train_timesteps, (z.shape[0],), device=\"cpu\"\n",
                "#         ).long()\n",
                "\n",
                "#         noise = torch.randn_like(z)\n",
                "#         noisy_z = scheduler.add_noise(z, noise, t)\n",
                "\n",
                "\n",
                "#         noise_pred = denoiser(torch.cat([noisy_z, z_mask], dim=1), t)\n",
                "#         noise_pred = noise_pred[:,:-1] # remove the dim corresponding to conditioning mask\n",
                "\n",
                "\n",
                "#         loss = loss_fn(noise, noise_pred)\n",
                "#         loss = loss * z_mask  # mask out the padding\n",
                "#         loss = loss.mean()\n",
                "\n",
                "#         accelerator.backward(loss)\n",
                "#         accelerator.clip_grad_norm_(denoiser.parameters(), 1.0)\n",
                "\n",
                "#         optimizer.step()\n",
                "#         lr_scheduler.step()\n",
                "\n",
                "#         if i % 10 == 0:\n",
                "#             pbar.set_postfix(\n",
                "#                 {\n",
                "#                     \"loss\": loss.item(),\n",
                "#                     \"lr\": lr_scheduler.get_last_lr()[0],\n",
                "#                     \"epoch\": epoch,\n",
                "#                 }\n",
                "#             )\n",
                "#             wandb.log(\n",
                "#                 {\n",
                "#                     \"loss\": loss.item(),\n",
                "#                     \"lr\": lr_scheduler.get_last_lr()[0],\n",
                "#                     \"epoch\": epoch,\n",
                "#                 }\n",
                "#             )\n",
                "\n",
                "#         ema_model.step(denoiser)\n",
                "\n",
                "#     if (epoch) % 50 == 0:  # plot samples\n",
                "\n",
                "#         denoiser.eval()\n",
                "\n",
                "#         val_batch = next(iter(val_latent_dataloader))\n",
                "#         val_batch['mask'][:2,0].sum(-1)\n",
                "#         ret_dict = sample_spikes_with_mask(ema_model, scheduler, ae_model, cfg, lengths=[int(i.item()) for i in val_batch['mask'][:2,0].sum(-1)], batch_size=2, device=\"cuda\")\n",
                "#         with torch.no_grad():\n",
                "#             val_batch_rates = ae_model(val_batch[\"signal\"])[0].cpu()\n",
                "\n",
                "#         plot_real_vs_sampled_rates_and_spikes(\n",
                "#             val_batch_rates,\n",
                "#             ret_dict[\"rates\"],\n",
                "#             val_batch[\"signal\"].cpu(),\n",
                "#             ret_dict[\"spikes\"],\n",
                "#             val_batch[\"mask\"].cpu(),\n",
                "#             ret_dict[\"masks\"],\n",
                "#             batch_idx=0,\n",
                "#         )\n",
                "#         plot_real_vs_sampled_rates_and_spikes(\n",
                "#             val_batch_rates,\n",
                "#             ret_dict[\"rates\"],\n",
                "#             val_batch[\"signal\"].cpu(),\n",
                "#             ret_dict[\"spikes\"],\n",
                "#             val_batch[\"mask\"].cpu(),\n",
                "#             ret_dict[\"masks\"],\n",
                "#             batch_idx=1,\n",
                "#         )\n",
                "\n",
                "#         # sampled_latents = sample(\n",
                "#         #     ema_denoiser=ema_model,\n",
                "#         #     scheduler=scheduler,\n",
                "#         #     cfg=cfg,\n",
                "#         #     batch_size=2,\n",
                "#         #     device=\"cuda\",\n",
                "#         # )\n",
                "#         # sampled_latents = sampled_latents * latent_dataset_train.latent_stds.to(\n",
                "#         #     sampled_latents.device\n",
                "#         # ) + latent_dataset_train.latent_means.to(sampled_latents.device)\n",
                "\n",
                "#         # real_latents = latent_dataset_train.latents[:2].cuda()\n",
                "#         # real_latents = real_latents * latent_dataset_train.latent_stds.to(\n",
                "#         #     real_latents.device\n",
                "#         # ) + latent_dataset_train.latent_means.to(real_latents.device)\n",
                "\n",
                "#         # with torch.no_grad():\n",
                "#         #     sampled_rates = ae_model.decode(sampled_latents).cpu()\n",
                "#         #     decoded_rates_from_real_latents = ae_model.decode(real_latents).cpu()\n",
                "\n",
                "#         # fig, ax = plt.subplots(1, 2, figsize=cm2inch(12, 4))\n",
                "#         # im = ax[0].imshow(sampled_rates[0], aspect=\"auto\")\n",
                "#         # ax[0].set_title(\"Sampled rates\")\n",
                "#         # fig.colorbar(im, ax=ax[0], orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "\n",
                "#         # im = ax[1].imshow(decoded_rates_from_real_latents[0], aspect=\"auto\")\n",
                "#         # ax[1].set_title(\"Real rates\")\n",
                "#         # fig.colorbar(im, ax=ax[1], orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "#         # fig.tight_layout()\n",
                "#         # plt.show()\n",
                "\n",
                "#         # # get avg spike count across neurons\n",
                "#         # real_spikes = latent_dataset_val.train_spikes.cpu()\n",
                "#         # gen_spikes, gen_rates = sample_spikes(\n",
                "#         #     ema_model,\n",
                "#         #     scheduler,\n",
                "#         #     ae_model,\n",
                "#         #     cfg,\n",
                "#         #     batch_size=real_spikes.shape[0] * 10,\n",
                "#         #     device=\"cuda\",\n",
                "#         # )\n",
                "\n",
                "#         # spike_count_means = real_spikes.mean(dim=(0, 2))\n",
                "#         # spike_count_stds = real_spikes.std(dim=(0, 2))\n",
                "#         # gen_spike_count_means = gen_spikes.mean(dim=(0, 2))\n",
                "#         # gen_spike_count_stds = gen_spikes.std(dim=(0, 2))\n",
                "\n",
                "#         # print(real_spikes, gen_spikes, spike_count_means, spike_count_stds)\n",
                "\n",
                "#         # # plot boxplots across neurons\n",
                "#         # plt.figure(figsize=cm2inch(4, 4))\n",
                "#         # plt.violinplot(\n",
                "#         #     [spike_count_means.numpy(), gen_spike_count_means.numpy()],\n",
                "#         #     positions=[1, 2],\n",
                "#         #     showmeans=True,\n",
                "#         # )\n",
                "#         # plt.ylabel(\"Mean spike count\")\n",
                "#         # plt.title(\"Mean spike count across neurons\")\n",
                "#         # plt.xticks([1, 2], [\"Real\", \"Generated\"])\n",
                "#         # plt.show()\n",
                "\n",
                "#         # # plot boxplots per neuron\n",
                "#         # from einops import reduce\n",
                "\n",
                "#         # spike_count_per_neuron = reduce(real_spikes, \"B C L -> B C\", reduction=\"mean\")\n",
                "#         # gen_spike_count_per_neuron = reduce(\n",
                "#         #     gen_spikes, \"B C L -> B C\", reduction=\"mean\"\n",
                "#         # )\n",
                "\n",
                "#         # # sort channels by mean spike count\n",
                "#         # sorted_indices = spike_count_means.argsort()\n",
                "#         # sorted_indices = torch.flip(sorted_indices, (0,))\n",
                "#         # print(sorted_indices)\n",
                "#         # spike_count_per_neuron = spike_count_per_neuron[:, sorted_indices]\n",
                "#         # gen_spike_count_per_neuron = gen_spike_count_per_neuron[:, sorted_indices]\n",
                "\n",
                "#         # plt.figure(figsize=(8, 4))\n",
                "#         # for i, (spike_count, gen_spike_count) in enumerate(\n",
                "#         #     zip(\n",
                "#         #         spike_count_per_neuron[:, ::10].T, gen_spike_count_per_neuron[:, ::10].T\n",
                "#         #     )\n",
                "#         # ):\n",
                "#         #     plt.violinplot(\n",
                "#         #         [spike_count.numpy(), gen_spike_count.numpy()],\n",
                "#         #         positions=[i, i + 0.5],\n",
                "#         #         showmeans=True,\n",
                "#         #     )\n",
                "#         #     # scatter plot across the violinplot for better visualization\n",
                "#         #     plt.scatter(\n",
                "#         #         [i] * len(spike_count), spike_count.numpy(), color=\"black\", alpha=0.1\n",
                "#         #     )\n",
                "#         #     plt.scatter(\n",
                "#         #         [i + 0.5] * len(gen_spike_count),\n",
                "#         #         gen_spike_count.numpy(),\n",
                "#         #         color=\"black\",\n",
                "#         #         alpha=0.1,\n",
                "#         #     )\n",
                "\n",
                "#         # plt.xticks(\n",
                "#         #     np.arange(len(sorted_indices[::10])) + 0.5,\n",
                "#         #     sorted_indices[::10].numpy().tolist(),\n",
                "#         # )\n",
                "#         # plt.ylabel(\"Spike count\")\n",
                "#         # plt.title(\"Spike count per neuron (real vs generated)\")\n",
                "#         # plt.xlabel(\"neuron index\")\n",
                "#         # plt.yscale(\"symlog\", linthresh=0.001)\n",
                "#         # plt.ylim(\n",
                "#         #     -0.0,\n",
                "#         #     max(\n",
                "#         #         spike_count_per_neuron.max().item(),\n",
                "#         #         gen_spike_count_per_neuron.max().item(),\n",
                "#         #     )\n",
                "#         #     + 0.1,\n",
                "#         # )\n",
                "#         # plt.show()\n",
                "\n",
                "#         # save\n",
                "#         accelerator.save_state(f\"exp/{cfg.exp_name}/epoch_{epoch}\")\n",
                "#         torch.save(ema_model.averaged_model.state_dict(), f\"exp/{cfg.exp_name}/epoch_{epoch}/ema_model.pt\")\n",
                "\n",
                "# pbar.close()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Export the EMA-averaged weights. The target directory is created first so that the\n",
                "# save does not fail when that epoch folder does not exist yet.\n",
                "# NOTE(review): the epoch number 400 is hard-coded - confirm it matches the weights\n",
                "# currently held by ema_model before running this cell.\n",
                "ema_ckpt_path = f\"exp/{cfg.exp_name}/epoch_{400}/ema_model.pt\"\n",
                "os.makedirs(os.path.dirname(ema_ckpt_path), exist_ok=True)\n",
                "torch.save(ema_model.averaged_model.state_dict(), ema_ckpt_path)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### 3.1. Eval with different mask lengths and check spike stats (unconditional)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Draw one sample per sixth training item; each sample gets the sum over row 0 of that\n",
                "# item's spike mask as its length.\n",
                "ret_dict = sample_spikes_with_mask(\n",
                "    ema_model,\n",
                "    scheduler,\n",
                "    ae_model,\n",
                "    cfg,\n",
                "    lengths=[\n",
                "        mask_row.sum()\n",
                "        for mask_row in train_latent_dataloader.dataset.train_spike_masks[::6, 0]\n",
                "    ],\n",
                "    batch_size=len(train_latent_dataloader.dataset.train_spike_masks[::6]),\n",
                "    device=\"cuda\",\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Reuse the stored samples when they exist; otherwise store the ret_dict that is already\n",
                "# in memory, so a fresh run no longer stops on a missing file.\n",
                "sampled_spikes_path = f\"exp/{cfg.exp_name}/sampled_spikes.pt\"\n",
                "if os.path.exists(sampled_spikes_path):\n",
                "    ret_dict = torch.load(sampled_spikes_path)\n",
                "else:\n",
                "    os.makedirs(os.path.dirname(sampled_spikes_path), exist_ok=True)\n",
                "    torch.save(ret_dict, sampled_spikes_path)\n",
                "\n",
                "# First 300 validation spike trains and their masks.\n",
                "val_spikes = val_latent_dataloader.dataset.train_spikes[:300]\n",
                "val_masks = val_latent_dataloader.dataset.train_spike_masks[:300]"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# For every validation trial keep only the columns where row 0 of its mask is non-zero.\n",
                "val_spikes_trimmed = [\n",
                "    trial[:, trial_mask[0].nonzero().flatten()]\n",
                "    for trial, trial_mask in zip(val_spikes, val_masks)\n",
                "]\n",
                "\n",
                "val_spikes_trimmed"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Every sixth training trial together with its mask.\n",
                "train_spikes = train_latent_dataloader.dataset.train_spikes[::6]\n",
                "train_masks = train_latent_dataloader.dataset.train_spike_masks[::6]\n",
                "\n",
                "# Keep, per trial, only the columns where row 0 of its mask is non-zero.\n",
                "train_spikes_trimmed = []\n",
                "for trial, trial_mask in zip(train_spikes, train_masks):\n",
                "    kept_columns = trial_mask[0].nonzero().flatten()\n",
                "    train_spikes_trimmed.append(trial[:, kept_columns])\n",
                "\n",
                "train_spikes_trimmed"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Pull the sampled spikes and their masks out of the sampling result.\n",
                "sampled_spikes = ret_dict['spikes']\n",
                "sampled_masks = ret_dict['masks']\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Display the sampled spikes and masks as the cell output.\n",
                "sampled_spikes, sampled_masks"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# For every sampled trial keep only the columns where row 0 of its mask is non-zero.\n",
                "sampled_spikes_trimmed = [\n",
                "    trial[:, trial_mask[0].nonzero().flatten()]\n",
                "    for trial, trial_mask in zip(sampled_spikes, sampled_masks)\n",
                "]"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# autoencoder reconstructions\n",
                "rec_train_spikes = rec_dict['rec_spikes']\n",
                "rec_train_spike_masks = rec_dict['signal_masks']\n",
                "\n",
                "# Trim every reconstruction to the columns flagged by its own mask, mirroring how the\n",
                "# train and sampled spikes are trimmed (assumes the masks share that layout -- TODO confirm).\n",
                "rec_train_spikes_trimmed = []\n",
                "\n",
                "for i in range(len(rec_train_spikes)):\n",
                "    # index with the reconstruction's mask, not with the spike counts themselves\n",
                "    nonzero_mask = rec_train_spike_masks[i,0].nonzero().flatten()\n",
                "    # take the reconstructed trial rather than an entry of sampled_spikes\n",
                "    spike = rec_train_spikes[i]\n",
                "    spike_ = spike[:,nonzero_mask]\n",
                "\n",
                "    rec_train_spikes_trimmed.append(spike_)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "len(sampled_spikes_trimmed), len(val_spikes_trimmed)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Concatenate the trimmed trials of every set along the last axis.\n",
                "sampled_spikes_trimmed_cat = torch.cat(sampled_spikes_trimmed, -1)\n",
                "val_spikes_trimmed_cat = torch.cat(val_spikes_trimmed, -1)\n",
                "train_spikes_trimmed_cat = torch.cat(train_spikes_trimmed, -1)\n",
                "rec_train_spikes_trimmed_cat = torch.cat(rec_train_spikes_trimmed, -1)\n",
                "\n",
                "# cell output: the sampled and validation results\n",
                "(sampled_spikes_trimmed_cat, val_spikes_trimmed_cat)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "train_spikes_trimmed_cat"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "colors = sns.color_palette('hsv', 128)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from scipy.stats import linregress\n",
                "\n",
                "# mean over the last axis of the concatenated trials, one value per row\n",
                "train_mean_counts = train_spikes_trimmed_cat.mean(-1)\n",
                "val_mean_counts = val_spikes_trimmed_cat.mean(-1)\n",
                "\n",
                "plt.scatter(train_mean_counts, val_mean_counts)\n",
                "plt.xlabel(\"train average spike count per neuron\")\n",
                "plt.ylabel(\"val average spike count per neuron\")\n",
                "\n",
                "# linregress reports Pearson's r; the title is labelled R2, so square it before formatting\n",
                "slope, intercept, r_value, p_value, std_err = linregress(train_mean_counts, val_mean_counts)\n",
                "plt.title(f'Human BCI data, R2={r_value**2:.2f}')\n",
                "\n",
                "# dashed identity line for reference\n",
                "plt.plot(np.linspace(0, 2, 100), np.linspace(0, 2, 100), 'k--')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from scipy.stats import linregress\n",
                "\n",
                "plt.figure(figsize=cm2inch((4.5, 4)), dpi=300)\n",
                "\n",
                "# dashed identity line, drawn underneath the scatter points\n",
                "identity_axis = np.linspace(0, 65, 100)\n",
                "plt.plot(identity_axis, identity_axis, 'k--', zorder=-10, alpha=0.99)\n",
                "\n",
                "# row-wise means of the concatenated trials, scaled by 1000/20.0 for plotting\n",
                "inv_bin_size = 1000/20.0\n",
                "gt_mean_counts = train_spikes_trimmed_cat.mean(-1)\n",
                "ldns_mean_counts = sampled_spikes_trimmed_cat.mean(-1)\n",
                "plt.scatter(\n",
                "    gt_mean_counts*inv_bin_size,\n",
                "    ldns_mean_counts*inv_bin_size,\n",
                "    color='darkred',\n",
                "    edgecolors='k',\n",
                "    linewidths=0.0,\n",
                "    alpha=0.6,\n",
                ")\n",
                "\n",
                "plt.xlabel(\"ground truth\")\n",
                "plt.ylabel(\"ldns\")\n",
                "plt.xlim(0, 65)\n",
                "plt.ylim(0, 65)\n",
                "\n",
                "# regression on the unscaled means; the result is kept in the notebook namespace only\n",
                "slope, intercept, r_value, p_value, std_err = linregress(gt_mean_counts, ldns_mean_counts)\n",
                "\n",
                "figure_file = os.path.join(\"exp\", cfg.exp_name, \"ldns_vs_gt_neuron_firing_rate.pdf\")\n",
                "plt.savefig(figure_file, dpi=300)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from scipy.stats import linregress\n",
                "\n",
                "# mean over the last axis of the concatenated trials, one value per row\n",
                "val_mean_counts = val_spikes_trimmed_cat.mean(-1)\n",
                "sampled_mean_counts = sampled_spikes_trimmed_cat.mean(-1)\n",
                "\n",
                "plt.scatter(val_mean_counts, sampled_mean_counts, c=colors)\n",
                "plt.xlabel(\"gt (val) average spike count per neuron\")\n",
                "plt.ylabel(\"sampled average spike count per neuron\")\n",
                "\n",
                "# linregress reports Pearson's r; the title is labelled R2, so square it before formatting\n",
                "slope, intercept, r_value, p_value, std_err = linregress(val_mean_counts, sampled_mean_counts)\n",
                "plt.title(f'Unconditional samples for Human BCI data, R2={r_value**2:.2f}')\n",
                "\n",
                "# dashed identity line for reference\n",
                "plt.plot(np.linspace(0, 2, 100), np.linspace(0, 2, 100), 'k--')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Correlation matrices between the rows of the train and the sampled spike counts.\n",
                "\n",
                "# the columns of the train data are shuffled before the correlation is computed\n",
                "shuffled_columns = torch.randperm(train_spikes_trimmed_cat.shape[-1])\n",
                "train_spikes_trimmed_cat_perm = train_spikes_trimmed_cat[:, shuffled_columns]\n",
                "\n",
                "corrcoefs_train = np.corrcoef(train_spikes_trimmed_cat_perm)\n",
                "corrcoefs_sampled = np.corrcoef(sampled_spikes_trimmed_cat)\n",
                "\n",
                "# zero the diagonals in place\n",
                "for corr_matrix in (corrcoefs_train, corrcoefs_sampled):\n",
                "    np.fill_diagonal(corr_matrix, 0)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# flattened correlation matrices, shared by the scatter and the statistics below\n",
                "corr_gt_flat = corrcoefs_train.flatten()\n",
                "corr_ldns_flat = corrcoefs_sampled.flatten()\n",
                "\n",
                "fig = plt.figure(figsize=cm2inch(4, 4), dpi=300)\n",
                "plt.scatter(corr_gt_flat, corr_ldns_flat, alpha=0.05, s=1, color='darkred')\n",
                "plt.xlabel(\"gt\")\n",
                "plt.ylabel(\"ldns\")\n",
                "plt.title(\"neuron vs neuron correlation\")\n",
                "\n",
                "# identity line spanning the joint value range of both matrices\n",
                "min_global = min(corrcoefs_train.min(), corrcoefs_sampled.min())\n",
                "max_global = max(corrcoefs_train.max(), corrcoefs_sampled.max())\n",
                "x = np.linspace(min_global, max_global, 10)\n",
                "plt.plot(x, x, 'k--', zorder=-10, alpha=0.99)\n",
                "\n",
                "# squared Pearson r from linregress, Pearson r from np.corrcoef, then sklearn's r2_score\n",
                "print(\"R2\", linregress(corr_gt_flat, corr_ldns_flat).rvalue**2)\n",
                "print(\"R\", np.corrcoef(corr_gt_flat, corr_ldns_flat)[0,1])\n",
                "from sklearn.metrics import r2_score\n",
                "r2_score(corr_gt_flat, corr_ldns_flat)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Sum every trimmed trial over axis 0 and join the per-trial results into one array per set.\n",
                "summed_spikes_train = np.concatenate([trial.sum(0) for trial in train_spikes_trimmed])\n",
                "summed_spikes_val = np.concatenate([trial.sum(0) for trial in val_spikes_trimmed])\n",
                "summed_spikes_sampled = np.concatenate([trial.sum(0) for trial in sampled_spikes_trimmed])"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from scipy.stats import gaussian_kde\n",
                "\n",
                "# NOTE(review): spike_stats_gt, spike_stats_sampled and save_path are assigned in a later\n",
                "# cell of this notebook, so that cell has to be executed before this one.\n",
                "\n",
                "\n",
                "def _scatter_with_identity(axis, gt_values, model_values, xlabel, ylabel, alpha):\n",
                "    \"\"\"Plot model vs. gt values as dots with a dashed identity line and equal, padded limits.\"\"\"\n",
                "    axis.plot(gt_values, model_values, '.', alpha=alpha, color='darkred', ms=2, rasterized=True)\n",
                "    lower = min(gt_values.min(), model_values.min())\n",
                "    upper = max(gt_values.max(), model_values.max())\n",
                "    axis.plot([lower, upper], [lower, upper], '--', color='black')\n",
                "    axis.set_xlabel(xlabel)\n",
                "    axis.set_ylabel(ylabel)\n",
                "    axis.set_aspect('equal')\n",
                "    # widen both ends by 15% of their magnitude so the extreme points are not on the frame\n",
                "    padded_limits = [lower - 0.15*np.abs(lower), upper + 0.15*np.abs(upper)]\n",
                "    axis.set_xlim(padded_limits)\n",
                "    axis.set_ylim(padded_limits)\n",
                "\n",
                "\n",
                "max_count = 220  # population spike counts are clipped to [0, max_count]\n",
                "\n",
                "fig, ax = plt.subplots(2, 2, figsize=cm2inch(9, 8))\n",
                "ax = ax.flatten()\n",
                "\n",
                "# --- panel 0: population spike count distribution ---\n",
                "val = torch.tensor(summed_spikes_sampled).clip(0, max_count)\n",
                "gt_spikes = torch.tensor(summed_spikes_train).clip(0, max_count)\n",
                "\n",
                "kde_model = gaussian_kde(val)\n",
                "kde_gt = gaussian_kde(gt_spikes)\n",
                "\n",
                "# evaluate both KDEs on every integer count and normalise them to sum to one\n",
                "x_eval = np.linspace(0, max_count, max_count + 1)\n",
                "density_model = kde_model(x_eval)\n",
                "density_gt = kde_gt(x_eval)\n",
                "density_model /= density_model.sum()\n",
                "density_gt /= density_gt.sum()\n",
                "\n",
                "# unit-width bins centred on the integers 0..max_count; the last edge is max_count + 0.5\n",
                "# so that the counts clipped to max_count stay inside the histogram\n",
                "bins_psc = np.arange(-0.5, max_count + 1.5, 1)\n",
                "ax[0].hist(gt_spikes, bins=bins_psc, density=True, alpha=0.5, label='data', color='grey', rasterized=True)\n",
                "ax[0].hist(val, bins=bins_psc, density=True, alpha=0.5, label='ldns', color='darkred', rasterized=True)\n",
                "\n",
                "# overlay the density estimates\n",
                "ax[0].plot(x_eval, density_gt, '.-', label='data kde', color='black', rasterized=False)\n",
                "ax[0].plot(x_eval, density_model, '.-', label='ldns kde', color='darkred', rasterized=False)\n",
                "\n",
                "ax[0].set_xlim(40, 160)\n",
                "# hide the y axis: no ticks and no left spine\n",
                "ax[0].set_yticks([])\n",
                "ax[0].spines['left'].set_visible(False)\n",
                "ax[0].legend(fontsize=7)\n",
                "ax[0].set_xlabel('spike count')\n",
                "\n",
                "# --- panel 1: pairwise correlations ---\n",
                "# np.tril(k=-1) returns a copy holding only the strict lower triangle, so the diagonal\n",
                "# is excluded without modifying corrcoefs_sampled / corrcoefs_train\n",
                "C_model = np.tril(corrcoefs_sampled, k=-1)\n",
                "C_gt = np.tril(corrcoefs_train, k=-1)\n",
                "_scatter_with_identity(ax[1], C_gt.flatten(), C_model.flatten(), 'gt', 'ldns', alpha=0.3)\n",
                "\n",
                "# --- panel 2: mean inter-spike interval ---\n",
                "_scatter_with_identity(\n",
                "    ax[2],\n",
                "    spike_stats_gt['mean_isi'].flatten(),\n",
                "    spike_stats_sampled['mean_isi'].flatten(),\n",
                "    'gt mean isi',\n",
                "    'ldns mean isi',\n",
                "    alpha=0.5,\n",
                ")\n",
                "\n",
                "# --- panel 3: std of the inter-spike interval ---\n",
                "_scatter_with_identity(\n",
                "    ax[3],\n",
                "    spike_stats_gt['std_isi'].flatten(),\n",
                "    spike_stats_sampled['std_isi'].flatten(),\n",
                "    'gt std isi',\n",
                "    'ldns std isi',\n",
                "    alpha=0.5,\n",
                ")\n",
                "ax[3].set_xticks([0.05, 0.15])\n",
                "ax[3].set_yticks([0.05, 0.15])\n",
                "\n",
                "# save the figure as png and pdf\n",
                "plt.savefig(save_path + 'Fig_population_and_single_spike_stats_ldns.png')\n",
                "plt.savefig(save_path + 'Fig_population_and_single_spike_stats_ldns.pdf')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# NOTE(review): save_path is assigned in a later cell of this notebook; run that cell first.\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(6.5, 3), dpi=300)\n",
                "\n",
                "# both panels use the same colour scale so the two matrices can be compared directly\n",
                "for axis, corr_matrix, panel_title in (\n",
                "    (ax[0], corrcoefs_sampled, \"ldns neuron corr\"),\n",
                "    (ax[1], corrcoefs_train, \"gt neuron corr\"),\n",
                "):\n",
                "    im = axis.imshow(corr_matrix, cmap='Reds', vmin=-0.2, vmax=0.35, aspect='auto')\n",
                "    cbar = plt.colorbar(im, ax=axis, orientation='vertical', fraction=0.046, pad=0.04)\n",
                "    axis.set_title(panel_title)\n",
                "    axis.axis('off')\n",
                "\n",
                "fig.tight_layout()\n",
                "# the file name carries the 'ldns' suffix, matching the other figures saved by this notebook\n",
                "plt.savefig(save_path + 'Fig_correlation_matrix_ldns.pdf')\n",
                "\n",
                "print(corrcoefs_sampled.min(), corrcoefs_sampled.max(), corrcoefs_train.min(), corrcoefs_train.max(), )"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# row-wise mean of the concatenated train and sampled spike counts, one line each\n",
                "plt.plot(train_spikes_trimmed_cat.mean(-1), label='train')\n",
                "plt.plot(sampled_spikes_trimmed_cat.mean(-1), label='sampled')\n",
                "# the labels above only become visible once a legend is drawn\n",
                "plt.legend();"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# NOTE(review): counts_to_spike_trains and compute_spike_stats_per_neuron are imported from\n",
                "# ntldm.utils.eval_utils in later cells; those imports must have run before this is called.\n",
                "def get_spike_train_and_stats(spikes, fps=200):\n",
                "    \"\"\"Convert spike counts to spike trains and return their per-neuron spike statistics.\n",
                "\n",
                "    Args:\n",
                "        spikes: array of spike counts; axis 0 is used as the number of samples and axis 2\n",
                "            as the number of neurons (presumably (samples, time, neurons) -- TODO confirm).\n",
                "        fps: forwarded unchanged to counts_to_spike_trains.\n",
                "\n",
                "    Returns:\n",
                "        The result of compute_spike_stats_per_neuron called with mean_output=False.\n",
                "    \"\"\"\n",
                "    spiketrain = counts_to_spike_trains(spikes, fps=fps)\n",
                "    spike_stats_m = compute_spike_stats_per_neuron(\n",
                "        spiketrain,\n",
                "        n_samples=spikes.shape[0],\n",
                "        n_neurons=spikes.shape[2],\n",
                "        mean_output=False,\n",
                "    )\n",
                "    return spike_stats_m"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "print(corrcoefs_sampled.min(), corrcoefs_sampled.max(),corrcoefs_train.min(), corrcoefs_train.max(), )\n",
                "\n",
                "# styling shared by both heatmaps and by both colorbars\n",
                "heatmap_kwargs = dict(aspect='auto', cmap='viridis', vmin=-0.5, vmax=1)\n",
                "colorbar_kwargs = dict(orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(9, 4), dpi=300)\n",
                "gt_ax, diffusion_ax = ax\n",
                "\n",
                "# left panel: ground truth; keeps its y-label, loses ticks and the left/bottom spines\n",
                "im = gt_ax.imshow(corrcoefs_train, **heatmap_kwargs)\n",
                "gt_ax.set_title(\"gt\")\n",
                "gt_ax.set_ylabel('neuron cross-correlation')\n",
                "plt.colorbar(im, ax=gt_ax, **colorbar_kwargs)\n",
                "\n",
                "# right panel: sampled data; axis switched off, colorbar ticks set by hand\n",
                "im = diffusion_ax.imshow(corrcoefs_sampled, **heatmap_kwargs)\n",
                "diffusion_ax.set_title(\"diffusion\")\n",
                "cbar = plt.colorbar(im, ax=diffusion_ax, **colorbar_kwargs)\n",
                "cbar.set_ticks([-0.5, 0, 0.5, 1])\n",
                "cbar.set_ticklabels([-0.5, 0, 0.5, 1])\n",
                "\n",
                "gt_ax.set_xticks([])\n",
                "gt_ax.set_yticks([])\n",
                "for side in ('left', 'bottom'):\n",
                "    gt_ax.spines[side].set_visible(False)\n",
                "diffusion_ax.axis('off')\n",
                "\n",
                "fig.tight_layout()\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import compute_spike_stats, counts_to_spike_trains\n",
                "\n",
                "fps = 1000/20\n",
                "\n",
                "# every trimmed trial is transposed and converted to numpy before being turned into spike trains\n",
                "train_counts_np = [trial.permute(1, 0).numpy() for trial in train_spikes_trimmed]\n",
                "sampled_counts_np = [trial.permute(1, 0).numpy() for trial in sampled_spikes_trimmed]\n",
                "\n",
                "spike_trains_train_uncat = counts_to_spike_trains(train_counts_np, fps=fps)\n",
                "spike_trains_sampled_uncat = counts_to_spike_trains(sampled_counts_np, fps=fps)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "spike_trains_train_uncat"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "spike_trains_train[(0, 117)].shape, spike_trains_train.keys()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import compute_spike_stats_per_neuron\n",
                "from ntldm.utils.plotting_utils import cm2inch, plot_spiketrain_stats\n",
                "\n",
                "# Statistics with mean_output=False for the ground-truth and the sampled spike trains;\n",
                "# n_samples is the number of trimmed trials, n_neurons the first dimension of a trial.\n",
                "spike_stats_gt = compute_spike_stats_per_neuron(\n",
                "    spike_trains_train_uncat,\n",
                "    mean_output=False,\n",
                "    n_neurons=train_spikes_trimmed[0].shape[0],\n",
                "    n_samples=len(train_spikes_trimmed),\n",
                ")\n",
                "spike_stats_sampled = compute_spike_stats_per_neuron(\n",
                "    spike_trains_sampled_uncat,\n",
                "    mean_output=False,\n",
                "    n_neurons=sampled_spikes_trimmed[0].shape[0],\n",
                "    n_samples=len(sampled_spikes_trimmed),\n",
                ")\n",
                "\n",
                "# output directory of this experiment, also read by other plotting cells\n",
                "save_path = \"exp/\" + cfg.exp_name + \"/\"\n",
                "\n",
                "plot_kwargs = dict(\n",
                "    figsize=cm2inch(16, 4),\n",
                "    color=\"darkred\",\n",
                "    labels=[\"ground truth\", \"ldns\"],\n",
                "    save=True,\n",
                "    save_path=save_path + \"compute_spike_stats_per_neuron_gt_ae\",\n",
                ")\n",
                "plot_spiketrain_stats(spike_stats_gt, spike_stats_sampled, **plot_kwargs)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Per-trial sums over axis 0, concatenated over all trials of each set.\n",
                "summed_spikes_train = np.concatenate([trial.sum(0) for trial in train_spikes_trimmed])\n",
                "summed_spikes_val = np.concatenate([trial.sum(0) for trial in val_spikes_trimmed])\n",
                "summed_spikes_sampled = np.concatenate([trial.sum(0) for trial in sampled_spikes_trimmed])"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "torch.from_numpy(summed_spikes_train), torch.from_numpy(summed_spikes_sampled)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "fig = plt.figure(figsize=cm2inch(4, 1.5), dpi=300)\n",
                "\n",
                "# size of the last axis of every trimmed training trial\n",
                "spikes_train_len = np.array([trial.shape[-1] for trial in train_spikes_trimmed])\n",
                "\n",
                "# the lengths are divided by 50 before plotting; the x axis is labelled in seconds\n",
                "_1, bins, _2 = plt.hist(spikes_train_len/50, color='grey', alpha=0.99, bins=40)\n",
                "plt.xlabel('trial length (s)')\n",
                "plt.xticks([2, 6, 10])\n",
                "\n",
                "# hide the y axis: no ticks and no left spine\n",
                "plt.yticks([])\n",
                "plt.gca().spines[\"left\"].set_visible(False)\n",
                "\n",
                "save_path = \"exp/\" + cfg.exp_name + \"/\"\n",
                "plt.savefig(save_path + \"trial_length_hist.pdf\", dpi=300)\n",
                "print(spikes_train_len.mean()/50)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from matplotlib.ticker import ScalarFormatter\n",
                "\n",
                "fig, ax = plt.subplots(1, 1, figsize=cm2inch(4, 4), dpi=300)\n",
                "\n",
                "# summary statistics of the summed spike counts\n",
                "sampled_mean, sampled_std = summed_spikes_sampled.mean(), summed_spikes_sampled.std()\n",
                "train_mean, train_std = summed_spikes_train.mean(), summed_spikes_train.std()\n",
                "\n",
                "print(\n",
                "    len(summed_spikes_train),\n",
                "    len(summed_spikes_sampled),\n",
                "    torch.tensor(summed_spikes_train),\n",
                "    torch.tensor(summed_spikes_sampled),\n",
                ")\n",
                "\n",
                "# overlay both normalised histograms, with the counts clipped to [0, 220]\n",
                "for counts, colour, name in (\n",
                "    (summed_spikes_train, \"grey\", \"gt\"),\n",
                "    (summed_spikes_sampled, \"darkred\", \"ldns\"),\n",
                "):\n",
                "    ax.hist(torch.tensor(counts).clip(0, 220), color=colour, label=name, bins=50, alpha=0.5, density=True)\n",
                "\n",
                "print(\n",
                "    \"GT:\",\n",
                "    torch.tensor(summed_spikes_train).mean(),\n",
                "    torch.tensor(summed_spikes_train).std(),\n",
                "    \"LDNS:\",\n",
                "    torch.tensor(summed_spikes_sampled).mean(),\n",
                "    torch.tensor(summed_spikes_sampled).std(),\n",
                ")\n",
                "\n",
                "ax.legend()\n",
                "ax.set_xlim(40, 160)\n",
                "\n",
                "# no y ticks; the y axis still gets a math-text scalar formatter\n",
                "plt.yticks([])\n",
                "plt.gca().yaxis.set_major_formatter(ScalarFormatter(useMathText=True))\n",
                "\n",
                "plt.xlabel(\"#spikes per 1s\")\n",
                "plt.ylabel('frequency')\n",
                "plt.savefig(save_path + \"population_spike_count_dist.pdf\", dpi=300)\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Ground-truth spike rasters: two training trials rendered as greyscale images\n",
                "# on a shared colour scale, each saved to its own PDF under save_path.\n",
                "vmin = 0\n",
                "vmax = 5\n",
                "\n",
                "# (trial index, output file name, draw the time axis?)\n",
                "gt_raster_panels = [\n",
                "    (5, \"gt_spikes1.pdf\", False),\n",
                "    (0, \"gt_spikes2.pdf\", True),\n",
                "]\n",
                "\n",
                "for trial_idx, out_name, show_time_axis in gt_raster_panels:\n",
                "    fig = plt.figure(figsize=cm2inch(6, 2), dpi=300)\n",
                "    im = plt.imshow(\n",
                "        train_spikes[trial_idx, :],\n",
                "        cmap=\"Greys\",\n",
                "        alpha=1.0,\n",
                "        aspect=\"auto\",\n",
                "        vmin=vmin,\n",
                "        vmax=vmax,\n",
                "    )\n",
                "\n",
                "    if show_time_axis:\n",
                "        # columns 0 and 500 are relabelled 0 and 10 on the time axis\n",
                "        plt.xticks([0, 500])\n",
                "        plt.gca().set_xticklabels([0, 10])\n",
                "        plt.xlabel(\"time (s)\")\n",
                "    else:\n",
                "        # remove the x-axis altogether\n",
                "        plt.xticks([])\n",
                "        plt.gca().spines[\"bottom\"].set_visible(False)\n",
                "\n",
                "    # no y ticks, only the axis label\n",
                "    plt.yticks([])\n",
                "    plt.ylabel(\"neurons\")\n",
                "\n",
                "    plt.savefig(save_path + out_name, dpi=300)\n",
                "    plt.show()\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Sampled spike rasters: three generated samples, each restricted to the time\n",
                "# bins where its mask is non-zero and zero-padded on the right to a common\n",
                "# width, then saved to its own PDF under save_path.\n",
                "vmax = 4\n",
                "padded_len = 512  # common panel width in time bins\n",
                "\n",
                "# (sample index, output file name, draw the time axis?)\n",
                "sampled_raster_panels = [\n",
                "    (248, \"sampled_spikes1.pdf\", False),\n",
                "    (16, \"sampled_spikes2.pdf\", False),\n",
                "    (89, \"sampled_spikes3.pdf\", True),\n",
                "]\n",
                "\n",
                "for sample_idx, out_name, show_time_axis in sampled_raster_panels:\n",
                "    fig = plt.figure(figsize=cm2inch(6, 2), dpi=300)\n",
                "\n",
                "    # indices of the time bins whose mask entry is non-zero\n",
                "    valid_bins = sampled_masks[sample_idx, 0].nonzero().flatten()\n",
                "    valid_spikes = sampled_spikes[sample_idx, :, valid_bins]\n",
                "    # pad with as many rows as the data has (previously hard-coded to 128)\n",
                "    right_padding = torch.zeros(valid_spikes.shape[0], padded_len - len(valid_bins))\n",
                "    sampled_spikes_padded2 = torch.cat((valid_spikes, right_padding), dim=-1)\n",
                "\n",
                "    plt.imshow(\n",
                "        sampled_spikes_padded2,\n",
                "        cmap=\"Greys\",\n",
                "        alpha=1.0,\n",
                "        aspect=\"auto\",\n",
                "        vmin=0,\n",
                "        vmax=vmax,\n",
                "    )\n",
                "\n",
                "    if show_time_axis:\n",
                "        # columns 0 and 500 are relabelled 0 and 10 on the time axis\n",
                "        plt.xticks([0, 500])\n",
                "        plt.gca().set_xticklabels([0, 10])\n",
                "        plt.xlabel(\"time (s)\")\n",
                "    else:\n",
                "        # remove the x-axis altogether\n",
                "        plt.xticks([])\n",
                "        plt.gca().spines[\"bottom\"].set_visible(False)\n",
                "\n",
                "    # no y ticks, only the axis label\n",
                "    plt.yticks([])\n",
                "    plt.ylabel(\"neurons\")\n",
                "\n",
                "    plt.savefig(save_path + out_name, dpi=300)\n",
                "    plt.show()\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# print every sample whose mask has fewer than 120 non-zero entries,\n",
                "# together with that count\n",
                "for i in range(len(sampled_masks)):\n",
                "    n_nonzero = sampled_masks[i, 0].nonzero().shape[0]\n",
                "    if n_nonzero < 120:\n",
                "        print(i, n_nonzero)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Marker plot of entry 10 of sampled_spikes_trimmed, one plt.plot call per row.\n",
                "# NOTE(review): the row index is plotted on the x-axis and the row values on the\n",
                "# y-axis, while the labels read x=\"time [s]\" / y=\"neuron idx\" -- confirm that\n",
                "# this orientation is intended.\n",
                "ms = 20\n",
                "trial_rows = sampled_spikes_trimmed[10]\n",
                "for i in range(len(trial_rows)):\n",
                "    # constant x position for this row; the shape is always taken from row 0\n",
                "    x_pos = np.ones_like(trial_rows[0]) * i\n",
                "    plt.plot(x_pos, trial_rows[i], \"|\", color=\"darkred\", markersize=ms)\n",
                "\n",
                "plt.xlabel(\"time [s]\")\n",
                "plt.ylabel(\"neuron idx\")\n",
                "plt.title(\"diff spikes\")\n",
                "plt.locator_params(nbins=5)\n",
                "\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "sampled_spikes_trimmed[10]"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "sampled_spikes_trimmed"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Compare autoencoder rates for one validation batch against previously sampled\n",
                "# rates/spikes for two batch entries.\n",
                "denoiser.eval()\n",
                "\n",
                "val_batch = next(iter(val_latent_dataloader))\n",
                "\n",
                "# NOTE(review): `ret_dict` is not created in this cell. It must already exist in\n",
                "# the kernel with the keys \"rates\", \"spikes\" and \"masks\". The call below is the\n",
                "# one that was left commented out; it samples only 2 items while entries 3 and 5\n",
                "# are plotted, so it cannot be re-enabled unchanged -- confirm how ret_dict is built.\n",
                "# ret_dict = sample_spikes_with_mask(ema_model, scheduler, ae_model, cfg, lengths=[int(i.item()) for i in val_batch['mask'][:2,0].sum(-1)], batch_size=2, device=\"cuda\")\n",
                "\n",
                "# first output of the autoencoder on the validation signals, moved to CPU\n",
                "with torch.no_grad():\n",
                "    val_batch_rates = ae_model(val_batch[\"signal\"])[0].cpu()\n",
                "\n",
                "for batch_idx in (3, 5):\n",
                "    plot_real_vs_sampled_rates_and_spikes(\n",
                "        val_batch_rates,\n",
                "        ret_dict[\"rates\"],\n",
                "        val_batch[\"signal\"].cpu(),\n",
                "        ret_dict[\"spikes\"],\n",
                "        val_batch[\"mask\"].cpu(),\n",
                "        ret_dict[\"masks\"],\n",
                "        batch_idx=batch_idx,\n",
                "    )"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# same real-vs-sampled comparison for two further batch entries;\n",
                "# relies on val_batch, val_batch_rates and ret_dict already being in the kernel\n",
                "for batch_idx in (40, 200):\n",
                "    plot_real_vs_sampled_rates_and_spikes(\n",
                "        val_batch_rates,\n",
                "        ret_dict[\"rates\"],\n",
                "        val_batch[\"signal\"].cpu(),\n",
                "        ret_dict[\"spikes\"],\n",
                "        val_batch[\"mask\"].cpu(),\n",
                "        ret_dict[\"masks\"],\n",
                "        batch_idx=batch_idx,\n",
                "    )"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### creating per-timestep conditioned diffusion model (also bidirectional, with different conditioning networks)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# # test\n",
                "# C_in = 32\n",
                "# C = 256\n",
                "# L = 1000\n",
                "# kernel = \"s4\"\n",
                "# bidirectional = True\n",
                "# kernel_params = None\n",
                "# num_blocks = 6\n",
                "# time_condition_dim = 512\n",
                "# condition_dim = None\n",
                "\n",
                "\n",
                "# denoiser = TSConditionalDenoiser(\n",
                "#     C_in=C_in,\n",
                "#     C=C,\n",
                "#     L=L,\n",
                "#     kernel=kernel,\n",
                "#     bidirectional=bidirectional,\n",
                "#     kernel_params=kernel_params,\n",
                "#     num_blocks=num_blocks,\n",
                "#     time_condition_dim=time_condition_dim,\n",
                "#     condition_dim=condition_dim,\n",
                "# )\n",
                "\n",
                "# x = torch.randn(2, C_in, L)\n",
                "# t = torch.randint(0, 1000, (2,))\n",
                "# t_mask = torch.ones(1, 2, L)\n",
                "# c_t = torch.randn(2, time_condition_dim, L)\n",
                "\n",
                "\n",
                "# from torchinfo import summary\n",
                "# summary(denoiser, (x.shape, t.shape, t_mask.shape, c_t.shape))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Marker plot of entry 10 of sampled_spikes_trimmed at a constant height.\n",
                "# NOTE(review): `neuron_idx` and `ms` are not defined in this cell; they must\n",
                "# already exist in the kernel (`ms` is set in an earlier plotting cell).\n",
                "spike_values = sampled_spikes_trimmed[10]\n",
                "row_height = np.ones_like(spike_values) * neuron_idx\n",
                "\n",
                "plt.plot(spike_values, row_height, \"|\", color=\"darkred\", markersize=ms)\n",
                "\n",
                "plt.xlabel(\"time [s]\")\n",
                "plt.ylabel(\"neuron idx\")\n",
                "plt.title(\"diff spikes\")\n",
                "plt.locator_params(nbins=5)\n",
                "\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "\n",
                "### 3. Train diffusion to unconditionally generate rates (and then spikes), check if these spikes follow the distribution of real spikes"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "## initialize the denoiser (conditioned per timestep, no global conditioning)\n",
                "from ntldm.networks import TSConditionalDenoiser\n",
                "\n",
                "model_cfg = cfg.denoiser_model\n",
                "denoiser = TSConditionalDenoiser(\n",
                "    C_in=model_cfg.C_in,\n",
                "    C=model_cfg.C,\n",
                "    L=cfg.dataset.max_seqlen,\n",
                "    kernel=model_cfg.kernel,\n",
                "    num_blocks=model_cfg.num_blocks,\n",
                "    bidirectional=model_cfg.get(\"bidirectional\", True),\n",
                "    time_condition_dim=512,  # phoneme bert embedding dim\n",
                "    condition_dim=None,  # no global conditioning\n",
                ")\n",
                "\n",
                "# initial values may be way off, so shrink the output layer's weight and bias to 10%\n",
                "for out_param in (denoiser.conv_out.weight, denoiser.conv_out.bias):\n",
                "    out_param.data = out_param.data * 0.1\n",
                "\n",
                "# noise scheduler with a linear beta schedule and no sample clipping\n",
                "scheduler = DDPMScheduler(\n",
                "    num_train_timesteps=model_cfg.num_train_timesteps,\n",
                "    clip_sample=False,\n",
                "    beta_schedule=\"linear\",\n",
                ")\n",
                "\n",
                "# AdamW with its default weight decay for now\n",
                "optimizer = torch.optim.AdamW(denoiser.parameters(), lr=cfg.training.lr)\n",
                "\n",
                "# cosine learning-rate schedule: warmup length comes from the config, and the\n",
                "# schedule horizon is 1.3x the configured number of training epochs\n",
                "num_batches = len(train_latent_dataloader)\n",
                "warmup_steps = num_batches * cfg.training.num_warmup_epochs\n",
                "total_steps = num_batches * cfg.training.num_epochs * 1.3\n",
                "lr_scheduler = get_scheduler(\n",
                "    name=\"cosine\",\n",
                "    optimizer=optimizer,\n",
                "    num_warmup_steps=warmup_steps,\n",
                "    num_training_steps=total_steps,\n",
                ")\n",
                "\n",
                "# hand model, optimizer and lr scheduler to accelerate\n",
                "denoiser, optimizer, lr_scheduler = accelerator.prepare(\n",
                "    denoiser, optimizer, lr_scheduler\n",
                ")\n",
                "\n",
                "# restore the saved state from the epoch_1300 checkpoint, then build the EMA\n",
                "# wrapper from the restored denoiser\n",
                "accelerator.load_state(f\"exp/{cfg.exp_name}/epoch_1300\")\n",
                "\n",
                "ema_model = EMAModel(denoiser)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils import count_parameters\n",
                "\n",
                "# report the denoiser size in millions of parameters\n",
                "n_params_millions = count_parameters(denoiser) / 1e6\n",
                "print(n_params_millions, \"M parameters\")\n",
                "\n",
                "# accelerator.load_state(f\"exp/{cfg.exp_name}/epoch_200\")  # 200 because I restarted the training with ep=0\n",
                "# ema_model = EMAModel(denoiser)\n",
                "\n",
                "# show both model reprs as the cell output\n",
                "ae_model, denoiser"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# def sample(\n",
                "#     ema_denoiser,\n",
                "#     scheduler,\n",
                "#     cfg,\n",
                "#     batch_size=1,\n",
                "#     generator=None,\n",
                "#     device=\"cuda\",\n",
                "#     signal_length=None\n",
                "# ):  \n",
                "#     if signal_length is None:\n",
                "#         signal_length = cfg.dataset.signal_length\n",
                "#     z_t = torch.randn(\n",
                "#         (batch_size, cfg.denoiser_model.C_in, signal_length)\n",
                "#     ).to(device)\n",
                "#     ema_denoiser_avg = ema_denoiser.averaged_model\n",
                "#     ema_denoiser_avg.eval()\n",
                "\n",
                "\n",
                "#     scheduler.set_timesteps(cfg.denoiser_model.num_train_timesteps)\n",
                "\n",
                "#     for t in tqdm(scheduler.timesteps, desc=\"Sampling DDPM\"):\n",
                "#         with torch.no_grad():\n",
                "#             model_output = ema_denoiser_avg(\n",
                "#                 z_t, torch.tensor([t] * batch_size).to(device).long()\n",
                "#             )\n",
                "#         z_t = scheduler.step(\n",
                "#             model_output, t, z_t, generator=generator, return_dict=False\n",
                "#         )[0]\n",
                "\n",
                "#     return z_t\n",
                "\n",
                "\n",
                "# def sample_spikes(ema_denoiser, scheduler, ae, cfg, batch_size=1, device=\"cuda\"):\n",
                "#     z_t = torch.randn(\n",
                "#         (batch_size, cfg.denoiser_model.C_in, cfg.dataset.signal_length)\n",
                "#     ).to(device)\n",
                "#     ema_denoiser_avg = ema_denoiser.averaged_model\n",
                "#     ema_denoiser_avg.eval()\n",
                "#     scheduler.set_timesteps(cfg.denoiser_model.num_train_timesteps)\n",
                "\n",
                "#     for t in tqdm(scheduler.timesteps, desc=\"Sampling DDPM\"):\n",
                "#         with torch.no_grad():\n",
                "#             model_output = ema_denoiser_avg(\n",
                "#                 z_t, torch.tensor([t] * batch_size).to(device).long()\n",
                "#             )\n",
                "#         z_t = scheduler.step(model_output, t, z_t, return_dict=False)[0]\n",
                "\n",
                "#     z_t = z_t * latent_dataset_train.latent_stds.to(z_t.device) + latent_dataset_train.latent_means.to(z_t.device)\n",
                "\n",
                "#     with torch.no_grad():\n",
                "#         rates = ae.decode(z_t).cpu()\n",
                "    \n",
                "#     spikes = torch.poisson(rates)\n",
                "\n",
                "#     return spikes, rates\n",
                "    "
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# ## to restart training from the last checkpoint\n",
                "\n",
                "# lr_scheduler = get_scheduler(\n",
                "#     name=\"cosine\",\n",
                "#     optimizer=optimizer,\n",
                "#     num_warmup_steps=num_batches * cfg.training.num_warmup_epochs,  # warmup for 10% of epochs\n",
                "#     num_training_steps=num_batches * cfg.training.num_epochs * 1.3,  # total number of steps\n",
                "# )\n",
                "# lr_scheduler = accelerator.prepare(lr_scheduler)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "loss_fn = torch.nn.SmoothL1Loss(\n",
                "    beta=0.04, reduction=\"none\"\n",
                ")  # faster convergence than mse\n",
                "\n",
                "\n",
                "wandb.init(project=\"ntldm\", entity=\"anon-project\")\n",
                "pbar = tqdm(range(0, cfg.training.num_epochs), desc=\"epochs\")\n",
                "for epoch in pbar:\n",
                "    for i, batch in enumerate(train_latent_dataloader):\n",
                "        denoiser.train()\n",
                "        optimizer.zero_grad()\n",
                "\n",
                "        z = batch[\"latent\"]\n",
                "        z_mask = batch[\"mask\"]\n",
                "        embedding = batch[\"embedding\"]\n",
                "        embedding_mask = batch[\"embedding_mask\"]\n",
                "\n",
                "        t = torch.randint(\n",
                "            0, cfg.denoiser_model.num_train_timesteps, (z.shape[0],), device=\"cpu\"\n",
                "        ).long()\n",
                "\n",
                "        noise = torch.randn_like(z)\n",
                "        noisy_z = scheduler.add_noise(z, noise, t)\n",
                "\n",
                "        concatenated_masks = torch.cat([z_mask, embedding_mask], dim=1)\n",
                "\n",
                "        noise_pred = denoiser(noisy_z, t, t_mask=concatenated_masks, c_t=embedding)\n",
                "\n",
                "        # loss = torch.nn.functional.mse_loss(noise, noise_pred) * 0.5\n",
                "        # loss = loss + (noise - noise_pred).abs().mean() * 0.5  # l1 loss\n",
                "\n",
                "        loss = loss_fn(noise, noise_pred)\n",
                "        loss = loss * z_mask  # mask out the padding\n",
                "        loss = loss.mean()\n",
                "\n",
                "        accelerator.backward(loss)\n",
                "        accelerator.clip_grad_norm_(denoiser.parameters(), 1.0)\n",
                "\n",
                "        optimizer.step()\n",
                "        lr_scheduler.step()\n",
                "\n",
                "        if i % 10 == 0:\n",
                "            pbar.set_postfix(\n",
                "                {\n",
                "                    \"loss\": loss.item(),\n",
                "                    \"lr\": lr_scheduler.get_last_lr()[0],\n",
                "                    \"epoch\": epoch,\n",
                "                }\n",
                "            )\n",
                "            wandb.log(\n",
                "                {\n",
                "                    \"loss\": loss.item(),\n",
                "                    \"lr\": lr_scheduler.get_last_lr()[0],\n",
                "                    \"epoch\": epoch,\n",
                "                }\n",
                "            )\n",
                "\n",
                "        ema_model.step(denoiser)\n",
                "\n",
                "    if (epoch) % 100 == 0:  # plot samples\n",
                "\n",
                "        # denoiser.eval()\n",
                "\n",
                "        # sampled_latents = sample(\n",
                "        #     ema_denoiser=ema_model,\n",
                "        #     scheduler=scheduler,\n",
                "        #     cfg=cfg,\n",
                "        #     batch_size=2,\n",
                "        #     device=\"cuda\",\n",
                "        # )\n",
                "        # sampled_latents = sampled_latents * latent_dataset_train.latent_stds.to(\n",
                "        #     sampled_latents.device\n",
                "        # ) + latent_dataset_train.latent_means.to(sampled_latents.device)\n",
                "\n",
                "        # real_latents = latent_dataset_train.latents[:2].cuda()\n",
                "        # real_latents = real_latents * latent_dataset_train.latent_stds.to(\n",
                "        #     real_latents.device\n",
                "        # ) + latent_dataset_train.latent_means.to(real_latents.device)\n",
                "\n",
                "        # with torch.no_grad():\n",
                "        #     sampled_rates = ae_model.decode(sampled_latents).cpu()\n",
                "        #     decoded_rates_from_real_latents = ae_model.decode(real_latents).cpu()\n",
                "\n",
                "        # fig, ax = plt.subplots(1, 2, figsize=cm2inch(12, 4))\n",
                "        # im = ax[0].imshow(sampled_rates[0], aspect=\"auto\")\n",
                "        # ax[0].set_title(\"Sampled rates\")\n",
                "        # fig.colorbar(im, ax=ax[0], orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "\n",
                "        # im = ax[1].imshow(decoded_rates_from_real_latents[0], aspect=\"auto\")\n",
                "        # ax[1].set_title(\"Real rates\")\n",
                "        # fig.colorbar(im, ax=ax[1], orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "        # fig.tight_layout()\n",
                "        # plt.show()\n",
                "\n",
                "        # # get avg spike count across neurons\n",
                "        # real_spikes = latent_dataset_val.train_spikes.cpu()\n",
                "        # gen_spikes, gen_rates = sample_spikes(\n",
                "        #     ema_model,\n",
                "        #     scheduler,\n",
                "        #     ae_model,\n",
                "        #     cfg,\n",
                "        #     batch_size=real_spikes.shape[0] * 10,\n",
                "        #     device=\"cuda\",\n",
                "        # )\n",
                "\n",
                "        # spike_count_means = real_spikes.mean(dim=(0, 2))\n",
                "        # spike_count_stds = real_spikes.std(dim=(0, 2))\n",
                "        # gen_spike_count_means = gen_spikes.mean(dim=(0, 2))\n",
                "        # gen_spike_count_stds = gen_spikes.std(dim=(0, 2))\n",
                "\n",
                "        # print(real_spikes, gen_spikes, spike_count_means, spike_count_stds)\n",
                "\n",
                "        # # plot boxplots across neurons\n",
                "        # plt.figure(figsize=cm2inch(4, 4))\n",
                "        # plt.violinplot(\n",
                "        #     [spike_count_means.numpy(), gen_spike_count_means.numpy()],\n",
                "        #     positions=[1, 2],\n",
                "        #     showmeans=True,\n",
                "        # )\n",
                "        # plt.ylabel(\"Mean spike count\")\n",
                "        # plt.title(\"Mean spike count across neurons\")\n",
                "        # plt.xticks([1, 2], [\"Real\", \"Generated\"])\n",
                "        # plt.show()\n",
                "\n",
                "        # # plot boxplots per neuron\n",
                "        # from einops import reduce\n",
                "\n",
                "        # spike_count_per_neuron = reduce(real_spikes, \"B C L -> B C\", reduction=\"mean\")\n",
                "        # gen_spike_count_per_neuron = reduce(\n",
                "        #     gen_spikes, \"B C L -> B C\", reduction=\"mean\"\n",
                "        # )\n",
                "\n",
                "        # # sort channels by mean spike count\n",
                "        # sorted_indices = spike_count_means.argsort()\n",
                "        # sorted_indices = torch.flip(sorted_indices, (0,))\n",
                "        # print(sorted_indices)\n",
                "        # spike_count_per_neuron = spike_count_per_neuron[:, sorted_indices]\n",
                "        # gen_spike_count_per_neuron = gen_spike_count_per_neuron[:, sorted_indices]\n",
                "\n",
                "        # plt.figure(figsize=(8, 4))\n",
                "        # for i, (spike_count, gen_spike_count) in enumerate(\n",
                "        #     zip(\n",
                "        #         spike_count_per_neuron[:, ::10].T, gen_spike_count_per_neuron[:, ::10].T\n",
                "        #     )\n",
                "        # ):\n",
                "        #     plt.violinplot(\n",
                "        #         [spike_count.numpy(), gen_spike_count.numpy()],\n",
                "        #         positions=[i, i + 0.5],\n",
                "        #         showmeans=True,\n",
                "        #     )\n",
                "        #     # scatter plot across the violinplot for better visualization\n",
                "        #     plt.scatter(\n",
                "        #         [i] * len(spike_count), spike_count.numpy(), color=\"black\", alpha=0.1\n",
                "        #     )\n",
                "        #     plt.scatter(\n",
                "        #         [i + 0.5] * len(gen_spike_count),\n",
                "        #         gen_spike_count.numpy(),\n",
                "        #         color=\"black\",\n",
                "        #         alpha=0.1,\n",
                "        #     )\n",
                "\n",
                "        # plt.xticks(\n",
                "        #     np.arange(len(sorted_indices[::10])) + 0.5,\n",
                "        #     sorted_indices[::10].numpy().tolist(),\n",
                "        # )\n",
                "        # plt.ylabel(\"Spike count\")\n",
                "        # plt.title(\"Spike count per neuron (real vs generated)\")\n",
                "        # plt.xlabel(\"neuron index\")\n",
                "        # plt.yscale(\"symlog\", linthresh=0.001)\n",
                "        # plt.ylim(\n",
                "        #     -0.0,\n",
                "        #     max(\n",
                "        #         spike_count_per_neuron.max().item(),\n",
                "        #         gen_spike_count_per_neuron.max().item(),\n",
                "        #     )\n",
                "        #     + 0.1,\n",
                "        # )\n",
                "        # plt.show()\n",
                "\n",
                "        # save\n",
                "        accelerator.save_state(f\"exp/{cfg.exp_name}/epoch_{epoch}\")\n",
                "\n",
                "pbar.close()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def sample_conditional(\n",
                "    ema_denoiser,\n",
                "    scheduler,\n",
                "    cfg,\n",
                "    concatenated_mask,\n",
                "    embedding,\n",
                "    batch_size=1,\n",
                "    generator=None,\n",
                "    device=\"cuda\",\n",
                "    signal_length=None,\n",
                "):\n",
                "    if signal_length is None:\n",
                "        signal_length = cfg.dataset.max_seqlen\n",
                "    z_t = torch.randn((batch_size, cfg.denoiser_model.C_in, signal_length)).to(device)\n",
                "    ema_denoiser_avg = ema_denoiser.averaged_model\n",
                "    ema_denoiser_avg.eval()\n",
                "\n",
                "    scheduler.set_timesteps(cfg.denoiser_model.num_train_timesteps)\n",
                "\n",
                "    for t in tqdm(scheduler.timesteps, desc=\"Sampling DDPM\"):\n",
                "        with torch.no_grad():\n",
                "            # print(z_t.shape, t, concatenated_mask.shape, embedding.shape)\n",
                "            model_output = ema_denoiser_avg(\n",
                "                z_t,\n",
                "                torch.tensor([t] * batch_size).to(device).long(),\n",
                "                t_mask=concatenated_mask,\n",
                "                c_t=embedding,\n",
                "            )\n",
                "        z_t = scheduler.step(\n",
                "            model_output, t, z_t, generator=generator, return_dict=False\n",
                "        )[0]\n",
                "\n",
                "    return z_t\n",
                "\n",
                "\n",
                "def gen_spikes_conditional(avg_denoiser, ae, dataloader):\n",
                "    rec_rates = []\n",
                "    real_latents = []\n",
                "    gen_latents = []\n",
                "    real_signals = []\n",
                "    original_rates = []\n",
                "    gen_spikes = []\n",
                "    masks = []\n",
                "\n",
                "    latent_means = dataloader.dataset.latent_means\n",
                "    latent_stds = dataloader.dataset.latent_stds\n",
                "\n",
                "    # dataset output dict looks like this:\n",
                "    # {'signal': tensor[128, 512] n=65536 (0.2Mb) x∈[0., 6.000] μ=0.348 σ=0.693,\n",
                "    # 'latent': tensor[32, 512] n=16384 (64Kb) x∈[-2.909, 1.777] μ=-0.128 σ=0.830,\n",
                "    # 'mask': tensor[1, 512] 2Kb x∈[0., 1.000] μ=0.607 σ=0.489,\n",
                "    # 'embedding': tensor[512, 512] n=262144 (1Mb) x∈[-9.599, 8.834] μ=-0.026 σ=1.963,\n",
                "    # 'embedding_mask': tensor[1, 512] 2Kb x∈[0., 1.000] μ=0.357 σ=0.480,\n",
                "    # 'original_sentence': 'We could talk about that.',\n",
                "    # 'phonemized_sentence': 'wiː kʊd tˈɔːk ɐbˌaʊt ðˈæt .'}\n",
                "\n",
                "    for i, batch in enumerate(dataloader):\n",
                "        real_latent = batch[\"latent\"]\n",
                "        real_signal = batch[\"signal\"]\n",
                "        embedding = batch[\"embedding\"]\n",
                "        embedding_mask = batch[\"embedding_mask\"]\n",
                "        latent_mask = batch[\"mask\"]\n",
                "        # print(embedding_mask.shape, real_signal.shape)\n",
                "\n",
                "        concatenated_mask = torch.cat([latent_mask, embedding_mask], dim=1)\n",
                "        # print(concatenated_mask.shape, embedding.shape)\n",
                "\n",
                "        sampled_latent = sample_conditional(\n",
                "            avg_denoiser,\n",
                "            scheduler,\n",
                "            cfg,\n",
                "            concatenated_mask,\n",
                "            embedding,\n",
                "            batch_size=real_signal.shape[0],\n",
                "            device=\"cuda\",\n",
                "        )\n",
                "\n",
                "        sampled_latent = sampled_latent * latent_stds.to(\n",
                "            sampled_latent.device\n",
                "        ) + latent_means.to(sampled_latent.device)\n",
                "    \n",
                "        real_latent = real_latent * latent_stds.to(real_latent.device) + latent_means.to(real_latent.device)\n",
                "\n",
                "        rec_rates.append(ae.decode(sampled_latent).cpu())\n",
                "        gen_latents.append(sampled_latent.cpu())\n",
                "        real_signals.append(real_signal.cpu())\n",
                "        original_rates.append(ae.decode(real_latent).cpu())\n",
                "        real_latents.append(real_latent.cpu())\n",
                "        gen_spikes.append(torch.poisson(rec_rates[-1]))\n",
                "        masks.append(latent_mask.cpu())\n",
                "\n",
                "    return {\n",
                "        \"real_latents\": torch.cat(real_latents),\n",
                "        \"gen_latents\": torch.cat(gen_latents),\n",
                "        \"real_signals\": torch.cat(real_signals),\n",
                "        \"original_rates\": torch.cat(original_rates),\n",
                "        \"gen_rates\": torch.cat(rec_rates),\n",
                "        \"gen_spikes\": torch.cat(gen_spikes),\n",
                "        \"masks\": torch.cat(masks),\n",
                "    }\n",
                "\n",
                "\n",
                "gen_dict = gen_spikes_conditional(ema_model, ae_model, val_latent_dataloader)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "gen_rates = []\n",
                "real_rates = []\n",
                "\n",
                "from matplotlib import pyplot as plt\n",
                "\n",
                "for i in range(gen_dict[\"original_rates\"].shape[0]):\n",
                "    mask = gen_dict[\"masks\"][i, 0].cpu()\n",
                "    gen_rate = gen_dict[\"gen_rates\"][i].cpu()\n",
                "    real_rate = gen_dict[\"original_rates\"][i].cpu()\n",
                "    # get the indices where mask is 1\n",
                "    indices = torch.where(mask == 1)[0]\n",
                "    # print([i for i in indices])\n",
                "    gen_rates.append(gen_rate[:, indices[0].item() : indices[-1].item() + 1])\n",
                "    real_rates.append(real_rate[:, indices[0].item() : indices[-1].item() + 1])\n",
                "\n",
                "for i in range(0, gen_dict[\"original_rates\"].shape[0], 20):\n",
                "    fig, ax = plt.subplots(1, 2, figsize=cm2inch(8, 4))\n",
                "    im = ax[0].imshow(real_rates[i].detach(), aspect=\"auto\", cmap=\"Greys\")\n",
                "    plt.colorbar(im, ax=ax[0])\n",
                "    im = ax[1].imshow(gen_rates[i].detach(), aspect=\"auto\", cmap=\"Greys\")\n",
                "    plt.colorbar(im, ax=ax[1])\n",
                "    fig.suptitle(f\"Real vs Generated rates, idx={i}\")\n",
                "    fig.tight_layout()\n",
                "    plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "len(train_latent_dataloader.dataset.original_sentences[0:10])"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "real_spikes = latent_dataset_val.train_spikes.cpu()\n",
                "gen_spikes, gen_rates  = sample_spikes(ema_model, scheduler, ae_model, cfg, batch_size=real_spikes.shape[0] * 10, device=\"cuda\")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# plot reconstructed spikes\n",
                "plt.figure(figsize=cm2inch((6, 4)))\n",
                "bins = np.linspace(0, 20, 20)\n",
                "plt.hist(real_spikes.sum(2).flatten(), density=True, color='grey', bins=bins, alpha=0.5)\n",
                "plt.hist(gen_spikes.sum(2).flatten(), density=True, color='darkred', bins=bins, alpha=0.5)\n",
                "\n",
                "plt.legend(['gt', 'diffusion'])\n",
                "plt.title('spike count distribution (samples from diffusion)')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "gen_spikes"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import compute_spike_stats, counts_to_spike_trains\n",
                "\n",
                "\n",
                "spike_trains_gt = counts_to_spike_trains(rearrange(latent_dataset_train.train_spikes[:500], 'b c l -> b l c'), fps=10)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "spike_stats_gt = compute_spike_stats(spike_trains_gt, n_samples=latent_dataset_train.train_spikes.shape[0], n_neurons=latent_dataset_train.train_spikes.shape[1])\n",
                "\n",
                "\n",
                "spike_trains_diff = counts_to_spike_trains(rearrange(gen_spikes[:500], 'b c l -> b l c'), fps=10)\n",
                "spike_stats_diff = compute_spike_stats(spike_trains_diff, n_samples=gen_spikes.shape[0], n_neurons=gen_spikes.shape[1])\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "spike_stats_1 = spike_stats_gt\n",
                "spike_stats_2 = spike_stats_diff\n",
                "\n",
                "fig, axs = plt.subplots(4, 10, figsize=cm2inch(20, 6), dpi=300)\n",
                "axs = axs.T\n",
                "for n_i, neuron in enumerate(sorted_indices[:100:10]):\n",
                "    axs1 = axs[n_i]\n",
                "    for i, (key, value) in enumerate(spike_stats_1.items()):\n",
                "        ax = axs1[i]\n",
                "        ax.hist(spike_stats_2[key][:,neuron], bins=30, alpha=0.5, color='r')\n",
                "        ax.hist(value[:,neuron], bins=30, alpha=0.5, color='k')\n",
                "        # ensure a square aspect ratio\n",
                "        # ax.set_aspect(\"equal\", adjustable=\"box\")\n",
                "        # plot identity line\n",
                "        # ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls=\"--\", c=\"black\")\n",
                "        if n_i == 0:\n",
                "            ax.set_ylabel(key)\n",
                "        # ax.set_title(key)\n",
                "        if i == 3:\n",
                "            ax.set_xlabel(f\"Neuron {neuron}\")\n",
                "        ax.set_xticks([])\n",
                "        ax.set_yticks([])\n",
                "\n",
                "fig.tight_layout()\n",
                "fig.suptitle('Spike stats acoss neurons (diffusion is red, real is grey)', y=1.05)\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# plot only colorbar (horizonta, greys from 0 to 5, discretized)\n",
                "fig, ax = plt.subplots(1, 1, figsize=cm2inch(0.5, 4), dpi=300)\n",
                "cmap = plt.cm.Greys\n",
                "norm = plt.Normalize(vmin=0, vmax=5)\n",
                "cb1 = matplotlib.colorbar.ColorbarBase(ax, cmap=cmap, norm=norm, orientation='vertical', boundaries=np.arange(0, 7) - 0.5, ticks=[0,5])\n",
                "cb1.set_label('spike count')\n",
                "plt.savefig(save_path + \"colorbar.pdf\", dpi=300)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "print(';skdjshf')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "timeseries",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.9.18"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
