{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "%load_ext autoreload\n",
                "%autoreload 2\n",
                "\n",
                "\n",
                "import argparse\n",
                "import os\n",
                "import sys\n",
                "\n",
                "ANONAUTHOR = False\n",
                "\n",
                "if ANONAUTHOR:\n",
                "    os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
                "    # append parent directory to path (../notebooks -> ..)\n",
                "    sys.path.append(os.path.dirname(os.getcwd()))\n",
                "    os.chdir(os.path.dirname(os.getcwd()))\n",
                "\n",
                "else:\n",
                "    os.chdir('../')\n",
                "\n",
                "\n",
                "import accelerate\n",
                "import auraloss  # freq loss\n",
                "import lovely_tensors as lt\n",
                "import matplotlib.pyplot as plt\n",
                "import matplotlib\n",
                "import numpy as np\n",
                "import scipy.io as io\n",
                "import seaborn as sns\n",
                "import torch\n",
                "import torch.nn as nn\n",
                "import wandb\n",
                "import yaml\n",
                "from diffusers.optimization import get_scheduler\n",
                "from omegaconf import OmegaConf\n",
                "from scipy.signal import welch\n",
                "from tqdm.auto import tqdm\n",
                "from einops import rearrange\n",
                "\n",
                "from ntldm.data.lds import get_lds_dataset\n",
                "from ntldm.data.latent_attractor import get_attractor_dataset\n",
                "from ntldm.data.lds import get_lds_dataset\n",
                "from ntldm.networks import S4AE, AutoEncoder, CountWrapper\n",
                "from ntldm.utils.plotting_utils import *\n",
                "from ntldm.losses import latent_regularizer\n",
                "from ntldm.networks import Denoiser\n",
                "from diffusers.training_utils import EMAModel\n",
                "from diffusers.schedulers import DDPMScheduler\n",
                "from ntldm.utils.plotting_utils import angle_to_color\n",
                "\n",
                "# always run from ../ntldm\n",
                "\n",
                "\n",
                "lt.monkey_patch()\n",
                "matplotlib.rc_file('matplotlibrc')\n",
                "\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "## load config and model path\n",
                "\n",
                "#cfg_ae = OmegaConf.load(\"conf/sweeps_count/autoencoder-count_s4-monkey_new.yaml\")\n",
                "cfg_ae = OmegaConf.load(\"conf/sweeps_count/autoencoder-count_s4-monkey_new_new_regularisation.yaml\")\n",
                "cfg_yaml = \"\"\"\n",
                "denoiser_model:\n",
                "  C_in: 16\n",
                "  C: 256\n",
                "  kernel: s4\n",
                "  num_blocks: 6\n",
                "  bidirectional: True\n",
                "  num_train_timesteps: 1000\n",
                "training:\n",
                "  lr: 0.001\n",
                "  weight_decay: 0.0\n",
                "  num_epochs: 2000\n",
                "  num_warmup_epochs: 50\n",
                "  batch_size: 512\n",
                "  random_seed: 42\n",
                "  precision: \"no\"\n",
                "exp_name: diffusion_s4-monkey_epoch_140\n",
                "\"\"\"\n",
                "\n",
                "cfg = OmegaConf.create(yaml.safe_load(cfg_yaml))\n",
                "cfg.dataset = cfg_ae.dataset\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "\n",
                "### 0. Load autoencoder (with checkpoint) and autoencoder dataset (run for all points 2-5)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import math\n",
                "from ntldm.data.monkey import get_monkey_dataloaders\n",
                "\n",
                "    \n",
                "ae_model = AutoEncoder(\n",
                "    C_in=cfg_ae.model.C_in,\n",
                "    C=cfg_ae.model.C,\n",
                "    C_latent=cfg_ae.model.C_latent,\n",
                "    L=cfg_ae.dataset.signal_length,\n",
                "    kernel=cfg_ae.model.kernel,\n",
                "    num_blocks=cfg_ae.model.num_blocks,\n",
                "    num_blocks_decoder=cfg_ae.model.get(\"num_blocks_decoder\", cfg_ae.model.num_blocks),\n",
                "    num_lin_per_mlp=cfg_ae.model.get(\"num_lin_per_mlp\", 2),  # default 2\n",
                "    bidirectional=cfg_ae.model.get(\"bidirectional\", True),\n",
                ")\n",
                "\n",
                "ae_model = CountWrapper(ae_model, use_sin_enc=cfg_ae.model.get(\"use_sin_enc\", False))\n",
                "\n",
                "\n",
                "\n",
                "# set seed\n",
                "torch.manual_seed(cfg.training.random_seed)\n",
                "np.random.seed(cfg.training.random_seed)\n",
                "\n",
                "train_dataloader, val_dataloader, test_dataloader = get_monkey_dataloaders(\n",
                "        cfg_ae.dataset.task, cfg_ae.dataset.datapath, bin_width=5, batch_size=cfg_ae.training.batch_size\n",
                "    )\n",
                "\n",
                "accelerator = accelerator = accelerate.Accelerator(\n",
                "    mixed_precision='no',\n",
                "    log_with=\"wandb\",\n",
                ")\n",
                "\n",
                "# prepare the ae model and dataset\n",
                "ae_model = accelerator.prepare(ae_model)\n",
                "    \n",
                "print(cfg_ae.exp_name)\n",
                "accelerator.load_state(f\"exp/{cfg_ae.exp_name}/epoch_140\") # best checkpoint CHANGED\n",
                "\n",
                "(\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ") = accelerator.prepare(\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ")\n",
                "\n",
                "\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "save_path = '/home/anonauthor/anonloc1/results/projects/latent-diffusion/MONKEY/figures_cond/'\n",
                "os.makedirs(save_path, exist_ok=True)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import reconstruct_spikes\n",
                "rec_dict = reconstruct_spikes(ae_model, test_dataloader)\n",
                "\n",
                "\n",
                "# plot reconstructed spikes\n",
                "plt.figure(figsize=cm2inch((6, 4)))\n",
                "bins = np.linspace(0, 15, 15)\n",
                "plt.hist(rec_dict['spikes'].sum(1).flatten(), density=True, color='grey', bins=bins, alpha=0.5)\n",
                "plt.hist(rec_dict['rec_spikes'].sum(1).flatten(), density=True, color='darkblue', bins=bins, alpha=0.5)\n",
                "\n",
                "plt.legend(['gt', 'ae'])\n",
                "plt.title('population spike count (test set)')\n",
                "plt.ylabel('frequency')\n",
                "plt.xlabel('number of spikes per time bin')\n",
                "plt.savefig(save_path + 'population_spike_count_test.png')\n",
                "plt.savefig(save_path + 'population_spike_count_test.pdf')\n",
                "\n",
                "\n",
                "\n",
                "# plot reconstructed spikes\n",
                "plt.figure(figsize=cm2inch((6, 4)))\n",
                "bins = np.linspace(0, 20, 20)\n",
                "plt.hist(rec_dict['spikes'].sum(2).flatten(), density=True, color='grey', bins=bins, alpha=0.5)\n",
                "plt.hist(rec_dict['rec_spikes'].sum(2).flatten(), density=True, color='darkblue', bins=bins, alpha=0.5)\n",
                "\n",
                "plt.legend(['gt', 'ae'])\n",
                "plt.title('spike count distribution (test set)')\n",
                "plt.ylabel('frequency')\n",
                "plt.xlabel('number of spikes per trial')\n",
                "plt.savefig(save_path + 'trial_spike_count_test.png')\n",
                "plt.savefig(save_path + 'trial_spike_count_test.pdf')\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### 1. Create dataset containing behavior, behavior angle, spike dataset, latents from ae ✅ (run for all points 2-5)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.data.monkey import LatentMonkeyDataset\n",
                "latent_dataset_train = LatentMonkeyDataset(train_dataloader, ae_model, clip=False)\n",
                "latent_dataset_val = LatentMonkeyDataset(\n",
                "    val_dataloader,\n",
                "    ae_model,\n",
                "    latent_means=latent_dataset_train.latent_means,\n",
                "    latent_stds=latent_dataset_train.latent_stds,\n",
                "    clip=False,\n",
                ")\n",
                "latent_dataset_test = LatentMonkeyDataset(\n",
                "    test_dataloader,\n",
                "    ae_model,\n",
                "    latent_means=latent_dataset_train.latent_means,\n",
                "    latent_stds=latent_dataset_train.latent_stds,\n",
                "    clip=False,\n",
                ")\n",
                "\n",
                "train_latent_dataloader = torch.utils.data.DataLoader(\n",
                "    latent_dataset_train,\n",
                "    batch_size=cfg.training.batch_size,\n",
                "    shuffle=True,\n",
                "    num_workers=4,\n",
                "    pin_memory=True,\n",
                ")\n",
                "\n",
                "val_latent_dataloader = torch.utils.data.DataLoader(\n",
                "    latent_dataset_val,\n",
                "    batch_size=cfg.training.batch_size,\n",
                "    shuffle=False,\n",
                "    num_workers=4,\n",
                "    pin_memory=True,\n",
                ")\n",
                "\n",
                "test_latent_dataloader = torch.utils.data.DataLoader(\n",
                "    latent_dataset_test,\n",
                "    batch_size=cfg.training.batch_size,\n",
                "    shuffle=False,\n",
                "    num_workers=4,\n",
                "    pin_memory=True,\n",
                ")\n",
                "\n",
                "num_batches = len(train_latent_dataloader)\n",
                "\n",
                "# check if signal length is power of 2\n",
                "if cfg.dataset.signal_length & (cfg.dataset.signal_length - 1) != 0:\n",
                "    cfg.training.precision = \"no\"  # torch.fft doesnt support half if L!=2^x\n",
                "\n",
                "# prepare the denoiser model and dataset\n",
                "(\n",
                "    train_latent_dataloader,\n",
                "    val_latent_dataloader,\n",
                "    test_latent_dataloader,\n",
                ") = accelerator.prepare(\n",
                "    train_latent_dataloader,\n",
                "    val_latent_dataloader,\n",
                "    test_latent_dataloader,\n",
                ")\n",
                "\n",
                "# ------------------ visualize dataset ------------------\n",
                "\n",
                "\n",
                "print(\"latent dataset\", latent_dataset_train.latents)\n",
                "print(\"latent dataset means\", latent_dataset_train.latent_means)\n",
                "print(\"latent dataset stds\", latent_dataset_train.latent_stds)\n",
                "plt.figure(figsize=cm2inch(5, 3))\n",
                "hist = plt.hist(latent_dataset_train.latents[:100].flatten(), bins=200, density=True)\n",
                "hist = plt.hist(latent_dataset_val.latents[:100].flatten(), bins=200, density=True)\n",
                "hist = plt.hist(latent_dataset_test.latents[:100].flatten(), bins=200, density=True)\n",
                "\n",
                "plt.title(\"Latent dataset histogram\")\n",
                "plt.show()\n",
                "\n",
                "plt.figure(figsize=cm2inch(5, 3))\n",
                "plt.hist(latent_dataset_train.behavior_angles.flatten(), bins=200, label=\"train\")\n",
                "plt.hist(latent_dataset_test.behavior_angles.flatten(), bins=200, label=\"test\")\n",
                "plt.legend()\n",
                "# xticks in terms of pi\n",
                "plt.xticks(\n",
                "    np.linspace(-np.pi, np.pi, 5),\n",
                "    [r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"],\n",
                ")\n",
                "plt.ylabel(\"number of reaches\")\n",
                "plt.xlabel(\"behavior angle\")\n",
                "plt.savefig(save_path + \"behavior_angles.png\")\n",
                "plt.savefig(save_path + \"behavior_angles.pdf\")\n",
                "\n",
                "# create cmap based on behavior angles and plot behavior_cumsum on 2d plot \n",
                "cmap = plt.get_cmap(\"hsv\")\n",
                "colors = cmap((latent_dataset_train.behavior_angles.squeeze() + np.pi) / (2 * np.pi))\n",
                "plt.figure(figsize=cm2inch(8, 6))\n",
                "\n",
                "for i in range(0, len(latent_dataset_train.behavior), 10):\n",
                "    plt.plot(\n",
                "        latent_dataset_train.behavior_cumsum[i, 0],\n",
                "        latent_dataset_train.behavior_cumsum[i, 1],\n",
                "        color=colors[i],\n",
                "        alpha=0.3,\n",
                "    )\n",
                "# colormap\n",
                "sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(-np.pi, np.pi))\n",
                "sm.set_array([])\n",
                "# set ticks formatting for colorbar in terms of pi\n",
                "cbar = plt.colorbar(sm, ticks=np.linspace(-np.pi, np.pi, 5), label=\"behavior angle\", ax=plt.gca())\n",
                "cbar.ax.set_yticklabels([r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"])\n",
                "plt.xlabel('x')\n",
                "plt.ylabel('y')\n",
                "plt.savefig(save_path + \"all_reaches_colored_by_angles.png\")\n",
                "plt.savefig(save_path + \"all_reaches_colored_by_angles.pdf\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import angle_to_color\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Decode behaviors from test set"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "## sanity check: ae decoded signal, mapped to behavior (trained on train data) should be similar to the actual behavior\n",
                "\n",
                "from ntldm.utils.behav_eval_utils import compute_decoded_behavior\n",
                "\n",
                "\n",
                "behav_dict = compute_decoded_behavior(\n",
                "    ae_model, train_latent_dataloader, test_latent_dataloader\n",
                ")\n",
                "\n",
                "predicted_val_behavior = behav_dict[\"predicted_val_behavior\"]\n",
                "real_val_behavior = behav_dict[\"real_val_behavior\"]\n",
                "predicted_traj = np.cumsum(predicted_val_behavior, axis=-1)\n",
                "real_traj = np.cumsum(real_val_behavior, axis=-1)\n",
                "# get the reach angle \n",
                "angle_real_val_behavior = np.arctan2(real_traj[:, 1, 50], real_traj[:, 0, 50])\n",
                "\n",
                "print(predicted_traj.shape, real_traj.shape)\n",
                "plt.figure(figsize=cm2inch(8, 6))\n",
                "\n",
                "for i, idx in enumerate(range(0, 70, 5)):\n",
                "    plt.plot(\n",
                "        real_traj[idx, 0, :], real_traj[idx, 1, :], label=\"real trajectory\", color=angle_to_color(angle_real_val_behavior[idx]),\n",
                "    )\n",
                "    plt.plot(\n",
                "        predicted_traj[idx, 0, :],\n",
                "        predicted_traj[idx, 1, :],\n",
                "        label=\"predicted trajectory\",\n",
                "        linestyle=\"--\",\n",
                "        color=angle_to_color(angle_real_val_behavior[idx]),\n",
                "    )\n",
                "plt.xlabel(\"x position\")\n",
                "plt.ylabel(\"y position\")\n",
                "# plt.legend()\n",
                "plt.title(\"real vs predicted trajectory\")\n",
                "\n",
                "sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(-np.pi, np.pi))\n",
                "sm.set_array([])\n",
                "# set ticks formatting for colorbar in terms of pi\n",
                "cbar = plt.colorbar(sm, ticks=np.linspace(-np.pi, np.pi, 5), label=\"behavior angle\", ax=plt.gca())\n",
                "cbar.ax.set_yticklabels([r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"])\n",
                "plt.xlabel('x')\n",
                "plt.ylabel('y')\n",
                "plt.savefig(save_path + \"test_reaches_and_recs_colored_by_angles.png\")\n",
                "plt.savefig(save_path + \"test_reaches_and_recs_colored_by_angles.pdf\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "plt.figure(figsize=cm2inch(3, 2.2))\n",
                "\n",
                "for i, idx in enumerate(range(0, 70, 5)):\n",
                "    plt.plot(\n",
                "        real_traj[idx, 0, :], real_traj[idx, 1, :],\n",
                "         label=\"real trajectory\", color=angle_to_color(angle_real_val_behavior[idx]),\n",
                "                 lw=1\n",
                "    )\n",
                "    plt.plot(\n",
                "        predicted_traj[idx, 0, :],\n",
                "        predicted_traj[idx, 1, :],\n",
                "        label=\"predicted trajectory\",\n",
                "        linestyle=\"--\",\n",
                "        color=angle_to_color(angle_real_val_behavior[idx]),\n",
                "        lw=1\n",
                "    )\n",
                "plt.xlabel(\"x position\")\n",
                "plt.ylabel(\"y position\")\n",
                "# plt.legend()\n",
                "#plt.title(\"real vs predicted trajectory\")\n",
                "\n",
                "sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(-np.pi, np.pi))\n",
                "sm.set_array([])\n",
                "# set ticks formatting for colorbar in terms of pi\n",
                "cbar = plt.colorbar(sm, ticks=np.linspace(-np.pi, np.pi, 5), label=\"behavior angle\", ax=plt.gca())\n",
                "cbar.ax.set_yticklabels([r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"])\n",
                "plt.xlabel('x')\n",
                "plt.ylabel('y')\n",
                "plt.savefig(save_path + \"Fig_4_test_reaches_and_recs_colored_by_angles.png\")\n",
                "plt.savefig(save_path + \"Fig_4_test_reaches_and_recs_colored_by_angles.pdf\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.sampling_utils import sample, sample_spikes\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Save or load the unconditional diffusion model "
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "load_model = True\n",
                "\n",
                "            \n",
                "if load_model:\n",
                "\n",
                "    with open(f\"conf/sweeps_count/{cfg.exp_name}_save_after_train_new_regularisation.yaml\") as f:\n",
                "        cfg = OmegaConf.create(yaml.safe_load(f))\n",
                "        \n",
                "    denoiser = Denoiser(\n",
                "        C_in=cfg.denoiser_model.C_in,\n",
                "        C=cfg.denoiser_model.C,\n",
                "        L=cfg.dataset.signal_length,\n",
                "        kernel=cfg.denoiser_model.kernel,\n",
                "        num_blocks=cfg.denoiser_model.num_blocks,\n",
                "        bidirectional=cfg.denoiser_model.get(\"bidirectional\", True),\n",
                "    )\n",
                "\n",
                "    # initial values may be way off so better to scale down the output layer\n",
                "    denoiser.conv_out.weight.data = denoiser.conv_out.weight.data * 0.1\n",
                "    denoiser.conv_out.bias.data = denoiser.conv_out.bias.data * 0.1\n",
                "\n",
                "    scheduler = DDPMScheduler(\n",
                "        num_train_timesteps=cfg.denoiser_model.num_train_timesteps,\n",
                "        clip_sample=False,\n",
                "        beta_schedule=\"linear\", # ddpm doesnt support cosine\n",
                "    )\n",
                "\n",
                "\n",
                "    optimizer = torch.optim.AdamW(\n",
                "        denoiser.parameters(), lr=cfg.training.lr\n",
                "    )  # default wd=0.01 for now\n",
                "\n",
                "\n",
                "\n",
                "    num_batches = len(train_latent_dataloader)\n",
                "    lr_scheduler = get_scheduler(\n",
                "        name=\"cosine\",\n",
                "        optimizer=optimizer,\n",
                "        num_warmup_steps=num_batches * cfg.training.num_warmup_epochs,  # warmup for 10% of epochs\n",
                "        num_training_steps=num_batches * cfg.training.num_epochs * 1.3,  # total number of steps\n",
                "    )\n",
                "\n",
                "    denoiser.load_state_dict(torch.load(f\"exp/{cfg.exp_name}/model_new_reg.pt\", map_location=\"cpu\"))\n",
                "\n",
                "    # prepare the denoiser model and dataset\n",
                "    (\n",
                "        denoiser,\n",
                "        optimizer,\n",
                "        lr_scheduler,\n",
                "    ) = accelerator.prepare(\n",
                "        denoiser,\n",
                "        optimizer,\n",
                "        lr_scheduler,\n",
                "    )\n",
                "\n",
                "    ema_model = EMAModel(denoiser)\n",
                "    \n",
                "        \n",
                "    print(f\"loaded model from exp/{cfg.exp_name}/model_new_reg.pt\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "save=True\n",
                "idx=0\n",
                "plot_rate_traces(\n",
                "        ae_model,\n",
                "        test_dataloader,\n",
                "        idx=idx,\n",
                "        figsize=cm2inch(10, 10),\n",
                "        true_data=True,\n",
                "        save=save,\n",
                "        save_path=save_path + \"rate_traces\",\n",
                "    )\n",
                "\n",
                "plot_spikes_next_to_each_other(\n",
                "    ae_model,\n",
                "    test_dataloader,\n",
                "    idx=idx,\n",
                "    figsize=cm2inch(15, 5),\n",
                "    save=save,\n",
                "    save_path=save_path + \"spikes_next_to_each_other\",\n",
                ")\n",
                "\n",
                "# run all sorts of analyses\n",
                "plot_inferred_latents(\n",
                "    ae_model,\n",
                "    test_dataloader,\n",
                "    n_latents=8,\n",
                "    y_stack=4,\n",
                "    figsize=cm2inch(10, 10),\n",
                "    color=\"royalblue\",\n",
                "    idx=idx,\n",
                "    save=save,\n",
                "    save_path=save_path + \"inferred_latents\",\n",
                "    )\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Create tensors for comparison"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "true_data = True\n",
                "n_siglen = 1\n",
                "sample_cutoff = int(10e5)\n",
                "\n",
                "ae_model.eval()\n",
                "\n",
                "ae_rates = []\n",
                "ae_latents = []\n",
                "diffusion_rates = []\n",
                "diffusion_latents = []\n",
                "gt_spikes = []\n",
                "\n",
                "if not true_data:\n",
                "    gt_rates = []\n",
                "    gt_latents = []\n",
                "\n",
                "\n",
                "# autoencoder eval \n",
                "for batch in train_dataloader:\n",
                "    signal = batch[\"signal\"]\n",
                "    with torch.no_grad():\n",
                "        output_rates, latent = ae_model(signal)\n",
                "        #output_rates = ae_model(signal)[0].cpu()\n",
                "    ae_rates.append(output_rates.cpu())\n",
                "    ae_latents.append(latent.cpu())\n",
                "    gt_spikes.append(signal.cpu())\n",
                "    if not true_data:\n",
                "        gt_rates.append(batch[\"rates\"].cpu())\n",
                "        gt_latents.append(batch[\"latents\"].cpu())\n",
                "\n",
                "    #break\n",
                "\n",
                "# concatenate along batch dimension\n",
                "ae_rates = torch.cat(ae_rates, dim=0)\n",
                "ae_latents = torch.cat(ae_latents, dim=0)\n",
                "gt_spikes = torch.cat(gt_spikes, dim=0)\n",
                "if not true_data:\n",
                "    gt_rates = torch.cat(gt_rates, dim=0)\n",
                "    gt_latents = torch.cat(gt_latents, dim=0)\n",
                "    \n",
                "\n",
                "\n",
                "# diffusion eval\n",
                "sampled_latents = sample(\n",
                "    ema_denoiser=ema_model, scheduler=scheduler, cfg=cfg, batch_size=ae_rates.shape[0], device=\"cuda\", signal_length=cfg.dataset.signal_length * n_siglen\n",
                ")\n",
                "\n",
                "# project back to non standardized space\n",
                "sampled_latents = sampled_latents * latent_dataset_train.latent_stds.to(\n",
                "    sampled_latents.device\n",
                ") + latent_dataset_train.latent_means.to(sampled_latents.device)\n",
                "\n",
                "with torch.no_grad():\n",
                "    sampled_rates = ae_model.decode(sampled_latents).cpu()\n",
                "    \n",
                "diffusion_rates.append(sampled_rates)\n",
                "diffusion_latents.append(sampled_latents.cpu())\n",
                "\n",
                "# concatenate along batch dimension\n",
                "diffusion_rates = torch.cat(diffusion_rates, dim=0)\n",
                "diffusion_latents = torch.cat(diffusion_latents, dim=0)\n",
                "\n",
                "vecs = [ae_rates, ae_latents, diffusion_rates, diffusion_latents, gt_spikes]\n",
                "\n",
                "if not true_data:\n",
                "    vecs.append(gt_rates)\n",
                "    vecs.append(gt_latents)\n",
                "\n",
                "vecs = [vec.cpu().numpy() for vec in vecs]\n",
                "vecs = [vec[:sample_cutoff] for vec in vecs]\n",
                "vecs = [rearrange(vec, 'b n t -> b t n') for vec in vecs]\n",
                "\n",
                "\n",
                "if not true_data:\n",
                "    ae_rates, ae_latents, diffusion_rates, diffusion_latents, gt_spikes, gt_rates, gt_latents = vecs\n",
                "else:\n",
                "    ae_rates, ae_latents, diffusion_rates, diffusion_latents, gt_spikes = vecs\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "fps = 1000 / 5\n",
                "plot_n_channel_sd(diffusion_latents,\n",
                "                  ae_latents,\n",
                "                  channels=np.arange(16),\n",
                "                  fps=fps, save=False,\n",
                "                  save_path=None,\n",
                "                  colors=['darkred', \"midnightblue\"],\n",
                "                  labels=['diffusion', 'ae'], ystack=4, figsize=cm2inch(20, 20))\n",
                "\n",
                "\n",
                "plot_n_channel_sd(ae_latents,diffusion_latents,\n",
                "                  channels=np.arange(16),\n",
                "                  fps=fps, save=True, save_path=save_path+'latents_power_spec_density_diff_ae',\n",
                "                  colors=[color_dict['ael'], color_dict['diffl']],\n",
                "                  labels=['ae', 'diff'], ystack=4, figsize=cm2inch(20, 20), lw=1)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# calculate correlation matrices"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import correlation_matrix  # also imported in a later cell; needed here on a fresh kernel\n",
                "\n",
                "\n",
                "def plot_and_save_corr_matrix(spikes, name):\n",
                "    \"\"\"Plot the neuron-by-neuron correlation matrix of `spikes` and save it as .png and .pdf.\n",
                "\n",
                "    The diagonal is zeroed so self-correlations do not dominate the colour scale, and the\n",
                "    colour limits are symmetric around 0 at the largest absolute off-diagonal value.\n",
                "    Files are written to save_path + f'{name}_correlation_matrix' + '.png' / '.pdf'.\n",
                "\n",
                "    Returns:\n",
                "        (C, min_max): the correlation matrix with zeroed diagonal and its colour limit.\n",
                "    \"\"\"\n",
                "    C = correlation_matrix(spikes, mode=\"concatenate\")\n",
                "    np.fill_diagonal(C, 0)\n",
                "    min_max = np.nanmax(np.abs(C))\n",
                "\n",
                "    fig, ax = plt.subplots(figsize=cm2inch(8, 6))\n",
                "    im = ax.imshow(C, vmax=min_max, vmin=-min_max, cmap=\"coolwarm\")\n",
                "    fig.colorbar(im, ax=ax)\n",
                "    ax.set_xlabel('neuron index')\n",
                "    ax.set_ylabel('neuron index')\n",
                "    # tight_layout has to run after the labels exist, otherwise they can be clipped in the saved files\n",
                "    fig.tight_layout()\n",
                "    fig.savefig(save_path + f'{name}_correlation_matrix.png')\n",
                "    fig.savefig(save_path + f'{name}_correlation_matrix.pdf')\n",
                "    return C, min_max\n",
                "\n",
                "\n",
                "C_gt, min_max = plot_and_save_corr_matrix(gt_spikes, 'gt')\n",
                "\n",
                "# One Poisson spike sample per model, drawn in the same order as before (AE, then diffusion).\n",
                "ae_spikes = np.random.poisson(ae_rates)\n",
                "C_ae, min_max_ae = plot_and_save_corr_matrix(ae_spikes, 'ae')\n",
                "\n",
                "diff_spikes = np.random.poisson(diffusion_rates)\n",
                "C_diff, min_max_diff = plot_and_save_corr_matrix(diff_spikes, 'diffusion')\n",
                "\n",
                "# Side-by-side comparison on a shared colour scale, plus a model-vs-gt scatter of all entries.\n",
                "max_all = np.max([min_max, min_max_ae, min_max_diff])\n",
                "\n",
                "fig, ax = plt.subplots(1, 4, figsize=cm2inch(18, 6))\n",
                "ax[0].imshow(C_gt, vmax=max_all, vmin=-max_all, cmap=\"coolwarm\")\n",
                "ax[0].set_title('gt')\n",
                "ax[1].imshow(C_ae, vmax=max_all, vmin=-max_all, cmap=\"coolwarm\")\n",
                "ax[1].set_title('ae')\n",
                "ax[2].imshow(C_diff, vmax=max_all, vmin=-max_all, cmap=\"coolwarm\")\n",
                "ax[2].set_title('diffusion')\n",
                "ax[3].plot(C_gt.flatten(), C_ae.flatten(), 'o', alpha=0.5, color=\"midnightblue\", label='ae')\n",
                "ax[3].plot(C_gt.flatten(), C_diff.flatten(), 'o', alpha=0.5, color=\"darkred\", label='diffusion')\n",
                "ax[3].plot([-max_all, max_all], [-max_all, max_all], 'k--')\n",
                "ax[3].set_xlabel('gt correlation')\n",
                "ax[3].set_ylabel('model correlation')\n",
                "ax[3].legend(frameon=False)\n",
                "\n",
                "fig.tight_layout()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import correlation_matrix\n",
                "from ntldm.utils.plotting_utils import plot_correlation_matrices_monkey\n",
                "\n",
                "# Draw one Poisson spike sample per model (AE first, then diffusion) and compare their\n",
                "# neuron-correlation structure with the recorded spikes.\n",
                "ae_spikes = np.random.poisson(ae_rates)\n",
                "diffusion_spikes = np.random.poisson(diffusion_rates)\n",
                "sampled_model_spikes = [ae_spikes, diffusion_spikes]\n",
                "\n",
                "plot_correlation_matrices_monkey(\n",
                "    gt_spikes,\n",
                "    sampled_model_spikes,\n",
                "    sample=None,\n",
                "    mode='concatenate',\n",
                "    model_labels=['ae', 'diffusion'],\n",
                "    model_colors=['midnightblue', 'darkred'],\n",
                "    figsize=cm2inch((16, 6)),\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import group_neurons_temp_corr, get_temp_corr_summary\n",
                "from ntldm.utils.plotting_utils import plot_cross_corr_summary, plot_temp_corr_summary\n",
                "\n",
                "TEMP_CORR_NLAGS = 30  # number of lags in the temporal auto-/cross-correlation summaries\n",
                "\n",
                "# Fresh Poisson samples from the model rates (AE first, then diffusion).\n",
                "ae_spikes = np.random.poisson(ae_rates)\n",
                "diff_spikes = np.random.poisson(diffusion_rates)\n",
                "\n",
                "# Neuron groups are computed from the ground truth only and reused for both models,\n",
                "# so the three summaries are directly comparable.\n",
                "groups = group_neurons_temp_corr(gt_spikes.transpose(1, 0, 2), num_groups=4)\n",
                "\n",
                "\n",
                "def temp_corr_summary(spikes):\n",
                "    \"\"\"Return (cross_corr_groups, auto_corr_groups) of `spikes` for the ground-truth `groups`.\n",
                "\n",
                "    Axes 0 and 1 are swapped before the call to match batch_first=False\n",
                "    (assumes `spikes` is laid out like gt_spikes -- TODO confirm for model samples).\n",
                "    \"\"\"\n",
                "    return get_temp_corr_summary(\n",
                "        spikes.transpose(1, 0, 2), groups, nlags=TEMP_CORR_NLAGS, mode='biased', batch_first=False\n",
                "    )\n",
                "\n",
                "\n",
                "cross_corr_groups, auto_corr_groups = temp_corr_summary(gt_spikes)\n",
                "fig_cross, fig_auto = plot_temp_corr_summary(cross_corr_groups, auto_corr_groups, name='Data')\n",
                "\n",
                "cross_corr_groups_sampled, auto_corr_groups_sampled = temp_corr_summary(ae_spikes)\n",
                "fig_cross_sampled, fig_auto_sampled = plot_temp_corr_summary(\n",
                "    cross_corr_groups_sampled, auto_corr_groups_sampled, name='AE Samples'\n",
                ")\n",
                "\n",
                "cross_corr_groups_diff, auto_corr_groups_diff = temp_corr_summary(diff_spikes)\n",
                "# Separate handles: these previously overwrote fig_cross_sampled / fig_auto_sampled (the AE figures).\n",
                "fig_cross_diff, fig_auto_diff = plot_temp_corr_summary(\n",
                "    cross_corr_groups_diff, auto_corr_groups_diff, name='Diffusion Samples'\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_cross_corr_summary\n",
                "\n",
                "save = True\n",
                "SUMMARY_FIGSIZE = cm2inch(6, 4)  # figsize forwarded to every plot_cross_corr_summary call\n",
                "AUTO_CORR_KWARGS = dict(title=\"auto-corr\", ylabel=\"auto-corr\")\n",
                "\n",
                "# (name, cmap, legend-label prefix) for each data source\n",
                "GT_STYLE = (\"gt\", \"Greys\", \"gt \")\n",
                "DIFF_STYLE = (\"diffusion\", \"Reds\", \"diff \")\n",
                "AE_STYLE = (\"ae\", \"Blues\", \"ae \")\n",
                "\n",
                "\n",
                "def overlay_corr_summaries(entries, figsize, ncol, fname, **plot_kwargs):\n",
                "    \"\"\"Overlay several correlation summaries on one new axis and save it as .png and .pdf.\n",
                "\n",
                "    Args:\n",
                "        entries: list of (corr_groups, (name, cmap, label)) tuples, drawn in order.\n",
                "        figsize: size of the new figure.\n",
                "        ncol: legend column count forwarded to plot_cross_corr_summary.\n",
                "        fname: file stem appended to save_path; written only when `save` is true and\n",
                "            save_path is not None.\n",
                "        **plot_kwargs: extra keyword arguments (e.g. title/ylabel) applied to every entry.\n",
                "\n",
                "    Returns:\n",
                "        (fig, ax) of the new figure.\n",
                "    \"\"\"\n",
                "    fig, ax = plt.subplots(1, 1, figsize=figsize)\n",
                "    for corr_groups, (name, cmap, label) in entries:\n",
                "        plot_cross_corr_summary(\n",
                "            corr_groups,\n",
                "            name=name,\n",
                "            figsize=SUMMARY_FIGSIZE,\n",
                "            cmap=cmap,\n",
                "            ax_corr=ax,\n",
                "            labels=label,\n",
                "            ncol=ncol,\n",
                "            **plot_kwargs,\n",
                "        )\n",
                "    if save and save_path is not None:\n",
                "        fig.savefig(save_path + fname + \".png\")\n",
                "        fig.savefig(save_path + fname + \".pdf\")\n",
                "    return fig, ax\n",
                "\n",
                "\n",
                "# Individual cross-correlation summaries, each saved to its own file.\n",
                "plot_cross_corr_summary(cross_corr_groups, name=\"gt\", figsize=SUMMARY_FIGSIZE, cmap=\"Greys\",\n",
                "                        save=True, save_path=save_path + \"cross_corr_gt\")\n",
                "plot_cross_corr_summary(cross_corr_groups_diff, name=\"diffusion\", figsize=SUMMARY_FIGSIZE, cmap=\"Reds\",\n",
                "                        save=True, save_path=save_path + \"cross_corr_diffusion\")\n",
                "plot_cross_corr_summary(cross_corr_groups_sampled, name=\"ae\", figsize=SUMMARY_FIGSIZE, cmap=\"Blues\",\n",
                "                        save=True, save_path=save_path + \"cross_corr_ae\")\n",
                "\n",
                "# Overlays on a shared axis.\n",
                "fig, ax = overlay_corr_summaries(\n",
                "    [(cross_corr_groups, GT_STYLE), (cross_corr_groups_diff, DIFF_STYLE), (cross_corr_groups_sampled, AE_STYLE)],\n",
                "    figsize=cm2inch(8, 6), ncol=3, fname=\"cross_corr_all\",\n",
                ")\n",
                "fig, ax = overlay_corr_summaries(\n",
                "    [(cross_corr_groups, GT_STYLE), (cross_corr_groups_diff, DIFF_STYLE)],\n",
                "    figsize=cm2inch(6, 4), ncol=2, fname=\"cross_corr_gt_diff\",\n",
                ")\n",
                "fig, ax = overlay_corr_summaries(\n",
                "    [(auto_corr_groups, GT_STYLE), (auto_corr_groups_diff, DIFF_STYLE), (auto_corr_groups_sampled, AE_STYLE)],\n",
                "    figsize=cm2inch(8, 6), ncol=3, fname=\"auto_corr_all\", **AUTO_CORR_KWARGS,\n",
                ")\n",
                "# title/ylabel now go to every entry; the gt entry of this figure previously got neither.\n",
                "fig, ax = overlay_corr_summaries(\n",
                "    [(auto_corr_groups, GT_STYLE), (auto_corr_groups_diff, DIFF_STYLE)],\n",
                "    figsize=cm2inch(6, 4), ncol=2, fname=\"auto_corr_gt_diff\", **AUTO_CORR_KWARGS,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Unsaved, interactive versions of the correlation summaries: individual panels first,\n",
                "# then overlays on a shared axis (gt / diffusion / ae, and gt vs ae only).\n",
                "_corr_sources = {\n",
                "    'gt': (cross_corr_groups, auto_corr_groups, 'Greys', 'gt g'),\n",
                "    'diffusion': (cross_corr_groups_diff, auto_corr_groups_diff, 'Reds', 'diff g'),\n",
                "    'ae': (cross_corr_groups_sampled, auto_corr_groups_sampled, 'Blues', 'ae g'),\n",
                "}\n",
                "_auto_labels = dict(title='auto-corr', ylabel='auto-corr')\n",
                "_all_three = ['gt', 'diffusion', 'ae']\n",
                "\n",
                "\n",
                "def _corr_summary_panels(kind, sources, ax_corr=None, ncol=None, **extra):\n",
                "    \"\"\"Call plot_cross_corr_summary once per source name in `sources`, in order.\n",
                "\n",
                "    kind: 'cross' selects the cross-correlation groups, anything else the auto-correlation groups.\n",
                "    ax_corr: when given, every entry is drawn on this axis with its legend label and `ncol`.\n",
                "    extra: additional keyword arguments forwarded unchanged (e.g. title/ylabel).\n",
                "    Returns whatever the last plot_cross_corr_summary call returned.\n",
                "    \"\"\"\n",
                "    result = None\n",
                "    for src in sources:\n",
                "        cross_groups, auto_groups, cmap, legend_label = _corr_sources[src]\n",
                "        kwargs = dict(name=src, figsize=cm2inch(6, 4), cmap=cmap)\n",
                "        if ax_corr is not None:\n",
                "            kwargs.update(ax_corr=ax_corr, labels=legend_label, ncol=ncol)\n",
                "        kwargs.update(extra)\n",
                "        result = plot_cross_corr_summary(cross_groups if kind == 'cross' else auto_groups, **kwargs)\n",
                "    return result\n",
                "\n",
                "\n",
                "_corr_summary_panels('cross', _all_three)\n",
                "\n",
                "fig, ax = plt.subplots(1, 1, figsize=cm2inch(8, 6))\n",
                "_corr_summary_panels('cross', _all_three, ax_corr=ax, ncol=3)\n",
                "\n",
                "fig, ax = plt.subplots(1, 1, figsize=cm2inch(8, 6))\n",
                "_corr_summary_panels('auto', _all_three, ax_corr=ax, ncol=3, **_auto_labels)\n",
                "\n",
                "_corr_summary_panels('auto', _all_three, **_auto_labels)\n",
                "\n",
                "fig, ax = plt.subplots(1, 1, figsize=cm2inch(8, 6))\n",
                "_corr_summary_panels('cross', ['gt', 'ae'], ax_corr=ax, ncol=2)\n",
                "\n",
                "fig, ax = plt.subplots(1, 1, figsize=cm2inch(8, 6))\n",
                "_corr_summary_panels('auto', ['gt', 'ae'], ax_corr=ax, ncol=2, **_auto_labels)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_rate_comparisons\n",
                "from ntldm.utils.eval_utils import average_rates, std_rates\n",
                "\n",
                "# Per-neuron ('neur') comparison of gt against the AE and diffusion spike samples:\n",
                "# first with plot_rate_comparisons' default statistic, then with std_rates.\n",
                "figsize = cm2inch((8, 8))\n",
                "_rate_models = [ae_spikes, diffusion_spikes]\n",
                "_rate_colors = ['midnightblue', 'darkred']\n",
                "\n",
                "plot_rate_comparisons(gt_spikes, _rate_models, mode='neur', figsize=figsize, colors=_rate_colors)\n",
                "\n",
                "plot_rate_comparisons(gt_spikes, _rate_models, fn=std_rates, mode='neur', figsize=figsize, colors=_rate_colors)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_rate_comparisons\n",
                "from ntldm.utils.eval_utils import average_rates, std_rates\n",
                "\n",
                "figsize = cm2inch((8, 8))\n",
                "\n",
                "\n",
                "def _compare_rates(mode, **extra):\n",
                "    \"\"\"Plot gt vs [AE, diffusion] samples with the shared figsize/colours; `extra` is forwarded.\"\"\"\n",
                "    return plot_rate_comparisons(\n",
                "        gt_spikes,\n",
                "        [ae_spikes, diffusion_spikes],\n",
                "        mode=mode,\n",
                "        figsize=figsize,\n",
                "        colors=['midnightblue', 'darkred'],\n",
                "        **extra,\n",
                "    )\n",
                "\n",
                "\n",
                "# Default statistic for the modes 'neur' (saved, axis in Hz), 'neurtime' and 'neursample'.\n",
                "_compare_rates(\n",
                "    'neur', fps=1000/5, save=True, xlabel='firing rate (Hz)',\n",
                "    save_path=save_path + \"mean_spike_comparison_gt_ae_diffusion\",\n",
                ")\n",
                "_compare_rates('neurtime')\n",
                "_compare_rates('neursample')\n",
                "\n",
                "# The same three modes with std_rates.\n",
                "_compare_rates(\n",
                "    'neur', fn=std_rates, fps=1000/5, save=True, xlabel='std fr (Hz)',\n",
                "    save_path=save_path + \"std_std_comparison_gt_ae_diffusion\",\n",
                ")\n",
                "_compare_rates('neurtime', fn=std_rates)\n",
                "_compare_rates('neursample', fn=std_rates)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import compute_spike_stats, counts_to_spike_trains\n",
                "\n",
                "fps = 1000 / 5  # 200 frames per second, i.e. 5 ms bins\n",
                "\n",
                "# Convert binned counts to spike trains: recorded data as-is, models via a fresh Poisson\n",
                "# draw from their rates (AE first, then diffusion).\n",
                "spike_trains_gt = counts_to_spike_trains(gt_spikes, fps=fps)\n",
                "spike_trains_ae, spike_trains_diff = (\n",
                "    counts_to_spike_trains(np.random.poisson(model_rates), fps=fps)\n",
                "    for model_rates in (ae_rates, diffusion_rates)\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# NOTE(review): disabled sketch of a per-(sample, neuron) mean inter-spike-interval table\n",
                "# built from spike_trains_gt. Every line is commented out, so this cell executes nothing.\n",
                "# n_samples, n_seqlen, n_neurons = gt_spikes.shape\n",
                "# mean_isis_per_sample_neuron = np.full((n_samples, n_neurons), np.nan)\n",
                "# for (sample_idx, neuron_idx), spikes in spike_trains_gt.items():\n",
                "#         if len(spikes) > 1:\n",
                "#             isis = np.diff(spikes)\n",
                "#             mean_isi = np.nanmean(isis) if len(isis) > 0 else np.nan\n",
                "#             mean_isis_per_sample_neuron[sample_idx, neuron_idx] = mean_isi\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import compute_spike_stats_per_neuron\n",
                "from ntldm.utils.plotting_utils import cm2inch, plot_spiketrain_stats\n",
                "\n",
                "\n",
                "def _per_neuron_stats(spike_trains, shaped_like):\n",
                "    \"\"\"Non-aggregated (mean_output=False) spike-train statistics.\n",
                "\n",
                "    n_samples / n_neurons are read from axes 0 and 2 of `shaped_like`.\n",
                "    \"\"\"\n",
                "    return compute_spike_stats_per_neuron(\n",
                "        spike_trains,\n",
                "        n_samples=shaped_like.shape[0],\n",
                "        n_neurons=shaped_like.shape[2],\n",
                "        mean_output=False,\n",
                "    )\n",
                "\n",
                "\n",
                "spike_stats_gt = _per_neuron_stats(spike_trains_gt, gt_spikes)\n",
                "spike_stats_ae = _per_neuron_stats(spike_trains_ae, ae_rates)\n",
                "spike_stats_diff = _per_neuron_stats(spike_trains_diff, diffusion_rates)\n",
                "\n",
                "stats_figsize = cm2inch(12, 4)\n",
                "\n",
                "# gt vs each model is saved to disk; ae vs diffusion is only shown.\n",
                "plot_spiketrain_stats(\n",
                "    spike_stats_gt, spike_stats_ae, figsize=stats_figsize, color=\"midnightblue\", labels=[\"gt\", \"ae\"],\n",
                "    save=True, save_path=save_path + \"compute_spike_stats_per_neuron_gt_ae\",\n",
                ")\n",
                "plot_spiketrain_stats(\n",
                "    spike_stats_gt, spike_stats_diff, figsize=stats_figsize, color=\"darkred\", labels=[\"gt\", \"diffusion\"],\n",
                "    save=True, save_path=save_path + \"compute_spike_stats_per_neuron_gt_diff\",\n",
                ")\n",
                "plot_spiketrain_stats(\n",
                "    spike_stats_ae, spike_stats_diff, figsize=stats_figsize, color=\"purple\", labels=[\"ae\", \"diffusion\"],\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Per-neuron spike-train statistics with mean_output=False (not averaged), re-plotted without saving.\n",
                "from ntldm.utils.eval_utils import compute_spike_stats_per_neuron\n",
                "from ntldm.utils.plotting_utils import cm2inch, plot_spiketrain_stats\n",
                "\n",
                "spike_stats_gt, spike_stats_ae, spike_stats_diff = (\n",
                "    compute_spike_stats_per_neuron(\n",
                "        trains, n_samples=ref.shape[0], n_neurons=ref.shape[2], mean_output=False\n",
                "    )\n",
                "    for trains, ref in (\n",
                "        (spike_trains_gt, gt_spikes),\n",
                "        (spike_trains_ae, ae_rates),\n",
                "        (spike_trains_diff, diffusion_rates),\n",
                "    )\n",
                ")\n",
                "\n",
                "plot_spiketrain_stats(spike_stats_gt, spike_stats_ae, figsize=cm2inch(12, 4),\n",
                "                      color=\"midnightblue\", labels=[\"gt\", \"ae\"])\n",
                "plot_spiketrain_stats(spike_stats_gt, spike_stats_diff, figsize=cm2inch(12, 4),\n",
                "                      color=\"darkred\", labels=[\"gt\", \"diffusion\"])\n",
                "plot_spiketrain_stats(spike_stats_ae, spike_stats_diff, figsize=cm2inch(12, 4),\n",
                "                      color=\"purple\", labels=[\"ae\", \"diffusion\"])\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Mean inter-spike interval from the per-neuron spike statistics: ground truth vs AE samples.\n",
                "# Points on the identity line mean the AE reproduces the ground-truth ISI.\n",
                "isi_gt = np.asarray(spike_stats_gt[\"mean_isi\"], dtype=float)\n",
                "isi_ae = np.asarray(spike_stats_ae[\"mean_isi\"], dtype=float)\n",
                "# identity line spans the observed range (it was hard-coded to [0, 0.3], which can miss data)\n",
                "isi_max = np.nanmax([np.nanmax(isi_gt), np.nanmax(isi_ae)])\n",
                "\n",
                "fig_isi, ax_isi = plt.subplots()\n",
                "ax_isi.plot(isi_gt, isi_ae, 'o')\n",
                "ax_isi.plot([0, isi_max], [0, isi_max], 'k-')\n",
                "ax_isi.set_xlabel('gt mean ISI')\n",
                "ax_isi.set_ylabel('ae mean ISI');"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "## eval to decode behavior from generated (smooth) rates, compare with real behavior\n",
                "from sklearn.linear_model import Ridge\n",
                "from sklearn.metrics import r2_score\n",
                "\n",
                "\n",
                "def _collect_ae_rates_and_behavior(ae, dataloader):\n",
                "    \"\"\"Run `ae` over every batch of `dataloader`; return (rates, behavior) concatenated on CPU.\n",
                "\n",
                "    rates is the first output of ae(batch[\"signal\"]) and behavior is batch[\"behavior\"];\n",
                "    both are concatenated along dim 0 (presumably [B C L] and [B 2 L] -- TODO confirm).\n",
                "    \"\"\"\n",
                "    rates, behavior = [], []\n",
                "    with torch.no_grad():\n",
                "        for batch in dataloader:\n",
                "            rates.append(ae(batch[\"signal\"])[0].cpu())\n",
                "            behavior.append(batch[\"behavior\"].cpu())\n",
                "    return torch.cat(rates, 0), torch.cat(behavior, 0)\n",
                "\n",
                "\n",
                "def gen_rates_and_train_decoded_behavior(\n",
                "    ema_denoiser,\n",
                "    scheduler,\n",
                "    ae,\n",
                "    cfg,\n",
                "    train_latent_dataloader,\n",
                "    val_latent_dataloader,\n",
                "    num_samples=100,\n",
                "    device=\"cuda\",\n",
                "    latent_dataset=None,\n",
                "    ridge_alpha=1e-6,\n",
                "):\n",
                "    \"\"\"Fit a linear decoder from AE rates to behavior and apply it to diffusion samples.\n",
                "\n",
                "    A Ridge regression maps AE rates to behavior, using one regression sample per time bin\n",
                "    of the training loader, and is scored with R^2 on the validation loader. It is then\n",
                "    applied to rates obtained by sampling latents from the EMA denoiser and decoding them\n",
                "    with the AE.\n",
                "\n",
                "    Args:\n",
                "        ema_denoiser: EMA wrapper; its `.averaged_model` is put in eval mode.\n",
                "        scheduler, cfg, device: forwarded to `sample`.\n",
                "        ae: autoencoder; ae(signal)[0] gives rates, ae.decode(latents) decodes sampled latents.\n",
                "        train_latent_dataloader, val_latent_dataloader: yield dicts with \"signal\" and \"behavior\".\n",
                "        num_samples: number of latent trajectories to sample.\n",
                "        latent_dataset: object with `latent_means` / `latent_stds` used to undo the latent\n",
                "            normalisation; defaults to the notebook-global `latent_dataset_train`.\n",
                "        ridge_alpha: regularisation strength of the Ridge decoder.\n",
                "\n",
                "    Returns:\n",
                "        dict with \"predicted_val_behavior\", \"real_val_behavior\" and\n",
                "        \"predicted_sampled_behavior\", each a numpy array laid out [b, c, l].\n",
                "    \"\"\"\n",
                "    if latent_dataset is None:\n",
                "        latent_dataset = latent_dataset_train  # notebook global, as in the original version\n",
                "\n",
                "    ema_denoiser.averaged_model.eval()\n",
                "    ae.eval()\n",
                "\n",
                "    train_rates, train_behavior = _collect_ae_rates_and_behavior(ae, train_latent_dataloader)\n",
                "    val_rates, val_behavior = _collect_ae_rates_and_behavior(ae, val_latent_dataloader)\n",
                "    print(\"train rates / behavior:\", tuple(train_rates.shape), tuple(train_behavior.shape))\n",
                "    print(\"val rates / behavior:\", tuple(val_rates.shape), tuple(val_behavior.shape))\n",
                "\n",
                "    # flatten trials and time so that every time bin is one regression sample\n",
                "    bs_val = val_rates.shape[0]\n",
                "    train_rates_flat = rearrange(train_rates, \"b c l -> (b l) c\").numpy()\n",
                "    val_rates_flat = rearrange(val_rates, \"b c l -> (b l) c\").numpy()\n",
                "    train_behavior_flat = rearrange(train_behavior, \"b c l -> (b l) c\").numpy()\n",
                "    val_behavior_flat = rearrange(val_behavior, \"b c l -> (b l) c\").numpy()\n",
                "\n",
                "    decoder = Ridge(alpha=ridge_alpha)\n",
                "    decoder.fit(train_rates_flat, train_behavior_flat)\n",
                "    predicted_behavior = decoder.predict(val_rates_flat)\n",
                "    r2 = r2_score(val_behavior_flat, predicted_behavior)\n",
                "    print(f\"R2 score on val: {r2:.3f}\")\n",
                "\n",
                "    # sample latents from the denoiser (`sample` is a notebook global) and undo the\n",
                "    # dataset normalisation before decoding\n",
                "    sampled_latents = sample(\n",
                "        ema_denoiser=ema_denoiser,\n",
                "        scheduler=scheduler,\n",
                "        cfg=cfg,\n",
                "        batch_size=num_samples,\n",
                "        device=device,\n",
                "    )\n",
                "    sampled_latents = sampled_latents * latent_dataset.latent_stds.to(\n",
                "        sampled_latents.device\n",
                "    ) + latent_dataset.latent_means.to(sampled_latents.device)\n",
                "\n",
                "    with torch.no_grad():\n",
                "        sampled_rates = ae.decode(sampled_latents).cpu()\n",
                "    n_sampled = sampled_rates.shape[0]  # use the actual count rather than assuming num_samples\n",
                "    sampled_rates_flat = rearrange(sampled_rates, \"b c l -> (b l) c\").numpy()\n",
                "    predicted_sampled_behavior = decoder.predict(sampled_rates_flat)\n",
                "\n",
                "    return {\n",
                "        \"predicted_val_behavior\": rearrange(predicted_behavior, \"(b l) c -> b c l\", b=bs_val),\n",
                "        \"real_val_behavior\": rearrange(val_behavior_flat, \"(b l) c -> b c l\", b=bs_val),\n",
                "        \"predicted_sampled_behavior\": rearrange(\n",
                "            predicted_sampled_behavior, \"(b l) c -> b c l\", b=n_sampled\n",
                "        ),\n",
                "    }\n",
                "\n",
                "\n",
                "behav_dict = gen_rates_and_train_decoded_behavior(\n",
                "    ema_denoiser=ema_model,\n",
                "    scheduler=scheduler,\n",
                "    ae=ae_model,\n",
                "    cfg=cfg,\n",
                "    train_latent_dataloader=train_latent_dataloader,\n",
                "    val_latent_dataloader=test_latent_dataloader,\n",
                "    num_samples=20,\n",
                "    device=\"cuda\",\n",
                ")\n",
                "\n",
                "predicted_val_behavior = behav_dict[\"predicted_val_behavior\"]\n",
                "real_val_behavior = behav_dict[\"real_val_behavior\"]\n",
                "predicted_sampled_behavior = behav_dict[\"predicted_sampled_behavior\"]\n",
                "\n",
                "# cumulative sum over time (last axis) turns the decoded behavior into X/Y position traces\n",
                "predicted_traj = np.cumsum(predicted_val_behavior, axis=-1)\n",
                "real_traj = np.cumsum(real_val_behavior, axis=-1)\n",
                "predicted_sampled_traj = np.cumsum(predicted_sampled_behavior, axis=-1)\n",
                "\n",
                "print(predicted_traj.shape, real_traj.shape, predicted_sampled_traj.shape)\n",
                "\n",
                "\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(10, 5))\n",
                "\n",
                "# every 7th validation trial, at most 10 (one tab10 colour each), clipped to the trials available\n",
                "colors = sns.color_palette(\"tab10\")\n",
                "for i, idx in enumerate(range(0, min(70, real_traj.shape[0]), 7)):\n",
                "    first = i == 0  # label only one pair so the legend has two entries\n",
                "    ax[0].plot(\n",
                "        real_traj[idx, 0, :],\n",
                "        real_traj[idx, 1, :],\n",
                "        label=\"Real trajectory\" if first else \"_nolegend_\",\n",
                "        color=colors[i],\n",
                "    )\n",
                "    ax[0].plot(\n",
                "        predicted_traj[idx, 0, :],\n",
                "        predicted_traj[idx, 1, :],\n",
                "        label=\"Predicted trajectory\" if first else \"_nolegend_\",\n",
                "        linestyle=\"--\",\n",
                "        color=colors[i],\n",
                "    )\n",
                "ax[0].set_xlabel(\"X position\")\n",
                "ax[0].set_ylabel(\"Y position\")\n",
                "ax[0].set_title(\"Real vs predicted trajectory (Val)\")\n",
                "ax[0].legend(frameon=False, fontsize=\"small\")\n",
                "\n",
                "# up to 20 sampled trajectories, one tab20 colour each\n",
                "colors = sns.color_palette(\"tab20\")\n",
                "for i, idx in enumerate(range(min(20, predicted_sampled_traj.shape[0]))):\n",
                "    ax[1].plot(\n",
                "        predicted_sampled_traj[idx, 0, :],\n",
                "        predicted_sampled_traj[idx, 1, :],\n",
                "        linestyle=\"-\",\n",
                "        color=colors[i],\n",
                "    )\n",
                "ax[1].set_xlabel(\"X position\")\n",
                "ax[1].set_ylabel(\"Y position\")\n",
                "ax[1].set_title(\"predicted trajectory (Sampled)\")\n",
                "\n",
                "fig.tight_layout()\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(10, 5))\n",
                "\n",
                "colors = sns.color_palette(\"tab10\")\n",
                "for i, idx in enumerate(range(0, 70, 7)):\n",
                "    ax[0].plot(\n",
                "        real_traj[idx, 0, :], real_traj[idx, 1, :], label=\"real trajectory\", color=colors[i]\n",
                "    )\n",
                "    ax[0].plot(\n",
                "        predicted_traj[idx, 0, :],\n",
                "        predicted_traj[idx, 1, :],\n",
                "        label=\"Predicted trajectory\",\n",
                "        linestyle=\"--\",\n",
                "        color=colors[i],\n",
                "    )\n",
                "ax[0].set_xlabel(\"x position\")\n",
                "ax[0].set_ylabel(\"y position\")\n",
                "ax[0].set_title(\"real vs predicted trajectory\")\n",
                "\n",
                "colors = sns.color_palette(\"tab20\")\n",
                "for i, idx in enumerate(range(0, 20)):\n",
                "    ax[1].plot(\n",
                "        predicted_sampled_traj[idx, 0, :],\n",
                "        predicted_sampled_traj[idx, 1, :],\n",
                "        # label=\"Predicted trajectory\",\n",
                "        linestyle=\"-\",\n",
                "        color=colors[i],\n",
                "    )\n",
                "ax[1].set_xlabel(\"x position\")\n",
                "ax[1].set_ylabel(\"y position\")\n",
                "ax[1].set_title(\"unconditional diffusion\")\n",
                "\n",
                "fig.tight_layout()\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "behav_dict = compute_decoded_behavior(\n",
                "    ae_model, train_latent_dataloader, test_latent_dataloader\n",
                ")\n",
                "predicted_val_behavior = behav_dict[\"predicted_val_behavior\"]\n",
                "real_val_behavior = behav_dict[\"real_val_behavior\"]\n",
                "predicted_traj = np.cumsum(predicted_val_behavior, axis=-1)\n",
                "real_traj = np.cumsum(real_val_behavior, axis=-1)\n",
                "# get the reach angle \n",
                "angle_real_val_behavior = np.arctan2(real_traj[:, 1, 50], real_traj[:, 0, 50])\n",
                "\n",
                "print(predicted_traj.shape, real_traj.shape)\n",
                "plt.figure(figsize=cm2inch(3, 2.2))\n",
                "\n",
                "for i, idx in enumerate(range(0, 70, 5)):\n",
                "    plt.plot(\n",
                "        real_traj[idx, 0, :], real_traj[idx, 1, :],\n",
                "         label=\"real trajectory\", color=angle_to_color(angle_real_val_behavior[idx]),lw=1\n",
                "    )\n",
                "    plt.plot(\n",
                "        predicted_traj[idx, 0, :],\n",
                "        predicted_traj[idx, 1, :],\n",
                "        label=\"predicted trajectory\",\n",
                "        linestyle=\"--\",\n",
                "        color=angle_to_color(angle_real_val_behavior[idx]),\n",
                "        lw=1\n",
                "    )\n",
                "plt.xlabel(\"x position\")\n",
                "plt.ylabel(\"y position\")\n",
                "# plt.legend()\n",
                "plt.title(\"real vs predicted trajectory\")\n",
                "\n",
                "sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(-np.pi, np.pi))\n",
                "sm.set_array([])\n",
                "# set ticks formatting for colorbar in terms of pi\n",
                "cbar = plt.colorbar(sm, ticks=np.linspace(-np.pi, np.pi, 5), label=\"behavior angle\", ax=plt.gca())\n",
                "cbar.ax.set_yticklabels([r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"])\n",
                "plt.xlabel('x')\n",
                "plt.ylabel('y')\n",
                "plt.savefig(save_path + \"test_reaches_and_recs_colored_by_angles.png\")\n",
                "plt.savefig(save_path + \"test_reaches_and_recs_colored_by_angles.pdf\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "save_path"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "\n",
                "### 3. Train diffusion to generate rates conditioned on behavior angle, check if these spikes can be decoded into behavior and its angle\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "cfg_yaml = \"\"\"\n",
                "denoiser_model:\n",
                "  C_in: 16\n",
                "  C: 256\n",
                "  kernel: s4\n",
                "  num_blocks: 6\n",
                "  bidirectional: True\n",
                "  condition_dim: 2 # behavior angle [sin, cos]\n",
                "  num_train_timesteps: 1000\n",
                "training:\n",
                "  lr: 0.001\n",
                "  weight_decay: 0.0\n",
                "  num_epochs: 2000\n",
                "  num_warmup_epochs: 50\n",
                "  batch_size: 512\n",
                "  random_seed: 42\n",
                "  precision: \"no\"\n",
                "exp_name: diffusion_s4-monkey_cond_angle_new_reg_epoch_140\n",
                "\"\"\"\n",
                "\n",
                "cfg = OmegaConf.create(yaml.safe_load(cfg_yaml))\n",
                "cfg.dataset = cfg_ae.dataset"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# ## initialize (unconditional) denoiser\n",
                "\n",
                "# from ntldm.networks import ConditionalDenoiser\n",
                "\n",
                "# denoiser = ConditionalDenoiser(\n",
                "#     C_in=cfg.denoiser_model.C_in,\n",
                "#     C=cfg.denoiser_model.C,\n",
                "#     L=cfg.dataset.signal_length,\n",
                "#     kernel=cfg.denoiser_model.kernel,\n",
                "#     num_blocks=cfg.denoiser_model.num_blocks,\n",
                "#     bidirectional=cfg.denoiser_model.get(\"bidirectional\", True),\n",
                "#     condition_dim=cfg.denoiser_model.condition_dim,\n",
                "# )\n",
                "\n",
                "# # initial values may be way off so better to scale down the output layer\n",
                "# denoiser.conv_out.weight.data = denoiser.conv_out.weight.data * 0.1\n",
                "# denoiser.conv_out.bias.data = denoiser.conv_out.bias.data * 0.1\n",
                "\n",
                "# scheduler = DDPMScheduler(\n",
                "#     num_train_timesteps=cfg.denoiser_model.num_train_timesteps,\n",
                "#     clip_sample=False,\n",
                "#     beta_schedule=\"linear\", # ddpm doesnt support cosine\n",
                "# )\n",
                "\n",
                "\n",
                "# optimizer = torch.optim.AdamW(\n",
                "#     denoiser.parameters(), lr=cfg.training.lr\n",
                "# )  # default wd=0.01 for now\n",
                "\n",
                "\n",
                "\n",
                "# num_batches = len(train_latent_dataloader)\n",
                "# lr_scheduler = get_scheduler(\n",
                "#     name=\"cosine\",\n",
                "#     optimizer=optimizer,\n",
                "#     num_warmup_steps=num_batches * cfg.training.num_warmup_epochs,  # warmup for 10% of epochs\n",
                "#     num_training_steps=num_batches * cfg.training.num_epochs * 1.3,  # total number of steps\n",
                "# )\n",
                "\n",
                "# # prepare the denoiser model and dataset\n",
                "# (\n",
                "#     denoiser,\n",
                "#     optimizer,\n",
                "#     train_latent_dataloader,\n",
                "#     val_latent_dataloader,\n",
                "#     lr_scheduler,\n",
                "# ) = accelerator.prepare(\n",
                "#     denoiser,\n",
                "#     optimizer,\n",
                "#     train_latent_dataloader,\n",
                "#     val_latent_dataloader,\n",
                "#     lr_scheduler,\n",
                "# )\n",
                "\n",
                "# ema_model = EMAModel(denoiser)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def sample_conditioned_on_angle(\n",
                "    ema_denoiser,\n",
                "    scheduler,\n",
                "    cfg,\n",
                "    batch_size=1,\n",
                "    generator=None,\n",
                "    device=\"cuda\",\n",
                "    signal_length=None,\n",
                "    angle: float = None,\n",
                "):\n",
                "    if signal_length is None:\n",
                "        signal_length = cfg.dataset.signal_length\n",
                "\n",
                "    if angle is None:\n",
                "        angle = torch.rand(batch_size) * 2 * np.pi - np.pi # sample unoformly from -pi to pi\n",
                "        angle_c = angle.unsqueeze(1)\n",
                "    else:\n",
                "        #angle_c = torch.tensor(angle).unsqueeze(1).repeat(batch_size, 1)\n",
                "        # Ensure angle is a tensor and has correct shape\n",
                "        if not isinstance(angle, torch.Tensor):\n",
                "            angle = torch.tensor([angle], dtype=torch.float32)\n",
                "        if angle.dim() == 0:\n",
                "            angle = angle.unsqueeze(0)  # Make it 1D if it's a scalar\n",
                "        angle_c = angle.unsqueeze(1).repeat(batch_size, 1)\n",
                "\n",
                "    z_t = torch.randn((batch_size, cfg.denoiser_model.C_in, signal_length)).to(device)\n",
                "\n",
                "    ema_denoiser_avg = ema_denoiser.averaged_model\n",
                "\n",
                "    scheduler.set_timesteps(cfg.denoiser_model.num_train_timesteps)\n",
                "\n",
                "    for t in tqdm(scheduler.timesteps, desc=\"Sampling DDPM\"):\n",
                "        with torch.no_grad():\n",
                "            model_output = ema_denoiser_avg(\n",
                "                z_t,\n",
                "                t=torch.tensor([t] * batch_size).to(device).long(),\n",
                "                c=torch.cat([torch.cos(angle_c), torch.sin(angle_c)], dim=1),\n",
                "            )\n",
                "        z_t = scheduler.step(\n",
                "            model_output, t, z_t, generator=generator, return_dict=False\n",
                "        )[0]\n",
                "\n",
                "    return z_t, angle"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "## eval to decode behavior from generated (smooth) rates, compare with real behavior\n",
                "\n",
                "\n",
                "def gen_rates_and_train_decoded_behavior_conditioned_on_angle(\n",
                "    ema_denoiser,\n",
                "    scheduler,\n",
                "    ae,\n",
                "    cfg,\n",
                "    train_latent_dataloader,\n",
                "    val_latent_dataloader,\n",
                "    num_samples=100,\n",
                "    device=\"cuda\",\n",
                "    angle=None,\n",
                "):\n",
                "\n",
                "    avg_denoiser = ema_denoiser.averaged_model\n",
                "    avg_denoiser.eval()\n",
                "\n",
                "    ae.eval()\n",
                "\n",
                "    train_rates = []\n",
                "    train_behavior = []\n",
                "    for batch in train_latent_dataloader:\n",
                "        signal = batch[\"signal\"]\n",
                "        behavior = batch[\"behavior\"]\n",
                "        with torch.no_grad():\n",
                "            output_rates = ae(signal)[0].cpu()\n",
                "        train_rates.append(output_rates)\n",
                "        train_behavior.append(behavior.cpu())\n",
                "\n",
                "    train_rates = torch.cat(train_rates, 0)  # [B C L]\n",
                "    train_behavior = torch.cat(train_behavior, 0)  # [B 2 L]\n",
                "    print(train_rates, train_behavior)\n",
                "\n",
                "    val_rates = []\n",
                "    val_behavior = []\n",
                "    for batch in val_latent_dataloader:\n",
                "        signal = batch[\"signal\"]\n",
                "        behavior = batch[\"behavior\"]\n",
                "        with torch.no_grad():\n",
                "            output_rates = ae(signal)[0].cpu()\n",
                "        val_rates.append(output_rates)\n",
                "        val_behavior.append(behavior.cpu())\n",
                "\n",
                "    val_rates = torch.cat(val_rates, 0)  # [B C L]\n",
                "    val_behavior = torch.cat(val_behavior, 0)  # [B 2 L]\n",
                "    print(val_rates, val_behavior)\n",
                "\n",
                "    # decode rates to behavior\n",
                "    from sklearn.linear_model import Ridge\n",
                "\n",
                "    bs_train = train_rates.shape[0]\n",
                "    bs_val = val_rates.shape[0]\n",
                "    train_rates = rearrange(train_rates, \"b c l -> (b l) c\").numpy()\n",
                "    val_rates = rearrange(val_rates, \"b c l -> (b l) c\").numpy()\n",
                "    train_behavior = rearrange(train_behavior, \"b c l -> (b l) c\").numpy()\n",
                "    val_behavior = rearrange(val_behavior, \"b c l -> (b l) c\").numpy()\n",
                "\n",
                "\n",
                "    RidgeRegressionModel = Ridge(alpha=1e-6)\n",
                "    RidgeRegressionModel.fit(train_rates, train_behavior)\n",
                "    predicted_behavior = RidgeRegressionModel.predict(val_rates)\n",
                "    # r2 score\n",
                "    from sklearn.metrics import r2_score\n",
                "\n",
                "    r2 = r2_score(val_behavior, predicted_behavior)\n",
                "    print(f\"R2 score on val: {r2:.3f}\")\n",
                "\n",
                "    \n",
                "    # sample from the denoiser\n",
                "    sampled_latents, angles = sample_conditioned_on_angle(\n",
                "        ema_denoiser=ema_denoiser,\n",
                "        scheduler=scheduler,\n",
                "        cfg=cfg,\n",
                "        batch_size=num_samples,\n",
                "        device=device,\n",
                "        angle=angle,\n",
                "    )\n",
                "    sampled_latents = sampled_latents * latent_dataset_train.latent_stds.to(\n",
                "        sampled_latents.device\n",
                "    ) + latent_dataset_train.latent_means.to(sampled_latents.device)\n",
                "\n",
                "    with torch.no_grad():\n",
                "        sampled_rates = ae.decode(sampled_latents).cpu()\n",
                "\n",
                "    sampled_rates = rearrange(sampled_rates, \"b c l -> (b l) c\").numpy()\n",
                "    predicted_sampled_behavior = RidgeRegressionModel.predict(sampled_rates)\n",
                "\n",
                "    return {\n",
                "        \"predicted_val_behavior\": rearrange(\n",
                "            predicted_behavior, \"(b l) c -> b c l\", b=bs_val\n",
                "        ),\n",
                "        \"real_val_behavior\": rearrange(val_behavior, \"(b l) c -> b c l\", b=bs_val),\n",
                "        \"predicted_sampled_behavior\": rearrange(predicted_sampled_behavior, \"(b l) c -> b c l\", b=num_samples),\n",
                "        \"angles\": angles,\n",
                "    }\n",
                "\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "loss_fn = torch.nn.SmoothL1Loss(\n",
                "    beta=0.1, reduction=\"mean\"\n",
                ")  # faster convergence than mse\n",
                "\n",
                "pbar = tqdm(range(0, cfg.training.num_epochs), desc=\"epochs\")\n",
                "for epoch in pbar:\n",
                "    for i, batch in enumerate(train_latent_dataloader):\n",
                "\n",
                "        optimizer.zero_grad()\n",
                "\n",
                "        z = batch[\"latent\"]\n",
                "        t = torch.randint(\n",
                "            0, cfg.denoiser_model.num_train_timesteps, (z.shape[0],), device=\"cpu\"\n",
                "        ).long()\n",
                "\n",
                "        c = torch.cat(\n",
                "            [torch.cos(batch[\"behavior_angle\"]), torch.sin(batch[\"behavior_angle\"])],\n",
                "            dim=1,\n",
                "        )\n",
                "\n",
                "        # print(z.shape, t.shape)\n",
                "        noise = torch.randn_like(z)\n",
                "        noisy_z = scheduler.add_noise(z, noise, t)\n",
                "        noise_pred = denoiser(noisy_z, t=t, c=c)\n",
                "\n",
                "        loss = loss_fn(noise, noise_pred)\n",
                "\n",
                "        accelerator.backward(loss)\n",
                "        accelerator.clip_grad_norm_(denoiser.parameters(), 1.0)\n",
                "\n",
                "        optimizer.step()\n",
                "        lr_scheduler.step()\n",
                "\n",
                "        if i % 10 == 0:\n",
                "            pbar.set_postfix({\"loss\": loss.item(), \"lr\": lr_scheduler.get_last_lr()[0]})\n",
                "\n",
                "        ema_model.step(denoiser)\n",
                "\n",
                "    if (epoch) % 100 == 0:\n",
                "\n",
                "        # plot samples\n",
                "        sampled_latents, angles = sample_conditioned_on_angle(\n",
                "            ema_denoiser=ema_model,\n",
                "            scheduler=scheduler,\n",
                "            cfg=cfg,\n",
                "            batch_size=2,\n",
                "            device=\"cuda\",\n",
                "        )\n",
                "        sampled_latents = sampled_latents * latent_dataset_train.latent_stds.to(\n",
                "            sampled_latents.device\n",
                "        ) + latent_dataset_train.latent_means.to(sampled_latents.device)\n",
                "\n",
                "        real_latents = latent_dataset_train.latents[:2].cuda()\n",
                "        real_latents = real_latents * latent_dataset_train.latent_stds.to(\n",
                "            real_latents.device\n",
                "        ) + latent_dataset_train.latent_means.to(real_latents.device)\n",
                "\n",
                "        with torch.no_grad():\n",
                "            sampled_rates = ae_model.decode(sampled_latents).cpu()\n",
                "            decoded_rates_from_real_latents = ae_model.decode(real_latents).cpu()\n",
                "\n",
                "        fig, ax = plt.subplots(1, 2, figsize=cm2inch(12, 4))\n",
                "        im = ax[0].imshow(sampled_rates[0], aspect=\"auto\")\n",
                "        ax[0].set_title(\"Sampled rates\")\n",
                "        fig.colorbar(im, ax=ax[0], orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "\n",
                "        im = ax[1].imshow(decoded_rates_from_real_latents[0], aspect=\"auto\")\n",
                "        ax[1].set_title(\"Real rates\")\n",
                "        fig.colorbar(im, ax=ax[1], orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "        fig.tight_layout()\n",
                "        plt.show()\n",
                "\n",
                "        # save\n",
                "        accelerator.save_state(f\"exp/{cfg.exp_name}/epoch_{epoch}\")\n",
                "\n",
                "        # plotting inferred behaviors from sampled rates\n",
                "\n",
                "        if epoch < 330:\n",
                "            continue # only plot trajectories once sufficiently learned\n",
                "\n",
                "        behav_dict = gen_rates_and_train_decoded_behavior_conditioned_on_angle(\n",
                "            ema_denoiser=ema_model,\n",
                "            scheduler=scheduler,\n",
                "            ae=ae_model,\n",
                "            cfg=cfg,\n",
                "            train_latent_dataloader=train_latent_dataloader,\n",
                "            val_latent_dataloader=val_latent_dataloader,\n",
                "            num_samples=20,\n",
                "            device=\"cuda\",\n",
                "        )\n",
                "\n",
                "        predicted_val_behavior = behav_dict[\"predicted_val_behavior\"]\n",
                "        real_val_behavior = behav_dict[\"real_val_behavior\"]\n",
                "        predicted_sampled_behavior = behav_dict[\"predicted_sampled_behavior\"]\n",
                "        angles = behav_dict[\"angles\"]\n",
                "\n",
                "        predicted_traj = np.cumsum(predicted_val_behavior, axis=-1)\n",
                "        real_traj = np.cumsum(real_val_behavior, axis=-1)\n",
                "        predicted_sampled_traj = np.cumsum(predicted_sampled_behavior, axis=-1)\n",
                "\n",
                "        print(predicted_traj.shape, real_traj.shape, predicted_sampled_traj.shape)\n",
                "\n",
                "        fig, ax = plt.subplots(1, 2, figsize=(6, 2))\n",
                "\n",
                "        \n",
                "        # create cmap based on behavior angles and plot behavior_cumsum on 2d plot \n",
                "        cmap = plt.get_cmap(\"hsv\")\n",
                "        colors = cmap(angles)\n",
                "        # angles = (angles * 20).numpy().tolist()\n",
                "        for i, idx in enumerate(range(0, 20)):\n",
                "            ax[1].plot(\n",
                "                predicted_sampled_traj[idx, 0, :],\n",
                "                predicted_sampled_traj[idx, 1, :],\n",
                "                # label=\"Predicted trajectory\",\n",
                "                linestyle=\"-\",\n",
                "                color=colors[i],\n",
                "            ) \n",
                "        ax[1].set_xlabel(\"X position\")\n",
                "        ax[1].set_ylabel(\"Y position\")\n",
                "        ax[1].set_title(\"Predicted trajectory (Sampled conditioned on $\\\\theta$)\")\n",
                "\n",
                "        # set ticks formatting for colorbar in terms of pi\n",
                "        sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(-np.pi, np.pi))\n",
                "        sm.set_array([])\n",
                "        cbar = plt.colorbar(sm, ticks=np.linspace(-np.pi, np.pi, 5), label=\"Behavior angle\", ax=ax[1])\n",
                "        cbar.ax.set_yticklabels([r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"])\n",
                "\n",
                "        fig.tight_layout()\n",
                "        plt.show()\n",
                "        \n",
                "pbar.close()\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "save_new_model = False\n",
                "load_model = True\n",
                "\n",
                "if save_new_model:\n",
                "\n",
                "    if accelerator.is_main_process:\n",
                "        os.makedirs(f\"exp/{cfg.exp_name}\", exist_ok=True)\n",
                "        torch.save(accelerator.unwrap_model(denoiser).state_dict(), f\"exp/{cfg.exp_name}/model.pt\")\n",
                "        with open(f\"conf/sweeps_count/{cfg.exp_name}_save_after_train.yaml\", \"w\") as f:\n",
                "            f.write(OmegaConf.to_yaml(cfg))\n",
                "\n",
                "        print('saved new model at ', f\"exp/{cfg.exp_name}/model.pt\")\n",
                "            \n",
                "elif load_model:\n",
                "\n",
                "    #     # load the congig and model path\n",
                "    # with open(f\"conf/sweeps_count/{cfg.exp_name}.yaml\") as f:\n",
                "    #     cfg = OmegaConf.create(yaml.safe_load(f))\n",
                "\n",
                "    with open(f\"conf/sweeps_count/{cfg.exp_name}_save_after_train.yaml\") as f:\n",
                "        cfg = OmegaConf.create(yaml.safe_load(f))\n",
                "        \n",
                "    from ntldm.networks import ConditionalDenoiser\n",
                "\n",
                "    denoiser = ConditionalDenoiser(\n",
                "        C_in=cfg.denoiser_model.C_in,\n",
                "        C=cfg.denoiser_model.C,\n",
                "        L=cfg.dataset.signal_length,\n",
                "        kernel=cfg.denoiser_model.kernel,\n",
                "        num_blocks=cfg.denoiser_model.num_blocks,\n",
                "        bidirectional=cfg.denoiser_model.get(\"bidirectional\", True),\n",
                "        condition_dim=cfg.denoiser_model.condition_dim,\n",
                "    )\n",
                "\n",
                "    # initial values may be way off so better to scale down the output layer\n",
                "    denoiser.conv_out.weight.data = denoiser.conv_out.weight.data * 0.1\n",
                "    denoiser.conv_out.bias.data = denoiser.conv_out.bias.data * 0.1\n",
                "\n",
                "    scheduler = DDPMScheduler(\n",
                "        num_train_timesteps=cfg.denoiser_model.num_train_timesteps,\n",
                "        clip_sample=False,\n",
                "        beta_schedule=\"linear\", # ddpm doesnt support cosine\n",
                "    )\n",
                "\n",
                "\n",
                "    optimizer = torch.optim.AdamW(\n",
                "        denoiser.parameters(), lr=cfg.training.lr\n",
                "    )  # default wd=0.01 for now\n",
                "\n",
                "\n",
                "\n",
                "    num_batches = len(train_latent_dataloader)\n",
                "    lr_scheduler = get_scheduler(\n",
                "        name=\"cosine\",\n",
                "        optimizer=optimizer,\n",
                "        num_warmup_steps=num_batches * cfg.training.num_warmup_epochs,  # warmup for 10% of epochs\n",
                "        num_training_steps=num_batches * cfg.training.num_epochs * 1.3,  # total number of steps\n",
                "    )\n",
                "\n",
                "    denoiser.load_state_dict(torch.load(f\"exp/{cfg.exp_name}/model.pt\", map_location=\"cpu\"))\n",
                "\n",
                "    # prepare the denoiser model and dataset\n",
                "    (\n",
                "        denoiser,\n",
                "        optimizer,\n",
                "        train_latent_dataloader,\n",
                "        val_latent_dataloader,\n",
                "        test_latent_dataloader,\n",
                "        lr_scheduler,\n",
                "    ) = accelerator.prepare(\n",
                "        denoiser,\n",
                "        optimizer,\n",
                "        train_latent_dataloader,\n",
                "        val_latent_dataloader,\n",
                "        test_latent_dataloader,\n",
                "        lr_scheduler,\n",
                "    )\n",
                "\n",
                "    ema_model = EMAModel(denoiser)\n",
                "\n",
                "        \n",
                "    print(f\"loaded model from exp/{cfg.exp_name}/model.pt\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Sample two angle-conditioned latent trajectories, rescale them with the\n",
                "# dataset's latent_stds / latent_means, decode them with the autoencoder and\n",
                "# compare a Poisson draw of the decoded rates with a recorded training trial.\n",
                "sampled_latents, angles = sample_conditioned_on_angle(\n",
                "    ema_denoiser=ema_model, scheduler=scheduler, cfg=cfg, batch_size=2, device=\"cuda\"\n",
                ")\n",
                "sampled_latents = (\n",
                "    sampled_latents * latent_dataset_train.latent_stds.to(sampled_latents.device)\n",
                "    + latent_dataset_train.latent_means.to(sampled_latents.device)\n",
                ")\n",
                "\n",
                "real_latents = latent_dataset_train.latents[:2].cuda()\n",
                "real_latents = (\n",
                "    real_latents * latent_dataset_train.latent_stds.to(real_latents.device)\n",
                "    + latent_dataset_train.latent_means.to(real_latents.device)\n",
                ")\n",
                "\n",
                "with torch.no_grad():\n",
                "    sampled_rates = ae_model.decode(sampled_latents).cpu()\n",
                "    decoded_rates_from_real_latents = ae_model.decode(real_latents).cpu()\n",
                "\n",
                "# Poisson spike counts from the decoded rates (draw order kept: sampled, then real).\n",
                "sampled_rates_poisson = torch.poisson(sampled_rates)\n",
                "decoded_rates_from_real_latents_poisson = torch.poisson(decoded_rates_from_real_latents)\n",
                "\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(10, 3))\n",
                "raster_panels = [\n",
                "    (ax[0], sampled_rates_poisson[0], \"sampled\"),\n",
                "    (ax[1], latent_dataset_train.train_spikes[0], \"data\"),\n",
                "]\n",
                "for panel_ax, panel_counts, panel_title in raster_panels:\n",
                "    im = panel_ax.imshow(panel_counts, aspect=\"auto\", cmap=\"Greys\", vmax=3)\n",
                "    panel_ax.set_title(panel_title)\n",
                "    fig.colorbar(im, ax=panel_ax, orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "fig.tight_layout()\n",
                "\n",
                "plt.show()\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# One recorded training trial next to one Poisson sample from the model.\n",
                "# Both rasters share a single colour range so their spike counts are\n",
                "# directly comparable.\n",
                "data_counts = latent_dataset_train.train_spikes[0]\n",
                "sample_counts = sampled_rates_poisson[0]\n",
                "shared_vmin = min(float(data_counts.min()), float(sample_counts.min()))\n",
                "shared_vmax = max(float(data_counts.max()), float(sample_counts.max()))\n",
                "\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(10, 3.3))\n",
                "\n",
                "im = ax[0].imshow(data_counts, aspect=\"auto\", cmap='Greys', vmin=shared_vmin, vmax=shared_vmax)\n",
                "ax[0].set_title(\"data\")\n",
                "fig.colorbar(im, ax=ax[0], orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "\n",
                "im = ax[1].imshow(sample_counts, aspect=\"auto\", cmap='Greys', vmin=shared_vmin, vmax=shared_vmax)\n",
                "ax[1].set_title(\"sampled\")\n",
                "fig.colorbar(im, ax=ax[1], orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "\n",
                "# Lay out only once both panels and their colorbars exist.\n",
                "fig.tight_layout()\n",
                "plt.show()\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Draw 200 angle-conditioned samples and decode behaviour from them; the test\n",
                "# loader is passed in the val_latent_dataloader slot.\n",
                "behav_dict = gen_rates_and_train_decoded_behavior_conditioned_on_angle(\n",
                "    ema_denoiser=ema_model,\n",
                "    scheduler=scheduler,\n",
                "    ae=ae_model,\n",
                "    cfg=cfg,\n",
                "    train_latent_dataloader=train_latent_dataloader,\n",
                "    val_latent_dataloader=test_latent_dataloader,\n",
                "    num_samples=200,\n",
                "    device=\"cuda\",\n",
                ")\n",
                "\n",
                "predicted_val_behavior = behav_dict[\"predicted_val_behavior\"]\n",
                "real_val_behavior = behav_dict[\"real_val_behavior\"]\n",
                "predicted_sampled_behavior = behav_dict[\"predicted_sampled_behavior\"]\n",
                "# Shift angles by pi (to 0-2pi), then scale to 0-1 to index the hsv colormap.\n",
                "angles = (behav_dict[\"angles\"].numpy() + np.pi) / (2 * np.pi)\n",
                "\n",
                "# A cumulative sum along the last axis turns per-step behaviour into trajectories.\n",
                "predicted_traj, real_traj, predicted_sampled_traj = (\n",
                "    np.cumsum(behavior, axis=-1)\n",
                "    for behavior in (predicted_val_behavior, real_val_behavior, predicted_sampled_behavior)\n",
                ")\n",
                "\n",
                "print(predicted_traj.shape, real_traj.shape, predicted_sampled_traj.shape, angles.shape)\n",
                "\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(12, 5))\n",
                "\n",
                "# Left: ten reaches, real (solid) vs. decoded (dashed), one tab10 colour per reach.\n",
                "for color, idx in zip(sns.color_palette(\"tab10\"), range(0, 70, 7)):\n",
                "    ax[0].plot(real_traj[idx, 0, :], real_traj[idx, 1, :], label=\"Real trajectory\", color=color)\n",
                "    ax[0].plot(\n",
                "        predicted_traj[idx, 0, :],\n",
                "        predicted_traj[idx, 1, :],\n",
                "        label=\"Predicted trajectory\",\n",
                "        linestyle=\"--\",\n",
                "        color=color,\n",
                "    )\n",
                "ax[0].set_xlabel(\"X position\")\n",
                "ax[0].set_ylabel(\"Y position\")\n",
                "ax[0].set_title(\"Real vs predicted trajectory (Val)\")\n",
                "\n",
                "# Right: first 20 sampled trajectories, coloured by their conditioning angle.\n",
                "cmap = plt.get_cmap(\"hsv\")\n",
                "colors = cmap(angles)\n",
                "for idx in range(20):\n",
                "    ax[1].plot(\n",
                "        predicted_sampled_traj[idx, 0, :],\n",
                "        predicted_sampled_traj[idx, 1, :],\n",
                "        linestyle=\"-\",\n",
                "        color=colors[idx],\n",
                "    )\n",
                "ax[1].set_xlabel(\"X position\")\n",
                "ax[1].set_ylabel(\"Y position\")\n",
                "ax[1].set_title(\"Predicted trajectory (Sampled conditioned on $\\\\theta$)\")\n",
                "\n",
                "# Colorbar in angle units (-pi..pi), ticks labelled in multiples of pi.\n",
                "sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(-np.pi, np.pi))\n",
                "sm.set_array([])\n",
                "cbar = fig.colorbar(sm, ticks=np.linspace(-np.pi, np.pi, 5), label=\"Behavior angle\", ax=ax[1])\n",
                "cbar.ax.set_yticklabels([r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"])\n",
                "\n",
                "fig.tight_layout()\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Three panels: real vs. decoded reaches (left), trajectories decoded from\n",
                "# angle-conditioned diffusion samples coloured by angle (middle), and a second\n",
                "# real vs. decoded panel (right).\n",
                "# NOTE(review): the right panel is an exact copy of the left one; confirm intent.\n",
                "\n",
                "# Defined before first use so the colorbars below do not depend on whichever\n",
                "# earlier cell last assigned `cmap`.\n",
                "cmap = plt.get_cmap(\"hsv\")\n",
                "pi_tick_labels = [r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"]\n",
                "\n",
                "\n",
                "def _add_angle_colorbar(target_ax):\n",
                "    \"\"\"Attach an hsv colorbar spanning [-pi, pi] to `target_ax`, ticks in multiples of pi.\"\"\"\n",
                "    mappable = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(-np.pi, np.pi))\n",
                "    mappable.set_array([])\n",
                "    colorbar = plt.colorbar(mappable, ticks=np.linspace(-np.pi, np.pi, 5), ax=target_ax)\n",
                "    colorbar.ax.set_yticklabels(pi_tick_labels)\n",
                "    return colorbar\n",
                "\n",
                "\n",
                "def _plot_real_vs_decoded(target_ax):\n",
                "    \"\"\"Plot every 5th of the first 70 reaches: real (solid) vs. decoded (dashed).\"\"\"\n",
                "    palette = sns.color_palette(\"Blues\")\n",
                "    for i, idx in enumerate(range(0, 70, 5)):\n",
                "        target_ax.plot(\n",
                "            real_traj[idx, 0, :],\n",
                "            real_traj[idx, 1, :],\n",
                "            label=\"real trajectory\",\n",
                "            color=palette[i % 5],\n",
                "            lw=0.5,\n",
                "        )\n",
                "        target_ax.plot(\n",
                "            predicted_traj[idx, 0, :],\n",
                "            predicted_traj[idx, 1, :],\n",
                "            label=\"predicted trajectory\",\n",
                "            linestyle=\"--\",\n",
                "            color=palette[i % 5],\n",
                "            lw=0.5,\n",
                "        )\n",
                "    target_ax.set_xlabel(\"x position\")\n",
                "    target_ax.set_ylabel(\"y position\")\n",
                "    target_ax.set_title(\"real vs decoded\")\n",
                "\n",
                "\n",
                "fig, ax = plt.subplots(1, 3, figsize=cm2inch(12, 3.5), sharex=True, sharey=True)\n",
                "\n",
                "# NOTE(review): the real-vs-decoded panels use the Blues palette, not angle colours,\n",
                "# so their colorbars do not describe them; they presumably only keep the panel\n",
                "# widths equal to the middle panel. Their range is [-pi, pi] so the pi-valued\n",
                "# ticks fall inside it.\n",
                "_plot_real_vs_decoded(ax[0])\n",
                "_add_angle_colorbar(ax[0])\n",
                "\n",
                "colors = cmap(angles)\n",
                "for i in range(100):\n",
                "    ax[1].plot(\n",
                "        predicted_sampled_traj[i, 0, :],\n",
                "        predicted_sampled_traj[i, 1, :],\n",
                "        linestyle=\"-\",\n",
                "        color=colors[i],\n",
                "        alpha=0.5,\n",
                "        lw=0.5,\n",
                "    )\n",
                "ax[1].set_xlabel(\"x position\")\n",
                "ax[1].set_title(\"diffusion conditioned on $\\\\theta$\")\n",
                "_add_angle_colorbar(ax[1])\n",
                "\n",
                "_plot_real_vs_decoded(ax[2])\n",
                "_add_angle_colorbar(ax[2])\n",
                "\n",
                "fig.tight_layout()\n",
                "# NOTE(review): \"conditinging\" is misspelled; the file names are left unchanged in\n",
                "# case other material references these paths.\n",
                "fig.savefig(save_path + \"Fig_4_all_conditinging_figure.png\")\n",
                "fig.savefig(save_path + \"Fig_4_all_conditinging_figure.pdf\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Left: real vs. decoded test-set reaches, each coloured by its own reach angle.\n",
                "# Right: trajectories decoded from angle-conditioned samples, coloured by the\n",
                "# conditioning angle.\n",
                "# NOTE(review): `predicted_sampled_traj` and `angles` are not recomputed here; they\n",
                "# come from whichever conditioned-sampling cell ran last.\n",
                "behav_dict = compute_decoded_behavior(\n",
                "    ae_model, train_latent_dataloader, test_latent_dataloader\n",
                ")\n",
                "predicted_val_behavior = behav_dict[\"predicted_val_behavior\"]\n",
                "real_val_behavior = behav_dict[\"real_val_behavior\"]\n",
                "predicted_traj = np.cumsum(predicted_val_behavior, axis=-1)\n",
                "real_traj = np.cumsum(real_val_behavior, axis=-1)\n",
                "\n",
                "# Reach angle of each real trial: direction of its (x, y) position at time\n",
                "# index REACH_ANGLE_STEP.\n",
                "REACH_ANGLE_STEP = 50\n",
                "angle_real_val_behavior = np.arctan2(\n",
                "    real_traj[:, 1, REACH_ANGLE_STEP], real_traj[:, 0, REACH_ANGLE_STEP]\n",
                ")\n",
                "\n",
                "print(predicted_traj.shape, real_traj.shape)\n",
                "\n",
                "# Defined before first use so the colorbars do not depend on earlier cells.\n",
                "cmap = plt.get_cmap(\"hsv\")\n",
                "pi_tick_labels = [r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"]\n",
                "\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(9, 3), sharex=True, sharey=True)\n",
                "\n",
                "for idx in range(0, 70, 5):\n",
                "    reach_color = angle_to_color(angle_real_val_behavior[idx])\n",
                "    ax[0].plot(\n",
                "        real_traj[idx, 0, :],\n",
                "        real_traj[idx, 1, :],\n",
                "        label=\"real trajectory\",\n",
                "        color=reach_color,\n",
                "        alpha=0.5,\n",
                "        lw=0.5,\n",
                "    )\n",
                "    ax[0].plot(\n",
                "        predicted_traj[idx, 0, :],\n",
                "        predicted_traj[idx, 1, :],\n",
                "        label=\"predicted trajectory\",\n",
                "        linestyle=\"--\",\n",
                "        color=reach_color,\n",
                "        alpha=0.5,\n",
                "        lw=0.5,\n",
                "    )\n",
                "ax[0].set_xlabel(\"x position\")\n",
                "ax[0].set_ylabel(\"y position\")\n",
                "ax[0].set_title(\"data vs. decoded\")\n",
                "\n",
                "# NOTE(review): assumes angle_to_color maps [-pi, pi] onto the same hsv scale as\n",
                "# this colorbar; confirm.\n",
                "sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(-np.pi, np.pi))\n",
                "sm.set_array([])\n",
                "cbar = plt.colorbar(sm, ticks=np.linspace(-np.pi, np.pi, 5), ax=ax[0])\n",
                "cbar.ax.set_yticklabels(pi_tick_labels)\n",
                "\n",
                "colors = cmap(angles)\n",
                "for i in range(70):\n",
                "    ax[1].plot(\n",
                "        predicted_sampled_traj[i, 0, :],\n",
                "        predicted_sampled_traj[i, 1, :],\n",
                "        linestyle=\"-\",\n",
                "        color=colors[i],\n",
                "        alpha=0.5,\n",
                "        lw=0.5,\n",
                "    )\n",
                "ax[1].set_xlabel(\"x position\")\n",
                "ax[1].set_title(\"conditioned on angle\")\n",
                "\n",
                "sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(-np.pi, np.pi))\n",
                "sm.set_array([])\n",
                "cbar = plt.colorbar(sm, ticks=np.linspace(-np.pi, np.pi, 5), ax=ax[1])\n",
                "cbar.ax.set_yticklabels(pi_tick_labels)\n",
                "\n",
                "fig.savefig(save_path + \"Fig_4_angle_cond_figure.png\")\n",
                "fig.savefig(save_path + \"Fig_4_angle_cond_figure.pdf\")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Two panels: 70 real vs. decoded reaches (left) and the first 100 trajectories\n",
                "# decoded from angle-conditioned samples, coloured by angle (right).\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(11, 5), sharex=True, sharey=True)\n",
                "ax_real, ax_sampled = ax\n",
                "\n",
                "colors = sns.color_palette(\"tab10\")\n",
                "for i in range(70):\n",
                "    shade = colors[i % 10]\n",
                "    ax_real.plot(real_traj[i, 0, :], real_traj[i, 1, :],\n",
                "                 label=\"real trajectory\", color=shade, alpha=0.5, lw=0.5)\n",
                "    ax_real.plot(predicted_traj[i, 0, :], predicted_traj[i, 1, :],\n",
                "                 label=\"predicted trajectory\", linestyle=\"--\", color=shade, alpha=0.5, lw=0.5)\n",
                "ax_real.set(xlabel=\"x position\", ylabel=\"y position\", title=\"real vs predicted trajectory\")\n",
                "\n",
                "cmap = plt.get_cmap(\"hsv\")\n",
                "colors = cmap(angles)\n",
                "for i in range(100):\n",
                "    ax_sampled.plot(predicted_sampled_traj[i, 0, :], predicted_sampled_traj[i, 1, :],\n",
                "                    linestyle=\"-\", color=colors[i], alpha=0.5, lw=0.5)\n",
                "ax_sampled.set(xlabel=\"x position\", ylabel=\"y position\", title=\"diffusion conditioned on $\\\\theta$\")\n",
                "\n",
                "# Colorbar in angle units with ticks at multiples of pi.\n",
                "sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(-np.pi, np.pi))\n",
                "sm.set_array([])\n",
                "cbar = fig.colorbar(sm, ticks=np.linspace(-np.pi, np.pi, 5), label=\"Behavior angle\", ax=ax_sampled)\n",
                "cbar.ax.set_yticklabels([r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"])\n",
                "\n",
                "fig.tight_layout()\n",
                "for file_name in (\"angle_cond_figure.png\", \"angle_cond_figure.pdf\"):\n",
                "    fig.savefig(save_path + file_name)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Distribution of reach angles in the train and test splits. Both histograms\n",
                "# use the same bin edges so their bars line up and can be compared directly;\n",
                "# alpha keeps the train bars visible under the test bars.\n",
                "train_angles = np.asarray(latent_dataset_train.behavior_angles.flatten())\n",
                "test_angles = np.asarray(latent_dataset_test.behavior_angles.flatten())\n",
                "angle_bin_edges = np.histogram_bin_edges(np.concatenate([train_angles, test_angles]), bins=200)\n",
                "\n",
                "fig, ax = plt.subplots(figsize=cm2inch(5, 3))\n",
                "ax.hist(train_angles, bins=angle_bin_edges, alpha=0.5, label=\"train\")\n",
                "ax.hist(test_angles, bins=angle_bin_edges, alpha=0.5, label=\"test\")\n",
                "ax.legend()\n",
                "# x ticks in multiples of pi\n",
                "ax.set_xticks(np.linspace(-np.pi, np.pi, 5))\n",
                "ax.set_xticklabels([r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"])\n",
                "ax.set_ylabel(\"number of reaches\")\n",
                "ax.set_xlabel(\"behavior angle\")\n",
                "fig.savefig(save_path + \"behavior_angles.png\")\n",
                "fig.savefig(save_path + \"behavior_angles.pdf\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Draw 100 angle-conditioned samples, passing the validation loader in the\n",
                "# val_latent_dataloader slot.\n",
                "# NOTE(review): `behav_dict` is reassigned by a later sampling cell before it is\n",
                "# read, so this result appears unused.\n",
                "conditioned_sampling_kwargs = dict(\n",
                "    ema_denoiser=ema_model,\n",
                "    scheduler=scheduler,\n",
                "    ae=ae_model,\n",
                "    cfg=cfg,\n",
                "    train_latent_dataloader=train_latent_dataloader,\n",
                "    val_latent_dataloader=val_latent_dataloader,\n",
                "    num_samples=100,\n",
                "    device=\"cuda\",\n",
                ")\n",
                "behav_dict = gen_rates_and_train_decoded_behavior_conditioned_on_angle(**conditioned_sampling_kwargs)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Normalised conditioning angle for a reach angle of -0.6 * pi/2, using the\n",
                "# (theta + pi) / (2 * pi) mapping from [-pi, pi] to [0, 1]; the same expression is\n",
                "# passed as `angle=` in the following sampling cell.\n",
                "target_angle_normalized = (np.pi + (-np.pi / 2 * 0.6)) / 2 / np.pi\n",
                "target_angle_normalized"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Draw 100 samples with `angle=` fixed to the normalised value\n",
                "# (-pi/2 * 0.6 + pi) / (2 * pi) (presumably every sample shares this\n",
                "# conditioning angle; confirm), then decode behaviour from them.\n",
                "behav_dict = gen_rates_and_train_decoded_behavior_conditioned_on_angle(\n",
                "    ema_denoiser=ema_model,\n",
                "    scheduler=scheduler,\n",
                "    ae=ae_model,\n",
                "    cfg=cfg,\n",
                "    train_latent_dataloader=train_latent_dataloader,\n",
                "    val_latent_dataloader=val_latent_dataloader,\n",
                "    num_samples=100,\n",
                "    device=\"cuda\",\n",
                "    angle=(-np.pi/2*0.6+np.pi)/2/np.pi,\n",
                ")\n",
                "\n",
                "predicted_val_behavior, real_val_behavior, predicted_sampled_behavior = (\n",
                "    behav_dict[key]\n",
                "    for key in (\"predicted_val_behavior\", \"real_val_behavior\", \"predicted_sampled_behavior\")\n",
                ")\n",
                "# Shift angles by pi (to 0-2pi), then scale to 0-1 to index the hsv colormap.\n",
                "angles = (behav_dict[\"angles\"].numpy() + np.pi) / (2 * np.pi)\n",
                "\n",
                "# Cumulative sums along the last axis turn per-step behaviour into trajectories.\n",
                "predicted_traj = np.cumsum(predicted_val_behavior, axis=-1)\n",
                "real_traj = np.cumsum(real_val_behavior, axis=-1)\n",
                "predicted_sampled_traj = np.cumsum(predicted_sampled_behavior, axis=-1)\n",
                "\n",
                "print(predicted_traj.shape, real_traj.shape, predicted_sampled_traj.shape, angles.shape)\n",
                "\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(12, 5))\n",
                "ax_val, ax_sampled = ax\n",
                "\n",
                "# Left: ten validation reaches, real (solid) vs. decoded (dashed).\n",
                "palette = sns.color_palette(\"tab10\")\n",
                "for slot, idx in enumerate(range(0, 70, 7)):\n",
                "    ax_val.plot(real_traj[idx, 0, :], real_traj[idx, 1, :],\n",
                "                label=\"Real trajectory\", color=palette[slot])\n",
                "    ax_val.plot(predicted_traj[idx, 0, :], predicted_traj[idx, 1, :],\n",
                "                label=\"Predicted trajectory\", linestyle=\"--\", color=palette[slot])\n",
                "ax_val.set_xlabel(\"X position\")\n",
                "ax_val.set_ylabel(\"Y position\")\n",
                "ax_val.set_title(\"Real vs predicted trajectory (Val)\")\n",
                "\n",
                "# Right: first 20 sampled trajectories, coloured by conditioning angle.\n",
                "cmap = plt.get_cmap(\"hsv\")\n",
                "colors = cmap(angles)\n",
                "for idx in range(20):\n",
                "    ax_sampled.plot(predicted_sampled_traj[idx, 0, :], predicted_sampled_traj[idx, 1, :],\n",
                "                    linestyle=\"-\", color=colors[idx])\n",
                "ax_sampled.set_xlabel(\"X position\")\n",
                "ax_sampled.set_ylabel(\"Y position\")\n",
                "ax_sampled.set_title(\"Predicted trajectory (Sampled conditioned on $\\\\theta$)\")\n",
                "\n",
                "# Colorbar in angle units, ticks labelled in multiples of pi.\n",
                "sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(-np.pi, np.pi))\n",
                "sm.set_array([])\n",
                "cbar = fig.colorbar(sm, ticks=np.linspace(-np.pi, np.pi, 5), label=\"Behavior angle\", ax=ax_sampled)\n",
                "cbar.ax.set_yticklabels([r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"])\n",
                "\n",
                "fig.tight_layout()\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Distribution of the normalised conditioning angles returned by the last\n",
                "# sampling cell (mapped to (theta + pi) / (2 * pi) there).\n",
                "fig, ax = plt.subplots()\n",
                "ax.hist(angles, bins=50)\n",
                "ax.set_xlabel(\"normalized conditioning angle\")\n",
                "ax.set_ylabel(\"number of samples\")\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Only the angle-conditioned samples: first 100 decoded trajectories, coloured\n",
                "# by conditioning angle, with a colorbar in multiples of pi.\n",
                "fig, ax = plt.subplots(figsize=cm2inch(6.5, 5))\n",
                "\n",
                "cmap = plt.get_cmap(\"hsv\")\n",
                "colors = cmap(angles)\n",
                "for i in range(100):\n",
                "    ax.plot(\n",
                "        predicted_sampled_traj[i, 0, :],\n",
                "        predicted_sampled_traj[i, 1, :],\n",
                "        linestyle=\"-\",\n",
                "        color=colors[i],\n",
                "        alpha=0.5,\n",
                "        lw=0.5,\n",
                "    )\n",
                "\n",
                "ax.set(xlabel=\"x position\", ylabel=\"y position\", title=\"diffusion conditioned on $\\\\theta$\")\n",
                "\n",
                "sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(-np.pi, np.pi))\n",
                "sm.set_array([])\n",
                "cbar = fig.colorbar(sm, ticks=np.linspace(-np.pi, np.pi, 5), label=\"Behavior angle\", ax=ax)\n",
                "cbar.ax.set_yticklabels([r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"])\n",
                "\n",
                "fig.tight_layout()\n",
                "for file_name in (\"angle_cond_figure_only_samples.png\", \"angle_cond_figure_only_samples.pdf\"):\n",
                "    fig.savefig(save_path + file_name)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Two panels without a colorbar: 70 real vs. decoded reaches (left) and the\n",
                "# first 100 angle-conditioned samples, coloured by angle (right).\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(11, 6), sharex=True, sharey=True)\n",
                "\n",
                "colors = sns.color_palette(\"tab10\")\n",
                "for i in range(70):\n",
                "    ax[0].plot(\n",
                "        real_traj[i, 0, :],\n",
                "        real_traj[i, 1, :],\n",
                "        label=\"real trajectory\",\n",
                "        color=colors[i % 10],\n",
                "        alpha=0.5,\n",
                "        lw=0.5,\n",
                "    )\n",
                "    ax[0].plot(\n",
                "        predicted_traj[i, 0, :],\n",
                "        predicted_traj[i, 1, :],\n",
                "        label=\"predicted trajectory\",\n",
                "        linestyle=\"--\",\n",
                "        color=colors[i % 10],\n",
                "        alpha=0.5,\n",
                "        lw=0.5,\n",
                "    )\n",
                "ax[0].set_xlabel(\"x position\")\n",
                "ax[0].set_ylabel(\"y position\")\n",
                "ax[0].set_title(\"real vs predicted trajectory\")\n",
                "\n",
                "cmap = plt.get_cmap(\"hsv\")\n",
                "colors = cmap(angles)\n",
                "for i in range(100):\n",
                "    ax[1].plot(\n",
                "        predicted_sampled_traj[i, 0, :],\n",
                "        predicted_sampled_traj[i, 1, :],\n",
                "        linestyle=\"-\",\n",
                "        color=colors[i],\n",
                "        alpha=0.5,\n",
                "        lw=0.5,\n",
                "    )\n",
                "ax[1].set_xlabel(\"x position\")\n",
                "ax[1].set_ylabel(\"y position\")\n",
                "ax[1].set_title(\"diffusion conditioned on $\\\\theta$\")\n",
                "\n",
                "fig.tight_layout()\n",
                "fig.savefig(save_path + \"angle_cond_figure_no_color_bar.png\")\n",
                "fig.savefig(save_path + \"angle_cond_figure_no_color_bar.pdf\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Real vs. decoded reaches (left) next to the first 50 sampled trajectories (right),\n",
                "# both coloured with the cycling tab20 palette; no colorbar.\n",
                "# NOTE(review): `predicted_sampled_traj` is whatever the last sampling cell produced;\n",
                "# in this notebook's order that is the angle-conditioned run, which conflicts with\n",
                "# the \"unconditional diffusion\" title and file name. Confirm the intended samples.\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(10, 5))\n",
                "ax_left, ax_right = ax\n",
                "colors = sns.color_palette(\"tab20\")\n",
                "\n",
                "for i in range(70):\n",
                "    shade = colors[i % 20]\n",
                "    ax_left.plot(real_traj[i, 0, :], real_traj[i, 1, :],\n",
                "                 label=\"real trajectory\", color=shade, alpha=0.5, lw=0.5)\n",
                "    ax_left.plot(predicted_traj[i, 0, :], predicted_traj[i, 1, :],\n",
                "                 label=\"Predicted trajectory\", linestyle=\"--\", color=shade, alpha=0.5, lw=0.5)\n",
                "ax_left.set(xlabel=\"x position\", ylabel=\"y position\", title=\"real vs predicted trajectory\")\n",
                "\n",
                "for i in range(50):\n",
                "    ax_right.plot(predicted_sampled_traj[i, 0, :], predicted_sampled_traj[i, 1, :],\n",
                "                  linestyle=\"-\", color=colors[i % 20], alpha=0.5, lw=0.5)\n",
                "ax_right.set(xlabel=\"x position\", ylabel=\"y position\", title=\"unconditional diffusion\")\n",
                "\n",
                "fig.tight_layout()\n",
                "\n",
                "for file_name in (\"unconditional_figure_no_color_bar.png\", \"unconditional_figure_no_color_bar.pdf\"):\n",
                "    fig.savefig(save_path + file_name)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Thick-line version of the trajectory figure: no explicit alpha/lw is passed, so\n",
                "# matplotlib's rcParams line defaults apply. Saved with a _thick suffix.\n",
                "N_REAL_VS_PRED = 70  # number of real/predicted trajectory pairs to overlay\n",
                "N_UNCOND = 50  # number of unconditional samples to overlay\n",
                "\n",
                "colors = sns.color_palette(\"tab20\")\n",
                "\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(10, 5))\n",
                "\n",
                "# Never request more trajectories than were actually computed.\n",
                "n_pairs = min(N_REAL_VS_PRED, len(real_traj), len(predicted_traj))\n",
                "for i in range(n_pairs):\n",
                "    color = colors[i % len(colors)]  # same color for a real/predicted pair; palette cycles\n",
                "    ax[0].plot(real_traj[i, 0, :], real_traj[i, 1, :], color=color)\n",
                "    ax[0].plot(predicted_traj[i, 0, :], predicted_traj[i, 1, :], linestyle=\"--\", color=color)\n",
                "ax[0].set_xlabel(\"x position\")\n",
                "ax[0].set_ylabel(\"y position\")\n",
                "ax[0].set_title(\"real vs predicted trajectory\")\n",
                "\n",
                "n_uncond = min(N_UNCOND, len(predicted_sampled_traj))\n",
                "for i in range(n_uncond):\n",
                "    ax[1].plot(\n",
                "        predicted_sampled_traj[i, 0, :],\n",
                "        predicted_sampled_traj[i, 1, :],\n",
                "        linestyle=\"-\",\n",
                "        color=colors[i % len(colors)],\n",
                "    )\n",
                "ax[1].set_xlabel(\"x position\")\n",
                "ax[1].set_ylabel(\"y position\")\n",
                "ax[1].set_title(\"unconditional diffusion\")\n",
                "\n",
                "fig.tight_layout()\n",
                "# Save this cell's figure explicitly instead of pyplot's implicit current figure.\n",
                "fig.savefig(save_path + \"unconditional_figure_no_color_bar_thick.png\")\n",
                "fig.savefig(save_path + \"unconditional_figure_no_color_bar_thick.pdf\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "timeseries",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.8.18"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
