{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import argparse\n",
                "import os\n",
                "import sys\n",
                "\n",
                "ANONAUTHOR = False\n",
                "\n",
                "if ANONAUTHOR:\n",
                "    os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
                "    # append parent directory to path (../notebooks -> ..)\n",
                "    sys.path.append(os.path.dirname(os.getcwd()))\n",
                "    os.chdir(os.path.dirname(os.getcwd()))\n",
                "\n",
                "else:\n",
                "    os.chdir('../')\n",
                "\n",
                "\n",
                "import accelerate\n",
                "import auraloss  # freq loss\n",
                "import lovely_tensors as lt\n",
                "import matplotlib.pyplot as plt\n",
                "import matplotlib\n",
                "import numpy as np\n",
                "import scipy.io as io\n",
                "import seaborn as sns\n",
                "import torch\n",
                "import torch.nn as nn\n",
                "import wandb\n",
                "import yaml\n",
                "from diffusers.optimization import get_scheduler\n",
                "from omegaconf import OmegaConf\n",
                "from scipy.signal import welch\n",
                "from tqdm.auto import tqdm\n",
                "from einops import rearrange\n",
                "\n",
                "from ntldm.data.lds import get_lds_dataset\n",
                "from ntldm.data.latent_attractor import get_attractor_dataset, LatentDataset\n",
                "from ntldm.data.lds import get_lds_dataset\n",
                "from ntldm.networks import S4AE, AutoEncoder, CountWrapper\n",
                "from ntldm.utils.plotting_utils import *\n",
                "from ntldm.losses import latent_regularizer_v2\n",
                "from ntldm.networks import Denoiser\n",
                "from diffusers.training_utils import EMAModel\n",
                "from diffusers.schedulers import DDPMScheduler\n",
                "# always run from ../ntldm\n",
                "\n",
                "\n",
                "lt.monkey_patch()\n",
                "matplotlib.rc_file('matplotlibrc')\n",
                "\n",
                "%load_ext autoreload\n",
                "%autoreload 2\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "#cfg_ae = OmegaConf.load(\"conf/sweeps_count/autoencoder-count_s4-Lorenz_z=4.yaml\")\n",
                "\n",
                "#cfg_ae = OmegaConf.load(\"conf/sweeps_count/autoencoder-count_s4-Lorenz_z=8_true_pointwise_decoder_with_test.yaml\")\n",
                "cfg_ae = OmegaConf.load(\"conf/sweeps_count/autoencoder-count_s4-Lorenz_z=8_new_regularization_03.yaml\")\n",
                "# CHANGE Cin to latetns   \n",
                "cfg_yaml = \"\"\"\n",
                "denoiser_model:\n",
                "  C_in: 8 \n",
                "  C: 64\n",
                "  kernel: s4\n",
                "  num_blocks: 4 \n",
                "  bidirectional: True\n",
                "  num_train_timesteps: 1000\n",
                "training:\n",
                "  lr: 0.001\n",
                "  weight_decay: 0.0\n",
                "  num_epochs: 1000\n",
                "  num_warmup_epochs: 50\n",
                "  batch_size: 512\n",
                "  random_seed: 42\n",
                "  precision: \"no\"\n",
                "exp_name: diffusion_s4-Lorenz_z=8_true-pointwise_decoder_new_regularization_03\n",
                "\"\"\"\n",
                "\n",
                "cfg = OmegaConf.create(yaml.safe_load(cfg_yaml))\n",
                "cfg.dataset = cfg_ae.dataset\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "cfg_ae"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "    \n",
                "ae_model = AutoEncoder(\n",
                "    C_in=cfg_ae.model.C_in,\n",
                "    C=cfg_ae.model.C,\n",
                "    C_latent=cfg_ae.model.C_latent,\n",
                "    L=cfg_ae.dataset.signal_length,\n",
                "    kernel=cfg_ae.model.kernel,\n",
                "    num_blocks=cfg_ae.model.num_blocks,\n",
                "    num_blocks_decoder=cfg_ae.model.num_blocks_decoder,\n",
                "    num_lin_per_mlp=cfg_ae.model.get(\"num_lin_per_mlp\", 2),  # default 2\n",
                ")\n",
                "\n",
                "ae_model = CountWrapper(ae_model, use_sin_enc=cfg_ae.model.get(\"use_sin_enc\", False))\n",
                "\n",
                "ae_model.load_state_dict(torch.load(f\"exp/{cfg_ae.exp_name}/model.pt\", map_location=\"cpu\"))\n",
                "\n",
                "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
                "\n",
                "ae_model = ae_model.to(device)\n",
                "\n",
                "train_dataloader, val_dataloader, test_dataloader = get_attractor_dataset(\n",
                "    system_name=cfg_ae.dataset.system_name,\n",
                "    n_neurons=cfg_ae.model.C_in,\n",
                "    sequence_length=cfg_ae.dataset.signal_length,\n",
                "    noise_std=0.05,\n",
                "    n_ic=cfg_ae.dataset.n_ic,\n",
                "    mean_spike_count=cfg_ae.dataset.mean_rate * cfg.dataset.signal_length,\n",
                "    train_frac=cfg_ae.dataset.split_frac_train,\n",
                "    valid_frac=cfg_ae.dataset.split_frac_val, # test is 1 - train - valid\n",
                "    random_seed=cfg_ae.training.random_seed,\n",
                "    batch_size=cfg_ae.training.batch_size,\n",
                "    softplus_beta=cfg_ae.dataset.get(\"softplus_beta\", 2.0),\n",
                ")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "ae_model"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Set up the accelerator and the latent dataset"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "\n",
                "# check if signal length is power of 2\n",
                "if cfg.dataset.signal_length & (cfg.dataset.signal_length - 1) != 0:\n",
                "    cfg.training.precision = \"no\"  # torch.fft doesnt support half if L!=2^x\n",
                "\n",
                "\n",
                "accelerator = accelerator = accelerate.Accelerator(\n",
                "    mixed_precision=cfg.training.precision,\n",
                "    log_with=\"wandb\",\n",
                ")\n",
                "#\n",
                "\n",
                "\n",
                "# prepare the ae model and dataset\n",
                "(\n",
                "    ae_model,\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ") = accelerator.prepare(\n",
                "    ae_model,\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ")\n",
                "\n",
                "\n",
                "# create the latent dataset\n",
                "latent_dataset_train = LatentDataset(train_dataloader, ae_model)\n",
                "latent_dataset_val = LatentDataset(\n",
                "    val_dataloader,\n",
                "    ae_model,\n",
                "    latent_means=latent_dataset_train.latent_means,\n",
                "    latent_stds=latent_dataset_train.latent_stds,\n",
                ")\n",
                "latent_dataset_test = LatentDataset(\n",
                "    test_dataloader,\n",
                "    ae_model,\n",
                "    latent_means=latent_dataset_train.latent_means,\n",
                "    latent_stds=latent_dataset_train.latent_stds,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "print(\"latent dataset\", latent_dataset_train.latents)\n",
                "print(\"latent dataset means\", latent_dataset_train.latent_means)\n",
                "print(\"latent dataset stds\", latent_dataset_train.latent_stds)\n",
                "plt.figure(figsize=cm2inch(5, 3))\n",
                "hist = plt.hist(latent_dataset_train.latents[:100].flatten(), bins=200)\n",
                "hist = plt.hist(latent_dataset_val.latents[:100].flatten(), bins=200)\n",
                "hist = plt.hist(latent_dataset_test.latents[:100].flatten(), bins=200)\n",
                "\n",
                "plt.title(\"Latent dataset histogram\")\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import evaluate_autoencoder\n",
                "\n",
                "\n",
                "save_path = 'exp/'+cfg_ae.exp_name\n",
                "save_path = '/home/anonauthor/anonloc1/results/projects/latent-diffusion/LORENZ/figures/'\n",
                "\n",
                "_1, _2 , test_dataloader_longer = get_attractor_dataset(\n",
                "    system_name=cfg_ae.dataset.system_name,\n",
                "    n_neurons=cfg_ae.model.C_in,\n",
                "    sequence_length=cfg_ae.dataset.signal_length * 4,\n",
                "    noise_std=0.05,\n",
                "    n_ic=cfg_ae.dataset.n_ic,\n",
                "    mean_spike_count=cfg_ae.dataset.mean_rate * cfg.dataset.signal_length * 4,\n",
                "    train_frac=cfg_ae.dataset.split_frac_train,\n",
                "    valid_frac=cfg_ae.dataset.split_frac_val, # test is 1 - train - valid\n",
                "    random_seed=cfg_ae.training.random_seed,\n",
                "    batch_size=cfg_ae.training.batch_size//16,\n",
                "    softplus_beta=cfg_ae.dataset.get(\"softplus_beta\", 2.0),\n",
                ")\n",
                "\n",
                "test_dataloader_longer = accelerator.prepare(test_dataloader_longer)\n",
                "\n",
                "evaluate_autoencoder(ae_model, test_dataloader, test_dataloader_longer, n_latents=8, save=True, save_path=save_path, idx=10, indices=[1,3,7])"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "save_path = '/home/anonauthor/anonloc1/results/projects/latent-diffusion/LORENZ/figures/'\n",
                "\n",
                "evaluate_autoencoder(ae_model, test_dataloader, test_dataloader_longer, n_latents=8,\n",
                "                      save=True, save_path=save_path, idx=10, indices=[6,4,3])"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "\n",
                "denoiser = Denoiser(\n",
                "    C_in=cfg.denoiser_model.C_in,\n",
                "    C=cfg.denoiser_model.C,\n",
                "    L=cfg.dataset.signal_length,\n",
                "    kernel=cfg.denoiser_model.kernel,\n",
                "    num_blocks=cfg.denoiser_model.num_blocks,\n",
                "    bidirectional=cfg.denoiser_model.get(\"bidirectional\", True),\n",
                ")\n",
                "\n",
                "scheduler = DDPMScheduler(\n",
                "    num_train_timesteps=cfg.denoiser_model.num_train_timesteps,\n",
                "    clip_sample=False,\n",
                "    beta_schedule=\"linear\", # ddpm doesnt support cosine\n",
                ")\n",
                "\n",
                "\n",
                "\n",
                "optimizer = torch.optim.AdamW(\n",
                "    denoiser.parameters(), lr=cfg.training.lr\n",
                ")  # default wd=0.01 for now\n",
                "\n",
                "train_latent_dataloader = torch.utils.data.DataLoader(\n",
                "    latent_dataset_train,\n",
                "    batch_size=cfg.training.batch_size,\n",
                "    shuffle=True,\n",
                "    num_workers=4,\n",
                "    pin_memory=True,\n",
                ")\n",
                "\n",
                "val_latent_dataloader = torch.utils.data.DataLoader(\n",
                "    latent_dataset_val,\n",
                "    batch_size=cfg.training.batch_size,\n",
                "    shuffle=False,\n",
                "    num_workers=4,\n",
                "    pin_memory=True,\n",
                ")\n",
                "\n",
                "test_latent_dataloader = torch.utils.data.DataLoader(\n",
                "    latent_dataset_test,\n",
                "    batch_size=cfg.training.batch_size,\n",
                "    shuffle=False,\n",
                "    num_workers=4,\n",
                "    pin_memory=True,\n",
                ")\n",
                "\n",
                "num_batches = len(train_latent_dataloader)\n",
                "lr_scheduler = get_scheduler(\n",
                "    name=\"cosine\",\n",
                "    optimizer=optimizer,\n",
                "    num_warmup_steps=num_batches * cfg.training.num_warmup_epochs,  # warmup for 10% of epochs\n",
                "    num_training_steps=num_batches * cfg.training.num_epochs * 1.3,  # total number of steps\n",
                ")\n",
                "\n",
                "# check if signal length is power of 2\n",
                "if cfg.dataset.signal_length & (cfg.dataset.signal_length - 1) != 0:\n",
                "    cfg.training.precision = \"no\"  # torch.fft doesnt support half if L!=2^x\n",
                "\n",
                "# prepare the denoiser model and dataset\n",
                "(\n",
                "    denoiser,\n",
                "    train_latent_dataloader,\n",
                "    val_latent_dataloader,\n",
                "    test_latent_dataloader,\n",
                "    lr_scheduler,\n",
                ") = accelerator.prepare(\n",
                "    denoiser,\n",
                "    train_latent_dataloader,\n",
                "    val_latent_dataloader,\n",
                "    test_latent_dataloader,\n",
                "    lr_scheduler,\n",
                ")\n",
                "\n",
                "ema_model = EMAModel(denoiser)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "ae_model, denoiser"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def sample(\n",
                "    ema_denoiser,\n",
                "    scheduler,\n",
                "    cfg,\n",
                "    batch_size=1,\n",
                "    generator=None,\n",
                "    device=\"cuda\",\n",
                "    signal_length=None\n",
                "):  \n",
                "    if signal_length is None:\n",
                "        signal_length = cfg.dataset.signal_length\n",
                "    z_t = torch.randn(\n",
                "        (batch_size, cfg.denoiser_model.C_in, signal_length)\n",
                "    ).to(device)\n",
                "    ema_denoiser_avg = ema_denoiser.averaged_model\n",
                "    ema_denoiser_avg.eval()\n",
                "\n",
                "    scheduler.set_timesteps(cfg.denoiser_model.num_train_timesteps)\n",
                "\n",
                "    for t in tqdm(scheduler.timesteps, desc=\"Sampling DDPM\"):\n",
                "        with torch.no_grad():\n",
                "            model_output = ema_denoiser_avg(\n",
                "                z_t, torch.tensor([t] * batch_size).to(device).long()\n",
                "            )\n",
                "        z_t = scheduler.step(\n",
                "            model_output, t, z_t, generator=generator, return_dict=False\n",
                "        )[0]\n",
                "\n",
                "    return z_t\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# train the diffusion model"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
                "\n",
                "\n",
                "pbar = tqdm(range(cfg.training.num_epochs), desc=\"epochs\")\n",
                "for epoch in pbar:\n",
                "    for i, batch in enumerate(train_latent_dataloader):\n",
                "\n",
                "        optimizer.zero_grad()\n",
                "\n",
                "        z = batch\n",
                "        t = torch.randint(\n",
                "            0, cfg.denoiser_model.num_train_timesteps, (z.shape[0],), device=\"cpu\"\n",
                "        ).long()\n",
                "        # print(z.shape, t.shape)\n",
                "        noise = torch.randn_like(z)\n",
                "        noisy_z = scheduler.add_noise(z, noise, t)\n",
                "        noise_pred = denoiser(noisy_z, t)\n",
                "\n",
                "        loss = torch.nn.functional.mse_loss(noise, noise_pred)\n",
                "        accelerator.backward(loss)\n",
                "        accelerator.clip_grad_norm_(denoiser.parameters(), 1.0)\n",
                "\n",
                "        optimizer.step()\n",
                "        lr_scheduler.step()\n",
                "\n",
                "        if i % 10 == 0:\n",
                "            pbar.set_postfix({\"loss\": loss.item(), \"lr\": lr_scheduler.get_last_lr()[0]})\n",
                "\n",
                "        ema_model.step(denoiser)\n",
                "\n",
                "    if (epoch) % 100 == 0: # plot samples\n",
                "\n",
                "        sampled_latents = sample(\n",
                "            ema_denoiser=ema_model, scheduler=scheduler, cfg=cfg, batch_size=2, device=\"cuda\"\n",
                "        )\n",
                "        sampled_latents = sampled_latents * latent_dataset_train.latent_stds.to(\n",
                "            sampled_latents.device\n",
                "        ) + latent_dataset_train.latent_means.to(sampled_latents.device)\n",
                "        with torch.no_grad():\n",
                "            sampled_rates = ae_model.decode(sampled_latents).cpu()\n",
                "\n",
                "        fig, ax = plt.subplots(1, 2, figsize=cm2inch(12,4))\n",
                "        im = ax[0].imshow(sampled_rates[0], aspect='auto')\n",
                "        ax[0].set_title(\"Sampled rates\")\n",
                "        fig.colorbar(im, ax=ax[0], orientation='vertical', fraction=0.046, pad=0.04)\n",
                "\n",
                "        im = ax[1].imshow(train_dataloader.dataset.dataset[0][\"rates\"], aspect='auto')\n",
                "        ax[1].set_title(\"Real rates\")\n",
                "        fig.colorbar(im, ax=ax[1], orientation='vertical', fraction=0.046, pad=0.04)\n",
                "        fig.tight_layout()\n",
                "        plt.show()\n",
                "\n",
                "pbar.close()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "cfg.exp_name"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "save_new_model = False\n",
                "load_model = True\n",
                "\n",
                "if save_new_model:\n",
                "\n",
                "    # save model and config file\n",
                "    if accelerator.is_main_process:\n",
                "        os.makedirs(f\"exp/{cfg.exp_name}\", exist_ok=True)\n",
                "        torch.save(accelerator.unwrap_model(denoiser).state_dict(), f\"exp/{cfg.exp_name}/model.pt\")\n",
                "        with open(f\"conf/sweeps_count/{cfg.exp_name}.yaml\", \"w\") as f:\n",
                "            f.write(OmegaConf.to_yaml(cfg))\n",
                "    print('saved model to ', cfg.exp_name)\n",
                "            \n",
                "elif load_model:\n",
                "    # load the congig and model path\n",
                "    with open(f\"conf/sweeps_count/{cfg.exp_name}.yaml\") as f:\n",
                "        cfg = OmegaConf.create(yaml.safe_load(f))\n",
                "\n",
                "    denoiser = Denoiser(\n",
                "        C_in=cfg.denoiser_model.C_in,\n",
                "        C=cfg.denoiser_model.C,\n",
                "        L=cfg.dataset.signal_length,\n",
                "        kernel=cfg.denoiser_model.kernel,\n",
                "        num_blocks=cfg.denoiser_model.num_blocks,\n",
                "        bidirectional=cfg.denoiser_model.get(\"bidirectional\", True),\n",
                "    )\n",
                "\n",
                "    denoiser.load_state_dict(torch.load(f\"exp/{cfg.exp_name}/model.pt\", map_location=\"cpu\"))\n",
                "\n",
                "    scheduler = DDPMScheduler(\n",
                "        num_train_timesteps=cfg.denoiser_model.num_train_timesteps,\n",
                "        clip_sample=False,\n",
                "        beta_schedule=\"linear\", # ddpm doesnt support cosine\n",
                "    )\n",
                "\n",
                "\n",
                "\n",
                "    optimizer = torch.optim.AdamW(\n",
                "        denoiser.parameters(), lr=cfg.training.lr\n",
                "    )  # default wd=0.01 for now\n",
                "\n",
                "    train_latent_dataloader = torch.utils.data.DataLoader(\n",
                "        latent_dataset_train,\n",
                "        batch_size=cfg.training.batch_size,\n",
                "        shuffle=True,\n",
                "        num_workers=4,\n",
                "        pin_memory=True,\n",
                "    )\n",
                "\n",
                "    val_latent_dataloader = torch.utils.data.DataLoader(\n",
                "        latent_dataset_val,\n",
                "        batch_size=cfg.training.batch_size,\n",
                "        shuffle=False,\n",
                "        num_workers=4,\n",
                "        pin_memory=True,\n",
                "    )\n",
                "    test_latent_dataloader = torch.utils.data.DataLoader(\n",
                "        latent_dataset_test,\n",
                "        batch_size=cfg.training.batch_size,\n",
                "        shuffle=False,\n",
                "        num_workers=4,\n",
                "        pin_memory=True,\n",
                "    )\n",
                "\n",
                "    num_batches = len(train_latent_dataloader)\n",
                "    lr_scheduler = get_scheduler(\n",
                "        name=\"cosine\",\n",
                "        optimizer=optimizer,\n",
                "        num_warmup_steps=num_batches * cfg.training.num_warmup_epochs,  # warmup for 10% of epochs\n",
                "        num_training_steps=num_batches * cfg.training.num_epochs * 1.3,  # total number of steps\n",
                "    )\n",
                "\n",
                "    # check if signal length is power of 2\n",
                "    if cfg.dataset.signal_length & (cfg.dataset.signal_length - 1) != 0:\n",
                "        cfg.training.precision = \"no\"  # torch.fft doesnt support half if L!=2^x\n",
                "\n",
                "    # prepare the denoiser model and dataset\n",
                "    (\n",
                "        denoiser,\n",
                "        train_latent_dataloader,\n",
                "        val_latent_dataloader,\n",
                "        test_latent_dataloader,\n",
                "        lr_scheduler,\n",
                "    ) = accelerator.prepare(\n",
                "        denoiser,\n",
                "        train_latent_dataloader,\n",
                "        val_latent_dataloader,\n",
                "        test_latent_dataloader,\n",
                "        lr_scheduler,\n",
                "    )\n",
                "\n",
                "    ema_model = EMAModel(denoiser)\n",
                "\n",
                "        \n",
                "sampled_latents = sample(\n",
                "    ema_denoiser=ema_model, scheduler=scheduler, cfg=cfg, batch_size=2, device=\"cuda\", signal_length=cfg.dataset.signal_length * 4\n",
                ")\n",
                "sampled_latents = sampled_latents * latent_dataset_train.latent_stds.to(\n",
                "    sampled_latents.device\n",
                ") + latent_dataset_train.latent_means.to(sampled_latents.device)\n",
                "with torch.no_grad():\n",
                "    sampled_rates = ae_model.decode(sampled_latents).cpu()\n",
                "\n",
                "fig, ax = plt.subplots(1, 1, figsize=cm2inch(12,4))\n",
                "im = ax.imshow(sampled_rates[0], aspect='auto')\n",
                "ax.set_title(\"Sampled rates\")\n",
                "fig.colorbar(im, ax=ax, orientation='vertical', fraction=0.046, pad=0.04)\n",
                "fig.tight_layout()\n",
                "plt.show()\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_3d_latent_trajectory_direct\n",
                "     \n",
                "sampled_latents = sample(\n",
                "    ema_denoiser=ema_model, scheduler=scheduler, cfg=cfg, batch_size=10, device=\"cuda\", signal_length=cfg.dataset.signal_length * 16\n",
                ")\n",
                "sampled_latents = sampled_latents * latent_dataset_train.latent_stds.to(\n",
                "    sampled_latents.device\n",
                ") + latent_dataset_train.latent_means.to(sampled_latents.device)\n",
                "with torch.no_grad():\n",
                "    sampled_rates = ae_model.decode(sampled_latents).cpu()\n",
                "\n",
                "fig, ax = plt.subplots(1, 1, figsize=cm2inch(12,4))\n",
                "im = ax.imshow(sampled_rates[0], aspect='auto')\n",
                "ax.set_title(\"Sampled rates\")\n",
                "fig.colorbar(im, ax=ax, orientation='vertical', fraction=0.046, pad=0.04)\n",
                "fig.tight_layout()\n",
                "plt.show()\n",
                "plot_3d_latent_trajectory_direct(sampled_latents, cmap=\"Reds\")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "def plot_3d_latent_trajectory_direct(\n",
                "    z,\n",
                "    figsize=(6, 6),\n",
                "    sample_idx=0,\n",
                "    indices=[0, 1, 2],\n",
                "    save=False,\n",
                "    save_path=None,\n",
                "    cmap='viridis'\n",
                "):\n",
                "    from mpl_toolkits.mplot3d import Axes3D\n",
                "    from matplotlib.animation import FuncAnimation\n",
                "    from matplotlib import cm\n",
                "\n",
                "    # Create a figure and an axis\n",
                "    fig = plt.figure(figsize=figsize)\n",
                "    ax = fig.add_subplot(111, projection=\"3d\")\n",
                "\n",
                "    z1 = z[sample_idx, indices[0], :].cpu().numpy()  # (1024,)\n",
                "    z2 = z[sample_idx, indices[1], :].cpu().numpy()\n",
                "    z3 = z[sample_idx, indices[2], :].cpu().numpy()\n",
                "\n",
                "    # Initialize the scatter plot with correct color dimensions\n",
                "    if cmap == 'viridis':\n",
                "        colors = cm.viridis(\n",
                "            np.linspace(0, 1, len(z1))\n",
                "        )  # Ensure colors are mapped from a colormap\n",
                "    elif cmap == 'Reds':\n",
                "        colors = cm.Reds(\n",
                "            np.linspace(0, 1, len(z1))\n",
                "        )\n",
                "    else:\n",
                "        colors = cm.viridis(\n",
                "            np.linspace(0, 1, len(z1))\n",
                "        )  # Ensure colors are mapped from a colormap\n",
                "    scatter = ax.scatter(z1, z2, z3, c=colors, s=1)\n",
                "    line, = ax.plot(z1, z2, z3, color=\"k\", alpha=0.3)\n",
                "    # set x y and z label names\n",
                "    ax.set_xlabel(f\"lat {indices[0]}\")\n",
                "    ax.set_ylabel(f\"lat {indices[1]}\")\n",
                "    ax.set_zlabel(f\"lat {indices[2]}\")\n",
                "    ax.locator_params(nbins=4)\n",
                "    plt.tight_layout()\n",
                "    if save and save_path is not None:\n",
                "        plt.savefig(save_path + \".png\")\n",
                "        plt.savefig(save_path + \".pdf\")\n",
                "    plt.show()\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "save_name = save_path + 'latent_trajectory/'\n",
                "os.makedirs(save_name, exist_ok=True)\n",
                "plot_3d_latent_trajectory_direct(sampled_latents,cmap=\"Reds\",\n",
                "save=True, save_path=save_name+'trajectory_long', sample_idx=4, figsize=cm2inch(8,8),  indices=[4,5,6])\n",
                "plot_3d_latent_trajectory_direct(sampled_latents[:,:,:cfg.dataset.signal_length], cmap=\"Reds\",\n",
                "save=True, save_path=save_name+'trajectory_short', sample_idx=4, figsize=cm2inch(8,8),  indices=[4,5,6])\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Default view (sample 0, default latent dims): full-length and truncated trajectory.\n",
                "save_name = save_path + '/latent_trajectory/'\n",
                "os.makedirs(save_name, exist_ok=True)\n",
                "for length_tag, traj_latents in (\n",
                "    ('long', sampled_latents),\n",
                "    ('short', sampled_latents[:, :, :cfg.dataset.signal_length]),\n",
                "):\n",
                "    plot_3d_latent_trajectory_direct(\n",
                "        traj_latents,\n",
                "        cmap=\"Reds\",\n",
                "        save=True,\n",
                "        save_path=save_name + 'trajectory_' + length_tag,\n",
                "        sample_idx=0,\n",
                "        figsize=cm2inch(8, 8),\n",
                "    )\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Evaluate the diffusion model on latents"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Collect AE reconstructions/latents and ground truth over the whole training set, then\n",
                "# draw one diffusion sample per training trial and decode it with the AE.\n",
                "true_data = False  # True: the dataset provides spikes only (no gt rates/latents)\n",
                "n_siglen = 1  # diffusion sample length in multiples of cfg.dataset.signal_length\n",
                "sample_cutoff = 1_000_000  # max number of trials kept per array (equals int(10e5))\n",
                "\n",
                "ae_model.eval()\n",
                "\n",
                "ae_rates = []\n",
                "ae_latents = []\n",
                "gt_spikes = []\n",
                "\n",
                "if not true_data:\n",
                "    gt_rates = []\n",
                "    gt_latents = []\n",
                "\n",
                "# autoencoder eval\n",
                "with torch.no_grad():\n",
                "    for batch in train_dataloader:\n",
                "        signal = batch[\"signal\"]\n",
                "        output_rates, latent = ae_model(signal)\n",
                "        ae_rates.append(output_rates.cpu())\n",
                "        ae_latents.append(latent.cpu())\n",
                "        gt_spikes.append(signal.cpu())\n",
                "        if not true_data:\n",
                "            gt_rates.append(batch[\"rates\"].cpu())\n",
                "            gt_latents.append(batch[\"latents\"].cpu())\n",
                "\n",
                "# concatenate along batch dimension\n",
                "ae_rates = torch.cat(ae_rates, dim=0)\n",
                "ae_latents = torch.cat(ae_latents, dim=0)\n",
                "gt_spikes = torch.cat(gt_spikes, dim=0)\n",
                "if not true_data:\n",
                "    gt_rates = torch.cat(gt_rates, dim=0)\n",
                "    gt_latents = torch.cat(gt_latents, dim=0)\n",
                "\n",
                "# diffusion eval: as many samples as there are training trials\n",
                "# NOTE(review): the sampling device is hard-coded to \"cuda\"\n",
                "sampled_latents = sample(\n",
                "    ema_denoiser=ema_model,\n",
                "    scheduler=scheduler,\n",
                "    cfg=cfg,\n",
                "    batch_size=ae_rates.shape[0],\n",
                "    device=\"cuda\",\n",
                "    signal_length=cfg.dataset.signal_length * n_siglen,\n",
                ")\n",
                "\n",
                "# project back to non standardized space\n",
                "sampled_latents = sampled_latents * latent_dataset_train.latent_stds.to(\n",
                "    sampled_latents.device\n",
                ") + latent_dataset_train.latent_means.to(sampled_latents.device)\n",
                "\n",
                "with torch.no_grad():\n",
                "    sampled_rates = ae_model.decode(sampled_latents).cpu()\n",
                "\n",
                "# a single sampling call, so no list/concatenation is needed\n",
                "diffusion_rates = sampled_rates\n",
                "diffusion_latents = sampled_latents.cpu()\n",
                "\n",
                "vecs = [ae_rates, ae_latents, diffusion_rates, diffusion_latents, gt_spikes]\n",
                "if not true_data:\n",
                "    vecs += [gt_rates, gt_latents]\n",
                "\n",
                "# to numpy, keep at most sample_cutoff trials, move channels last: (b, n, t) -> (b, t, n)\n",
                "vecs = [rearrange(vec.cpu().numpy()[:sample_cutoff], 'b n t -> b t n') for vec in vecs]\n",
                "\n",
                "if not true_data:\n",
                "    ae_rates, ae_latents, diffusion_rates, diffusion_latents, gt_spikes, gt_rates, gt_latents = vecs\n",
                "else:\n",
                "    ae_rates, ae_latents, diffusion_rates, diffusion_latents, gt_spikes = vecs\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Generate figures"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Colors and legend labels shared by the figures below; an \"l\" suffix marks the latent variant.\n",
                "color_dict = {\n",
                "    \"ae\": \"midnightblue\",\n",
                "    \"ael\": \"royalblue\",\n",
                "    \"diff\": \"darkred\",\n",
                "    \"diffl\": \"orangered\",\n",
                "    \"gt\": \"darkgrey\",\n",
                "    \"gtl\": \"#808080\",\n",
                "}\n",
                "lab_dict = {\n",
                "    \"ae\": \"ae\",\n",
                "    \"ael\": \"latents ae\",\n",
                "    \"diff\": \"diffusion\",\n",
                "    \"diffl\": \"latents diffusion\",\n",
                "    \"gt\": \"gt\",\n",
                "    \"gtl\": \"latents gt\",\n",
                "}\n",
                "\n",
                "# preview every color on an AE rate trace (every 4th neuron of trial 0)\n",
                "plt.figure(figsize=cm2inch(8, 5))\n",
                "for idx, key in enumerate(color_dict):\n",
                "    plt.plot(ae_rates[0, :, 4 * idx], color=color_dict[key], label=lab_dict[key])\n",
                "plt.legend()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.utils import l2_distances\n",
                "\n",
                "\n",
                "def plot_l2_distance_curves(samples, reference, title):\n",
                "    \"\"\"Compare sorted pairwise L2 distances of generated samples with a reference set.\n",
                "\n",
                "    Draws three sorted distance curves on a log x-axis (samples-to-reference,\n",
                "    within-reference, within-samples) and a dashed line at x = len(samples).\n",
                "\n",
                "    Returns:\n",
                "        (samp_dists, real_dists, dists, close_ids, ex1, ex2): the within-samples and\n",
                "        within-reference distances, followed by the full samples-to-reference\n",
                "        l2_distances output.\n",
                "    \"\"\"\n",
                "    def to_tensor(arr):\n",
                "        return torch.from_numpy(np.float32(arr))\n",
                "\n",
                "    samp_dists, _, _, _ = l2_distances(to_tensor(samples), to_tensor(samples))\n",
                "    real_dists, _, _, _ = l2_distances(to_tensor(reference), to_tensor(reference))\n",
                "    dists, close_ids, ex1, ex2 = l2_distances(to_tensor(samples), to_tensor(reference))\n",
                "\n",
                "    fig, ax = plt.subplots()\n",
                "    ax.plot(np.sort(dists.numpy().flatten()), label='distance real train')\n",
                "    ax.plot(np.sort(real_dists.numpy().flatten()), label='distance within train')\n",
                "    ax.plot(np.sort(samp_dists.numpy().flatten()), label='distance within samples')\n",
                "    # fixed-height marker at x = number of samples\n",
                "    ax.vlines(len(samples), 0.0, 70.0, colors=\"black\", linestyles=\"dashed\")\n",
                "    ax.set_xscale(\"log\")\n",
                "    ax.set_xlabel(\"rank (sorted distances)\")\n",
                "    ax.set_ylabel(\"L2 distance\")\n",
                "    ax.set_title(title)\n",
                "    ax.legend()\n",
                "    plt.show()\n",
                "    return samp_dists, real_dists, dists, close_ids, ex1, ex2\n",
                "\n",
                "\n",
                "samp_dists, real_dists, dists, close_ids, ex1, ex2 = plot_l2_distance_curves(\n",
                "    diffusion_latents, ae_latents, \"latents: diffusion samples vs. AE (train)\"\n",
                ")\n",
                "samp_dists, real_dists, dists, close_ids, ex1, ex2 = plot_l2_distance_curves(\n",
                "    diffusion_rates, ae_rates, \"rates: diffusion samples vs. AE (train)\"\n",
                ")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "tuple(arr.shape for arr in (ae_rates, ae_latents, diffusion_rates, diffusion_latents, gt_spikes))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Output folder for every figure and cache file written below. File names are appended by\n",
                "# plain string concatenation, so the path must end with a separator.\n",
                "# Set LDNS_COMPARISON_DIR to override the machine-specific default, e.g. with\n",
                "# 'exp/' + cfg.exp_name + '/ae_diffusion_comparison/'.\n",
                "save_path = os.environ.get(\n",
                "    \"LDNS_COMPARISON_DIR\",\n",
                "    \"/home/anonauthor/anonloc1/results/projects/latent-diffusion/LORENZ/ae_diffusion_comparison/\",\n",
                ")\n",
                "save_path = os.path.join(save_path, \"\")  # guarantee a trailing separator\n",
                "os.makedirs(save_path, exist_ok=True)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Power spectral density of ground-truth vs. diffusion rates for a fixed set of neurons.\n",
                "fps = 1\n",
                "plot_n_channel_sd(\n",
                "    gt_rates,\n",
                "    diffusion_rates,\n",
                "    channels=[0, 10, 30, 50, 80, 100, 120, 127],\n",
                "    fps=fps,\n",
                "    save=True,\n",
                "    save_path=save_path + 'rate_power_spec_density_diffusion_gt',\n",
                "    colors=[color_dict['gt'], color_dict['diff']],\n",
                "    labels=['gt', 'diff'],\n",
                "    ystack=2,\n",
                "    figsize=cm2inch(14, 6.5),\n",
                "    lw=1,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Power spectral density of ground-truth vs. AE-reconstructed rates.\n",
                "fps = 1\n",
                "plot_n_channel_sd(\n",
                "    gt_rates, ae_rates,\n",
                "    channels=[0, 10, 30, 50, 80, 100, 120, 127],\n",
                "    fps=fps,\n",
                "    save=True,\n",
                "    save_path=save_path + 'rate_power_spec_density',\n",
                "    colors=[color_dict['gt'], color_dict['ae']],\n",
                "    labels=['gt', 'ae'],\n",
                "    ystack=2,\n",
                "    figsize=cm2inch(14, 6.5),\n",
                "    lw=1,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Power spectral density of AE vs. diffusion latents for each selected latent dimension.\n",
                "# Use the first (up to) 8 latent dims in a fixed order so the figure is reproducible;\n",
                "# latents are (trials, time, dims) after the rearrange in the evaluation cell.\n",
                "channels = np.arange(min(8, ae_latents.shape[-1]))\n",
                "\n",
                "plot_n_channel_sd(ae_latents, diffusion_latents,\n",
                "                  channels=channels,\n",
                "                  fps=fps, save=True, save_path=save_path+'latents_power_spec_density_diff_ae',\n",
                "                  colors=[color_dict['ael'], color_dict['diffl']],\n",
                "                  labels=['ae', 'diff'], ystack=2, figsize=cm2inch(14, 6.5), lw=1)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Power spectral density of AE vs. diffusion rates for the same neurons as the gt comparisons.\n",
                "plot_n_channel_sd(\n",
                "    ae_rates,\n",
                "    diffusion_rates,\n",
                "    channels=[0, 10, 30, 50, 80, 100, 120, 127],\n",
                "    fps=fps,\n",
                "    save=True,\n",
                "    save_path=save_path + 'rates_power_spec_density_diff_ae',\n",
                "    colors=[color_dict['ae'], color_dict['diff']],\n",
                "    labels=['ae', 'diff'],\n",
                "    ystack=2,\n",
                "    figsize=cm2inch(14, 6.5),\n",
                "    lw=1,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import correlation_matrix\n",
                "\n",
                "\n",
                "def group_neurons_by_corr(data, num_groups=4):\n",
                "    \"\"\"Group neurons by their overall correlation with all other neurons.\n",
                "\n",
                "    Neurons are ranked by the sum of their squared off-diagonal correlations\n",
                "    (highest first) and split into `num_groups` consecutive groups of equal size.\n",
                "\n",
                "    Args:\n",
                "        data: array passed to `correlation_matrix(data, mode=\"concatenate\")`; in this\n",
                "            notebook the (trials, time, neurons) spike array.\n",
                "        num_groups: number of groups to return.\n",
                "\n",
                "    Returns:\n",
                "        List of `num_groups` neuron-index arrays. When the neuron count is not divisible\n",
                "        by `num_groups`, the lowest-ranked leftover neurons are not assigned to any group.\n",
                "    \"\"\"\n",
                "    C_mat = correlation_matrix(data, mode=\"concatenate\")\n",
                "    np.fill_diagonal(C_mat, 0)  # ignore self-correlation\n",
                "    # sum up the square of the correlations\n",
                "    summed_sq_corr = np.sum(C_mat ** 2, axis=0)\n",
                "    sorted_indices = np.argsort(-summed_sq_corr)\n",
                "\n",
                "    # split the ranked indices into equally sized groups\n",
                "    group_size = len(sorted_indices) // num_groups\n",
                "    return [sorted_indices[i * group_size:(i + 1) * group_size] for i in range(num_groups)]\n",
                "\n",
                "\n",
                "grouped_neurons = group_neurons_by_corr(gt_spikes, num_groups=4)\n",
                "\n",
                "grouped_neurons"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import group_neurons_temp_corr, get_temp_corr_summary\n",
                "from ntldm.utils.plotting_utils import plot_temp_corr_summary\n",
                "import pickle\n",
                "# Group neurons according to their correlation to make the pairwise correlation plots cheaper\n",
                "# to compute, then calculate the time-lagged auto- and cross-correlations for each group\n",
                "# (the correlation of a neuron with itself is set to 0 at lag 0).\n",
                "# Results are cached as pickles in save_path (written by the next cell). Set calc_again=True\n",
                "# to force recomputation; it also happens automatically when a cache file is missing.\n",
                "calc_again = False\n",
                "\n",
                "corr_cache_files = {\n",
                "    \"gt\": save_path + 'cross_corr_groups.pkl',\n",
                "    \"ae\": save_path + 'cross_corr_groups_sampled.pkl',\n",
                "    \"diff\": save_path + 'cross_corr_groups_diff.pkl',\n",
                "}\n",
                "calc_again = calc_again or not all(os.path.exists(p) for p in corr_cache_files.values())\n",
                "\n",
                "if calc_again:\n",
                "    ae_spikes = np.random.poisson(ae_rates)\n",
                "    diff_spikes = np.random.poisson(diffusion_rates)\n",
                "\n",
                "    # transpose(1, 0, 2) swaps the first two axes to match batch_first=False\n",
                "    groups = group_neurons_temp_corr(gt_spikes.transpose(1, 0, 2), num_groups=4)\n",
                "    cross_corr_groups, auto_corr_groups = get_temp_corr_summary(\n",
                "        gt_spikes.transpose(1, 0, 2), groups, nlags=30, mode='biased', batch_first=False)\n",
                "    fig_cross, fig_auto = plot_temp_corr_summary(cross_corr_groups, auto_corr_groups, name='Data')\n",
                "\n",
                "    cross_corr_groups_sampled, auto_corr_groups_sampled = get_temp_corr_summary(\n",
                "        ae_spikes.transpose(1, 0, 2), groups, nlags=30, mode='biased', batch_first=False)\n",
                "    fig_cross_sampled, fig_auto_sampled = plot_temp_corr_summary(\n",
                "        cross_corr_groups_sampled, auto_corr_groups_sampled, name='AE Samples')\n",
                "\n",
                "    cross_corr_groups_diff, auto_corr_groups_diff = get_temp_corr_summary(\n",
                "        diff_spikes.transpose(1, 0, 2), groups, nlags=30, mode='biased', batch_first=False)\n",
                "    # separate names so the AE figures above are not overwritten\n",
                "    fig_cross_diff, fig_auto_diff = plot_temp_corr_summary(\n",
                "        cross_corr_groups_diff, auto_corr_groups_diff, name='Diffusion Samples')\n",
                "else:\n",
                "    # pickle.load can execute arbitrary code: only load caches written by this notebook\n",
                "    with open(corr_cache_files[\"gt\"], 'rb') as f:\n",
                "        cross_corr_groups, auto_corr_groups = pickle.load(f)\n",
                "    with open(corr_cache_files[\"ae\"], 'rb') as f:\n",
                "        cross_corr_groups_sampled, auto_corr_groups_sampled = pickle.load(f)\n",
                "    with open(corr_cache_files[\"diff\"], 'rb') as f:\n",
                "        cross_corr_groups_diff, auto_corr_groups_diff = pickle.load(f)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Cache the correlation summaries; the analysis cell reloads them when calc_again is False.\n",
                "import pickle\n",
                "\n",
                "corr_results = {\n",
                "    'cross_corr_groups.pkl': [cross_corr_groups, auto_corr_groups],\n",
                "    'cross_corr_groups_sampled.pkl': [cross_corr_groups_sampled, auto_corr_groups_sampled],\n",
                "    'cross_corr_groups_diff.pkl': [cross_corr_groups_diff, auto_corr_groups_diff],\n",
                "}\n",
                "for file_name, payload in corr_results.items():\n",
                "    with open(save_path + file_name, 'wb') as f:\n",
                "        pickle.dump(payload, f)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_cross_corr_summary\n",
                "\n",
                "save = True\n",
                "\n",
                "# (summary name, colormap, legend-label prefix) for each data source\n",
                "GT_STYLE = (\"gt\", \"Greys\", \"gt \")\n",
                "DIFF_STYLE = (\"diffusion\", \"Reds\", \"diff \")\n",
                "AE_STYLE = (\"ae\", \"Blues\", \"ae \")\n",
                "\n",
                "\n",
                "def overlay_corr_summaries(series, file_stem, figsize, ncol, **plot_kwargs):\n",
                "    \"\"\"Draw several correlation summaries on one shared axis and save the figure.\n",
                "\n",
                "    Args:\n",
                "        series: list of (corr_groups, style) pairs drawn in order; style is a\n",
                "            (name, cmap, label) tuple such as GT_STYLE.\n",
                "        file_stem: file name without extension, saved as .png and .pdf in save_path.\n",
                "        figsize: figure size passed to plt.subplots.\n",
                "        ncol: legend column count forwarded to plot_cross_corr_summary.\n",
                "        **plot_kwargs: extra keyword arguments forwarded to every\n",
                "            plot_cross_corr_summary call (e.g. title/ylabel).\n",
                "\n",
                "    Returns:\n",
                "        The created figure.\n",
                "    \"\"\"\n",
                "    fig, ax = plt.subplots(1, 1, figsize=figsize)\n",
                "    for corr_groups, (name, cmap, label) in series:\n",
                "        plot_cross_corr_summary(\n",
                "            corr_groups,\n",
                "            name=name,\n",
                "            figsize=cm2inch(6, 4),\n",
                "            cmap=cmap,\n",
                "            ax_corr=ax,\n",
                "            labels=label,\n",
                "            ncol=ncol,\n",
                "            **plot_kwargs,\n",
                "        )\n",
                "    if save and save_path is not None:\n",
                "        fig.savefig(save_path + file_stem + \".png\")\n",
                "        fig.savefig(save_path + file_stem + \".pdf\")\n",
                "    return fig\n",
                "\n",
                "\n",
                "# one stand-alone cross-correlation figure per source: cross_corr_gt / _diffusion / _ae\n",
                "for corr_groups, (name, cmap, _label) in [\n",
                "    (cross_corr_groups, GT_STYLE),\n",
                "    (cross_corr_groups_diff, DIFF_STYLE),\n",
                "    (cross_corr_groups_sampled, AE_STYLE),\n",
                "]:\n",
                "    plot_cross_corr_summary(\n",
                "        corr_groups,\n",
                "        name=name,\n",
                "        figsize=cm2inch(6, 4),\n",
                "        cmap=cmap,\n",
                "        save=True,\n",
                "        save_path=save_path + \"cross_corr_\" + name,\n",
                "    )\n",
                "\n",
                "# overlays (all sources, and gt vs. diffusion) for cross- and auto-correlations;\n",
                "# every auto-correlation call gets the same title and y-label\n",
                "auto_corr_labels = dict(title=\"auto-corr\", ylabel=\"auto-corr\")\n",
                "overlay_specs = [\n",
                "    ([(cross_corr_groups, GT_STYLE), (cross_corr_groups_diff, DIFF_STYLE), (cross_corr_groups_sampled, AE_STYLE)],\n",
                "     \"cross_corr_all\", cm2inch(8, 6), 3, {}),\n",
                "    ([(cross_corr_groups, GT_STYLE), (cross_corr_groups_diff, DIFF_STYLE)],\n",
                "     \"cross_corr_gt_diff\", cm2inch(6, 4), 2, {}),\n",
                "    ([(auto_corr_groups, GT_STYLE), (auto_corr_groups_diff, DIFF_STYLE), (auto_corr_groups_sampled, AE_STYLE)],\n",
                "     \"auto_corr_all\", cm2inch(8, 6), 3, auto_corr_labels),\n",
                "    ([(auto_corr_groups, GT_STYLE), (auto_corr_groups_diff, DIFF_STYLE)],\n",
                "     \"auto_corr_gt_diff\", cm2inch(6, 4), 2, auto_corr_labels),\n",
                "]\n",
                "for series, file_stem, overlay_figsize, ncol, extra_kwargs in overlay_specs:\n",
                "    overlay_corr_summaries(series, file_stem, overlay_figsize, ncol, **extra_kwargs)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from matplotlib.colors import LinearSegmentedColormap\n",
                "\n",
                "save = True\n",
                "\n",
                "\n",
                "def create_custom_cmap(base_cmap, num_colors):\n",
                "    \"\"\"Resample a registered matplotlib colormap into a colormap with `num_colors` entries.\n",
                "\n",
                "    Args:\n",
                "        base_cmap: name of a registered colormap, e.g. 'Greys'.\n",
                "        num_colors: number of entries (N) of the returned colormap.\n",
                "\n",
                "    Returns:\n",
                "        A LinearSegmentedColormap built from all colors of `base_cmap`.\n",
                "    \"\"\"\n",
                "    # matplotlib.colormaps replaces cm.get_cmap (deprecated in 3.7, removed in 3.9)\n",
                "    cmap = matplotlib.colormaps[base_cmap]\n",
                "    colors = [cmap(i) for i in range(cmap.N)]\n",
                "    # build via the class so listed colormaps (e.g. 'viridis') work as well\n",
                "    return LinearSegmentedColormap.from_list('Custom cmap', colors, num_colors)\n",
                "\n",
                "\n",
                "# one shade per plotted neuron: grey for gt, blue for the AE, red for diffusion\n",
                "num_shades = 20\n",
                "grey_cmap = create_custom_cmap('Greys', num_shades)\n",
                "red_cmap = create_custom_cmap('Reds', num_shades)\n",
                "blue_cmap = create_custom_cmap('Blues', num_shades)\n",
                "\n",
                "# rates of the first num_shades neurons in trial 0\n",
                "fig, ax = plt.subplots(3, 1, figsize=cm2inch(8, 8), sharex=True)\n",
                "for i in range(num_shades):\n",
                "    ax[0].plot(gt_rates[0, :, i], color=grey_cmap(i), label=f\"gt {i}\")\n",
                "    ax[1].plot(ae_rates[0, :, i], color=blue_cmap(i), label=f\"ae {i}\")\n",
                "    ax[2].plot(diffusion_rates[0, :, i], color=red_cmap(i), label=f\"diffusion {i}\")\n",
                "\n",
                "ax[0].set_ylabel(\"gt rates\")\n",
                "ax[1].set_ylabel(\"ae rates\")\n",
                "ax[2].set_ylabel(\"diffusion\")\n",
                "# label the shared x-axis before tight_layout so the layout reserves space for it\n",
                "ax[-1].set_xlabel(\"time (a.u.)\")\n",
                "fig.tight_layout()\n",
                "\n",
                "if save and save_path is not None:\n",
                "    fig.savefig(save_path + \"supp_fig_rates_all.png\")\n",
                "    fig.savefig(save_path + \"supp_fig_rates_all.pdf\")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Plot population spike histogram"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_population_spike_histogram\n",
                "\n",
                "# Draw one Poisson spike sample per model and reuse it for all histograms below, so the\n",
                "# stand-alone and the combined figure show the same samples.\n",
                "ae_spikes_hist = np.random.poisson(ae_rates)\n",
                "diff_spikes_hist = np.random.poisson(diffusion_rates)\n",
                "model_spike_samples = {\"ae\": ae_spikes_hist, \"diff\": diff_spikes_hist}\n",
                "\n",
                "# stand-alone figures: population_spike_hist_ae / population_spike_hist_diff\n",
                "for key, model_spikes in model_spike_samples.items():\n",
                "    plot_population_spike_histogram(\n",
                "        gt_spikes,\n",
                "        model_spikes,\n",
                "        labels=[\"gt\", key],\n",
                "        colors=[color_dict[\"gt\"], color_dict[key]],\n",
                "        save=True,\n",
                "        save_path=save_path + \"population_spike_hist_\" + key,\n",
                "    )\n",
                "\n",
                "# combined figure: ae (left) and diffusion (right), each next to gt\n",
                "fig, axs = plt.subplots(1, 2, figsize=cm2inch(8, 3.5))\n",
                "for ax, (key, model_spikes) in zip(axs, model_spike_samples.items()):\n",
                "    plot_population_spike_histogram(\n",
                "        gt_spikes,\n",
                "        model_spikes,\n",
                "        labels=[\"gt\", key],\n",
                "        colors=[color_dict[\"gt\"], color_dict[key], color_dict[\"diff\"]],\n",
                "        ax=ax,\n",
                "        x_label=\"# spikes/bin\",\n",
                "        y_label=\"frequency\",\n",
                "    )\n",
                "plt.tight_layout()\n",
                "fig.savefig(save_path + \"population_spike_hist_all.png\")\n",
                "fig.savefig(save_path + \"population_spike_hist_all.pdf\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Histogram of population spike counts (spikes summed over neurons per time bin),\n",
                "# ground truth vs. one Poisson sample of the diffusion rates.\n",
                "# diffusion_spikes is reused by later cells.\n",
                "diffusion_spikes = np.random.poisson(diffusion_rates)\n",
                "# fixed range [0, max_val_pop_count] with 50 bins (width 2); counts outside it are not shown\n",
                "max_val_pop_count = 100\n",
                "bins = np.linspace(0, max_val_pop_count, 50 + 1)\n",
                "\n",
                "fig, ax = plt.subplots(figsize=cm2inch((3, 2)))\n",
                "ax.hist(gt_spikes.sum(2).flatten(), density=True, color='grey', bins=bins, alpha=0.5)\n",
                "ax.hist(diffusion_spikes.sum(2).flatten(), density=True, color='darkred', bins=bins, alpha=0.5)\n",
                "\n",
                "ax.legend(['gt', 'ldns'])\n",
                "ax.set_title('population spike count')\n",
                "ax.set_ylabel('frequency')\n",
                "ax.set_xlabel('number of spikes per time bin')\n",
                "fig.savefig(save_path + 'population_spike_count_diff.png')\n",
                "fig.savefig(save_path + 'population_spike_count_diff.pdf')\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# run the garbage collector to reclaim memory held by unreachable objects\n",
                "import gc\n",
                "gc.collect()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import correlation_matrix\n",
                "from ntldm.utils.plotting_utils import plot_correlation_matrices\n",
                "\n",
                "\n",
                "def plot_rate_correlations(save_name, **mode_kwargs):\n",
                "    \"\"\"Plot gt vs. AE/diffusion rate correlation matrices, saved as save_path + save_name.\"\"\"\n",
                "    return plot_correlation_matrices(\n",
                "        gt_rates,\n",
                "        [ae_rates, diffusion_rates],\n",
                "        model_labels=[\"ae\", \"diffusion\"],\n",
                "        model_colors=[\"midnightblue\", \"darkred\"],\n",
                "        figsize=cm2inch((16, 6)),\n",
                "        save=True,\n",
                "        save_path=save_path + save_name,\n",
                "        ms=1,\n",
                "        **mode_kwargs,\n",
                "    )\n",
                "\n",
                "\n",
                "# single trial (sample 0)\n",
                "plot_rate_correlations(\"rates_correlation_one_sample\", sample=0)\n",
                "# mode average computes the average correlation matrix over all samples\n",
                "plot_rate_correlations(\"rates_correlation_average_across_trials\", sample=None, mode=\"average\")\n",
                "# mode concatenate: presumably correlations over all trials concatenated in time - TODO confirm\n",
                "plot_rate_correlations(\"rates_correlation_concat\", sample=None, mode=\"concatenate\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "plot_correlation_matrices(\n",
                "    gt_rates,\n",
                "    [ae_rates, diffusion_rates],\n",
                "    sample=None,\n",
                "    mode=\"concatenate\",\n",
                "    model_labels=[\"ae\", \"diffusion\"],\n",
                "    model_colors=[\"midnightblue\", \"darkred\"],\n",
                "    figsize=cm2inch((10, 4)),\n",
                "    save = True, \n",
                "    save_path=save_path + \"rates_correlation_concat_small\", ms=1\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "ae_spikes = np.random.poisson(ae_rates)\n",
                "from ntldm.utils.plotting_utils import plot_correlation_matrices_monkey\n",
                "plot_correlation_matrices_monkey(\n",
                "    gt_spikes,\n",
                "    diffusion_spikes],\n",
                "    sample=None,\n",
                "    mode=\"concatenate\",\n",
                "    model_labels=[\"ae\", \"diffusion\"],\n",
                "    model_colors=[\"midnightblue\", \"darkred\"],\n",
                "    figsize=cm2inch((10, 4)),\n",
                "    save = True, \n",
                "    save_path=save_path + \"spikes_correlation_concat_small\", ms=1\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "fig = plt.figure(figsize=cm2inch(8, 4))\n",
                "for i in range(20):\n",
                "    plt.plot(diffusion_latents[i], color= red_cmap(i))\n",
                "plt.xlabel(\"time (a.u.)\")\n",
                "plt.ylabel(\"sampled latents\")\n",
                "fig.savefig(save_path + \"sampled_latents.png\")\n",
                "fig.savefig(save_path + \"sampled_latents.pdf\")\n",
                "\n",
                "fig = plt.figure(figsize=cm2inch(8, 4))\n",
                "for i in range(20):\n",
                "    plt.plot(ae_latents[i], color= blue_cmap(i))\n",
                "plt.xlabel(\"time (a.u.)\")\n",
                "plt.ylabel(\"ae latents\")\n",
                "fig.savefig(save_path + \"ae_latents.png\")\n",
                "fig.savefig(save_path + \"ae_latents.pdf\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "num_channels = 64\n",
                "\n",
                "fig, axs = plt.subplots(num_channels // 10 + 1, 10, figsize=cm2inch((60, 40)))\n",
                "for idx in range(num_channels):\n",
                "    plot_sd(\n",
                "        fig=fig,\n",
                "        ax=axs[idx // 10, idx % 10],\n",
                "        arr_one=gt_rates[:, :, idx],\n",
                "        arr_two=ae_rates[:, :, idx],\n",
                "        fs=200,\n",
                "        nperseg=260,\n",
                "        agg_function=np.median,\n",
                "        with_quantiles=True,\n",
                "        x_ss=slice(0, 60),\n",
                "        color_one=\"black\",\n",
                "        color_two=\"C0\",\n",
                "    )\n",
                "fig.savefig(save_path + \"rate_spetral_density_ae_gt.png\")\n",
                "fig.savefig(save_path + \"rate_spetral_density_ae_gt.pdf\")\n",
                "\n",
                "num_channels = 64\n",
                "\n",
                "fig, axs = plt.subplots(num_channels // 10 + 1, 10, figsize=cm2inch((60, 40)))\n",
                "for idx in range(num_channels):\n",
                "    plot_sd(\n",
                "        fig=fig,\n",
                "        ax=axs[idx // 10, idx % 10],\n",
                "        arr_one=gt_rates[:, :, idx],\n",
                "        arr_two=diffusion_rates[:, :, idx],\n",
                "        fs=200,\n",
                "        nperseg=260,\n",
                "        agg_function=np.median,\n",
                "        with_quantiles=True,\n",
                "        x_ss=slice(0, 60),\n",
                "        color_one=\"black\",\n",
                "        color_two=\"C3\",\n",
                "    )\n",
                "fig.savefig(save_path + \"rate_spetral_density_diffusion_gt.png\")\n",
                "fig.savefig(save_path + \"rate_spetral_density_diffusion_gt.pdf\")\n",
                "#plt.title('diffusion')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import bits_per_spike, neg_log_likelihood, rmse, compute_fano_factor\n",
                "bps_recs = bits_per_spike(ae_rates, gt_spikes)\n",
                "bps_gt = bits_per_spike(gt_rates, gt_spikes)\n",
                "nll_recs = neg_log_likelihood(ae_rates, gt_spikes,reduction=\"none\").mean((0,1))\n",
                "nll_gt = neg_log_likelihood(gt_rates, gt_spikes,reduction=\"none\").mean((0,1))\n",
                "mse_gt_recs = rmse(gt_rates, ae_rates)\n",
                "fano_gt = compute_fano_factor(gt_spikes)\n",
                "fano_recs = compute_fano_factor(np.random.poisson(ae_rates))\n",
                "fano_diffusion = compute_fano_factor(np.random.poisson(diffusion_rates))\n",
                "\n",
                "# plot boxplots of the fano factors grey is gt, blue is recs\n",
                "plt.figure(figsize=cm2inch((2, 4)))\n",
                "plt.boxplot([fano_gt, fano_recs], showfliers=False, widths=0.5, meanline=True,\n",
                "             meanprops=dict(color=\"red\"), medianprops=dict(color=\"black\"), patch_artist=True)\n",
                "# plot the dots\n",
                "jitter = np.random.normal(0, 0.05, len(fano_gt))\n",
                "plt.plot(np.ones_like(fano_gt)+jitter, fano_gt, \".\", color=\"grey\", alpha=0.5)\n",
                "plt.plot(2 * np.ones_like(fano_recs)+jitter, fano_recs, \".\", color=\"grey\", alpha=0.5)\n",
                "plt.xticks([1, 2], [\"gt\", \"recs\"])\n",
                "plt.ylabel('fano factor')\n",
                "\n",
                "# plot boxplots of the fano factors grey is gt, blue is recs\n",
                "plt.figure(figsize=cm2inch((2, 4)))\n",
                "plt.boxplot([nll_gt, nll_recs], showfliers=False, widths=0.5, meanline=True,\n",
                "             meanprops=dict(color=\"red\"), medianprops=dict(color=\"black\"), patch_artist=True)\n",
                "# plot the dots\n",
                "jitter = np.random.normal(0, 0.05, len(nll_recs))\n",
                "plt.plot(np.ones_like(nll_gt)+jitter, nll_gt, \".\", color=\"grey\", alpha=0.5)\n",
                "plt.plot(2 * np.ones_like(nll_recs)+jitter, nll_recs, \".\", color=\"grey\", alpha=0.5)\n",
                "plt.xticks([1, 2], [\"gt\", \"recs\"])\n",
                "plt.ylabel('nll')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Example data\n",
                "np.random.seed(0)  # For reproducibility\n",
                "# Plot setup\n",
                "fig = plt.figure(figsize=cm2inch((2, 4)))\n",
                "boxprops = dict(linestyle='-', linewidth=1, facecolor='grey')\n",
                "boxprops2 = dict(linestyle='-', linewidth=1, facecolor='lightblue')\n",
                "\n",
                "# Boxplots\n",
                "bp = plt.boxplot([nll_gt, nll_recs], positions=[1, 2], showfliers=False, widths=0.5, meanline=True,\n",
                "                 meanprops=dict(color=\"red\"), medianprops=dict(color=\"black\"),\n",
                "                 patch_artist=True, boxprops=boxprops)\n",
                "bp['boxes'][1].set_facecolor('lightblue')\n",
                "\n",
                "# Scatter dots\n",
                "jitter = np.random.normal(0, 0.05, len(nll_recs))\n",
                "plt.plot(np.ones_like(nll_gt) + jitter, nll_gt, \".\", color=\"black\", alpha=0.5, ms=0.5)\n",
                "plt.plot(2 * np.ones_like(nll_recs) + jitter, nll_recs, \".\", color=\"midnightblue\", alpha=0.5, ms=0.5)\n",
                "\n",
                "# Axes and labels\n",
                "plt.xticks([1, 2], [\"gt\", \"ae\"])\n",
                "plt.ylabel('neg log lik')\n",
                "plt.show()\n",
                "fig.savefig(save_path + \"neg_log_lik.png\")\n",
                "fig.savefig(save_path + \"neg_log_lik.pdf\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "nll_diffusion = neg_log_likelihood(diffusion_rates, gt_spikes,reduction=\"none\").mean((0,1))\n",
                "#nll_diffusion = neg_log_likelihood(np.mean(gt_spikes, axis=(0,1), keepdims=True)*np.ones_like(gt_spikes), gt_spikes,reduction=\"none\").mean((0,1))\n",
                "\n",
                "# Example data\n",
                "np.random.seed(0)  # For reproducibility\n",
                "# Plot setup\n",
                "fig = plt.figure(figsize=cm2inch((2, 4)))\n",
                "boxprops = dict(linestyle='-', linewidth=1, facecolor='grey')\n",
                "boxprops2 = dict(linestyle='-', linewidth=1, facecolor='lightblue')\n",
                "\n",
                "# Boxplots\n",
                "bp = plt.boxplot([nll_gt, nll_recs, nll_diffusion], positions=[1, 2, 3], showfliers=False, widths=0.5, meanline=True,\n",
                "                 meanprops=dict(color=\"red\"), medianprops=dict(color=\"black\"),\n",
                "                 patch_artist=True, boxprops=boxprops)\n",
                "bp['boxes'][1].set_facecolor('lightblue')\n",
                "bp['boxes'][2].set_facecolor('salmon')\n",
                "\n",
                "\n",
                "# Scatter dots\n",
                "jitter = np.random.normal(0, 0.05, len(nll_recs))\n",
                "plt.plot(np.ones_like(nll_gt) + jitter, nll_gt, \".\", color=\"black\", alpha=0.5, ms=0.5)\n",
                "plt.plot(2 * np.ones_like(nll_recs) + jitter, nll_recs, \".\", color=\"midnightblue\", alpha=0.5, ms=0.5)\n",
                "plt.plot(3 * np.ones_like(nll_diffusion) + jitter, nll_diffusion, \".\", color=\"darkred\", alpha=0.5, ms=0.5)\n",
                "\n",
                "# Axes and labels\n",
                "plt.xticks([1, 2, 3], [\"gt\", \"ae\", \"diff\"])\n",
                "plt.ylabel('neg log lik')\n",
                "plt.show()\n",
                "\n",
                "fig.savefig(save_path + \"neg_log_lik_with_diffusion.png\")\n",
                "fig.savefig(save_path + \"neg_log_lik_with_diffusion.pdf\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "plot_bits_per_spike(\n",
                "    [bps_recs, bps_gt],\n",
                "    legend_labels=[\"ae\", \"gt\"],\n",
                "    colors=[\"midnightblue\", \"grey\"],\n",
                "    bins=30,\n",
                "    log=True,\n",
                "    save = True, \n",
                "    save_path=save_path + \"all_bits_per_spike\"\n",
                ")\n",
                "plot_bits_per_spike(\n",
                "    [bps_recs.mean((0, 1)), bps_gt.mean((0, 1))],\n",
                "    legend_labels=[\"ae\", \"gt\"],\n",
                "    colors=[\"midnightblue\", \"grey\"],\n",
                "    bins=30,\n",
                "    log=False,\n",
                "    save = True, \n",
                "    save_path=save_path + \"mean_bits_per_spike\"\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_rate_comparisons\n",
                "from ntldm.utils.eval_utils import average_rates, std_rates\n",
                "\n",
                "figsize = cm2inch((8, 8)) \n",
                "plot_rate_comparisons(gt_rates, [ae_rates, diffusion_rates],\n",
                "                      mode='neur', figsize = figsize,\n",
                "                      colors=['midnightblue', 'darkred'], save = True, \n",
                "    save_path=save_path + \"mean_rate_comparison_gt_ae_diffusion\")\n",
                "plot_rate_comparisons(gt_rates, [ae_rates, diffusion_rates],\n",
                "                      mode='neurtime',figsize = figsize,\n",
                "                      colors=['midnightblue', 'darkred'])\n",
                "plot_rate_comparisons(gt_rates, [ae_rates, diffusion_rates],\n",
                "                      mode='neursample',figsize = figsize,\n",
                "                      colors=['midnightblue', 'darkred'])\n",
                "\n",
                "\n",
                "plot_rate_comparisons(gt_rates, [ae_rates, diffusion_rates],\n",
                "                      fn = std_rates, mode='neur',figsize = figsize,\n",
                "                      colors=['midnightblue', 'darkred'], save = True, \n",
                "    save_path=save_path + \"std_rate_comparison_gt_ae_diffusion\")\n",
                "plot_rate_comparisons(gt_rates, [ae_rates, diffusion_rates],\n",
                "                      fn = std_rates, mode='neurtime',figsize = figsize,\n",
                "                      colors=['midnightblue', 'darkred'])\n",
                "plot_rate_comparisons(gt_rates, [ae_rates, diffusion_rates],\n",
                "                      fn = std_rates, mode='neursample',figsize = figsize,\n",
                "                      colors=['midnightblue', 'darkred'])"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import compute_spike_stats, counts_to_spike_trains, compute_spike_stats_per_neuron\n",
                "fps=1\n",
                "spike_trains_gt = counts_to_spike_trains(gt_spikes, fps=fps)\n",
                "spike_trains_ae = counts_to_spike_trains(np.random.poisson(ae_rates), fps=fps)\n",
                "spike_trains_diff = counts_to_spike_trains(np.random.poisson(diffusion_rates), fps=fps)\n",
                "\n",
                "\n",
                "# spike_stats_gt = compute_spike_stats(spike_trains_gt, n_samples=gt_spikes.shape[0], n_neurons=gt_spikes.shape[2])\n",
                "# spike_stats_ae = compute_spike_stats(spike_trains_ae, n_samples=ae_rates.shape[0], n_neurons=ae_rates.shape[2])\n",
                "# spike_stats_diff = compute_spike_stats(spike_trains_diff, n_samples=diffusion_rates.shape[0], n_neurons=diffusion_rates.shape[2])\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import compute_spike_stats_per_neuron\n",
                "from ntldm.utils.plotting_utils import cm2inch, plot_spiketrain_stats\n",
                "\n",
                "spike_stats_gt = compute_spike_stats_per_neuron(\n",
                "    spike_trains_gt,\n",
                "    n_samples=gt_spikes.shape[0],\n",
                "    n_neurons=gt_spikes.shape[2],\n",
                "    mean_output=False,\n",
                ")\n",
                "spike_stats_ae = compute_spike_stats_per_neuron(\n",
                "    spike_trains_ae,\n",
                "    n_samples=ae_rates.shape[0],\n",
                "    n_neurons=ae_rates.shape[2],\n",
                "    mean_output=False,\n",
                ")\n",
                "spike_stats_diff = compute_spike_stats_per_neuron(\n",
                "    spike_trains_diff,\n",
                "    n_samples=diffusion_rates.shape[0],\n",
                "    n_neurons=diffusion_rates.shape[2],\n",
                "    mean_output=False,\n",
                ")\n",
                "\n",
                "plot_spiketrain_stats(\n",
                "    spike_stats_gt,\n",
                "    spike_stats_ae,\n",
                "    figsize=cm2inch(12, 4),\n",
                "    color=\"midnightblue\",\n",
                "    labels=[\"gt\", \"ae\"],\n",
                "    save = True, \n",
                "    save_path=save_path + \"compute_spike_stats_per_neuron_gt_ae\"\n",
                ")\n",
                "plot_spiketrain_stats(\n",
                "    spike_stats_gt,\n",
                "    spike_stats_diff,\n",
                "    figsize=cm2inch(12, 4),\n",
                "    color=\"darkred\",\n",
                "    labels=[\"gt\", \"diffusion\"],\n",
                "    save = True, \n",
                "    save_path=save_path + \"compute_spike_stats_per_neuron_gt_diff\"\n",
                ")\n",
                "plot_spiketrain_stats(\n",
                "    spike_stats_ae,\n",
                "    spike_stats_diff,\n",
                "    figsize=cm2inch(12, 4),\n",
                "    color=\"purple\",\n",
                "    labels=[\"ae\", \"diffusion\"],\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "plt.subplots(1, 1, figsize=cm2inch(8,5))\n",
                "channels = np.arange(0, 128, 32)\n",
                "output_spikes = np.random.poisson(diffusion_rates)[idx]\n",
                "maxval = np.max(\n",
                "    output_spikes.flatten()\n",
                ")\n",
                "plt.imshow(\n",
                "    output_spikes.T,\n",
                "    vmin=0,\n",
                "    vmax=maxval,\n",
                "    aspect=\"auto\",\n",
                "    cmap=\"Reds\",\n",
                ")# add colorbar\n",
                "plt.colorbar()\n",
                "plt.xlabel('time (a.u.)')\n",
                "plt.ylabel('neuron id')\n",
                "plt.savefig(save_path + \"diffusion_output_spikes.png\")\n",
                "plt.savefig(save_path + \"diffusion_output_spikes.pdf\")\n",
                "\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "save_path"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import visualise_spikes_trains\n",
                "visualise_spikes_trains(spike_trains_gt,spike_trains_ae, spike_trains_diff, ae_rates, figsize=(cm2inch(10, 6)),ms=0.4, \n",
                "                        save = True, \n",
                "    save_path=save_path + \"spike_trains\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "figsize=(cm2inch(6, 3))\n",
                "ms =1\n",
                "plt.figure(figsize=figsize)\n",
                "for (sample_idx, neuron_idx), spikes in spike_trains_gt.items():\n",
                "    plt.plot(spikes, np.ones_like(spikes) * neuron_idx, \"|\", color=\"black\", markersize=ms)\n",
                "    \n",
                "    \n",
                "    if neuron_idx == ae_rates.shape[-1]-1:\n",
                "        break\n",
                "plt.xlabel('time [s]')\n",
                "plt.ylabel('neuron idx')\n",
                "plt.title('gt spikes')\n",
                "plt.locator_params(nbins=5)\n",
                "plt.figure(figsize=figsize)\n",
                "\n",
                "for (sample_idx, neuron_idx), spikes in spike_trains_ae.items():\n",
                "    plt.plot(spikes, np.ones_like(spikes) * neuron_idx, \"|\", color=\"midnightblue\", markersize=ms)\n",
                "    \n",
                "    \n",
                "    if neuron_idx == ae_rates.shape[-1]-1:\n",
                "        break\n",
                "plt.xlabel('time [s]')\n",
                "plt.ylabel('neuron idx')\n",
                "plt.title('ae spikes')\n",
                "plt.locator_params(nbins=5)    \n",
                "plt.figure(figsize=figsize)\n",
                "\n",
                "for (sample_idx, neuron_idx), spikes in spike_trains_diff.items():\n",
                "    plt.plot(spikes, np.ones_like(spikes) * neuron_idx, \"|\", color=\"darkred\", markersize=ms)\n",
                "    \n",
                "    \n",
                "    if neuron_idx == ae_rates.shape[-1]-1:\n",
                "        break\n",
                "    \n",
                "plt.xlabel('time [s]')\n",
                "plt.ylabel('neuron idx')\n",
                "plt.title('diff spikes')\n",
                "plt.locator_params(nbins=5)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "ldiff",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.8.18"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
