{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import argparse\n",
                "import os\n",
                "import sys\n",
                "\n",
                "ANONAUTHOR = False\n",
                "\n",
                "# Resolve the repository root (parent of this notebook's folder) only once, so that\n",
                "# re-running this cell does not keep climbing up the directory tree.\n",
                "if \"REPO_ROOT\" not in globals():\n",
                "    REPO_ROOT = os.path.dirname(os.getcwd())\n",
                "\n",
                "if ANONAUTHOR:\n",
                "    os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
                "    # make the repository root importable (../notebooks -> ..)\n",
                "    if REPO_ROOT not in sys.path:\n",
                "        sys.path.append(REPO_ROOT)\n",
                "\n",
                "# conf/, exp/ and matplotlibrc below are all resolved relative to the repository root\n",
                "os.chdir(REPO_ROOT)\n",
                "\n",
                "\n",
                "import accelerate\n",
                "import auraloss  # freq loss\n",
                "import lovely_tensors as lt\n",
                "import matplotlib.pyplot as plt\n",
                "import matplotlib\n",
                "import numpy as np\n",
                "import scipy.io as io\n",
                "import seaborn as sns\n",
                "import torch\n",
                "import torch.nn as nn\n",
                "import wandb\n",
                "import yaml\n",
                "from diffusers.optimization import get_scheduler\n",
                "from omegaconf import OmegaConf\n",
                "from scipy.signal import welch\n",
                "from tqdm.auto import tqdm\n",
                "from einops import rearrange\n",
                "\n",
                "from ntldm.data.lds import get_lds_dataset\n",
                "from ntldm.data.latent_attractor import get_attractor_dataset, LatentDataset\n",
                "from ntldm.data.lds import get_lds_dataset\n",
                "from ntldm.networks import S4AE, AutoEncoder, CountWrapper\n",
                "from ntldm.utils.plotting_utils import *\n",
                "from ntldm.losses import latent_regularizer_v2\n",
                "from ntldm.networks import Denoiser\n",
                "from diffusers.training_utils import EMAModel\n",
                "from diffusers.schedulers import DDPMScheduler\n",
                "# always run from ../ntldm\n",
                "\n",
                "\n",
                "lt.monkey_patch()\n",
                "matplotlib.rc_file('matplotlibrc')\n",
                "\n",
                "%load_ext autoreload\n",
                "%autoreload 2\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# autoencoder whose latents the diffusion model is trained on\n",
                "#cfg_ae = OmegaConf.load(\"conf/sweeps_count/autoencoder-count_s4-Lorenz_z=4.yaml\")\n",
                "#cfg_ae = OmegaConf.load(\"conf/sweeps_count/autoencoder-count_s4-Lorenz_z=8_true_pointwise_decoder_with_test.yaml\")\n",
                "cfg_ae = OmegaConf.load(\"conf/sweeps_count/autoencoder-count_s4-Lorenz_z=8_new_regularization_03.yaml\")\n",
                "\n",
                "# denoiser_model.C_in must equal the autoencoder's latent size (cfg_ae.model.C_latent);\n",
                "# this is checked right after the config is built.\n",
                "cfg_yaml = \"\"\"\n",
                "denoiser_model:\n",
                "  C_in: 8 \n",
                "  C: 64\n",
                "  kernel: s4\n",
                "  num_blocks: 4 \n",
                "  bidirectional: True\n",
                "  num_train_timesteps: 1000\n",
                "training:\n",
                "  lr: 0.001\n",
                "  weight_decay: 0.0\n",
                "  num_epochs: 1000\n",
                "  num_warmup_epochs: 50\n",
                "  batch_size: 512\n",
                "  random_seed: 42\n",
                "  precision: \"no\"\n",
                "exp_name: diffusion_s4-Lorenz_z=8_true-pointwise_decoder_new_regularization_03\n",
                "\"\"\"\n",
                "\n",
                "cfg = OmegaConf.create(yaml.safe_load(cfg_yaml))\n",
                "cfg.dataset = cfg_ae.dataset\n",
                "\n",
                "# fail early instead of hitting a shape mismatch deep inside training/sampling\n",
                "assert cfg.denoiser_model.C_in == cfg_ae.model.C_latent, (\n",
                "    f\"denoiser C_in={cfg.denoiser_model.C_in} does not match \"\n",
                "    f\"autoencoder C_latent={cfg_ae.model.C_latent}\"\n",
                ")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "cfg_ae"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "    \n",
                "ae_model = AutoEncoder(\n",
                "    C_in=cfg_ae.model.C_in,\n",
                "    C=cfg_ae.model.C,\n",
                "    C_latent=cfg_ae.model.C_latent,\n",
                "    L=cfg_ae.dataset.signal_length,\n",
                "    kernel=cfg_ae.model.kernel,\n",
                "    num_blocks=cfg_ae.model.num_blocks,\n",
                "    num_blocks_decoder=cfg_ae.model.num_blocks_decoder,\n",
                "    num_lin_per_mlp=cfg_ae.model.get(\"num_lin_per_mlp\", 2),  # default 2\n",
                ")\n",
                "\n",
                "ae_model = CountWrapper(ae_model, use_sin_enc=cfg_ae.model.get(\"use_sin_enc\", False))\n",
                "\n",
                "# the checkpoint is a state dict; load it on CPU so this also works without a GPU\n",
                "ae_model.load_state_dict(torch.load(f\"exp/{cfg_ae.exp_name}/model.pt\", map_location=\"cpu\"))\n",
                "\n",
                "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
                "# the autoencoder is only used for inference in this notebook\n",
                "ae_model = ae_model.to(device).eval()\n",
                "\n",
                "# dataloaders built from the autoencoder's own dataset/training config (same seed and split)\n",
                "train_dataloader, val_dataloader, test_dataloader = get_attractor_dataset(\n",
                "    system_name=cfg_ae.dataset.system_name,\n",
                "    n_neurons=cfg_ae.model.C_in,\n",
                "    sequence_length=cfg_ae.dataset.signal_length,\n",
                "    noise_std=0.05,\n",
                "    n_ic=cfg_ae.dataset.n_ic,\n",
                "    mean_spike_count=cfg_ae.dataset.mean_rate * cfg_ae.dataset.signal_length,\n",
                "    train_frac=cfg_ae.dataset.split_frac_train,\n",
                "    valid_frac=cfg_ae.dataset.split_frac_val,  # test is 1 - train - valid\n",
                "    random_seed=cfg_ae.training.random_seed,\n",
                "    batch_size=cfg_ae.training.batch_size,\n",
                "    softplus_beta=cfg_ae.dataset.get(\"softplus_beta\", 2.0),\n",
                ")\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Set up the accelerator and the latent dataset"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "\n",
                "# torch.fft does not support half precision unless the signal length is a power of 2.\n",
                "# The parentheses matter: `!=` binds tighter than `&` in Python, so\n",
                "# `L & (L - 1) != 0` would evaluate `L & True` (an oddness test) instead.\n",
                "if (cfg.dataset.signal_length & (cfg.dataset.signal_length - 1)) != 0:\n",
                "    cfg.training.precision = \"no\"\n",
                "\n",
                "accelerator = accelerate.Accelerator(\n",
                "    mixed_precision=cfg.training.precision,\n",
                "    log_with=\"wandb\",\n",
                ")\n",
                "\n",
                "# prepare the ae model and dataloaders with the accelerator\n",
                "(\n",
                "    ae_model,\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ") = accelerator.prepare(\n",
                "    ae_model,\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ")\n",
                "\n",
                "# latent datasets; val/test reuse the latent means/stds computed on the train split\n",
                "latent_dataset_train = LatentDataset(train_dataloader, ae_model)\n",
                "latent_dataset_val = LatentDataset(\n",
                "    val_dataloader,\n",
                "    ae_model,\n",
                "    latent_means=latent_dataset_train.latent_means,\n",
                "    latent_stds=latent_dataset_train.latent_stds,\n",
                ")\n",
                "latent_dataset_test = LatentDataset(\n",
                "    test_dataloader,\n",
                "    ae_model,\n",
                "    latent_means=latent_dataset_train.latent_means,\n",
                "    latent_stds=latent_dataset_train.latent_stds,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "print(\"latent dataset\", latent_dataset_train.latents)\n",
                "print(\"latent dataset means\", latent_dataset_train.latent_means)\n",
                "print(\"latent dataset stds\", latent_dataset_train.latent_stds)\n",
                "\n",
                "# overlay the value distributions of the first 100 latent sequences of each split;\n",
                "# semi-transparent + labelled so later splits do not hide earlier ones\n",
                "fig, ax = plt.subplots(figsize=cm2inch(5, 3))\n",
                "for split_name, split in [\n",
                "    (\"train\", latent_dataset_train),\n",
                "    (\"val\", latent_dataset_val),\n",
                "    (\"test\", latent_dataset_test),\n",
                "]:\n",
                "    ax.hist(split.latents[:100].flatten(), bins=200, alpha=0.5, label=split_name)\n",
                "ax.set_title(\"Latent dataset histogram\")\n",
                "ax.legend()\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import evaluate_autoencoder\n",
                "\n",
                "# figures go to the paper figure folder (alternative: 'exp/' + cfg_ae.exp_name)\n",
                "save_path = '/home/anonauthor/anonloc1/results/projects/latent-diffusion/LORENZ/PAPER/figures/'\n",
                "os.makedirs(save_path, exist_ok=True)\n",
                "\n",
                "# test split with 4x longer sequences, same seed and split fractions as above\n",
                "long_factor = 4\n",
                "test_dataloader_longer = get_attractor_dataset(\n",
                "    system_name=cfg_ae.dataset.system_name,\n",
                "    n_neurons=cfg_ae.model.C_in,\n",
                "    sequence_length=cfg_ae.dataset.signal_length * long_factor,\n",
                "    noise_std=0.05,\n",
                "    n_ic=cfg_ae.dataset.n_ic,\n",
                "    mean_spike_count=cfg_ae.dataset.mean_rate * cfg_ae.dataset.signal_length * long_factor,\n",
                "    train_frac=cfg_ae.dataset.split_frac_train,\n",
                "    valid_frac=cfg_ae.dataset.split_frac_val,  # test is 1 - train - valid\n",
                "    random_seed=cfg_ae.training.random_seed,\n",
                "    batch_size=cfg_ae.training.batch_size // 16,  # presumably to bound memory for the longer sequences\n",
                "    softplus_beta=cfg_ae.dataset.get(\"softplus_beta\", 2.0),\n",
                ")[2]  # only the test split (train, val, test) is needed here\n",
                "test_dataloader_longer = accelerator.prepare(test_dataloader_longer)\n",
                "\n",
                "evaluate_autoencoder(ae_model, test_dataloader, test_dataloader_longer, n_latents=8, save=True, save_path=save_path, idx=10, indices=[1,3,7])"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "#save_path = '/home/anonauthor/anonloc1/results/projects/latent-diffusion/LORENZ/figures/'\n",
                "\n",
                "evaluate_autoencoder(ae_model, test_dataloader, test_dataloader_longer, n_latents=8,\n",
                "                      save=True, save_path=save_path, idx=10, indices=[6,4,3])"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def plot_spikes_vs_reconstruction(\n",
                "    model,\n",
                "    dataloader,\n",
                "    idx=0,\n",
                "    figsize=cm2inch(12, 5),\n",
                "    save=False,\n",
                "    save_path=None,\n",
                "    xlabel=\"time (a.u.)\",\n",
                "    ylabel=\"neuron\",\n",
                "    binary=False,\n",
                "):\n",
                "    \"\"\"Plot observed spikes (left) next to Poisson samples drawn from the predicted rates (right).\n",
                "\n",
                "    Only the first batch of `dataloader` is used; `idx` selects the trial within it.\n",
                "    Both rasters share the color range [0, maxval]; with `binary=True` maxval is 1 and\n",
                "    no colorbars are drawn. Saves `<save_path>.png/.pdf` when `save` is set.\n",
                "    (Previously this cell was a bare function body referring to undefined names.)\n",
                "    \"\"\"\n",
                "    model.eval()\n",
                "    batch = next(iter(dataloader))\n",
                "    with torch.no_grad():\n",
                "        output_rates = model(batch[\"signal\"])[0].cpu()\n",
                "\n",
                "    observed = batch[\"signal\"][idx].cpu().numpy()\n",
                "    # one Poisson draw per entry of the predicted rates of trial `idx`\n",
                "    sampled = np.random.poisson(output_rates[idx].numpy())\n",
                "    maxval = 1 if binary else max(observed.max(), sampled.max())\n",
                "\n",
                "    fig, ax = plt.subplots(1, 2, figsize=figsize)\n",
                "    im_obs = ax[0].imshow(observed, vmin=0, vmax=maxval, aspect=\"auto\", cmap=\"Greys\")\n",
                "    im_rec = ax[1].imshow(sampled, vmin=0, vmax=maxval, aspect=\"auto\", cmap=\"Greys\")\n",
                "    if not binary:\n",
                "        plt.colorbar(im_rec, ax=ax[1])\n",
                "        plt.colorbar(im_obs, ax=ax[0])\n",
                "    ax[0].set_ylabel(ylabel)\n",
                "    ax[0].set_xlabel(xlabel)\n",
                "    ax[1].set_xlabel(xlabel)\n",
                "\n",
                "    fig.tight_layout()\n",
                "    if save and save_path is not None:\n",
                "        fig.savefig(save_path + \".png\")\n",
                "        fig.savefig(save_path + \".pdf\")\n",
                "    return fig, ax"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "\n",
                "def binarized_spikes_and_latents(\n",
                "    model,\n",
                "    dataloader,\n",
                "    idx=0,\n",
                "    true_data=True,\n",
                "    figsize=cm2inch(12, 5),\n",
                "    save=False,\n",
                "    save_path=None,\n",
                "    xlabel=\"time (a.u.)\",\n",
                "    ylabel=\"neuron\",\n",
                "    binary=False,\n",
                "    ax=None\n",
                "):\n",
                "    \"\"\"Plot the spike raster of trial `idx` (left) next to its inferred latents (right).\n",
                "\n",
                "    Uses only the first batch of `dataloader`. `model(signal)` returns at least\n",
                "    (rates, latents); only the latents are plotted. Pass `ax` (two axes) to draw into\n",
                "    an existing figure, otherwise a new 1x2 figure is created. With `binary=True` the\n",
                "    raster color range is [0, 1] and no colorbar is drawn. `true_data` is unused.\n",
                "    \"\"\"\n",
                "    model.eval()\n",
                "    batch = next(iter(dataloader))\n",
                "    with torch.no_grad():\n",
                "        # single forward pass; element 1 of the output holds the latents\n",
                "        latents = model(batch[\"signal\"])[1].cpu()\n",
                "\n",
                "    if ax is None:\n",
                "        fig, ax = plt.subplots(1, 2, figsize=figsize, sharex=True)\n",
                "    else:\n",
                "        fig = ax[0].figure\n",
                "\n",
                "    spikes = batch[\"signal\"][idx].cpu().numpy()\n",
                "    maxval = 1 if binary else np.max(spikes)\n",
                "\n",
                "    # keep the image handle: the colorbar below needs it\n",
                "    im = ax[0].imshow(spikes, vmin=0, vmax=maxval, aspect=\"auto\", cmap=\"Greys\")\n",
                "    # assumes latents[idx] is (C_latent, time) -- TODO confirm; .T gives one line per latent\n",
                "    ax[1].plot(latents[idx].T)\n",
                "\n",
                "    if not binary:\n",
                "        plt.colorbar(im, ax=ax[0])\n",
                "\n",
                "    ax[0].set_ylabel(ylabel)\n",
                "    ax[0].set_xlabel(xlabel)\n",
                "    ax[1].set_xlabel(xlabel)\n",
                "\n",
                "    fig.tight_layout()\n",
                "    if save and save_path is not None:\n",
                "        fig.savefig(save_path + \".png\")\n",
                "        fig.savefig(save_path + \".pdf\")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Call the library plotting functions explicitly: later cells define a local\n",
                "# `plot_rate_traces` with a different signature (no y_stack/indices), which would\n",
                "# otherwise shadow the imported one when this cell is re-run.\n",
                "import ntldm.utils.plotting_utils as plotting_utils\n",
                "\n",
                "idx = 0\n",
                "plotting_utils.plot_spikes_next_to_each_other(\n",
                "    ae_model,\n",
                "    test_dataloader,\n",
                "    idx=idx,\n",
                "    figsize=cm2inch(8, 4),\n",
                "    save=True,\n",
                "    save_path=save_path + \"Fig_2_spikes_next_to_each_other\",\n",
                "    binary=True,\n",
                ")\n",
                "\n",
                "plotting_utils.plot_inferred_latents(\n",
                "    ae_model,\n",
                "    test_dataloader,\n",
                "    n_latents=2,\n",
                "    y_stack=2,\n",
                "    figsize=cm2inch(4, 4.5),\n",
                "    color=\"royalblue\",\n",
                "    idx=idx,\n",
                "    save=True,\n",
                "    save_path=save_path + \"Fig_2_inferred_latents\",\n",
                "    indices=[4, 5],\n",
                ")\n",
                "\n",
                "# NOTE(review): later cells also save to \"Fig_2_predicted_rates\" and overwrite this file\n",
                "plotting_utils.plot_rate_traces(\n",
                "    ae_model,\n",
                "    test_dataloader,\n",
                "    y_stack=2,\n",
                "    figsize=cm2inch(4, 4.5),\n",
                "    idx=idx,\n",
                "    save=True,\n",
                "    save_path=save_path + \"Fig_2_predicted_rates\",\n",
                "    indices=[4, 5],\n",
                ")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "def plot_rate_traces(\n",
                "    model,\n",
                "    dataloader,\n",
                "    figsize=(12, 5),\n",
                "    idx=0,\n",
                "    true_data=True,\n",
                "    color=\"royalblue\",\n",
                "    save=False,\n",
                "    save_path=None,\n",
                "    xlabel=\"time (s)\",\n",
                "    ylabel=\"rate (Hz)\",\n",
                "    channels=None,\n",
                "):\n",
                "    \"\"\"Plot predicted rates (left y-axis) over observed spikes (twin right y-axis).\n",
                "\n",
                "    Uses only the first batch of `dataloader`; `idx` selects the trial. One row per entry\n",
                "    of `channels` (default: channels 0 and 64). When `true_data` is False, the\n",
                "    ground-truth `batch[\"rates\"]` are overlaid as dashed lines. Note: this shadows the\n",
                "    `plot_rate_traces` imported from ntldm.utils.plotting_utils.\n",
                "    \"\"\"\n",
                "    if channels is None:\n",
                "        channels = np.arange(0, 128, 64)\n",
                "\n",
                "    model.eval()\n",
                "    batch = next(iter(dataloader))\n",
                "    signal = batch[\"signal\"]\n",
                "    with torch.no_grad():\n",
                "        output_rates = model(signal)[0].cpu().numpy()\n",
                "    signal = signal.cpu().numpy()\n",
                "\n",
                "    # squeeze=False keeps `ax` indexable even for a single channel\n",
                "    fig, ax = plt.subplots(len(channels), 1, figsize=figsize, sharex=True, squeeze=False)\n",
                "    ax = ax[:, 0]\n",
                "\n",
                "    for i, channel in enumerate(channels):\n",
                "        spike_ax = ax[i].twinx()\n",
                "        spikes = signal[idx, channel]\n",
                "        spike_ax.plot(spikes, label=\"spikes\", color=\"black\", lw=0.8, alpha=0.5)\n",
                "        ax[i].plot(output_rates[idx, channel], label=\"pred\", color=color)\n",
                "        if not true_data:\n",
                "            ax[i].plot(\n",
                "                batch[\"rates\"][idx, channel].cpu().numpy(),\n",
                "                '--', lw=1, label=\"real\", color=\"black\", alpha=1,\n",
                "            )\n",
                "        # per-channel spike-count ticks (no longer overridden for the last row only)\n",
                "        spike_ax.set_yticks([0, np.max(spikes)])\n",
                "\n",
                "    # legends are placed outside the axes, to the right\n",
                "    ax[-1].legend(loc=\"upper left\", bbox_to_anchor=(1.5, 1.5))\n",
                "    spike_ax.legend(loc=\"upper left\", bbox_to_anchor=(1.5, 1))\n",
                "    ax[-1].set_xlabel(xlabel)\n",
                "    ax[0].set_ylabel(ylabel)\n",
                "    if save and save_path is not None:\n",
                "        fig.savefig(save_path + \".png\")\n",
                "        fig.savefig(save_path + \".pdf\")\n",
                "\n",
                "\n",
                "# run all sorts of analyses\n",
                "plot_rate_traces(\n",
                "    ae_model,\n",
                "    test_dataloader,\n",
                "    figsize=cm2inch(3, 3),\n",
                "    idx=idx,\n",
                "    save=True,\n",
                "    save_path=save_path + \"Fig_2_predicted_rates\",\n",
                "    true_data=False,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "figsize = cm2inch(3, 3)\n",
                "idx = 0\n",
                "true_data = False\n",
                "color = \"royalblue\"\n",
                "save = True\n",
                "save_path = '/home/anonauthor/anonloc1/results/projects/latent-diffusion/LORENZ/PAPER/figures/'\n",
                "figure_path = save_path + \"Fig_2_predicted_rates\"\n",
                "xlabel = \"time (s)\"\n",
                "ylabel = \"rate (Hz)\"\n",
                "\n",
                "ae_model.eval()\n",
                "batch = next(iter(test_dataloader))\n",
                "signal = batch[\"signal\"]\n",
                "with torch.no_grad():\n",
                "    output_rates = ae_model(signal)[0].cpu()\n",
                "signal = signal.cpu()  # move signal to cpu\n",
                "\n",
                "channels = np.arange(0, 128, 64)\n",
                "# one row per channel so that ax[i] is a single Axes: `.twinx()` and `.legend()`\n",
                "# fail on the row arrays that a 2x2 grid would give\n",
                "fig, ax = plt.subplots(len(channels), 1, figsize=figsize, sharex=True)\n",
                "\n",
                "for i, channel in enumerate(channels):\n",
                "    ax2 = ax[i].twinx()\n",
                "    ax2.plot(signal[idx, channel].cpu().numpy(), label=\"spikes\", color=\"black\", lw=0.8, alpha=0.5)\n",
                "    ax[i].plot(output_rates[idx, channel].cpu().numpy(), label=\"pred\", color=color)\n",
                "\n",
                "    if not true_data:\n",
                "        ax[i].plot(batch[\"rates\"][idx, channel].cpu().numpy(), '--', lw=1, label=\"real\", color=\"black\", alpha=1)\n",
                "    ax2.set_yticks([0, np.max(signal[idx, channel].cpu().numpy())])\n",
                "\n",
                "ax[-1].legend(loc=\"upper left\", bbox_to_anchor=(1.5, 1.5))\n",
                "ax2.legend(loc=\"upper left\", bbox_to_anchor=(1.5, 1))\n",
                "ax[-1].set_xlabel(xlabel)\n",
                "ax[0].set_ylabel(ylabel)\n",
                "\n",
                "# NOTE(review): same file name as the plot_rate_traces calls in other cells\n",
                "if save and figure_path is not None:\n",
                "    fig.savefig(figure_path + \".png\")\n",
                "    fig.savefig(figure_path + \".pdf\")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "figure_path"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "figsize = cm2inch(4.5, 3.2)\n",
                "idx = 0\n",
                "true_data = True\n",
                "color = \"midnightblue\"\n",
                "save = True\n",
                "save_path = '/home/anonauthor/anonloc1/results/projects/latent-diffusion/LORENZ/PAPER/figures/'\n",
                "figure_path = save_path + \"Fig_2_predicted_rates_4_channels\"\n",
                "xlabel = \"time (s)\"\n",
                "ylabel = \"rate (Hz)\"\n",
                "\n",
                "ae_model.eval()\n",
                "batch = next(iter(test_dataloader))\n",
                "signal = batch[\"signal\"]\n",
                "with torch.no_grad():\n",
                "    output_rates = ae_model(signal)[0].cpu()\n",
                "signal = signal.cpu()  # move signal to cpu\n",
                "\n",
                "fig, ax = plt.subplots(2, 2, figsize=figsize, sharex=True)\n",
                "fig.subplots_adjust(wspace=0.2, hspace=0.2)\n",
                "ax = ax.flatten()\n",
                "\n",
                "# spikes and predicted rates share one y-axis per panel (no twin axis here)\n",
                "channels = [0, 32, 64, 127]\n",
                "for panel, channel in zip(ax, channels):\n",
                "    panel.plot(signal[idx, channel].numpy(), label=\"spikes\", color=\"black\", lw=0.8, alpha=0.5)\n",
                "    panel.plot(output_rates[idx, channel].numpy(), label=\"pred\", color=color, lw=1)\n",
                "    if not true_data:\n",
                "        panel.plot(batch[\"rates\"][idx, channel].cpu().numpy(), '--', lw=1, label=\"real\", color=\"black\", alpha=1)\n",
                "    panel.set_yticks([])\n",
                "\n",
                "# only the left column (panels 0 and 2) gets y ticks and the y label\n",
                "ax[0].set_yticks([0, 4])\n",
                "ax[2].set_yticks([0, 4])\n",
                "ax[-1].legend(loc=\"upper left\", bbox_to_anchor=(1.5, 1))\n",
                "ax[2].set_xlabel(xlabel)\n",
                "ax[2].set_ylabel(ylabel)\n",
                "ax[0].set_xlim(-2, 128)\n",
                "if save and figure_path is not None:\n",
                "    fig.savefig(figure_path + \".png\")\n",
                "    fig.savefig(figure_path + \".pdf\")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# NOTE(review): this redefines plot_rate_traces from an earlier cell and shadows the\n",
                "# version imported from ntldm.utils.plotting_utils; consider keeping a single definition.\n",
                "def plot_rate_traces(\n",
                "    model,\n",
                "    dataloader,\n",
                "    figsize=(12, 5),\n",
                "    idx=0,\n",
                "    true_data=True,\n",
                "    color=\"royalblue\",\n",
                "    save=False,\n",
                "    save_path=None,\n",
                "    xlabel=\"time (s)\",\n",
                "    ylabel=\"rate (Hz)\",\n",
                "    channels=None,\n",
                "):\n",
                "    \"\"\"Plot predicted rates (left y-axis) over observed spikes (twin right y-axis).\n",
                "\n",
                "    Uses only the first batch of `dataloader`; `idx` selects the trial. One row per entry\n",
                "    of `channels` (default: channels 0 and 64). When `true_data` is False, the\n",
                "    ground-truth `batch[\"rates\"]` are overlaid as dashed lines.\n",
                "    \"\"\"\n",
                "    if channels is None:\n",
                "        channels = np.arange(0, 128, 64)\n",
                "\n",
                "    model.eval()\n",
                "    batch = next(iter(dataloader))\n",
                "    signal = batch[\"signal\"]\n",
                "    with torch.no_grad():\n",
                "        output_rates = model(signal)[0].cpu().numpy()\n",
                "    signal = signal.cpu().numpy()\n",
                "\n",
                "    # squeeze=False keeps `ax` indexable even for a single channel\n",
                "    fig, ax = plt.subplots(len(channels), 1, figsize=figsize, sharex=True, squeeze=False)\n",
                "    ax = ax[:, 0]\n",
                "\n",
                "    for i, channel in enumerate(channels):\n",
                "        spike_ax = ax[i].twinx()\n",
                "        spikes = signal[idx, channel]\n",
                "        spike_ax.plot(spikes, label=\"spikes\", color=\"black\", lw=0.8, alpha=0.5)\n",
                "        ax[i].plot(output_rates[idx, channel], label=\"pred\", color=color)\n",
                "        if not true_data:\n",
                "            ax[i].plot(\n",
                "                batch[\"rates\"][idx, channel].cpu().numpy(),\n",
                "                '--', lw=1, label=\"real\", color=\"black\", alpha=1,\n",
                "            )\n",
                "        spike_ax.set_yticks([0, np.max(spikes)])\n",
                "\n",
                "    # legends are placed outside the axes, to the right\n",
                "    ax[-1].legend(loc=\"upper left\", bbox_to_anchor=(1.5, 1.5))\n",
                "    spike_ax.legend(loc=\"upper left\", bbox_to_anchor=(1.5, 1))\n",
                "    ax[-1].set_xlabel(xlabel)\n",
                "    ax[0].set_ylabel(ylabel)\n",
                "    if save and save_path is not None:\n",
                "        fig.savefig(save_path + \".png\")\n",
                "        fig.savefig(save_path + \".pdf\")\n",
                "\n",
                "# Run the function with the specified parameters\n",
                "plot_rate_traces(\n",
                "    model=ae_model,\n",
                "    dataloader=test_dataloader,\n",
                "    figsize=cm2inch(3, 3),  # cm2inch presumably comes from the plotting_utils star import\n",
                "    idx=idx,\n",
                "    save=True,\n",
                "    save_path=save_path + \"Fig_2_predicted_rates\",\n",
                "    true_data=False\n",
                ")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "evaluate_autoencoder(ae_model, test_dataloader, test_dataloader_longer, n_latents=8,\n",
                "                      save=True, save_path=save_path, idx=10, indices=[6,4,3])    "
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Train the diffusion model"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def sample(\n",
                "    ema_denoiser,\n",
                "    scheduler,\n",
                "    cfg,\n",
                "    batch_size=1,\n",
                "    generator=None,\n",
                "    device=\"cuda\",\n",
                "    signal_length=None\n",
                "):\n",
                "    \"\"\"Draw latent trajectories with the EMA denoiser via the full DDPM reverse process.\n",
                "\n",
                "    Starts from standard normal noise of shape\n",
                "    (batch_size, cfg.denoiser_model.C_in, signal_length), where signal_length defaults to\n",
                "    cfg.dataset.signal_length, and applies `scheduler.step` over\n",
                "    cfg.denoiser_model.num_train_timesteps timesteps. `generator` is used both for the\n",
                "    initial noise and for the scheduler steps, so a seeded generator makes sampling\n",
                "    reproducible. Returns the final z_t on `device`.\n",
                "    \"\"\"\n",
                "    if signal_length is None:\n",
                "        signal_length = cfg.dataset.signal_length\n",
                "\n",
                "    # draw the initial noise on the generator's device (CPU when no generator is given,\n",
                "    # as before) and then move it; previously `generator` was ignored here\n",
                "    noise_device = generator.device if generator is not None else \"cpu\"\n",
                "    z_t = torch.randn(\n",
                "        (batch_size, cfg.denoiser_model.C_in, signal_length),\n",
                "        generator=generator,\n",
                "        device=noise_device,\n",
                "    ).to(device)\n",
                "    ema_denoiser_avg = ema_denoiser.averaged_model\n",
                "    ema_denoiser_avg.eval()\n",
                "\n",
                "    scheduler.set_timesteps(cfg.denoiser_model.num_train_timesteps)\n",
                "\n",
                "    for t in tqdm(scheduler.timesteps, desc=\"Sampling DDPM\"):\n",
                "        # the same timestep for every sample in the batch\n",
                "        t_batch = torch.full((batch_size,), int(t), dtype=torch.long, device=device)\n",
                "        with torch.no_grad():\n",
                "            model_output = ema_denoiser_avg(z_t, t_batch)\n",
                "        z_t = scheduler.step(\n",
                "            model_output, t, z_t, generator=generator, return_dict=False\n",
                "        )[0]\n",
                "\n",
                "    return z_t\n",
                "\n",
                "\n",
                "\n",
                "# Either persist the denoiser trained above (save_new_model) or rebuild the\n",
                "# denoiser, DDPM scheduler and EMA wrapper from disk (load_model). With both flags\n",
                "# False, `ema_model` and `scheduler` must already exist from earlier cells.\n",
                "save_new_model = False\n",
                "load_model = True\n",
                "\n",
                "if save_new_model:\n",
                "    # Save the denoiser weights and the config file (main process only).\n",
                "    if accelerator.is_main_process:\n",
                "        os.makedirs(f\"exp/{cfg.exp_name}\", exist_ok=True)\n",
                "        # open(..., \"w\") does not create missing parent directories\n",
                "        os.makedirs(\"conf/sweeps_count\", exist_ok=True)\n",
                "        torch.save(accelerator.unwrap_model(denoiser).state_dict(), f\"exp/{cfg.exp_name}/model.pt\")\n",
                "        with open(f\"conf/sweeps_count/{cfg.exp_name}.yaml\", \"w\") as f:\n",
                "            f.write(OmegaConf.to_yaml(cfg))\n",
                "        # report only from the process that actually wrote the files\n",
                "        print('saved model to ', cfg.exp_name)\n",
                "\n",
                "elif load_model:\n",
                "    # Load the config and the denoiser weights from the paths the branch above writes to.\n",
                "    with open(f\"conf/sweeps_count/{cfg.exp_name}.yaml\") as f:\n",
                "        cfg = OmegaConf.create(yaml.safe_load(f))\n",
                "\n",
                "    denoiser = Denoiser(\n",
                "        C_in=cfg.denoiser_model.C_in,\n",
                "        C=cfg.denoiser_model.C,\n",
                "        L=cfg.dataset.signal_length,\n",
                "        kernel=cfg.denoiser_model.kernel,\n",
                "        num_blocks=cfg.denoiser_model.num_blocks,\n",
                "        bidirectional=cfg.denoiser_model.get(\"bidirectional\", True),\n",
                "    )\n",
                "\n",
                "    denoiser.load_state_dict(torch.load(f\"exp/{cfg.exp_name}/model.pt\", map_location=\"cpu\"))\n",
                "\n",
                "    scheduler = DDPMScheduler(\n",
                "        num_train_timesteps=cfg.denoiser_model.num_train_timesteps,\n",
                "        clip_sample=False,\n",
                "        # NOTE(review): diffusers names its cosine schedule \"squaredcos_cap_v2\"; \"linear\"\n",
                "        # is kept here — confirm it matches the schedule used during training.\n",
                "        beta_schedule=\"linear\",\n",
                "    )\n",
                "\n",
                "    # Optimizer, dataloaders and LR scheduler are rebuilt as for training; the sampling\n",
                "    # code in this cell only needs `scheduler`, `cfg` and `ema_model`.\n",
                "    optimizer = torch.optim.AdamW(\n",
                "        denoiser.parameters(), lr=cfg.training.lr\n",
                "    )  # default wd=0.01 for now\n",
                "\n",
                "    train_latent_dataloader = torch.utils.data.DataLoader(\n",
                "        latent_dataset_train,\n",
                "        batch_size=cfg.training.batch_size,\n",
                "        shuffle=True,\n",
                "        num_workers=4,\n",
                "        pin_memory=True,\n",
                "    )\n",
                "\n",
                "    val_latent_dataloader = torch.utils.data.DataLoader(\n",
                "        latent_dataset_val,\n",
                "        batch_size=cfg.training.batch_size,\n",
                "        shuffle=False,\n",
                "        num_workers=4,\n",
                "        pin_memory=True,\n",
                "    )\n",
                "    test_latent_dataloader = torch.utils.data.DataLoader(\n",
                "        latent_dataset_test,\n",
                "        batch_size=cfg.training.batch_size,\n",
                "        shuffle=False,\n",
                "        num_workers=4,\n",
                "        pin_memory=True,\n",
                "    )\n",
                "\n",
                "    num_batches = len(train_latent_dataloader)\n",
                "    lr_scheduler = get_scheduler(\n",
                "        name=\"cosine\",\n",
                "        optimizer=optimizer,\n",
                "        # warmup lasts cfg.training.num_warmup_epochs epochs\n",
                "        num_warmup_steps=num_batches * cfg.training.num_warmup_epochs,\n",
                "        # schedule horizon of 1.3x the training steps, cast to an integer step count\n",
                "        num_training_steps=int(num_batches * cfg.training.num_epochs * 1.3),\n",
                "    )\n",
                "\n",
                "    # torch.fft does not support half precision when the length is not a power of 2.\n",
                "    # NOTE(review): `accelerator` is created in an earlier cell, so changing the\n",
                "    # precision here presumably no longer affects it — confirm.\n",
                "    if cfg.dataset.signal_length & (cfg.dataset.signal_length - 1) != 0:\n",
                "        cfg.training.precision = \"no\"\n",
                "\n",
                "    # prepare the denoiser model and the dataloaders\n",
                "    (\n",
                "        denoiser,\n",
                "        train_latent_dataloader,\n",
                "        val_latent_dataloader,\n",
                "        test_latent_dataloader,\n",
                "        lr_scheduler,\n",
                "    ) = accelerator.prepare(\n",
                "        denoiser,\n",
                "        train_latent_dataloader,\n",
                "        val_latent_dataloader,\n",
                "        test_latent_dataloader,\n",
                "        lr_scheduler,\n",
                "    )\n",
                "\n",
                "    # EMA wrapper around the loaded weights; `sample` evaluates its `.averaged_model`.\n",
                "    ema_model = EMAModel(denoiser)\n",
                "\n",
                "        \n",
                "# Quick check: draw 2 latent sequences 4x longer than the training length, undo the\n",
                "# latent standardisation and decode them with the AE.\n",
                "sampled_latents = sample(\n",
                "    ema_denoiser=ema_model,\n",
                "    scheduler=scheduler,\n",
                "    cfg=cfg,\n",
                "    batch_size=2,\n",
                "    device=\"cuda\",\n",
                "    signal_length=4 * cfg.dataset.signal_length,\n",
                ")\n",
                "sampled_latents = (\n",
                "    latent_dataset_train.latent_stds.to(sampled_latents.device) * sampled_latents\n",
                "    + latent_dataset_train.latent_means.to(sampled_latents.device)\n",
                ")\n",
                "with torch.no_grad():\n",
                "    sampled_rates = ae_model.decode(sampled_latents).cpu()\n",
                "\n",
                "# heat-map of the decoded rates of the first sample\n",
                "fig, ax = plt.subplots(figsize=cm2inch(12, 4))\n",
                "im = ax.imshow(sampled_rates[0], aspect=\"auto\")\n",
                "ax.set_title(\"Sampled rates\")\n",
                "fig.colorbar(im, ax=ax, orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "fig.tight_layout()\n",
                "plt.show()\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_3d_latent_trajectory_direct\n",
                "     \n",
                "# Draw 20 latent sequences 16x longer than the training length (reused by the 3D\n",
                "# trajectory plots below), undo the latent standardisation and decode them.\n",
                "sampled_latents = sample(\n",
                "    ema_denoiser=ema_model,\n",
                "    scheduler=scheduler,\n",
                "    cfg=cfg,\n",
                "    batch_size=20,\n",
                "    device=\"cuda\",\n",
                "    signal_length=16 * cfg.dataset.signal_length,\n",
                ")\n",
                "sampled_latents = (\n",
                "    latent_dataset_train.latent_stds.to(sampled_latents.device) * sampled_latents\n",
                "    + latent_dataset_train.latent_means.to(sampled_latents.device)\n",
                ")\n",
                "with torch.no_grad():\n",
                "    sampled_rates = ae_model.decode(sampled_latents).cpu()\n",
                "\n",
                "# heat-map of the decoded rates of the first sample\n",
                "fig, ax = plt.subplots(figsize=cm2inch(12, 4))\n",
                "im = ax.imshow(sampled_rates[0], aspect=\"auto\")\n",
                "ax.set_title(\"Sampled rates\")\n",
                "fig.colorbar(im, ax=ax, orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "fig.tight_layout()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# Ground-truth dataset with 16x longer sequences and 20 initial conditions, presumably\n",
                "# for comparison with the long diffusion samples.\n",
                "train_long, _, _ = get_attractor_dataset(\n",
                "    system_name=cfg_ae.dataset.system_name,\n",
                "    n_neurons=cfg_ae.model.C_in,\n",
                "    sequence_length=cfg_ae.dataset.signal_length * 16,\n",
                "    noise_std=0.05,\n",
                "    n_ic=20,  # fixed here instead of cfg_ae.dataset.n_ic\n",
                "    # Taken from the AE config (cfg_ae), consistent with every other argument here.\n",
                "    # NOTE(review): confirm whether this should also scale with the 16x sequence_length.\n",
                "    mean_spike_count=cfg_ae.dataset.mean_rate * cfg_ae.dataset.signal_length,\n",
                "    train_frac=cfg_ae.dataset.split_frac_train,\n",
                "    valid_frac=cfg_ae.dataset.split_frac_val,  # test is 1 - train - valid\n",
                "    random_seed=cfg_ae.training.random_seed,\n",
                "    batch_size=cfg_ae.training.batch_size,\n",
                "    softplus_beta=cfg_ae.dataset.get(\"softplus_beta\", 2.0),\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Only one batch of the long ground-truth dataset is needed for the 3D plot below,\n",
                "# so take the first one instead of iterating over the whole loader.\n",
                "batch = next(iter(train_long))\n",
                "batch[\"latents\"].shape"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "%matplotlib widget\n",
                "# Output folder for the latent-trajectory figures (also used by the next cell).\n",
                "save_name = save_path + '/Latents/'\n",
                "os.makedirs(save_name, exist_ok=True)\n",
                "\n",
                "# 3D plot of one ground-truth latent trajectory from the long dataset.\n",
                "# NOTE(review): this figure is written under save_path (not save_name); with save=True the\n",
                "# helper may already save it, and the plt.savefig calls below write it again — confirm.\n",
                "plot_3d_latent_trajectory_direct(\n",
                "    batch['latents'],\n",
                "    cmap=\"Greys\",\n",
                "    sample_idx=3,\n",
                "    indices=[1, 2, 0],\n",
                "    figsize=cm2inch(4, 4),\n",
                "    ticksoff=True,\n",
                "    ms=0.01,\n",
                "    lw=0.5,\n",
                "    save=True,\n",
                "    save_path=save_path + 'true_lorenz',\n",
                ")\n",
                "plt.savefig(save_path + 'true_lorenz.png')\n",
                "plt.savefig(save_path + 'true_lorenz.pdf')\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "%matplotlib widget\n",
                "indices = [4, 6, 5]  # presumably the latent dims used as the 3D axes\n",
                "sample_idx = 8  # which sampled sequence to plot\n",
                "\n",
                "# Plot the same sampled trajectory twice: over its full length and over the first\n",
                "# training-length window only; both figures are written under save_name.\n",
                "for traj_suffix, traj_latents in (\n",
                "    (\"trajectory_long\", sampled_latents),\n",
                "    (\"trajectory_short\", sampled_latents[:, :, :cfg.dataset.signal_length]),\n",
                "):\n",
                "    plot_3d_latent_trajectory_direct(\n",
                "        traj_latents,\n",
                "        cmap=\"Reds\",\n",
                "        sample_idx=sample_idx,\n",
                "        indices=indices,\n",
                "        figsize=cm2inch(4, 4),\n",
                "        ticksoff=True,\n",
                "        ms=0.5,\n",
                "        lw=0.5,\n",
                "        save=True,\n",
                "        save_path=save_name + traj_suffix,\n",
                "    )\n",
                "    plt.savefig(save_name + traj_suffix + '.png')\n",
                "    plt.savefig(save_name + traj_suffix + '.pdf')"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Evaluate the diffusion model on latents"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# Evaluate on the training set: run the AE over every batch, draw the same number of\n",
                "# diffusion samples, decode them, and convert all arrays to numpy with the axis order\n",
                "# rearranged 'b n t -> b t n' for the figures below.\n",
                "true_data = False  # False: batches also carry ground-truth \"rates\" and \"latents\"\n",
                "n_siglen = 1  # diffusion sample length in multiples of cfg.dataset.signal_length\n",
                "sample_cutoff = int(10e5)  # 10e5 == 1e6: keep at most this many sequences per array\n",
                "\n",
                "ae_model.eval()\n",
                "\n",
                "# per-batch results, concatenated after the loop\n",
                "ae_rates = []\n",
                "ae_latents = []\n",
                "diffusion_rates = []\n",
                "diffusion_latents = []\n",
                "gt_spikes = []\n",
                "\n",
                "if not true_data:\n",
                "    gt_rates = []\n",
                "    gt_latents = []\n",
                "\n",
                "\n",
                "count = 0  # number of batches processed\n",
                "# autoencoder eval: reconstruct every training batch without gradients\n",
                "for batch in train_dataloader:\n",
                "    signal = batch[\"signal\"]\n",
                "    with torch.no_grad():\n",
                "        output_rates, latent = ae_model(signal)\n",
                "        #output_rates = ae_model(signal)[0].cpu()\n",
                "    ae_rates.append(output_rates.cpu())\n",
                "    ae_latents.append(latent.cpu())\n",
                "    gt_spikes.append(signal.cpu())\n",
                "    if not true_data:\n",
                "        gt_rates.append(batch[\"rates\"].cpu())\n",
                "        gt_latents.append(batch[\"latents\"].cpu())\n",
                "    count += 1\n",
                "    # uncomment to stop after two batches for a quick run:\n",
                "    # if count > 1:\n",
                "    #     break\n",
                "\n",
                "# concatenate along batch dimension\n",
                "ae_rates = torch.cat(ae_rates, dim=0)\n",
                "ae_latents = torch.cat(ae_latents, dim=0)\n",
                "gt_spikes = torch.cat(gt_spikes, dim=0)\n",
                "if not true_data:\n",
                "    gt_rates = torch.cat(gt_rates, dim=0)\n",
                "    gt_latents = torch.cat(gt_latents, dim=0)\n",
                "    \n",
                "\n",
                "\n",
                "# diffusion eval: one diffusion sample per training sequence\n",
                "# NOTE(review): all ae_rates.shape[0] samples are drawn in a single batch, which may not\n",
                "# fit in GPU memory for large datasets.\n",
                "sampled_latents = sample(\n",
                "    ema_denoiser=ema_model, scheduler=scheduler, cfg=cfg, batch_size=ae_rates.shape[0], device=\"cuda\", signal_length=cfg.dataset.signal_length * n_siglen\n",
                ")\n",
                "\n",
                "# project back to the non-standardised latent space (undo latent_stds / latent_means)\n",
                "sampled_latents = sampled_latents * latent_dataset_train.latent_stds.to(\n",
                "    sampled_latents.device\n",
                ") + latent_dataset_train.latent_means.to(sampled_latents.device)\n",
                "\n",
                "with torch.no_grad():\n",
                "    sampled_rates = ae_model.decode(sampled_latents).cpu()\n",
                "    \n",
                "diffusion_rates.append(sampled_rates)\n",
                "diffusion_latents.append(sampled_latents.cpu())\n",
                "\n",
                "# concatenate along batch dimension (a single element here)\n",
                "diffusion_rates = torch.cat(diffusion_rates, dim=0)\n",
                "diffusion_latents = torch.cat(diffusion_latents, dim=0)\n",
                "\n",
                "# order matters: unpacked positionally at the end of this cell\n",
                "vecs = [ae_rates, ae_latents, diffusion_rates, diffusion_latents, gt_spikes]\n",
                "\n",
                "if not true_data:\n",
                "    vecs.append(gt_rates)\n",
                "    vecs.append(gt_latents)\n",
                "\n",
                "# to numpy, truncate to sample_cutoff sequences, and reorder axes 'b n t -> b t n'\n",
                "vecs = [vec.cpu().numpy() for vec in vecs]\n",
                "vecs = [vec[:sample_cutoff] for vec in vecs]\n",
                "vecs = [rearrange(vec, 'b n t -> b t n') for vec in vecs]\n",
                "\n",
                "\n",
                "if not true_data:\n",
                "    ae_rates, ae_latents, diffusion_rates, diffusion_latents, gt_spikes, gt_rates, gt_latents = vecs\n",
                "else:\n",
                "    ae_rates, ae_latents, diffusion_rates, diffusion_latents, gt_spikes = vecs\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Generate figures"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Output path prefix used by the figure-saving calls below\n",
                "save_path"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Switch back to static inline figures (earlier cells used the interactive widget backend)\n",
                "%matplotlib inline"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Colours and legend labels shared by the figures; keys ending in \"l\" are for latents.\n",
                "color_dict = dict(\n",
                "    ae=\"midnightblue\",\n",
                "    ael=\"royalblue\",\n",
                "    diff=\"darkred\",\n",
                "    diffl=\"orangered\",\n",
                "    gt=\"darkgrey\",\n",
                "    gtl=\"#808080\",\n",
                ")\n",
                "lab_dict = dict(\n",
                "    ae=\"ae\",\n",
                "    ael=\"latents ae\",\n",
                "    diff=\"diffusion\",\n",
                "    diffl=\"latents diffusion\",\n",
                "    gt=\"gt\",\n",
                "    gtl=\"latents gt\",\n",
                ")\n",
                "\n",
                "# Preview each colour on a different AE rate channel (0, 4, 8, ...) of the first sample.\n",
                "plt.figure(figsize=cm2inch(8, 5))\n",
                "for i, (key, color) in enumerate(color_dict.items()):\n",
                "    plt.plot(ae_rates[0, :, 4 * i], label=lab_dict[key], color=color)\n",
                "plt.legend()"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Perform L2-distance checks for overfitting to the training set"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "run_overfit = False  # set True to run the pairwise L2-distance checks\n",
                "if run_overfit:\n",
                "    from ntldm.utils.utils import l2_distances\n",
                "\n",
                "    def _plot_l2_overfit_check(samples, train, title):\n",
                "        \"\"\"Plot sorted L2 distances to check whether samples merely copy training data.\n",
                "\n",
                "        Draws three curves (sample-to-train, within-train and within-sample distances)\n",
                "        on a log x-axis, with a dashed vertical line at the number of samples.\n",
                "        Returns the four outputs of the sample-to-train `l2_distances` call.\n",
                "        \"\"\"\n",
                "        def as_tensor(arr):\n",
                "            return torch.from_numpy(np.float32(arr))\n",
                "\n",
                "        samp_dists, _, _, _ = l2_distances(as_tensor(samples), as_tensor(samples))\n",
                "        real_dists, _, _, _ = l2_distances(as_tensor(train), as_tensor(train))\n",
                "        dists, close_ids, ex1, ex2 = l2_distances(as_tensor(samples), as_tensor(train))\n",
                "\n",
                "        fig, ax = plt.subplots()\n",
                "        ax.plot(np.sort(dists.numpy().flatten()), label='distance real train')\n",
                "        ax.plot(np.sort(real_dists.numpy().flatten()), label='distance within train')\n",
                "        ax.plot(np.sort(samp_dists.numpy().flatten()), label='distance within samples')\n",
                "        # full-height marker instead of a fixed 0-70 y-range, which need not fit both\n",
                "        # the latent and the rate distance scales\n",
                "        ax.axvline(len(samples), color=\"black\", linestyle=\"dashed\")\n",
                "        ax.set_xscale(\"log\")\n",
                "        ax.set_title(title)\n",
                "        ax.legend()\n",
                "        plt.show()\n",
                "        return dists, close_ids, ex1, ex2\n",
                "\n",
                "    dists, close_ids, ex1, ex2 = _plot_l2_overfit_check(diffusion_latents, ae_latents, \"latents\")\n",
                "    dists, close_ids, ex1, ex2 = _plot_l2_overfit_check(diffusion_rates, ae_rates, \"rates\")\n",
                "\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Shapes of the evaluation arrays (axes reordered 'b n t -> b t n' above)\n",
                "tuple(arr.shape for arr in (ae_rates, ae_latents, diffusion_rates, diffusion_latents, gt_spikes))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Empty two-row figure (rows 3:1 in height, shared x-axis).\n",
                "# NOTE(review): nothing is drawn into `fig`/`axs` in the cells visible here — confirm it is still needed.\n",
                "fig, axs = plt.subplots(\n",
                "    nrows=2,\n",
                "    ncols=1,\n",
                "    sharex=True,\n",
                "    figsize=cm2inch(8, 8),\n",
                "    gridspec_kw=dict(height_ratios=[3, 1]),\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# save_path = 'exp/'+cfg.exp_name+'/ae_diffusion_comparison/'\n",
                "# save_path = '/home/anonauthor/anonloc1/results/projects/latent-diffusion/LORENZ/ae_diffusion_comparison/'\n",
                "# os.makedirs(save_path, exist_ok=True)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Power spectral density of selected rate channels: ground truth vs diffusion samples.\n",
                "fps = 1  # presumably the sampling rate used for the spectral estimate\n",
                "plot_n_channel_sd(\n",
                "    gt_rates,\n",
                "    diffusion_rates,\n",
                "    channels=[0, 10, 30, 50, 80, 100, 120, 127],\n",
                "    fps=fps,\n",
                "    colors=[color_dict['gt'], color_dict['diff']],\n",
                "    labels=['gt', 'diff'],\n",
                "    ystack=2,\n",
                "    figsize=cm2inch(14, 6.5),\n",
                "    lw=1,\n",
                "    save=True,\n",
                "    save_path=save_path + 'rate_power_spec_density_diffusion_gt',\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# Power spectral density of selected rate channels: ground truth vs AE reconstructions.\n",
                "fps = 1  # presumably the sampling rate used for the spectral estimate\n",
                "plot_n_channel_sd(\n",
                "    gt_rates,\n",
                "    ae_rates,\n",
                "    channels=[0, 10, 30, 50, 80, 100, 120, 127],\n",
                "    fps=fps,\n",
                "    colors=[color_dict['gt'], color_dict['ae']],\n",
                "    labels=['gt', 'ae'],\n",
                "    ystack=2,\n",
                "    figsize=cm2inch(14, 6.5),\n",
                "    lw=1,\n",
                "    save=True,\n",
                "    save_path=save_path + 'rate_power_spec_density',\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Deterministic order of the 8 latent channels. No np.random call here, so the global\n",
                "# NumPy RNG used for the Poisson spike draws further down is not advanced by this cell.\n",
                "channels = [4, 5, 6, 7, 0, 1, 2, 3]\n",
                "\n",
                "# Power spectral density of the latents: AE vs diffusion (`fps` is set in the cells above).\n",
                "plot_n_channel_sd(ae_latents,diffusion_latents,\n",
                "                  channels=channels,\n",
                "                  fps=fps, save=True, save_path=save_path+'Fig2_latents_power_spec_density_diff_ae',\n",
                "                  colors=[color_dict['ael'], color_dict['diffl']],\n",
                "                  labels=['ae', 'diff'], ystack=2, figsize=cm2inch(14, 6.5), lw=1)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Power spectral density of selected rate channels: AE reconstructions vs diffusion samples.\n",
                "plot_n_channel_sd(\n",
                "    ae_rates,\n",
                "    diffusion_rates,\n",
                "    channels=[0, 10, 30, 50, 80, 100, 120, 127],\n",
                "    fps=fps,\n",
                "    colors=[color_dict['ae'], color_dict['diff']],\n",
                "    labels=['ae', 'diff'],\n",
                "    ystack=2,\n",
                "    figsize=cm2inch(14, 6.5),\n",
                "    lw=1,\n",
                "    save=True,\n",
                "    save_path=save_path + 'rates_power_spec_density_diff_ae',\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def group_neurons_by_corr(data, num_groups=4):\n",
                "    \"\"\"Group neurons by their overall correlation with all other neurons.\n",
                "\n",
                "    Neurons are ranked by the sum of squared correlations with every other neuron\n",
                "    (self-correlation set to 0) and split, in descending order, into consecutive\n",
                "    groups of equal size.\n",
                "\n",
                "    Args:\n",
                "        data: spike array handed to `correlation_matrix(..., mode=\"concatenate\")`;\n",
                "            presumably (batch, time, neurons) as produced above — TODO confirm.\n",
                "        num_groups: number of groups to return.\n",
                "\n",
                "    Returns:\n",
                "        List of `num_groups` index arrays. If the neuron count is not divisible by\n",
                "        `num_groups`, the lowest-ranked remainder neurons belong to no group.\n",
                "    \"\"\"\n",
                "    # correlate the array that was passed in, not the notebook-global gt_spikes\n",
                "    C_mat = correlation_matrix(data, mode=\"concatenate\")\n",
                "    np.fill_diagonal(C_mat, 0)\n",
                "    # sum up the square of the correlations\n",
                "    summed_sq_corr = np.sum(C_mat ** 2, axis=0)\n",
                "    sorted_indices = np.argsort(-summed_sq_corr)  # most correlated first\n",
                "\n",
                "    group_size = len(sorted_indices) // num_groups\n",
                "    # split the ranked indices into equally sized consecutive groups\n",
                "    groups = [sorted_indices[i * group_size:(i + 1) * group_size] for i in range(num_groups)]\n",
                "    return groups\n",
                "\n",
                "grouped_neurons = group_neurons_by_corr(gt_spikes, num_groups=4)\n",
                "\n",
                "grouped_neurons"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import group_neurons_temp_corr, get_temp_corr_summary\n",
                "from ntldm.utils.plotting_utils import plot_temp_corr_summary\n",
                "import pickle\n",
                "# Group neurons by correlation so the pairwise-correlation plots are cheaper to compute,\n",
                "# then compute time-lagged auto- and cross-correlations per group for the recorded spikes,\n",
                "# for Poisson spikes drawn from the AE rates and for Poisson spikes drawn from the\n",
                "# diffusion rates (presumably with a neuron's self-correlation at lag 0 set to 0).\n",
                "calc_again = True  # False: reload previously pickled results from save_path instead\n",
                "\n",
                "if calc_again:\n",
                "    # NOTE(review): np.random is not seeded, so the sampled spikes differ between runs.\n",
                "    ae_spikes = np.random.poisson(ae_rates)\n",
                "    diff_spikes = np.random.poisson(diffusion_rates)\n",
                "\n",
                "    # arrays are (batch, time, neurons) here; transpose(1, 0, 2) matches batch_first=False\n",
                "    groups = group_neurons_temp_corr(gt_spikes.transpose(1,0,2), num_groups=4)\n",
                "    cross_corr_groups, auto_corr_groups = get_temp_corr_summary(gt_spikes.transpose(1,0,2), groups, nlags=30,mode='biased',\n",
                "                                                                batch_first=False)\n",
                "    fig_cross, fig_auto = plot_temp_corr_summary(cross_corr_groups, auto_corr_groups, name='Data')\n",
                "\n",
                "    cross_corr_groups_sampled, auto_corr_groups_sampled = get_temp_corr_summary(ae_spikes.transpose(1,0,2), groups, nlags=30, mode='biased',\n",
                "                                                                batch_first=False)\n",
                "    fig_cross_sampled, fig_auto_sampled = plot_temp_corr_summary(cross_corr_groups_sampled, auto_corr_groups_sampled, name='AE Samples')\n",
                "\n",
                "    cross_corr_groups_diff, auto_corr_groups_diff = get_temp_corr_summary(diff_spikes.transpose(1,0,2), groups, nlags=30, mode='biased',\n",
                "                                                                batch_first=False)\n",
                "    # separate names so the AE-sample figure handles above are not overwritten\n",
                "    fig_cross_diff, fig_auto_diff = plot_temp_corr_summary(cross_corr_groups_diff, auto_corr_groups_diff, name='Diffusion Samples')\n",
                "\n",
                "else:\n",
                "    # Reload previously pickled results (`pickle` is imported at the top of this cell).\n",
                "    # NOTE: pickle.load can execute arbitrary code; only load files you wrote yourself.\n",
                "    with open(save_path+'cross_corr_groups.pkl', 'rb') as f:\n",
                "        cross_corr_groups, auto_corr_groups = pickle.load(f)\n",
                "    with open(save_path+'cross_corr_groups_sampled.pkl', 'rb') as f:\n",
                "        cross_corr_groups_sampled, auto_corr_groups_sampled = pickle.load(f)\n",
                "    with open(save_path+'cross_corr_groups_diff.pkl', 'rb') as f:\n",
                "        cross_corr_groups_diff, auto_corr_groups_diff = pickle.load(f)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# # pickle all of the results\n",
                "# import pickle\n",
                "# with open(save_path+'cross_corr_groups.pkl', 'wb') as f:\n",
                "#     pickle.dump([cross_corr_groups, auto_corr_groups], f)\n",
                "# with open(save_path+'cross_corr_groups_sampled.pkl', 'wb') as f:\n",
                "#     pickle.dump([cross_corr_groups_sampled, auto_corr_groups_sampled], f)\n",
                "# with open(save_path+'cross_corr_groups_diff.pkl', 'wb') as f:\n",
                "#     pickle.dump([cross_corr_groups_diff, auto_corr_groups_diff], f)\n",
                "    \n",
                "    \n",
                "# # load those pickle files\n",
                "# import pickle\n",
                "# with open(save_path+'cross_corr_groups.pkl', 'rb') as f:\n",
                "#     cross_corr_groups, auto_corr_groups = pickle.load(f)\n",
                "# with open(save_path+'cross_corr_groups_sampled.pkl', 'rb') as f:\n",
                "#     cross_corr_groups_sampled, auto_corr_groups_sampled = pickle.load(f)\n",
                "# with open(save_path+'cross_corr_groups_diff.pkl', 'rb') as f:\n",
                "#     cross_corr_groups_diff, auto_corr_groups_diff = pickle.load(f)\n",
                "    \n",
                "    "
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_cross_corr_summary\n",
                "\n",
                "# Overlay ground truth (Greys) and diffusion samples (Reds) in one small figure each:\n",
                "# first the cross-correlation summary, then the auto-correlation summary.\n",
                "save = True\n",
                "for gt_groups, diff_groups, panel_figsize, diff_extra_kwargs, fig_name in (\n",
                "    (cross_corr_groups, cross_corr_groups_diff, cm2inch(6, 4), {},\n",
                "     \"Fig_2_cross_corr_gt_diff\"),\n",
                "    (auto_corr_groups, auto_corr_groups_diff, cm2inch(3, 1.5),\n",
                "     {\"title\": \"auto-corr\", \"ylabel\": \"auto-corr\"}, \"Fig2_auto_corr_gt_diff\"),\n",
                "):\n",
                "    fig, ax = plt.subplots(1, 1, figsize=cm2inch(3, 2))\n",
                "    plot_cross_corr_summary(\n",
                "        gt_groups, name=\"gt\", figsize=panel_figsize, cmap=\"Greys\",\n",
                "        ax_corr=ax, labels=\"gt \", ncol=2,\n",
                "    )\n",
                "    plot_cross_corr_summary(\n",
                "        diff_groups, name=\"diffusion\", figsize=panel_figsize, cmap=\"Reds\",\n",
                "        ax_corr=ax, labels=\"diff \", ncol=2, **diff_extra_kwargs,\n",
                "    )\n",
                "    if save and save_path is not None:\n",
                "        fig.savefig(save_path + fig_name + \".png\")\n",
                "        fig.savefig(save_path + fig_name + \".pdf\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_cross_corr_summary\n",
                "plot_cross_corr_summary(\n",
                "    cross_corr_groups,\n",
                "    name=\"gt\",\n",
                "    figsize=cm2inch(6, 4),\n",
                "    cmap=\"Greys\",\n",
                "    save=True,\n",
                "    save_path=save_path + \"cross_corr_gt\",\n",
                ")\n",
                "plot_cross_corr_summary(\n",
                "    cross_corr_groups_diff,\n",
                "    name=\"diffusion\",\n",
                "    figsize=cm2inch(6, 4),\n",
                "    cmap=\"Reds\",\n",
                "    save=True,\n",
                "    save_path=save_path + \"cross_corr_diffusion\",\n",
                ")\n",
                "plot_cross_corr_summary(\n",
                "    cross_corr_groups_sampled,\n",
                "    name=\"ae\",\n",
                "    figsize=cm2inch(6, 4),\n",
                "    cmap=\"Blues\",\n",
                "    save=True,\n",
                "    save_path=save_path + \"cross_corr_ae\",\n",
                ")\n",
                "\n",
                "\n",
                "fig, ax = plt.subplots(1, 1, figsize=cm2inch(8, 6))\n",
                "save = True\n",
                "plot_cross_corr_summary(\n",
                "    cross_corr_groups,\n",
                "    name=\"gt\",\n",
                "    figsize=cm2inch(6, 4),\n",
                "    cmap=\"Greys\",\n",
                "    ax_corr=ax,\n",
                "    labels=\"gt \",\n",
                "    ncol=3,\n",
                ")\n",
                "plot_cross_corr_summary(\n",
                "    cross_corr_groups_diff,\n",
                "    name=\"diffusion\",\n",
                "    figsize=cm2inch(6, 4),\n",
                "    cmap=\"Reds\",\n",
                "    ax_corr=ax,\n",
                "    labels=\"diff \",\n",
                "    ncol=3,\n",
                ")\n",
                "plot_cross_corr_summary(\n",
                "    cross_corr_groups_sampled,\n",
                "    name=\"ae\",\n",
                "    figsize=cm2inch(6, 4),\n",
                "    cmap=\"Blues\",\n",
                "    ax_corr=ax,\n",
                "    labels=\"ae \",\n",
                "    ncol=3,\n",
                ")\n",
                "if save and save_path is not None:\n",
                "    fig.savefig(save_path + \"cross_corr_all.png\")\n",
                "    fig.savefig(save_path + \"cross_corr_all.pdf\")\n",
                "\n",
                "fig, ax = plt.subplots(1, 1, figsize=cm2inch(6, 4))\n",
                "save = True\n",
                "plot_cross_corr_summary(\n",
                "    cross_corr_groups,\n",
                "    name=\"gt\",\n",
                "    figsize=cm2inch(6, 4),\n",
                "    cmap=\"Greys\",\n",
                "    ax_corr=ax,\n",
                "    labels=\"gt \",\n",
                "    ncol=2,\n",
                ")\n",
                "plot_cross_corr_summary(\n",
                "    cross_corr_groups_diff,\n",
                "    name=\"diffusion\",\n",
                "    figsize=cm2inch(6, 4),\n",
                "    cmap=\"Reds\",\n",
                "    ax_corr=ax,\n",
                "    labels=\"diff \",\n",
                "    ncol=2,\n",
                ")\n",
                "if save and save_path is not None:\n",
                "    fig.savefig(save_path + \"cross_corr_gt_diff.png\")\n",
                "    fig.savefig(save_path + \"cross_corr_gt_diff.pdf\")\n",
                "\n",
                "fig, ax = plt.subplots(1, 1, figsize=cm2inch(8, 6))\n",
                "plot_cross_corr_summary(\n",
                "    auto_corr_groups,\n",
                "    name=\"gt\",\n",
                "    figsize=cm2inch(6, 4),\n",
                "    cmap=\"Greys\",\n",
                "    ax_corr=ax,\n",
                "    labels=\"gt \",\n",
                "    ncol=3,\n",
                "    title=\"auto-corr\",\n",
                "    ylabel=\"auto-corr\",\n",
                ")\n",
                "plot_cross_corr_summary(\n",
                "    auto_corr_groups_diff,\n",
                "    name=\"diffusion\",\n",
                "    figsize=cm2inch(6, 4),\n",
                "    cmap=\"Reds\",\n",
                "    ax_corr=ax,\n",
                "    labels=\"diff \",\n",
                "    ncol=3,\n",
                "    title=\"auto-corr\",\n",
                "    ylabel=\"auto-corr\",\n",
                ")\n",
                "plot_cross_corr_summary(\n",
                "    auto_corr_groups_sampled,\n",
                "    name=\"ae\",\n",
                "    figsize=cm2inch(6, 4),\n",
                "    cmap=\"Blues\",\n",
                "    ax_corr=ax,\n",
                "    labels=\"ae \",\n",
                "    ncol=3,\n",
                "    title=\"auto-corr\",\n",
                "    ylabel=\"auto-corr\",\n",
                ")\n",
                "if save and save_path is not None:\n",
                "    fig.savefig(save_path + \"auto_corr_all.png\")\n",
                "    fig.savefig(save_path + \"auto_corr_all.pdf\")\n",
                "    \n",
                "\n",
                "fig, ax = plt.subplots(1, 1, figsize=cm2inch(6, 4))\n",
                "save = True\n",
                "plot_cross_corr_summary(\n",
                "    auto_corr_groups,\n",
                "    name=\"gt\",\n",
                "    figsize=cm2inch(6, 4),\n",
                "    cmap=\"Greys\",\n",
                "    ax_corr=ax,\n",
                "    labels=\"gt \",\n",
                "    ncol=2,\n",
                ")\n",
                "plot_cross_corr_summary(\n",
                "    auto_corr_groups_diff,\n",
                "    name=\"diffusion\",\n",
                "    figsize=cm2inch(6, 4),\n",
                "    cmap=\"Reds\",\n",
                "    ax_corr=ax,\n",
                "    labels=\"diff \",\n",
                "    ncol=2,\n",
                "    title=\"auto-corr\",\n",
                "    ylabel=\"auto-corr\",\n",
                ")\n",
                "if save and save_path is not None:\n",
                "    fig.savefig(save_path + \"auto_corr_gt_diff.png\")\n",
                "    fig.savefig(save_path + \"auto_corr_gt_diff.pdf\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "save=True\n",
                "# Define custom function to create a colormap with specified colors\n",
                "def create_custom_cmap(base_cmap, num_colors):\n",
                "    \"\"\"Create a custom colormap with num_colors from a base colormap.\"\"\"\n",
                "    cmap = plt.cm.get_cmap(base_cmap)\n",
                "    colors = [cmap(i) for i in range(cmap.N)]\n",
                "    return cmap.from_list('Custom cmap', colors, num_colors)\n",
                "\n",
                "# Generate custom colormaps for different shades of grey, blue, and another shade of blue\n",
                "num_shades = 20\n",
                "grey_cmap = create_custom_cmap('Greys', num_shades)\n",
                "red_cmap = create_custom_cmap('Reds', num_shades)\n",
                "blue_cmap = create_custom_cmap('Blues', num_shades)  # Adjust base cmap as needed\n",
                "\n",
                "\n",
                "fig, ax = plt.subplots(3, 1, figsize=cm2inch(8, 8), sharex=True)\n",
                "for i in range(20):\n",
                "    ax[0].plot(gt_rates[0,:, i], color=grey_cmap(i), label=f\"gt {i}\")\n",
                "    ax[1].plot(ae_rates[0,:, i], color=blue_cmap(i), label=f\"ae {i}\")\n",
                "\n",
                "    ax[2].plot(diffusion_rates[0,:, i], color=red_cmap(i), label=f\"diffusion {i}\")\n",
                "\n",
                "\n",
                "ax[0].set_ylabel(\"gt rates\")\n",
                "ax[1].set_ylabel(\"ae rates\")\n",
                "ax[2].set_ylabel(\"diffusion\")\n",
                "\n",
                "plt.tight_layout()\n",
                "plt.xlabel(\"time (a.u.)\")\n",
                "\n",
                "if save and save_path is not None:\n",
                "    fig.savefig(save_path + \"supp_fig_rates_all.png\")\n",
                "    fig.savefig(save_path + \"supp_fig_rates_all.pdf\")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Plot population spike histogram"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_population_spike_histogram\n",
                "plot_population_spike_histogram(\n",
                "    gt_spikes,\n",
                "    np.random.poisson(ae_rates),\n",
                "    labels=[\"gt\", \"ae\"],\n",
                "    colors=[color_dict[\"gt\"], color_dict[\"ae\"]],\n",
                "    save=True,\n",
                "    save_path=save_path + \"population_spike_hist_ae\",\n",
                ")\n",
                "plot_population_spike_histogram(\n",
                "    gt_spikes,\n",
                "    np.random.poisson(diffusion_rates),\n",
                "    labels=[\"gt\", \"diff\"],\n",
                "    colors=[color_dict[\"gt\"], color_dict[\"diff\"]],\n",
                "    save=True,\n",
                "    save_path=save_path + \"population_spike_hist_diff\",\n",
                ")\n",
                "\n",
                "fig, axs = plt.subplots(1, 2, figsize=cm2inch(8, 3.5))\n",
                "plot_population_spike_histogram(\n",
                "    gt_spikes,\n",
                "    np.random.poisson(ae_rates),\n",
                "    labels=[\"gt\", \"ae\"],\n",
                "    colors=[color_dict[\"gt\"], color_dict[\"ae\"], color_dict[\"diff\"]],\n",
                "    ax=axs[0],\n",
                "    x_label=\"# spikes/bin\",\n",
                "    y_label=\"frequency\",\n",
                ")\n",
                "plot_population_spike_histogram(\n",
                "    gt_spikes,\n",
                "    np.random.poisson(diffusion_rates),\n",
                "    labels=[\"gt\", \"diff\"],\n",
                "    colors=[color_dict[\"gt\"], color_dict[\"diff\"], color_dict[\"diff\"]],\n",
                "    ax=axs[1],\n",
                "    x_label=\"# spikes/bin\",\n",
                "    y_label=\"frequency\",\n",
                ")\n",
                "plt.tight_layout()\n",
                "fig.savefig(save_path + \"population_spike_hist_all.png\")\n",
                "fig.savefig(save_path + \"population_spike_hist_all.pdf\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# plot reconstructed spikes\n",
                "diffusion_spikes = np.random.poisson(diffusion_rates)\n",
                "max_val_pop_count = np.max([diffusion_spikes.sum(2).max(), gt_spikes.sum(2).max()])\n",
                "max_val_pop_count = int(max_val_pop_count)\n",
                "max_val_pop_count =100\n",
                "bins = np.linspace(-0.5, max_val_pop_count-0.5, max_val_pop_count+1)\n",
                "bins = np.linspace(0, 100, 50+1)\n",
                "\n",
                "plt.figure(figsize=cm2inch((2.5, 2)))\n",
                "plt.hist(gt_spikes.sum(2).flatten(), density=True, color='grey', bins=bins, alpha=0.5)\n",
                "plt.hist(diffusion_spikes.sum(2).flatten(), density=True, color='darkred', bins=bins, alpha=0.5)\n",
                "\n",
                "plt.legend(['gt', 'ldns'])\n",
                "plt.title('population spike count')\n",
                "plt.ylabel('frequency')\n",
                "plt.xlabel('number of spikes per time bin')\n",
                "plt.savefig(save_path + 'Fig_2_population_spike_count_diff.png')\n",
                "plt.savefig(save_path + 'Fig_2_population_spike_count_diff.pdf')\n",
                "\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "plot_population_spike_histogram(\n",
                "    gt_spikes,\n",
                "    np.random.poisson(diffusion_rates),\n",
                "    labels=[\"gt\", \"diff\"],\n",
                "    colors=[color_dict[\"gt\"], color_dict[\"diff\"]],\n",
                "    save=True,\n",
                "    save_path=save_path + \"population_spike_hist_diff\",\n",
                "    figsize=cm2inch(4, 3.5),\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import rmse_nan, average_rates, std_rates, kl_div_nan, correlation_matrix\n",
                "from ntldm.utils.eval_utils import compute_spike_stats_per_neuron, counts_to_spike_trains\n",
                "from scipy.stats import gaussian_kde\n",
                "from ntldm.utils.eval_utils import group_neurons_temp_corr, get_temp_corr_summary\n",
                "\n",
                "\n",
                "\n",
                "def rmse_mean_rate(spikes, gt):\n",
                "    gt_m = average_rates(gt, mode='neur', fps=fps_monkey,).flatten()\n",
                "    spikes_m = average_rates(spikes, mode='neur', fps=fps_monkey).flatten()\n",
                "    return rmse_nan(gt_m, spikes_m)\n",
                "\n",
                "def rmse_std_rate(spikes, gt):\n",
                "    gt_s = std_rates(gt, mode='neur', fps=fps_monkey,).flatten()\n",
                "    spikes_s = std_rates(spikes, mode='neur', fps=fps_monkey).flatten()\n",
                "    return rmse_nan(gt_s, spikes_s)\n",
                "\n",
                "def get_spike_train_and_stats(spikes, fps=200):\n",
                "    spiketrain = counts_to_spike_trains(spikes, fps=fps)\n",
                "    spike_stats_m = compute_spike_stats_per_neuron(\n",
                "        spiketrain,\n",
                "        n_samples=spikes.shape[0],\n",
                "        n_neurons=spikes.shape[2],\n",
                "        mean_output=False,\n",
                "    )\n",
                "    return spike_stats_m\n",
                "\n",
                "\n",
                "# plot population spike histogram for diffusion as well as correlation and isi\n",
                "color_dict = {\n",
                "    'lfads': 'purple',\n",
                "    'ldns': 'darkred',\n",
                "    'hist': 'darkgreen',\n",
                "    'train': 'grey'\n",
                "}\n",
                "\n",
                "spikes_dict_single = {\n",
                "    'ldns': diffusion_spikes\n",
                "}\n",
                "# get the overall maximum population spike count to calculate the kl div\n",
                "max_vals = []\n",
                "for key in spikes_dict_single.keys():\n",
                "    max_vals.append(spikes_dict_single[key].sum(2).max())\n",
                "max_vals.append(gt_spikes.sum(2).max())\n",
                "maxval = int(np.max(max_vals))\n",
                "\n",
                "\n",
                "\n",
                "for i, (key, val) in enumerate(spikes_dict_single.items()):\n",
                "    # if key != 'ldns':\n",
                "    #     continue\n",
                "    fig, ax = plt.subplots(2, 2, figsize=cm2inch(9, 7))\n",
                "    ax =ax.flatten()\n",
                "    kde_model = gaussian_kde(val.sum(2).flatten())\n",
                "    kde_gt = gaussian_kde(gt_spikes.sum(2).flatten())\n",
                "\n",
                "    # Evaluating densities over a common range derived from data\n",
                "    x_eval = np.linspace(0, maxval, maxval+1)\n",
                "    density_model = kde_model(x_eval)\n",
                "    density_gt = kde_gt(x_eval)\n",
                "    density_model /= density_model.sum()\n",
                "    density_gt /= density_gt.sum()\n",
                "    \n",
                "    bins_psc = np.arange(-0.5, maxval-0.5, 1)\n",
                "    # plot the population spike count histogram for the lfads model and the data\n",
                "    ax[0].hist(gt_spikes.sum(2).flatten(), bins=bins_psc, density=True, alpha=0.5, label='data', color='grey')\n",
                "    ax[0].hist(val.sum(2).flatten(), bins=bins_psc, density=True, alpha=0.5, label=key, color=color_dict[key])\n",
                "\n",
                "    # now plot the density estimate\n",
                "    ax[0].plot(x_eval, density_gt, '.-', label='data kde', color='black')\n",
                "    ax[0].plot(x_eval, density_model, '.-', label=key+' kde', color=color_dict[key])\n",
                "    \n",
                "    ax[0].legend(fontsize=6)\n",
                "    ax[0].set_xlabel('spike count')\n",
                "    ax[0].set_ylabel('density')\n",
                "    \n",
                "    \n",
                "    # now plot the correlation matrix\n",
                "    \n",
                "    # get the correlation structure\n",
                "    C_model = correlation_matrix(val, mode=\"concatenate\")\n",
                "    np.fill_diagonal(C_model, 0)\n",
                "    C_model = np.tril(C_model, k=-1)\n",
                "    \n",
                "    C_gt = correlation_matrix(gt_spikes, mode=\"concatenate\")\n",
                "    np.fill_diagonal(C_gt, 0)\n",
                "    C_gt = np.tril(C_gt, k=-1)\n",
                "    \n",
                "    ax[1].plot(C_gt.flatten(), C_model.flatten(), '.', alpha=0.5, color=color_dict[key], ms=2)\n",
                "    data_limits = [min(C_gt.min(), C_model.min()), max(C_gt.max(), C_model.max())]\n",
                "    #data_limits = [-0.02, 0.05]\n",
                "    ax[1].plot([data_limits[0], data_limits[1]], [data_limits[0], data_limits[1]], '--', color='black')\n",
                "    # make a 1:1 line that spans the whole plot\n",
                "    #ax[1].plot([0, 1], [0, 1], '--', color='black', transform= ax[1].transAxes)\n",
                "    ax[1].set_xlabel('gt')\n",
                "    ax[1].set_ylabel(key)\n",
                "    ax[1].set_aspect('equal')\n",
                "    data_limis_ax = [data_limits[0]-0.15*np.abs(data_limits[0]), data_limits[1]+0.15*np.abs(data_limits[1])]\n",
                "    ax[1].set_xlim(data_limis_ax)\n",
                "    ax[1].set_ylim(data_limis_ax)\n",
                "\n",
                "\n",
                "    # compute the spikes stats\n",
                "    spike_stats_gt = get_spike_train_and_stats(gt_spikes, fps=1)\n",
                "\n",
                "    spike_stats_model = get_spike_train_and_stats(val, fps=1)\n",
                "\n",
                "\n",
                "    \n",
                "    ax[2].plot(spike_stats_gt['mean_isi'].flatten(),spike_stats_model['mean_isi'].flatten(), '.', alpha=0.5, color=color_dict[key], ms=2)\n",
                "    data_limits = [min(spike_stats_gt['mean_isi'].flatten().min(), spike_stats_model['mean_isi'].flatten().min()), max(spike_stats_gt['mean_isi'].flatten().max(), spike_stats_model['mean_isi'].flatten().max())]\n",
                "    #data_limits = [-0.02, 0.05]\n",
                "    ax[2].plot([data_limits[0], data_limits[1]], [data_limits[0], data_limits[1]], '--', color='black')\n",
                "    # make a 1:1 line that spans the whole plot\n",
                "    #ax[1].plot([0, 1], [0, 1], '--', color='black', transform= ax[1].transAxes)\n",
                "    ax[2].set_xlabel('gt mean isi')\n",
                "    ax[2].set_ylabel(key + ' mean isi')\n",
                "    ax[2].set_aspect('equal')\n",
                "    data_limis_ax = [data_limits[0]-0.15*np.abs(data_limits[0]), data_limits[1]+0.15*np.abs(data_limits[1])]\n",
                "    ax[2].set_xlim(data_limis_ax)\n",
                "    ax[2].set_ylim(data_limis_ax)\n",
                "\n",
                "    ax[3].plot(spike_stats_gt['std_isi'].flatten(),spike_stats_model['std_isi'].flatten(), '.', alpha=0.5, color=color_dict[key], ms=2)\n",
                "    data_limits = [min(spike_stats_gt['std_isi'].flatten().min(), spike_stats_model['std_isi'].flatten().min()), max(spike_stats_gt['std_isi'].flatten().max(), spike_stats_model['std_isi'].flatten().max())]\n",
                "    #data_limits = [-0.02, 0.05]\n",
                "    ax[3].plot([data_limits[0], data_limits[1]], [data_limits[0], data_limits[1]], '--', color='black')\n",
                "    # make a 1:1 line that spans the whole plot\n",
                "    #ax[1].plot([0, 1], [0, 1], '--', color='black', transform= ax[1].transAxes)\n",
                "    ax[3].set_xlabel('gt std isi')\n",
                "    ax[3].set_ylabel(key + ' std isi')\n",
                "    ax[3].set_aspect('equal')\n",
                "    data_limis_ax = [data_limits[0]-0.15*np.abs(data_limits[0]), data_limits[1]+0.15*np.abs(data_limits[1])]\n",
                "    ax[3].set_xlim(data_limis_ax)\n",
                "    ax[3].set_ylim(data_limis_ax)\n",
                "\n",
                "    plt.tight_layout()\n",
                "    # save the figure\n",
                "    plt.savefig(save_path + f'Fig_population_and_single_spike_stats_{key}.png')\n",
                "    plt.savefig(save_path + f'Fig_population_and_single_spike_stats_{key}.pdf')\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from scipy import stats\n",
                "%matplotlib inline\n",
                "\n",
                "for i, (key, val) in enumerate(spikes_dict_single.items()):\n",
                "    # if key != 'ldns':\n",
                "    #     continue\n",
                "    fig, ax = plt.subplots(2, 2, figsize=cm2inch(9, 7))\n",
                "    ax =ax.flatten()\n",
                "    kde_model = gaussian_kde(val.sum(2).flatten())\n",
                "    kde_gt = gaussian_kde(gt_spikes.sum(2).flatten())\n",
                "\n",
                "    # Evaluating densities over a common range derived from data\n",
                "    x_eval = np.linspace(0, maxval, maxval+1)\n",
                "    density_model = kde_model(x_eval)\n",
                "    density_gt = kde_gt(x_eval)\n",
                "    density_model /= density_model.sum()\n",
                "    density_gt /= density_gt.sum()\n",
                "    \n",
                "    bins_psc = np.arange(-0.5, maxval-0.5, 1)\n",
                "    # plot the population spike count histogram for the lfads model and the data\n",
                "    ax[0].hist(gt_spikes.sum(2).flatten(), bins=bins_psc, density=True, alpha=0.5, label='data', color='grey')\n",
                "    ax[0].hist(val.sum(2).flatten(), bins=bins_psc, density=True, alpha=0.5, label=key, color=color_dict[key])\n",
                "\n",
                "    # now plot the density estimate\n",
                "    ax[0].plot(x_eval, density_gt, '.-', label='data kde', color='black')\n",
                "    ax[0].plot(x_eval, density_model, '.-', label=key+' kde', color=color_dict[key])\n",
                "    \n",
                "    ax[0].legend(fontsize=6)\n",
                "    ax[0].set_xlabel('spike count')\n",
                "    #ax[0].set_ylabel('density')\n",
                "    ax[0].set_yticks([])\n",
                "\n",
                "    \n",
                "    \n",
                "    # now plot the correlation matrix\n",
                "    \n",
                "    # get the correlation structure\n",
                "    C_model = correlation_matrix(val, mode=\"concatenate\")\n",
                "    np.fill_diagonal(C_model, 0)\n",
                "    C_model = np.tril(C_model, k=-1)\n",
                "    \n",
                "    C_gt = correlation_matrix(gt_spikes, mode=\"concatenate\")\n",
                "    np.fill_diagonal(C_gt, 0)\n",
                "    C_gt = np.tril(C_gt, k=-1)\n",
                "    \n",
                "    ax[1].plot(C_gt.flatten(), C_model.flatten(), '.', alpha=0.5, color=color_dict[key], ms=2)\n",
                "    data_limits = [min(C_gt.min(), C_model.min()), max(C_gt.max(), C_model.max())]\n",
                "    #data_limits = [-0.02, 0.05]\n",
                "    ax[1].plot([data_limits[0], data_limits[1]], [data_limits[0], data_limits[1]], '--', color='black')\n",
                "    # make a 1:1 line that spans the whole plot\n",
                "    #ax[1].plot([0, 1], [0, 1], '--', color='black', transform= ax[1].transAxes)\n",
                "    ax[1].set_xlabel('gt')\n",
                "    ax[1].set_ylabel(key)\n",
                "    ax[1].set_aspect('equal')\n",
                "    data_limis_ax = [data_limits[0]-0.15*np.abs(data_limits[0]), data_limits[1]+0.15*np.abs(data_limits[1])]\n",
                "    ax[1].set_xlim(data_limis_ax)\n",
                "    ax[1].set_ylim(data_limis_ax)\n",
                "\n",
                "\n",
                "    # # compute the spikes stats\n",
                "    # spike_stats_gt = get_spike_train_and_stats(gt_spikes, fps=1)\n",
                "\n",
                "    # spike_stats_model = get_spike_train_and_stats(val, fps=1)\n",
                "\n",
                "\n",
                "    \n",
                "    ax[2].plot(spike_stats_gt['mean_isi'].flatten(),spike_stats_model['mean_isi'].flatten(), '.', alpha=0.5, color=color_dict[key], ms=2)\n",
                "    data_limits = [min(spike_stats_gt['mean_isi'].flatten().min(), spike_stats_model['mean_isi'].flatten().min()), max(spike_stats_gt['mean_isi'].flatten().max(), spike_stats_model['mean_isi'].flatten().max())]\n",
                "    #data_limits = [-0.02, 0.05]\n",
                "    ax[2].plot([data_limits[0], data_limits[1]], [data_limits[0], data_limits[1]], '--', color='black')\n",
                "    # make a 1:1 line that spans the whole plot\n",
                "    #ax[1].plot([0, 1], [0, 1], '--', color='black', transform= ax[1].transAxes)\n",
                "    ax[2].set_xlabel('gt mean isi')\n",
                "    ax[2].set_ylabel(key + ' mean isi')\n",
                "    ax[2].set_aspect('equal')\n",
                "    data_limis_ax = [data_limits[0]-0.15*np.abs(data_limits[0]), data_limits[1]+0.15*np.abs(data_limits[1])]\n",
                "    ax[2].set_xlim(data_limis_ax)\n",
                "    ax[2].set_ylim(data_limis_ax)\n",
                "\n",
                "    ax[3].plot(spike_stats_gt['std_isi'].flatten(),spike_stats_model['std_isi'].flatten(), '.', alpha=0.5, color=color_dict[key], ms=2)\n",
                "    data_limits = [min(spike_stats_gt['std_isi'].flatten().min(), spike_stats_model['std_isi'].flatten().min()), max(spike_stats_gt['std_isi'].flatten().max(), spike_stats_model['std_isi'].flatten().max())]\n",
                "    #data_limits = [-0.02, 0.05]\n",
                "    ax[3].plot([data_limits[0], data_limits[1]], [data_limits[0], data_limits[1]], '--', color='black')\n",
                "    # make a 1:1 line that spans the whole plot\n",
                "    #ax[1].plot([0, 1], [0, 1], '--', color='black', transform= ax[1].transAxes)\n",
                "    ax[3].set_xlabel('gt std isi')\n",
                "    ax[3].set_ylabel(key + ' std isi')\n",
                "    ax[3].set_aspect('equal')\n",
                "    data_limis_ax = [data_limits[0]-0.08*np.abs(data_limits[0]), data_limits[1]+0.08*np.abs(data_limits[1])]\n",
                "    ax[3].set_xlim(data_limis_ax)\n",
                "    ax[3].set_ylim(data_limis_ax)\n",
                "    \n",
                "    print('pearsen R mean isi', stats.pearsonr(spike_stats_gt['mean_isi'].flatten(),spike_stats_model['mean_isi'].flatten()))\n",
                "    print('pearsen R std isi', stats.pearsonr(spike_stats_gt['std_isi'].flatten(),spike_stats_model['std_isi'].flatten()))\n",
                "    print('pearsen R corr matrices', stats.pearsonr(C_gt.flatten(), C_model.flatten()))\n",
                "    print('pearsen R density gt model', stats.pearsonr(density_gt, density_model))\n",
                "\n",
                "    plt.tight_layout()\n",
                "    # save the figure\n",
                "    plt.savefig(save_path + f'Fig_population_and_single_spike_stats_{key}.png')\n",
                "    plt.savefig(save_path + f'Fig_population_and_single_spike_stats_{key}.pdf')\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import numpy as np\n",
                "import matplotlib.pyplot as plt\n",
                "from scipy import stats\n",
                "\n",
                "\n",
                "# Perform linear regression\n",
                "slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)\n",
                "\n",
                "# Compute R^2 value\n",
                "r_squared = r_value**2\n",
                "\n",
                "# Create a scatter plot\n",
                "plt.figure(figsize=(8, 6))\n",
                "plt.scatter(x, y, color='blue', label='Data points')\n",
                "plt.plot(x, slope*x + intercept, color='red', label=f'Fit line: y={slope:.2f}x + {intercept:.2f}')\n",
                "plt.title('scatter plot with regression line')\n",
                "plt.xlabel('x')\n",
                "plt.ylabel('y')\n",
                "plt.legend()\n",
                "plt.grid(True)\n",
                "plt.show()\n",
                "\n",
                "# Returning the computed statistics\n",
                "(slope, intercept, r_squared, p_value, std_err)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Pearson correlation of the 'mean_isi' statistics: ground truth vs. model.\n",
                "# NOTE(review): spike_stats_model is not defined in this cell; it comes from an earlier cell.\n",
                "gt_mean_isi = spike_stats_gt['mean_isi'].flatten()\n",
                "model_mean_isi = spike_stats_model['mean_isi'].flatten()\n",
                "stats.pearsonr(gt_mean_isi, model_mean_isi)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Run a full garbage-collection pass to release unreferenced objects; the returned\n",
                "# number of unreachable objects found is shown as the cell output.\n",
                "import gc\n",
                "\n",
                "gc.collect()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import correlation_matrix\n",
                "from ntldm.utils.plotting_utils import plot_correlation_matrices\n",
                "\n",
                "# Settings shared by all three gt-vs-(ae, diffusion) rate correlation figures.\n",
                "corr_common_kwargs = dict(\n",
                "    model_labels=['ae', 'diffusion'],\n",
                "    model_colors=['midnightblue', 'darkred'],\n",
                "    figsize=cm2inch((16, 6)),\n",
                "    save=True,\n",
                "    ms=1,\n",
                ")\n",
                "# Variants: a single sample (index 0); the average correlation matrix over all samples\n",
                "# (mode='average'); and all samples with mode='concatenate'.\n",
                "corr_variants = [\n",
                "    (dict(sample=0), 'rates_correlation_one_sample'),\n",
                "    (dict(sample=None, mode='average'), 'rates_correlation_average_across_trials'),\n",
                "    (dict(sample=None, mode='concatenate'), 'rates_correlation_concat'),\n",
                "]\n",
                "for variant_kwargs, file_stem in corr_variants:\n",
                "    plot_correlation_matrices(\n",
                "        gt_rates,\n",
                "        [ae_rates, diffusion_rates],\n",
                "        save_path=save_path + file_stem,\n",
                "        **variant_kwargs,\n",
                "        **corr_common_kwargs,\n",
                "    )"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Fig. 2 panels: both use mode='concatenate' over all samples (sample=None) at the small size.\n",
                "fig2_corr_kwargs = dict(sample=None, mode='concatenate', figsize=cm2inch((10, 3)), save=True, ms=1)\n",
                "\n",
                "# Rates: ground truth vs. ae and diffusion.\n",
                "plot_correlation_matrices(\n",
                "    gt_rates,\n",
                "    [ae_rates, diffusion_rates],\n",
                "    model_labels=['ae', 'diffusion'],\n",
                "    model_colors=['midnightblue', 'darkred'],\n",
                "    save_path=save_path + 'Fig_2_rates_correlation_concat',\n",
                "    **fig2_corr_kwargs,\n",
                ")\n",
                "\n",
                "# Spikes: ground truth vs. diffusion. diffusion_spikes is passed twice with the first label\n",
                "# set to None, presumably so only the diffusion panel is labelled (TODO confirm).\n",
                "plot_correlation_matrices(\n",
                "    gt_spikes,\n",
                "    [diffusion_spikes, diffusion_spikes],\n",
                "    model_labels=[None, 'diffusion'],\n",
                "    model_colors=['darkred', 'darkred'],\n",
                "    save_path=save_path + 'Fig_2_spikes_correlation_concat',\n",
                "    **fig2_corr_kwargs,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_correlation_matrices_monkey\n",
                "\n",
                "# Poisson spikes sampled from the ae rates (not used by the plot in this cell).\n",
                "ae_spikes = np.random.poisson(ae_rates)\n",
                "\n",
                "# Small gt-vs-diffusion spike correlation figure; diffusion_spikes fills both model slots,\n",
                "# with the second label set to None.\n",
                "plot_correlation_matrices_monkey(\n",
                "    gt_spikes,\n",
                "    [diffusion_spikes] * 2,\n",
                "    sample=None,\n",
                "    mode='concatenate',\n",
                "    model_labels=['diffusion', None],\n",
                "    model_colors=['darkred'] * 2,\n",
                "    figsize=cm2inch((12, 4)),\n",
                "    save=True,\n",
                "    save_path=save_path + 'Fig_2_spikes_correlation_concat_small',\n",
                "    ms=1,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# (latents, colormap, y-label, file stem) for the sampled diffusion latents and the ae latents.\n",
                "latent_panels = [\n",
                "    (diffusion_latents, red_cmap, 'sampled latents', 'sampled_latents'),\n",
                "    (ae_latents, blue_cmap, 'ae latents', 'ae_latents'),\n",
                "]\n",
                "for panel_latents, panel_cmap, panel_ylabel, panel_stem in latent_panels:\n",
                "    fig = plt.figure(figsize=cm2inch(8, 4))\n",
                "    # overlay the first 20 entries, each coloured by its index in the colormap\n",
                "    for trace_idx in range(20):\n",
                "        plt.plot(panel_latents[trace_idx], color=panel_cmap(trace_idx))\n",
                "    plt.xlabel('time (a.u.)')\n",
                "    plt.ylabel(panel_ylabel)\n",
                "    for file_ext in ('png', 'pdf'):\n",
                "        fig.savefig(save_path + f'{panel_stem}.{file_ext}')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import math\n",
                "\n",
                "# Per-channel spectral density of the rates: ground truth (black) vs. each model.\n",
                "num_channels = 64  # only the first 64 channels are plotted\n",
                "n_cols = 10\n",
                "n_rows = math.ceil(num_channels / n_cols)\n",
                "\n",
                "# (model rates, colour, name used in the file name).\n",
                "# NOTE: the 'spetral' spelling is kept so existing figure paths stay stable.\n",
                "for model_rates, model_color, model_name in [\n",
                "    (ae_rates, 'C0', 'ae'),\n",
                "    (diffusion_rates, 'C3', 'diffusion'),\n",
                "]:\n",
                "    # squeeze=False keeps `axs` 2-D even when all channels fit in a single row\n",
                "    fig, axs = plt.subplots(n_rows, n_cols, figsize=cm2inch((60, 40)), squeeze=False)\n",
                "    for idx in range(num_channels):\n",
                "        plot_sd(\n",
                "            fig=fig,\n",
                "            ax=axs[idx // n_cols, idx % n_cols],\n",
                "            arr_one=gt_rates[:, :, idx],\n",
                "            arr_two=model_rates[:, :, idx],\n",
                "            fs=200,\n",
                "            nperseg=260,\n",
                "            agg_function=np.median,\n",
                "            with_quantiles=True,\n",
                "            x_ss=slice(0, 60),\n",
                "            color_one='black',\n",
                "            color_two=model_color,\n",
                "        )\n",
                "    # hide grid cells that have no channel assigned\n",
                "    for unused_ax in axs.flat[num_channels:]:\n",
                "        unused_ax.axis('off')\n",
                "    fig.savefig(save_path + f'rate_spetral_density_{model_name}_gt.png')\n",
                "    fig.savefig(save_path + f'rate_spetral_density_{model_name}_gt.pdf')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import bits_per_spike, neg_log_likelihood, rmse, compute_fano_factor\n",
                "\n",
                "# Bits per spike of the ae reconstruction and of the gt rates, scored against the observed spikes.\n",
                "bps_recs = bits_per_spike(ae_rates, gt_spikes)\n",
                "bps_gt = bits_per_spike(gt_rates, gt_spikes)\n",
                "# Negative log-likelihood averaged over axes (0, 1); presumably one value per neuron (TODO confirm).\n",
                "nll_recs = neg_log_likelihood(ae_rates, gt_spikes, reduction='none').mean((0, 1))\n",
                "nll_gt = neg_log_likelihood(gt_rates, gt_spikes, reduction='none').mean((0, 1))\n",
                "mse_gt_recs = rmse(gt_rates, ae_rates)\n",
                "# Fano factors of the observed spikes and of Poisson spikes sampled from the model rates.\n",
                "fano_gt = compute_fano_factor(gt_spikes)\n",
                "fano_recs = compute_fano_factor(np.random.poisson(ae_rates))\n",
                "fano_diffusion = compute_fano_factor(np.random.poisson(diffusion_rates))\n",
                "\n",
                "\n",
                "def _boxplot_with_dots(groups, tick_labels, ylabel):\n",
                "    \"\"\"Draw one box per group in a new small figure and overlay the jittered raw values as grey dots.\n",
                "\n",
                "    NOTE: matplotlib only draws `meanline`/`meanprops` when `showmeans=True`, which is not set,\n",
                "    so no mean line appears (kept as in the original plots).\n",
                "    \"\"\"\n",
                "    plt.figure(figsize=cm2inch((2, 4)))\n",
                "    plt.boxplot(groups, showfliers=False, widths=0.5, meanline=True,\n",
                "                meanprops=dict(color='red'), medianprops=dict(color='black'), patch_artist=True)\n",
                "    for position, group_values in enumerate(groups, start=1):\n",
                "        # jitter sized per group so groups of different lengths cannot break broadcasting\n",
                "        group_jitter = np.random.normal(0, 0.05, len(group_values))\n",
                "        plt.plot(position * np.ones_like(group_values) + group_jitter, group_values, '.', color='grey', alpha=0.5)\n",
                "    plt.xticks(range(1, len(groups) + 1), tick_labels)\n",
                "    plt.ylabel(ylabel)\n",
                "\n",
                "\n",
                "# Fano factor: ground truth vs. ae reconstruction\n",
                "_boxplot_with_dots([fano_gt, fano_recs], ['gt', 'recs'], 'fano factor')\n",
                "# Negative log-likelihood: gt rates vs. ae reconstruction\n",
                "_boxplot_with_dots([nll_gt, nll_recs], ['gt', 'recs'], 'nll')"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Fix to test set again"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Fixed seed so the jitter of the dots is identical on every run.\n",
                "np.random.seed(0)\n",
                "fig = plt.figure(figsize=cm2inch((1.5, 2)))\n",
                "boxprops = dict(linestyle='-', linewidth=1, facecolor='grey')\n",
                "boxprops2 = dict(linestyle='-', linewidth=1, facecolor='lightblue')\n",
                "\n",
                "# Two boxes (gt, ae); all start grey and the ae box is recoloured.\n",
                "bp = plt.boxplot(\n",
                "    [nll_gt, nll_recs],\n",
                "    positions=[1, 2],\n",
                "    showfliers=False,\n",
                "    widths=0.5,\n",
                "    meanline=True,\n",
                "    meanprops=dict(color='red'),\n",
                "    medianprops=dict(color='black'),\n",
                "    patch_artist=True,\n",
                "    boxprops=boxprops,\n",
                ")\n",
                "bp['boxes'][1].set_facecolor('lightblue')\n",
                "\n",
                "# Overlay the individual values; one shared jitter vector (sized by nll_recs) for both columns.\n",
                "jitter = np.random.normal(0, 0.05, len(nll_recs))\n",
                "for dot_pos, dot_values, dot_color in ((1, nll_gt, 'black'), (2, nll_recs, 'midnightblue')):\n",
                "    plt.plot(dot_pos * np.ones_like(dot_values) + jitter, dot_values, '.', color=dot_color, alpha=0.5, ms=0.5)\n",
                "\n",
                "plt.xticks([1, 2], ['gt', 'ae'])\n",
                "plt.ylabel('neg log lik')\n",
                "plt.show()\n",
                "for file_ext in ('png', 'pdf'):\n",
                "    fig.savefig(save_path + f'Fig_2_neg_log_lik.{file_ext}')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "nll_diffusion = neg_log_likelihood(diffusion_rates, gt_spikes, reduction='none').mean((0, 1))\n",
                "# Alternative baseline: rates = mean of gt_spikes over axes (0, 1), broadcast to the full shape.\n",
                "#nll_diffusion = neg_log_likelihood(np.mean(gt_spikes, axis=(0,1), keepdims=True)*np.ones_like(gt_spikes), gt_spikes,reduction=\"none\").mean((0,1))\n",
                "\n",
                "# Fixed seed so the jitter of the dots is identical on every run.\n",
                "np.random.seed(0)\n",
                "fig = plt.figure(figsize=cm2inch((2, 4)))\n",
                "boxprops = dict(linestyle='-', linewidth=1, facecolor='grey')\n",
                "boxprops2 = dict(linestyle='-', linewidth=1, facecolor='lightblue')\n",
                "\n",
                "# Three boxes (gt, ae, diffusion); all start grey and the model boxes are recoloured.\n",
                "bp = plt.boxplot(\n",
                "    [nll_gt, nll_recs, nll_diffusion],\n",
                "    positions=[1, 2, 3],\n",
                "    showfliers=False,\n",
                "    widths=0.5,\n",
                "    meanline=True,\n",
                "    meanprops=dict(color='red'),\n",
                "    medianprops=dict(color='black'),\n",
                "    patch_artist=True,\n",
                "    boxprops=boxprops,\n",
                ")\n",
                "for box_idx, face_color in ((1, 'lightblue'), (2, 'salmon')):\n",
                "    bp['boxes'][box_idx].set_facecolor(face_color)\n",
                "\n",
                "# Overlay the individual values; one shared jitter vector (sized by nll_recs) for all columns.\n",
                "jitter = np.random.normal(0, 0.05, len(nll_recs))\n",
                "for dot_pos, dot_values, dot_color in ((1, nll_gt, 'black'), (2, nll_recs, 'midnightblue'), (3, nll_diffusion, 'darkred')):\n",
                "    plt.plot(dot_pos * np.ones_like(dot_values) + jitter, dot_values, '.', color=dot_color, alpha=0.5, ms=0.5)\n",
                "\n",
                "plt.xticks([1, 2, 3], ['gt', 'ae', 'diff'])\n",
                "plt.ylabel('neg log lik')\n",
                "plt.show()\n",
                "\n",
                "for file_ext in ('png', 'pdf'):\n",
                "    fig.savefig(save_path + f'neg_log_lik_with_diffusion.{file_ext}')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Bits-per-spike histograms: all values (log=True) and values averaged over axes (0, 1) (log=False).\n",
                "bps_variants = [\n",
                "    ([bps_recs, bps_gt], True, 'all_bits_per_spike'),\n",
                "    ([bps_recs.mean((0, 1)), bps_gt.mean((0, 1))], False, 'mean_bits_per_spike'),\n",
                "]\n",
                "for bps_values, use_log, file_stem in bps_variants:\n",
                "    plot_bits_per_spike(\n",
                "        bps_values,\n",
                "        legend_labels=['ae', 'gt'],\n",
                "        colors=['midnightblue', 'grey'],\n",
                "        bins=30,\n",
                "        log=use_log,\n",
                "        save=True,\n",
                "        save_path=save_path + file_stem,\n",
                "    )"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_rate_comparisons\n",
                "from ntldm.utils.eval_utils import average_rates, std_rates\n",
                "\n",
                "figsize = cm2inch((8, 8))\n",
                "# Two statistics: the function's default `fn` (presumably the mean rate - TODO confirm) and std_rates.\n",
                "# Each is plotted for three modes; only the 'neur' figure of each statistic is saved.\n",
                "rate_stat_variants = [\n",
                "    ({}, 'mean_rate_comparison_gt_ae_diffusion'),\n",
                "    ({'fn': std_rates}, 'std_rate_comparison_gt_ae_diffusion'),\n",
                "]\n",
                "for stat_kwargs, file_stem in rate_stat_variants:\n",
                "    for rate_mode in ('neur', 'neurtime', 'neursample'):\n",
                "        save_kwargs = {'save': True, 'save_path': save_path + file_stem} if rate_mode == 'neur' else {}\n",
                "        plot_rate_comparisons(\n",
                "            gt_rates,\n",
                "            [ae_rates, diffusion_rates],\n",
                "            mode=rate_mode,\n",
                "            figsize=figsize,\n",
                "            colors=['midnightblue', 'darkred'],\n",
                "            **stat_kwargs,\n",
                "            **save_kwargs,\n",
                "        )"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# 2x2 grid with shared x/y axes; only the top-left panel is filled here (mode='neur').\n",
                "fig, axs = plt.subplots(nrows=2, ncols=2, figsize=cm2inch(8, 8), sharex=True, sharey=True)\n",
                "top_left_ax = axs[0, 0]\n",
                "plot_rate_comparisons(\n",
                "    gt_rates,\n",
                "    [ae_rates, diffusion_rates],\n",
                "    mode='neur',\n",
                "    ax=top_left_ax,\n",
                "    colors=['midnightblue', 'darkred'],\n",
                ")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import pickle\n",
                "\n",
                "from ntldm.utils.eval_utils import compute_spike_stats, counts_to_spike_trains, compute_spike_stats_per_neuron\n",
                "\n",
                "fps = 1\n",
                "# True: recompute spike trains and statistics and refresh the pickle cache; False: load the cache.\n",
                "calc_again = False\n",
                "if calc_again:\n",
                "    spike_trains_gt = counts_to_spike_trains(gt_spikes, fps=fps)\n",
                "    spike_trains_ae = counts_to_spike_trains(np.random.poisson(ae_rates), fps=fps)\n",
                "    spike_trains_diff = counts_to_spike_trains(np.random.poisson(diffusion_rates), fps=fps)\n",
                "\n",
                "    # The statistics must exist before they are cached; otherwise the dump below pickles\n",
                "    # stale values (or raises NameError on a fresh kernel).\n",
                "    # NOTE(review): uses compute_spike_stats, as in this cell's previously commented-out code;\n",
                "    # a later cell uses compute_spike_stats_per_neuron - confirm which variant the cache should hold.\n",
                "    spike_stats_gt = compute_spike_stats(spike_trains_gt, n_samples=gt_spikes.shape[0], n_neurons=gt_spikes.shape[2])\n",
                "    spike_stats_ae = compute_spike_stats(spike_trains_ae, n_samples=ae_rates.shape[0], n_neurons=ae_rates.shape[2])\n",
                "    spike_stats_diff = compute_spike_stats(spike_trains_diff, n_samples=diffusion_rates.shape[0], n_neurons=diffusion_rates.shape[2])\n",
                "\n",
                "    # save them as pickle files\n",
                "    with open(save_path+'spike_stats_gt.pkl', 'wb') as f:\n",
                "        pickle.dump(spike_stats_gt, f)\n",
                "    with open(save_path+'spike_stats_ae.pkl', 'wb') as f:\n",
                "        pickle.dump(spike_stats_ae, f)\n",
                "    with open(save_path+'spike_stats_diff.pkl', 'wb') as f:\n",
                "        pickle.dump(spike_stats_diff, f)\n",
                "else:\n",
                "    # NOTE: spike_trains_* are only created in the branch above, so cells that use them\n",
                "    # need calc_again=True on a fresh kernel.\n",
                "    # pickle.load can execute arbitrary code: only load cache files written by this notebook.\n",
                "    with open(save_path+'spike_stats_gt.pkl', 'rb') as f:\n",
                "        spike_stats_gt = pickle.load(f)\n",
                "    with open(save_path+'spike_stats_ae.pkl', 'rb') as f:\n",
                "        spike_stats_ae = pickle.load(f)\n",
                "    with open(save_path+'spike_stats_diff.pkl', 'rb') as f:\n",
                "        spike_stats_diff = pickle.load(f)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import pickle\n",
                "\n",
                "# Write the spike statistics currently in memory to the pickle cache (overwrites existing files).\n",
                "for cache_name, cache_value in (('gt', spike_stats_gt), ('ae', spike_stats_ae), ('diff', spike_stats_diff)):\n",
                "    with open(save_path + f'spike_stats_{cache_name}.pkl', 'wb') as f:\n",
                "        pickle.dump(cache_value, f)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import compute_spike_stats_per_neuron\n",
                "from ntldm.utils.plotting_utils import cm2inch, plot_spiketrain_stats\n",
                "\n",
                "\n",
                "def _stats_per_neuron(spike_trains, shape_source):\n",
                "    \"\"\"Per-neuron spike-train statistics (mean_output=False).\n",
                "\n",
                "    The sample and neuron counts are read from axes 0 and 2 of `shape_source`.\n",
                "    \"\"\"\n",
                "    return compute_spike_stats_per_neuron(\n",
                "        spike_trains,\n",
                "        n_samples=shape_source.shape[0],\n",
                "        n_neurons=shape_source.shape[2],\n",
                "        mean_output=False,\n",
                "    )\n",
                "\n",
                "\n",
                "spike_stats_gt = _stats_per_neuron(spike_trains_gt, gt_spikes)\n",
                "spike_stats_ae = _stats_per_neuron(spike_trains_ae, ae_rates)\n",
                "spike_stats_diff = _stats_per_neuron(spike_trains_diff, diffusion_rates)\n",
                "\n",
                "stats_figsize = cm2inch(12, 4)\n",
                "# gt vs. ae and gt vs. diffusion are saved; ae vs. diffusion is only displayed.\n",
                "plot_spiketrain_stats(\n",
                "    spike_stats_gt, spike_stats_ae,\n",
                "    figsize=stats_figsize, color='midnightblue', labels=['gt', 'ae'],\n",
                "    save=True, save_path=save_path + 'compute_spike_stats_per_neuron_gt_ae',\n",
                ")\n",
                "plot_spiketrain_stats(\n",
                "    spike_stats_gt, spike_stats_diff,\n",
                "    figsize=stats_figsize, color='darkred', labels=['gt', 'diffusion'],\n",
                "    save=True, save_path=save_path + 'compute_spike_stats_per_neuron_gt_diff',\n",
                ")\n",
                "plot_spiketrain_stats(\n",
                "    spike_stats_ae, spike_stats_diff,\n",
                "    figsize=stats_figsize, color='purple', labels=['ae', 'diffusion'],\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Fig. 2 version of the gt vs. diffusion spike-train statistics (narrower figure).\n",
                "fig2_stats_kwargs = dict(figsize=cm2inch(10, 4), color='darkred', labels=['gt', 'diffusion'])\n",
                "plot_spiketrain_stats(\n",
                "    spike_stats_gt,\n",
                "    spike_stats_diff,\n",
                "    save=True,\n",
                "    save_path=save_path + 'Fig_2_compute_spike_stats_per_neuron_gt_diff',\n",
                "    **fig2_stats_kwargs,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Fig. 2 mean-rate comparison, gt spikes vs. diffusion spikes (fps=1).\n",
                "# NOTE(review): two colours are passed for a single model; if colours are matched to models by\n",
                "# position, diffusion is drawn in 'midnightblue' rather than the 'darkred' used elsewhere - verify.\n",
                "plot_rate_comparisons(\n",
                "    gt_spikes,\n",
                "    [diffusion_spikes],\n",
                "    mode='neur',\n",
                "    figsize=cm2inch((5, 1.5)),\n",
                "    fps=1,\n",
                "    colors=['midnightblue', 'darkred'],\n",
                "    save=True,\n",
                "    save_path=save_path + 'Fig_2_mean_rate_comparison_gt_ae_diffusion',\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# NOTE(review): `bins` is not defined in this cell; it comes from an earlier cell (hidden state).\n",
                "\n",
                "# Histogram of the population spike count (sum over axis 2, the neuron axis) for gt and ldns spikes.\n",
                "plt.figure(figsize=cm2inch((1.5, 1.5)))\n",
                "plt.hist(gt_spikes.sum(2).flatten(), density=True, color='grey', bins=bins, alpha=0.5)\n",
                "plt.hist(diffusion_spikes.sum(2).flatten(), density=True, color='darkred', bins=bins, alpha=0.5)\n",
                "\n",
                "plt.legend(['gt', 'ldns'], loc='upper left', bbox_to_anchor=(1.5, 1.5))\n",
                "plt.title('population spike count')\n",
                "plt.ylabel('frequency')\n",
                "plt.xlabel('#spikes per 1s')\n",
                "plt.savefig(save_path + 'Fig_2_population_spike_count_diff.png')\n",
                "plt.savefig(save_path + 'Fig_2_population_spike_count_diff.pdf')\n",
                "\n",
                "# Mean rates (mode='neur') are computed once and reused for the scatter and the identity line.\n",
                "gt_mean_rates = average_rates(gt_spikes, mode='neur', fps=1).flatten()\n",
                "diff_mean_rates = average_rates(diffusion_spikes, mode='neur', fps=1).flatten()\n",
                "\n",
                "plt.figure(figsize=cm2inch((1.5, 1.5)))\n",
                "plt.plot(gt_mean_rates, diff_mean_rates, '.', color='darkred', alpha=0.5, ms=1)\n",
                "# black dashed identity line spanning the joint data range\n",
                "min_val = np.min([gt_mean_rates.min(), diff_mean_rates.min()])\n",
                "max_val = np.max([gt_mean_rates.max(), diff_mean_rates.max()])\n",
                "plt.plot([min_val, max_val], [min_val, max_val], 'k--')\n",
                "plt.xlabel('gt rate (Hz)')\n",
                "plt.ylabel('ldns rate (Hz)')\n",
                "plt.title('mean rate')\n",
                "plt.savefig(save_path + 'Fig_2_firing_rate.png')\n",
                "plt.savefig(save_path + 'Fig_2_firing_rate.pdf')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Heatmap of Poisson spikes sampled from the rates of one diffusion sample (axis 0 = samples).\n",
                "# The sample index used to come implicitly from `idx` left over by the spectral-density\n",
                "# channel loop (63 after that loop); it is now explicit so the cell works on its own.\n",
                "raster_sample_idx = 63\n",
                "# Sample only the selected trial instead of the whole rate array.\n",
                "output_spikes = np.random.poisson(diffusion_rates[raster_sample_idx])\n",
                "maxval = output_spikes.max()\n",
                "\n",
                "fig, ax = plt.subplots(1, 1, figsize=cm2inch(8, 5))\n",
                "im = ax.imshow(output_spikes.T, vmin=0, vmax=maxval, aspect='auto', cmap='Reds')\n",
                "fig.colorbar(im, ax=ax)\n",
                "ax.set_xlabel('time (a.u.)')\n",
                "ax.set_ylabel('neuron id')\n",
                "fig.savefig(save_path + 'diffusion_output_spikes.png')\n",
                "fig.savefig(save_path + 'diffusion_output_spikes.pdf')\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "save_path  # path prefix prepended to the figure/cache file names in the cells above"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import visualise_spikes_trains\n",
                "\n",
                "# Side-by-side spike trains for gt, ae and diffusion; ae_rates is passed as the 4th positional argument.\n",
                "visualise_spikes_trains(\n",
                "    spike_trains_gt,\n",
                "    spike_trains_ae,\n",
                "    spike_trains_diff,\n",
                "    ae_rates,\n",
                "    figsize=cm2inch(10, 6),\n",
                "    ms=0.4,\n",
                "    save=True,\n",
                "    save_path=save_path + 'spike_trains',\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "figsize = cm2inch(6, 3)\n",
                "ms = 1\n",
                "# One raster per source. Iteration stops at the first entry whose neuron index is the last\n",
                "# neuron, which presumably means only the first sample is drawn (assumes the dict iterates\n",
                "# sample by sample - TODO confirm).\n",
                "raster_sources = [\n",
                "    (spike_trains_gt, 'black', 'gt spikes'),\n",
                "    (spike_trains_ae, 'midnightblue', 'ae spikes'),\n",
                "    (spike_trains_diff, 'darkred', 'diff spikes'),\n",
                "]\n",
                "for raster_trains, raster_color, raster_title in raster_sources:\n",
                "    plt.figure(figsize=figsize)\n",
                "    for (sample_idx, neuron_idx), spikes in raster_trains.items():\n",
                "        plt.plot(spikes, np.ones_like(spikes) * neuron_idx, '|', color=raster_color, markersize=ms)\n",
                "        if neuron_idx == ae_rates.shape[-1] - 1:\n",
                "            break\n",
                "    plt.xlabel('time [s]')\n",
                "    plt.ylabel('neuron idx')\n",
                "    plt.title(raster_title)\n",
                "    plt.locator_params(nbins=5)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "ldiff",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.8.18"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
