{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "%load_ext autoreload\n",
                "%autoreload 2\n",
                "\n",
                "\n",
                "import argparse\n",
                "import os\n",
                "import sys\n",
                "\n",
                "# Environment switch: True selects the branch that also restricts CUDA to GPU 0 and extends sys.path.\n",
                "ANONAUTHOR = False\n",
                "\n",
                "# NOTE(review): both branches change the working directory relative to the current one,\n",
                "# so re-running this cell without restarting the kernel moves up one more level each time.\n",
                "if ANONAUTHOR:\n",
                "    # expose only GPU 0 to CUDA\n",
                "    os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
                "    # append parent directory to path (../notebooks -> ..)\n",
                "    sys.path.append(os.path.dirname(os.getcwd()))\n",
                "    # make the parent directory the working directory as well\n",
                "    os.chdir(os.path.dirname(os.getcwd()))\n",
                "\n",
                "else:\n",
                "    # move the working directory up one level\n",
                "    os.chdir('../')\n",
                "\n",
                "\n",
                "import accelerate\n",
                "import auraloss  # freq loss\n",
                "import lovely_tensors as lt\n",
                "import matplotlib.pyplot as plt\n",
                "import matplotlib\n",
                "import numpy as np\n",
                "import scipy.io as io\n",
                "import seaborn as sns\n",
                "import torch\n",
                "import torch.nn as nn\n",
                "import wandb\n",
                "import yaml\n",
                "from diffusers.optimization import get_scheduler\n",
                "from omegaconf import OmegaConf\n",
                "from scipy.signal import welch\n",
                "from tqdm.auto import tqdm\n",
                "from einops import rearrange\n",
                "\n",
                "from ntldm.data.lds import get_lds_dataset\n",
                "from ntldm.data.latent_attractor import get_attractor_dataset\n",
                "from ntldm.data.lds import get_lds_dataset\n",
                "from ntldm.networks import S4AE, AutoEncoder, CountWrapper\n",
                "from ntldm.utils.plotting_utils import *\n",
                "from ntldm.losses import latent_regularizer\n",
                "from ntldm.networks import Denoiser\n",
                "from diffusers.training_utils import EMAModel\n",
                "from diffusers.schedulers import DDPMScheduler\n",
                "# always run from ../ntldm\n",
                "\n",
                "\n",
                "# lovely_tensors: patch tensor printing (presumably compact summaries instead of raw values)\n",
                "lt.monkey_patch()\n",
                "# load plot styling from the 'matplotlibrc' file; the relative path is resolved against the\n",
                "# current working directory, so this depends on the directory change made earlier in this cell\n",
                "matplotlib.rc_file('matplotlibrc')\n",
                "\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "## load config and model path\n",
                "\n",
                "#cfg_ae = OmegaConf.load(\"conf/sweeps_count/autoencoder-count_s4-monkey_new.yaml\")\n",
                "\n",
                "# Autoencoder config; its dataset section is reused for the diffusion config at the end of this cell.\n",
                "cfg_ae = OmegaConf.load(\"conf/sweeps_count/autoencoder-count_s4-monkey_new_new_regularisation.yaml\")\n",
                "\n",
                "\n",
                "# Inline YAML holding the denoiser architecture, the training settings and the experiment name.\n",
                "cfg_yaml = \"\"\"\n",
                "denoiser_model:\n",
                "  C_in: 16\n",
                "  C: 256\n",
                "  kernel: s4\n",
                "  num_blocks: 6\n",
                "  bidirectional: True\n",
                "  num_train_timesteps: 1000\n",
                "training:\n",
                "  lr: 0.001\n",
                "  weight_decay: 0.0\n",
                "  num_epochs: 2000\n",
                "  num_warmup_epochs: 50\n",
                "  batch_size: 512\n",
                "  random_seed: 42\n",
                "  precision: \"no\"\n",
                "exp_name: diffusion_s4-monkey_vel_cond_new_regularisation_140_epoch\n",
                "\"\"\"\n",
                "\n",
                "# Parse the YAML text into an OmegaConf object and attach the autoencoder's dataset section,\n",
                "# so that cfg.dataset.signal_length is available to the denoiser code.\n",
                "cfg = OmegaConf.create(yaml.safe_load(cfg_yaml))\n",
                "cfg.dataset = cfg_ae.dataset\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### Same as the other notebook on diffusion on monkey data, except this one is conditioned on the velocity traces.\n",
                " "
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "\n",
                "### 0. Load autoencoder (with checkpoint) and autoencoder dataset (run for all points 2-5)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "    \n",
                "# Build the autoencoder from the autoencoder config; optional keys fall back to the defaults given in .get().\n",
                "ae_model = AutoEncoder(\n",
                "    C_in=cfg_ae.model.C_in,\n",
                "    C=cfg_ae.model.C,\n",
                "    C_latent=cfg_ae.model.C_latent,\n",
                "    L=cfg_ae.dataset.signal_length,\n",
                "    kernel=cfg_ae.model.kernel,\n",
                "    num_blocks=cfg_ae.model.num_blocks,\n",
                "    num_blocks_decoder=cfg_ae.model.get(\"num_blocks_decoder\", cfg_ae.model.num_blocks),\n",
                "    num_lin_per_mlp=cfg_ae.model.get(\"num_lin_per_mlp\", 2),  # default 2\n",
                "    bidirectional=cfg_ae.model.get(\"bidirectional\", True),\n",
                ")\n",
                "\n",
                "# Wrap the autoencoder in CountWrapper (from ntldm.networks); use_sin_enc defaults to False when not configured.\n",
                "ae_model = CountWrapper(ae_model, use_sin_enc=cfg_ae.model.get(\"use_sin_enc\", False))\n",
                "\n",
                "\n",
                "\n",
                "\n",
                "# NOTE(review): `math` is not used anywhere in the visible cells.\n",
                "import math\n",
                "from ntldm.data.monkey import get_monkey_dataloaders\n",
                "\n",
                "# set seed\n",
                "# (seeds torch and numpy with the diffusion config's seed before the dataloaders are created)\n",
                "torch.manual_seed(cfg.training.random_seed)\n",
                "np.random.seed(cfg.training.random_seed)\n",
                "\n",
                "# Train/val/test dataloaders; batch size comes from the autoencoder config, bin width is fixed to 5.\n",
                "# NOTE(review): the unit of bin_width is not visible here (presumably milliseconds) - confirm in ntldm.data.monkey.\n",
                "train_dataloader, val_dataloader, test_dataloader = get_monkey_dataloaders(\n",
                "        cfg_ae.dataset.task, cfg_ae.dataset.datapath, bin_width=5, batch_size=cfg_ae.training.batch_size\n",
                "    )\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "\n",
                "# Accelerator used for device placement and for restoring checkpoints.\n",
                "accelerator = accelerate.Accelerator(\n",
                "    mixed_precision='no',\n",
                "    log_with=\"wandb\",\n",
                ")\n",
                "\n",
                "# prepare the ae model before load_state so that the checkpoint is restored into it\n",
                "ae_model = accelerator.prepare(ae_model)\n",
                "\n",
                "print(cfg_ae.exp_name)\n",
                "accelerator.load_state(f\"exp/{cfg_ae.exp_name}/epoch_140\") # best checkpoint CHANGED\n",
                "\n",
                "# prepare the dataloaders of the raw spike dataset\n",
                "(\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ") = accelerator.prepare(\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ")\n",
                "\n",
                "def reconstruct_spikes(model, full_dataloader):\n",
                "    \"\"\"Run the autoencoder over a dataloader and collect inputs, latents and sampled spikes.\n",
                "\n",
                "    Args:\n",
                "        model: callable returning ``(rates, latents)`` for a batch of spike counts;\n",
                "            it is switched to eval mode.\n",
                "        full_dataloader: iterable of dict batches with a \"signal\" entry.\n",
                "\n",
                "    Returns:\n",
                "        dict of CPU tensors concatenated along dim 0: 'latents', 'spikes' (the inputs)\n",
                "        and 'rec_spikes' (Poisson samples drawn from the predicted rates).\n",
                "    \"\"\"\n",
                "    model.eval()\n",
                "    collected = {'latents': [], 'spikes': [], 'rec_spikes': []}\n",
                "    with torch.no_grad():\n",
                "        for batch in full_dataloader:\n",
                "            spike_counts = batch[\"signal\"]\n",
                "            rates, latent = model(spike_counts)\n",
                "            collected['latents'].append(latent.cpu())\n",
                "            collected['spikes'].append(spike_counts.cpu())\n",
                "            # sampled spike counts given the predicted rates\n",
                "            collected['rec_spikes'].append(torch.poisson(rates.cpu()))\n",
                "\n",
                "    return {name: torch.cat(chunks, 0) for name, chunks in collected.items()}\n",
                "\n",
                "rec_dict = reconstruct_spikes(ae_model, test_dataloader)\n",
                "\n",
                "\n",
                "# Compare ground-truth and reconstructed spike counts (summed over dim 2) on shared bins.\n",
                "plt.figure(figsize=cm2inch((6, 4)))\n",
                "bins = np.linspace(0, 20, 20)\n",
                "for hist_key, hist_color in (('spikes', 'grey'), ('rec_spikes', 'darkblue')):\n",
                "    plt.hist(\n",
                "        rec_dict[hist_key].sum(2).flatten(),\n",
                "        density=True,\n",
                "        color=hist_color,\n",
                "        bins=bins,\n",
                "        alpha=0.5,\n",
                "    )\n",
                "\n",
                "# legend entries follow the plotting order above\n",
                "plt.legend(['gt', 'ae'])\n",
                "plt.title('spike count distribution (test set)')\n",
                "\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### 1. Create dataset containing behavior, behavior angle, spike dataset, latents from ae ✅ (run for all points 2-5)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# # create the latent dataset\n",
                "# class LatentMonkeyDataset(torch.utils.data.Dataset):\n",
                "#     def __init__(\n",
                "#         self, dataloader, ae_model, clip=True, latent_means=None, latent_stds=None\n",
                "#     ):\n",
                "#         self.full_dataloader = dataloader\n",
                "#         self.ae_model = ae_model\n",
                "#         self.latents, self.train_spikes, self.behavior = self.create_latents()\n",
                "#         # normalize to N(0, 1)\n",
                "#         if latent_means is None or latent_stds is None:\n",
                "#             self.latent_means = self.latents.mean(dim=(0, 2)).unsqueeze(0).unsqueeze(2)\n",
                "#             self.latent_stds = self.latents.std(dim=(0, 2)).unsqueeze(0).unsqueeze(2)\n",
                "#         else:\n",
                "#             self.latent_means = latent_means\n",
                "#             self.latent_stds = latent_stds\n",
                "#         self.latents = (self.latents - self.latent_means) / self.latent_stds\n",
                "#         if clip:\n",
                "#             self.latents = self.latents.clamp(-5, 5)\n",
                "\n",
                "#         assert len(self.latents) == len(self.behavior) and len(self.latents) == len(\n",
                "#             self.train_spikes\n",
                "#         ), f\"Lengths of latents, behavior, and spikes do not match: {len(self.latents)}, {len(self.behavior)}, {len(self.train_spikes)}\"\n",
                "\n",
                "#         self.behavior = self.behavior / 1e3  # better this way\n",
                "\n",
                "#         self.behavior_cumsum = torch.cumsum(self.behavior, dim=-1)\n",
                "\n",
                "#         self.behavior_angles = torch.atan2(\n",
                "#             self.behavior[:, 1, 50], self.behavior[:, 0, 50]\n",
                "#         )\n",
                "#         self.behavior_angles = rearrange(self.behavior_angles, \"B -> B 1\")\n",
                "\n",
                "#     def create_latents(self):\n",
                "#         latent_dataset = []\n",
                "#         train_spikes = []\n",
                "#         behavior = []\n",
                "#         self.ae_model.eval()\n",
                "#         for i, batch in tqdm(\n",
                "#             enumerate(self.full_dataloader),\n",
                "#             total=len(self.full_dataloader),\n",
                "#             desc=\"Creating latent dataset\",\n",
                "#         ):\n",
                "#             with torch.no_grad():\n",
                "#                 z = self.ae_model.encode(batch[\"signal\"])\n",
                "#                 latent_dataset.append(z.cpu())\n",
                "#                 train_spikes.append(batch[\"signal\"].cpu())\n",
                "#                 behavior.append(batch[\"behavior\"].cpu())\n",
                "#         return torch.cat(latent_dataset), torch.cat(train_spikes), torch.cat(behavior)\n",
                "\n",
                "#     def __len__(self):\n",
                "#         return len(self.latents)\n",
                "\n",
                "#     def __getitem__(self, idx):\n",
                "#         return {\n",
                "#             \"signal\": self.train_spikes[idx],\n",
                "#             \"latent\": self.latents[idx],\n",
                "#             \"behavior\": self.behavior[idx],\n",
                "#             \"behavior_angle\": self.behavior_angles[idx],\n",
                "#         }\n",
                "\n",
                "\n",
                "from ntldm.data.monkey import LatentMonkeyDataset\n",
                "\n",
                "# Training split: no latent means/stds are passed in, so the dataset presumably derives them itself.\n",
                "latent_dataset_train = LatentMonkeyDataset(train_dataloader, ae_model, clip=False)\n",
                "\n",
                "# Validation and test splits are normalised with the training split's statistics.\n",
                "train_norm_kwargs = dict(\n",
                "    latent_means=latent_dataset_train.latent_means,\n",
                "    latent_stds=latent_dataset_train.latent_stds,\n",
                "    clip=False,\n",
                ")\n",
                "latent_dataset_val = LatentMonkeyDataset(val_dataloader, ae_model, **train_norm_kwargs)\n",
                "latent_dataset_test = LatentMonkeyDataset(test_dataloader, ae_model, **train_norm_kwargs)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# Print the training latents together with the statistics stored on the training dataset.\n",
                "for stat_label, stat_value in (\n",
                "    (\"latent dataset\", latent_dataset_train.latents),\n",
                "    (\"latent dataset means\", latent_dataset_train.latent_means),\n",
                "    (\"latent dataset stds\", latent_dataset_train.latent_stds),\n",
                "):\n",
                "    print(stat_label, stat_value)\n",
                "\n",
                "# Overlay histograms of the first 100 training and validation latents.\n",
                "plt.figure(figsize=cm2inch(5, 3))\n",
                "for latent_split in (latent_dataset_train, latent_dataset_val):\n",
                "    hist = plt.hist(latent_split.latents[:100].flatten(), bins=200, density=True)\n",
                "plt.title(\"Latent dataset histogram\")\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# tick positions and labels expressed in terms of pi\n",
                "pi_tick_positions = np.linspace(-np.pi, np.pi, 5)\n",
                "pi_tick_labels = [r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"]\n",
                "\n",
                "# distribution of the behavior angles in the training set\n",
                "plt.hist(latent_dataset_train.behavior_angles.flatten(), bins=200)\n",
                "plt.xticks(pi_tick_positions, pi_tick_labels)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Colour every 10th cumulative-behavior trajectory by its behavior angle (hsv colormap).\n",
                "cmap = plt.get_cmap(\"hsv\")\n",
                "# map angles from [-pi, pi] to [0, 1] for the colormap lookup\n",
                "angle_fraction = (latent_dataset_train.behavior_angles.squeeze() + np.pi) / (2 * np.pi)\n",
                "colors = cmap(angle_fraction)\n",
                "plt.figure(figsize=cm2inch(8, 4))\n",
                "\n",
                "trajectories = latent_dataset_train.behavior_cumsum\n",
                "for i in range(0, len(latent_dataset_train.behavior), 10):\n",
                "    x_trace, y_trace = trajectories[i, 0], trajectories[i, 1]\n",
                "    plt.plot(x_trace, y_trace, color=colors[i], alpha=0.3)\n",
                "\n",
                "# colorbar spanning [-pi, pi] with ticks labelled in terms of pi\n",
                "angle_ticks = np.linspace(-np.pi, np.pi, 5)\n",
                "angle_tick_labels = [r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", r\"$0$\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"]\n",
                "sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(-np.pi, np.pi))\n",
                "sm.set_array([])\n",
                "cbar = plt.colorbar(sm, ticks=angle_ticks, label=\"Behavior angle\", ax=plt.gca())\n",
                "cbar.ax.set_yticklabels(angle_tick_labels)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Settings shared by the train and validation latent dataloaders.\n",
                "latent_loader_kwargs = dict(\n",
                "    batch_size=cfg.training.batch_size,\n",
                "    num_workers=4,\n",
                "    pin_memory=True,\n",
                ")\n",
                "\n",
                "# only the training loader is shuffled\n",
                "train_latent_dataloader = torch.utils.data.DataLoader(\n",
                "    latent_dataset_train, shuffle=True, **latent_loader_kwargs\n",
                ")\n",
                "val_latent_dataloader = torch.utils.data.DataLoader(\n",
                "    latent_dataset_val, shuffle=False, **latent_loader_kwargs\n",
                ")\n",
                "\n",
                "num_batches = len(train_latent_dataloader)\n",
                "\n",
                "# torch.fft doesnt support half if L!=2^x, so fall back to full precision\n",
                "# when the signal length is not a power of 2 (bit trick: L & (L - 1) is 0 only for powers of 2)\n",
                "if (cfg.dataset.signal_length & (cfg.dataset.signal_length - 1)) != 0:\n",
                "    cfg.training.precision = \"no\"\n",
                "\n",
                "# hand both latent dataloaders to the accelerator\n",
                "train_latent_dataloader, val_latent_dataloader = accelerator.prepare(\n",
                "    train_latent_dataloader,\n",
                "    val_latent_dataloader,\n",
                ")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# inspect the dataset object behind the prepared training dataloader (shown as the cell output)\n",
                "train_dataloader.dataset"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "## sanity check: ae decoded signal, mapped to behavior (trained on train data) should be similar to the actual behavior\n",
                "\n",
                "\n",
                "def compute_decoded_behavior(model, train_latent_dataloader, val_latent_dataloader, alpha=1e-6):\n",
                "    \"\"\"Sanity check: linearly decode behavior from the autoencoder's output rates.\n",
                "\n",
                "    A ridge regression is fit from per-time-step rates to behavior on the training\n",
                "    loader and then evaluated on the validation loader.\n",
                "\n",
                "    Args:\n",
                "        model: callable whose first output for a batch of spikes is the rate tensor\n",
                "            ([B C L]); it is switched to eval mode.\n",
                "        train_latent_dataloader: iterable of dict batches with \"signal\" and \"behavior\".\n",
                "        val_latent_dataloader: same structure, used only for evaluation.\n",
                "        alpha: ridge regularisation strength.\n",
                "\n",
                "    Returns:\n",
                "        dict with \"predicted_val_behavior\" and \"real_val_behavior\" (numpy arrays laid\n",
                "        out as [B C L]) and \"r2\", the R^2 score of the prediction on the validation data.\n",
                "    \"\"\"\n",
                "    # scikit-learn is only needed by this sanity check\n",
                "    from sklearn.linear_model import Ridge\n",
                "    from sklearn.metrics import r2_score\n",
                "\n",
                "    def _collect(dataloader):\n",
                "        \"\"\"Concatenate decoded rates and behavior over all batches, on CPU.\"\"\"\n",
                "        rates, behavior = [], []\n",
                "        for batch in dataloader:\n",
                "            with torch.no_grad():\n",
                "                rates.append(model(batch[\"signal\"])[0].cpu())\n",
                "            behavior.append(batch[\"behavior\"].cpu())\n",
                "        return torch.cat(rates, 0), torch.cat(behavior, 0)  # [B C L], [B 2 L]\n",
                "\n",
                "    model.eval()\n",
                "\n",
                "    train_rates, train_behavior = _collect(train_latent_dataloader)\n",
                "    print(train_rates, train_behavior)\n",
                "\n",
                "    val_rates, val_behavior = _collect(val_latent_dataloader)\n",
                "    print(val_rates, val_behavior)\n",
                "\n",
                "    # every time step of every trial becomes one regression sample\n",
                "    bs_val = val_rates.shape[0]\n",
                "    train_rates = rearrange(train_rates, \"b c l -> (b l) c\").numpy()\n",
                "    val_rates = rearrange(val_rates, \"b c l -> (b l) c\").numpy()\n",
                "    train_behavior = rearrange(train_behavior, \"b c l -> (b l) c\").numpy()\n",
                "    val_behavior = rearrange(val_behavior, \"b c l -> (b l) c\").numpy()\n",
                "\n",
                "    # decode rates to behavior\n",
                "    ridge = Ridge(alpha=alpha)\n",
                "    ridge.fit(train_rates, train_behavior)\n",
                "    predicted_behavior = ridge.predict(val_rates)\n",
                "    r2 = r2_score(val_behavior, predicted_behavior)\n",
                "\n",
                "    return {\n",
                "        \"predicted_val_behavior\": rearrange(\n",
                "            predicted_behavior, \"(b l) c -> b c l\", b=bs_val\n",
                "        ),\n",
                "        \"real_val_behavior\": rearrange(val_behavior, \"(b l) c -> b c l\", b=bs_val),\n",
                "        \"r2\": r2,\n",
                "    }\n",
                "\n",
                "\n",
                "\n",
                "# Decode behavior from the autoencoder rates (regression fit on train, evaluated on validation).\n",
                "behav_dict = compute_decoded_behavior(\n",
                "    ae_model, train_latent_dataloader, val_latent_dataloader\n",
                ")\n",
                "predicted_val_behavior = behav_dict[\"predicted_val_behavior\"]\n",
                "real_val_behavior = behav_dict[\"real_val_behavior\"]\n",
                "\n",
                "# Cumulative sum over the last axis turns the velocity traces into trajectories.\n",
                "predicted_traj = np.cumsum(predicted_val_behavior, axis=-1)\n",
                "real_traj = np.cumsum(real_val_behavior, axis=-1)\n",
                "print(predicted_traj.shape, real_traj.shape)\n",
                "plt.figure(figsize=cm2inch(8, 4))\n",
                "\n",
                "# Plot every 7th of the first 70 validation trials: real trajectory solid, prediction dashed, same colour.\n",
                "# NOTE(review): indexes up to 63, so this assumes at least 64 validation trials.\n",
                "colors = sns.color_palette(\"tab10\")\n",
                "for i, idx in enumerate(range(0, 70, 7)):\n",
                "    plt.plot(\n",
                "        real_traj[idx, 0, :], real_traj[idx, 1, :], label=\"Real trajectory\", color=colors[i]\n",
                "    )\n",
                "    plt.plot(\n",
                "        predicted_traj[idx, 0, :],\n",
                "        predicted_traj[idx, 1, :],\n",
                "        label=\"Predicted trajectory\",\n",
                "        linestyle=\"--\",\n",
                "        color=colors[i],\n",
                "    )\n",
                "plt.xlabel(\"X position\")\n",
                "plt.ylabel(\"Y position\")\n",
                "# plt.legend()\n",
                "plt.title(\"Real vs predicted trajectory\")\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "\n",
                "### 2. Train diffusion to **conditionally** generate rates (and then spikes) on the velocity traces, then decode the generated rates into the velocity traces"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "## initialize the velocity-conditioned denoiser\n",
                "\n",
                "# The two behavior (velocity) channels are concatenated to the latent channels at the input.\n",
                "denoiser = Denoiser(\n",
                "    C_in=cfg.denoiser_model.C_in + 2, # add 2 for behavior \n",
                "    C=cfg.denoiser_model.C,\n",
                "    L=cfg.dataset.signal_length,\n",
                "    kernel=cfg.denoiser_model.kernel,\n",
                "    num_blocks=cfg.denoiser_model.num_blocks,\n",
                "    bidirectional=cfg.denoiser_model.get(\"bidirectional\", True),\n",
                ")\n",
                "\n",
                "# initial values may be way off so better to scale down the output layer\n",
                "denoiser.conv_out.weight.data = denoiser.conv_out.weight.data * 0.1\n",
                "denoiser.conv_out.bias.data = denoiser.conv_out.bias.data * 0.1\n",
                "\n",
                "# Noise scheduler used for training and sampling.\n",
                "scheduler = DDPMScheduler(\n",
                "    num_train_timesteps=cfg.denoiser_model.num_train_timesteps,\n",
                "    clip_sample=False,\n",
                "    beta_schedule=\"linear\", # ddpm doesnt support cosine\n",
                ")\n",
                "\n",
                "\n",
                "# NOTE(review): cfg.training.weight_decay (0.0) is not passed, so AdamW's own default applies.\n",
                "optimizer = torch.optim.AdamW(\n",
                "    denoiser.parameters(), lr=cfg.training.lr\n",
                ")  # default wd=0.01 for now\n",
                "\n",
                "\n",
                "\n",
                "num_batches = len(train_latent_dataloader)\n",
                "lr_scheduler = get_scheduler(\n",
                "    name=\"cosine\",\n",
                "    optimizer=optimizer,\n",
                "    num_warmup_steps=num_batches * cfg.training.num_warmup_epochs,  # warmup for num_warmup_epochs epochs\n",
                "    num_training_steps=num_batches * cfg.training.num_epochs * 1.3,  # 1.3x the planned number of steps; NOTE(review): this is a float\n",
                ")\n",
                "\n",
                "# prepare the denoiser model, optimizer and lr scheduler\n",
                "(\n",
                "    denoiser,\n",
                "    optimizer,\n",
                "    lr_scheduler,\n",
                ") = accelerator.prepare(\n",
                "    denoiser,\n",
                "    optimizer,\n",
                "    lr_scheduler,\n",
                ")\n",
                "\n",
                "# accelerator.load_state(f\"exp/{cfg.exp_name}/epoch_1000\", model=denoiser)\n",
                "\n",
                "\n",
                "# EMA wrapper around the denoiser; the sampling functions read its `averaged_model`.\n",
                "ema_model = EMAModel(denoiser)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "from ntldm.utils import count_parameters\n",
                "\n",
                "# report the denoiser size (presumably the parameter count, scaled to millions)\n",
                "print(count_parameters(denoiser)/1e6, \"M parameters\")\n",
                "\n",
                "# accelerator.load_state(f\"exp/{cfg.exp_name}/epoch_200\") # 200 because i restarted the training with ep=0\n",
                "# ema_model = EMAModel(denoiser)\n",
                "\n",
                "\n",
                "# display both models as the cell output\n",
                "ae_model, denoiser"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def sample_with_velocity(\n",
                "    ema_denoiser,\n",
                "    scheduler,\n",
                "    cfg,\n",
                "    behavior_vel,\n",
                "    batch_size=1,\n",
                "    generator=None,\n",
                "    device=\"cuda\",\n",
                "    signal_length=None,\n",
                "):\n",
                "    \"\"\"Draw latents from the EMA denoiser, conditioned on velocity traces.\n",
                "\n",
                "    Args:\n",
                "        ema_denoiser: EMA wrapper; its ``averaged_model`` is used as the denoiser.\n",
                "        scheduler: diffusion scheduler providing ``set_timesteps``, ``timesteps`` and ``step``.\n",
                "        cfg: config with ``denoiser_model.C_in``, ``denoiser_model.num_train_timesteps``\n",
                "            and ``dataset.signal_length``.\n",
                "        behavior_vel: velocity tensor of shape [B 2 L].\n",
                "        batch_size: B, must equal ``behavior_vel.shape[0]``.\n",
                "        generator: forwarded to ``scheduler.step`` only; the initial noise does not use it.\n",
                "        device: device on which sampling runs.\n",
                "        signal_length: L; defaults to ``cfg.dataset.signal_length``.\n",
                "\n",
                "    Returns:\n",
                "        Tensor of shape [B, cfg.denoiser_model.C_in, L] after the last scheduler step.\n",
                "    \"\"\"\n",
                "    signal_length = cfg.dataset.signal_length if signal_length is None else signal_length\n",
                "\n",
                "    # validate the conditioning tensor axis by axis\n",
                "    shape_checks = (\n",
                "        (0, batch_size, \"Velocity shape should be [B 2 L], B should match batch size\"),\n",
                "        (1, 2, \"Velocity shape should be [B 2 L], 2 for x and y velocity\"),\n",
                "        (2, signal_length, \"Velocity shape should be [B 2 L], L should match signal length\"),\n",
                "    )\n",
                "    for axis, expected, message in shape_checks:\n",
                "        assert behavior_vel.shape[axis] == expected, message\n",
                "\n",
                "    condition = behavior_vel.to(device)\n",
                "\n",
                "    # start from Gaussian noise in latent space\n",
                "    latent = torch.randn((batch_size, cfg.denoiser_model.C_in, signal_length)).to(device)\n",
                "    denoise_net = ema_denoiser.averaged_model\n",
                "    denoise_net.eval()\n",
                "\n",
                "    scheduler.set_timesteps(cfg.denoiser_model.num_train_timesteps)\n",
                "\n",
                "    for t in tqdm(scheduler.timesteps, desc=\"Sampling DDPM\"):\n",
                "        timestep_batch = torch.tensor([t] * batch_size).to(device).long()\n",
                "        with torch.no_grad():\n",
                "            # the denoiser sees latent and velocity channels; its last 2 output channels are dropped\n",
                "            noise_pred = denoise_net(torch.cat([latent, condition], dim=1), timestep_batch)[:, :-2]\n",
                "\n",
                "        latent = scheduler.step(\n",
                "            noise_pred, t, latent, generator=generator, return_dict=False\n",
                "        )[0]\n",
                "\n",
                "    return latent\n",
                "\n",
                "\n",
                "def sample_spikes_with_velocity(\n",
                "    ema_denoiser, scheduler, ae, cfg, behavior_vel, batch_size=1, device=\"cuda\", num_samples_per_batch=1,\n",
                "):\n",
                "    \"\"\"Sample latents conditioned on velocity, decode them to rates and draw Poisson spikes.\n",
                "\n",
                "    Every velocity trace is used for ``num_samples_per_batch`` independent samples.\n",
                "\n",
                "    Args:\n",
                "        ema_denoiser: EMA wrapper; its ``averaged_model`` is used as the denoiser.\n",
                "        scheduler: diffusion scheduler providing ``set_timesteps``, ``timesteps`` and ``step``.\n",
                "        ae: autoencoder wrapper whose ``decode`` maps latents to rates.\n",
                "        cfg: config with ``denoiser_model.C_in``, ``denoiser_model.num_train_timesteps``\n",
                "            and ``dataset.signal_length``.\n",
                "        behavior_vel: velocity tensor of shape [B 2 L].\n",
                "        batch_size: B, must equal ``behavior_vel.shape[0]``.\n",
                "        device: device on which sampling runs.\n",
                "        num_samples_per_batch: S, number of samples drawn per velocity trace.\n",
                "\n",
                "    Returns:\n",
                "        dict with \"spikes\" and \"rates\" (CPU tensors) and \"latents\" (de-normalised, on\n",
                "        ``device``), each reshaped to [S B C L].\n",
                "\n",
                "    Note:\n",
                "        Reads the latent means/stds from the notebook-level ``latent_dataset_train``.\n",
                "    \"\"\"\n",
                "    # `repeat` is used below but is not among the notebook's explicit einops imports\n",
                "    from einops import repeat\n",
                "\n",
                "    assert (\n",
                "        behavior_vel.shape[0] == batch_size\n",
                "    ), \"Velocity shape should be [B 2 L], B should match batch size\"\n",
                "    assert (\n",
                "        behavior_vel.shape[1] == 2\n",
                "    ), \"Velocity shape should be [B 2 L], 2 for x and y velocity\"\n",
                "    assert (\n",
                "        behavior_vel.shape[2] == cfg.dataset.signal_length\n",
                "    ), \"Velocity shape should be [B 2 L], L should match signal length\"\n",
                "\n",
                "    # the effective batch holds S copies of every velocity trace\n",
                "    total_batch = num_samples_per_batch * batch_size\n",
                "\n",
                "    behavior_vel = behavior_vel.to(device)\n",
                "    behavior_vel = repeat(behavior_vel, \"B C L -> (S B) C L\", S=num_samples_per_batch)\n",
                "\n",
                "    z_t = torch.randn(\n",
                "        (num_samples_per_batch, batch_size, cfg.denoiser_model.C_in, cfg.dataset.signal_length)\n",
                "    ).to(device)\n",
                "    z_t = rearrange(z_t, \"S B C L -> (S B) C L\")\n",
                "\n",
                "    ema_denoiser_avg = ema_denoiser.averaged_model\n",
                "    ema_denoiser_avg.eval()\n",
                "    scheduler.set_timesteps(cfg.denoiser_model.num_train_timesteps)\n",
                "\n",
                "    for t in tqdm(scheduler.timesteps, desc=\"Sampling DDPM\"):\n",
                "        # one timestep entry per element of the flattened (S B) batch\n",
                "        timesteps = torch.full((total_batch,), int(t), device=device, dtype=torch.long)\n",
                "        with torch.no_grad():\n",
                "            model_output = ema_denoiser_avg(\n",
                "                torch.cat([z_t, behavior_vel], dim=1),\n",
                "                timesteps,\n",
                "            )\n",
                "            # drop the 2 conditioning channels so the prediction matches z_t,\n",
                "            # as done in sample_with_velocity\n",
                "            model_output = model_output[:, :-2]\n",
                "        z_t = scheduler.step(model_output, t, z_t, return_dict=False)[0]\n",
                "\n",
                "    # undo the latent normalisation before decoding\n",
                "    z_t = z_t * latent_dataset_train.latent_stds.to(\n",
                "        z_t.device\n",
                "    ) + latent_dataset_train.latent_means.to(z_t.device)\n",
                "\n",
                "    with torch.no_grad():\n",
                "        rates = ae.decode(z_t).cpu()\n",
                "\n",
                "    spikes = torch.poisson(rates)\n",
                "\n",
                "    # split the flattened batch back into [S B C L]\n",
                "    z_t = rearrange(z_t, \"(S B) C L -> S B C L\", S=num_samples_per_batch)\n",
                "    spikes = rearrange(spikes, \"(S B) C L -> S B C L\", S=num_samples_per_batch)\n",
                "    rates = rearrange(rates, \"(S B) C L -> S B C L\", S=num_samples_per_batch)\n",
                "\n",
                "    return {\n",
                "        \"spikes\": spikes,\n",
                "        \"rates\": rates,\n",
                "        \"latents\": z_t,\n",
                "    }"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# ## to restart training from the last checkpoint\n",
                "\n",
                "# lr_scheduler = get_scheduler(\n",
                "#     name=\"cosine\",\n",
                "#     optimizer=optimizer,\n",
                "#     num_warmup_steps=num_batches * cfg.training.num_warmup_epochs,  # warmup for 10% of epochs\n",
                "#     num_training_steps=num_batches * cfg.training.num_epochs * 1.3,  # total number of steps\n",
                "# )\n",
                "# lr_scheduler = accelerator.prepare(lr_scheduler)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "cfg.exp_name"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
                "\n",
                "loss_fn = torch.nn.SmoothL1Loss(\n",
                "    beta=0.04, reduction=\"mean\"\n",
                ")  # faster convergence than mse\n",
                "\n",
                "pbar = tqdm(range(0, cfg.training.num_epochs), desc=\"epochs\")\n",
                "for epoch in pbar:\n",
                "    for i, batch in enumerate(train_latent_dataloader):\n",
                "        denoiser.train()\n",
                "        optimizer.zero_grad()\n",
                "\n",
                "        z = batch[\"latent\"]\n",
                "        behavior_vel = batch['behavior']\n",
                "        t = torch.randint(\n",
                "            0, cfg.denoiser_model.num_train_timesteps, (z.shape[0],), device=\"cpu\"\n",
                "        ).long()\n",
                "        # print(z.shape, t.shape)\n",
                "        noise = torch.randn_like(z)\n",
                "        noisy_z = scheduler.add_noise(z, noise, t)\n",
                "        noise_pred = denoiser(torch.cat([noisy_z,behavior_vel], dim=1), t)\n",
                "        noise_pred = noise_pred[:, :-2]\n",
                "\n",
                "        # loss = torch.nn.functional.mse_loss(noise, noise_pred) * 0.5\n",
                "        # loss = loss + (noise - noise_pred).abs().mean() * 0.5  # l1 loss\n",
                "\n",
                "        loss = loss_fn(noise, noise_pred)\n",
                "        # loss = (noise - noise_pred).abs().mean()\n",
                "\n",
                "        accelerator.backward(loss)\n",
                "        accelerator.clip_grad_norm_(denoiser.parameters(), 1.0)\n",
                "\n",
                "        optimizer.step()\n",
                "        lr_scheduler.step()\n",
                "\n",
                "        if i % 10 == 0:\n",
                "            pbar.set_postfix({\"loss\": loss.item(), \"lr\": lr_scheduler.get_last_lr()[0]})\n",
                "\n",
                "        ema_model.step(denoiser)\n",
                "\n",
                "    if (epoch) % 100 == 0:  # plot samples\n",
                "\n",
                "        denoiser.eval()\n",
                "\n",
                "        sampled_latents = sample_with_velocity(\n",
                "            ema_denoiser=ema_model,\n",
                "            scheduler=scheduler,\n",
                "            cfg=cfg,\n",
                "            behavior_vel=latent_dataset_train.behavior[:2],\n",
                "            batch_size=2,\n",
                "            device=\"cuda\",\n",
                "        )\n",
                "        sampled_latents = sampled_latents * latent_dataset_train.latent_stds.to(\n",
                "            sampled_latents.device\n",
                "        ) + latent_dataset_train.latent_means.to(sampled_latents.device)\n",
                "\n",
                "        real_latents = latent_dataset_train.latents[:2].cuda()\n",
                "        real_latents = real_latents * latent_dataset_train.latent_stds.to(\n",
                "            real_latents.device\n",
                "        ) + latent_dataset_train.latent_means.to(real_latents.device)\n",
                "\n",
                "        with torch.no_grad():\n",
                "            sampled_rates = ae_model.decode(sampled_latents).cpu()\n",
                "            decoded_rates_from_real_latents = ae_model.decode(real_latents).cpu()\n",
                "\n",
                "        fig, ax = plt.subplots(1, 2, figsize=cm2inch(12, 4))\n",
                "        im = ax[0].imshow(sampled_rates[0], aspect=\"auto\")\n",
                "        ax[0].set_title(\"Sampled rates\")\n",
                "        fig.colorbar(im, ax=ax[0], orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "\n",
                "        im = ax[1].imshow(decoded_rates_from_real_latents[0], aspect=\"auto\")\n",
                "        ax[1].set_title(\"Real rates\")\n",
                "        fig.colorbar(im, ax=ax[1], orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "        fig.tight_layout()\n",
                "        plt.show()\n",
                "\n",
                "        # # get avg spike count across neurons\n",
                "        # real_spikes = latent_dataset_val.train_spikes.cpu()\n",
                "        # gen_spikes, gen_rates = sample_spikes(\n",
                "        #     ema_model,\n",
                "        #     scheduler,\n",
                "        #     ae_model,\n",
                "        #     cfg,\n",
                "        #     batch_size=real_spikes.shape[0] * 10,\n",
                "        #     device=\"cuda\",\n",
                "        # )\n",
                "\n",
                "        # spike_count_means = real_spikes.mean(dim=(0, 2))\n",
                "        # spike_count_stds = real_spikes.std(dim=(0, 2))\n",
                "        # gen_spike_count_means = gen_spikes.mean(dim=(0, 2))\n",
                "        # gen_spike_count_stds = gen_spikes.std(dim=(0, 2))\n",
                "\n",
                "        # print(real_spikes, gen_spikes, spike_count_means, spike_count_stds)\n",
                "\n",
                "        # # plot boxplots across neurons\n",
                "        # plt.figure(figsize=cm2inch(4, 4))\n",
                "        # plt.violinplot(\n",
                "        #     [spike_count_means.numpy(), gen_spike_count_means.numpy()],\n",
                "        #     positions=[1, 2],\n",
                "        #     showmeans=True,\n",
                "        # )\n",
                "        # plt.ylabel(\"Mean spike count\")\n",
                "        # plt.title(\"Mean spike count across neurons\")\n",
                "        # plt.xticks([1, 2], [\"Real\", \"Generated\"])\n",
                "        # plt.show()\n",
                "\n",
                "        # # plot boxplots per neuron\n",
                "        # from einops import reduce\n",
                "\n",
                "        # spike_count_per_neuron = reduce(real_spikes, \"B C L -> B C\", reduction=\"mean\")\n",
                "        # gen_spike_count_per_neuron = reduce(\n",
                "        #     gen_spikes, \"B C L -> B C\", reduction=\"mean\"\n",
                "        # )\n",
                "\n",
                "        # # sort channels by mean spike count\n",
                "        # sorted_indices = spike_count_means.argsort()\n",
                "        # sorted_indices = torch.flip(sorted_indices, (0,))\n",
                "        # print(sorted_indices)\n",
                "        # spike_count_per_neuron = spike_count_per_neuron[:, sorted_indices]\n",
                "        # gen_spike_count_per_neuron = gen_spike_count_per_neuron[:, sorted_indices]\n",
                "\n",
                "        # plt.figure(figsize=(8, 4))\n",
                "        # for i, (spike_count, gen_spike_count) in enumerate(\n",
                "        #     zip(\n",
                "        #         spike_count_per_neuron[:, ::10].T, gen_spike_count_per_neuron[:, ::10].T\n",
                "        #     )\n",
                "        # ):\n",
                "        #     plt.violinplot(\n",
                "        #         [spike_count.numpy(), gen_spike_count.numpy()],\n",
                "        #         positions=[i, i + 0.5],\n",
                "        #         showmeans=True,\n",
                "        #     )\n",
                "        #     # scatter plot across the violinplot for better visualization\n",
                "        #     plt.scatter(\n",
                "        #         [i] * len(spike_count), spike_count.numpy(), color=\"black\", alpha=0.1\n",
                "        #     )\n",
                "        #     plt.scatter(\n",
                "        #         [i + 0.5] * len(gen_spike_count),\n",
                "        #         gen_spike_count.numpy(),\n",
                "        #         color=\"black\",\n",
                "        #         alpha=0.1,\n",
                "        #     )\n",
                "\n",
                "        # plt.xticks(\n",
                "        #     np.arange(len(sorted_indices[::10])) + 0.5,\n",
                "        #     sorted_indices[::10].numpy().tolist(),\n",
                "        # )\n",
                "        # plt.ylabel(\"Spike count\")\n",
                "        # plt.title(\"Spike count per neuron (real vs generated)\")\n",
                "        # plt.xlabel(\"neuron index\")\n",
                "        # plt.yscale(\"symlog\", linthresh=0.001)\n",
                "        # plt.ylim(\n",
                "        #     -0.0,\n",
                "        #     max(\n",
                "        #         spike_count_per_neuron.max().item(),\n",
                "        #         gen_spike_count_per_neuron.max().item(),\n",
                "        #     )\n",
                "        #     + 0.1,\n",
                "        # )\n",
                "        # plt.show()\n",
                "\n",
                "        # save\n",
                "        accelerator.save_state(f\"exp/{cfg.exp_name}/epoch_{epoch}\")\n",
                "\n",
                "pbar.close()\n",
                "\n",
                "save_new_model = True\n",
                "load_model = False\n",
                "\n",
                "if save_new_model:\n",
                "\n",
                "    if accelerator.is_main_process:\n",
                "        os.makedirs(f\"exp/{cfg.exp_name}\", exist_ok=True)\n",
                "        torch.save(accelerator.unwrap_model(denoiser).state_dict(), f\"exp/{cfg.exp_name}/model.pt\")\n",
                "        with open(f\"conf/sweeps_count/{cfg.exp_name}_save_after_train.yaml\", \"w\") as f:\n",
                "            f.write(OmegaConf.to_yaml(cfg))\n",
                "\n",
                "        print('saved new model at ', f\"exp/{cfg.exp_name}/model.pt\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "f'exp/{cfg.exp_name}'"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "save_new_model = False\n",
                "load_model = True\n",
                "\n",
                "if save_new_model:\n",
                "\n",
                "    if accelerator.is_main_process:\n",
                "        os.makedirs(f\"exp/{cfg.exp_name}\", exist_ok=True)\n",
                "        torch.save(accelerator.unwrap_model(denoiser).state_dict(), f\"exp/{cfg.exp_name}/model.pt\")\n",
                "        with open(f\"conf/sweeps_count/{cfg.exp_name}_save_after_train.yaml\", \"w\") as f:\n",
                "            f.write(OmegaConf.to_yaml(cfg))\n",
                "\n",
                "        print('saved new model at ', f\"exp/{cfg.exp_name}/model.pt\")\n",
                "elif load_model:\n",
                "    ## initialize (unconditional) denoiser\n",
                "    with open(f\"conf/sweeps_count/{cfg.exp_name}_save_after_train.yaml\") as f:\n",
                "        cfg = OmegaConf.create(yaml.safe_load(f))\n",
                "    \n",
                "    denoiser = Denoiser(\n",
                "        C_in=cfg.denoiser_model.C_in + 2, \n",
                "        C=cfg.denoiser_model.C,\n",
                "        L=cfg.dataset.signal_length,\n",
                "        kernel=cfg.denoiser_model.kernel,\n",
                "        num_blocks=cfg.denoiser_model.num_blocks,\n",
                "        bidirectional=cfg.denoiser_model.get(\"bidirectional\", True),\n",
                "    )\n",
                "\n",
                "\n",
                "    # initial values may be way off so better to scale down the output layer\n",
                "    denoiser.conv_out.weight.data = denoiser.conv_out.weight.data * 0.1\n",
                "    denoiser.conv_out.bias.data = denoiser.conv_out.bias.data * 0.1\n",
                "\n",
                "    scheduler = DDPMScheduler(\n",
                "        num_train_timesteps=cfg.denoiser_model.num_train_timesteps,\n",
                "        clip_sample=False,\n",
                "        beta_schedule=\"linear\", # ddpm doesnt support cosine\n",
                "    )\n",
                "\n",
                "\n",
                "    optimizer = torch.optim.AdamW(\n",
                "        denoiser.parameters(), lr=cfg.training.lr\n",
                "    )  # default wd=0.01 for now\n",
                "\n",
                "\n",
                "\n",
                "    num_batches = len(train_latent_dataloader)\n",
                "    lr_scheduler = get_scheduler(\n",
                "        name=\"cosine\",\n",
                "        optimizer=optimizer,\n",
                "        num_warmup_steps=num_batches * cfg.training.num_warmup_epochs,  # warmup for 10% of epochs\n",
                "        num_training_steps=num_batches * cfg.training.num_epochs * 1.3,  # total number of steps\n",
                "    )\n",
                "    \n",
                "    denoiser.load_state_dict(torch.load(f\"exp/{cfg.exp_name}/model.pt\", map_location=\"cpu\"))\n",
                "\n",
                "    # prepare the denoiser model and dataset\n",
                "    (\n",
                "        denoiser,\n",
                "        optimizer,\n",
                "        lr_scheduler,\n",
                "    ) = accelerator.prepare(\n",
                "        denoiser,\n",
                "        optimizer,\n",
                "        lr_scheduler,\n",
                "    )\n",
                "\n",
                "    # accelerator.load_state(f\"exp/{cfg.exp_name}/epoch_1000\", model=denoiser)\n",
                "\n",
                "\n",
                "    ema_model = EMAModel(denoiser)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "denoiser.eval()\n",
                "\n",
                "sampled_latents = sample_with_velocity(\n",
                "    ema_denoiser=ema_model,\n",
                "    scheduler=scheduler,\n",
                "    cfg=cfg,\n",
                "    behavior_vel=latent_dataset_train.behavior[:2],\n",
                "    batch_size=2,\n",
                "    device=\"cuda\",\n",
                ")\n",
                "\n",
                "sampled_latents = sampled_latents * latent_dataset_train.latent_stds.to(\n",
                "    sampled_latents.device\n",
                ") + latent_dataset_train.latent_means.to(sampled_latents.device)\n",
                "\n",
                "real_latents = latent_dataset_train.latents[:2].cuda()\n",
                "real_latents = real_latents * latent_dataset_train.latent_stds.to(\n",
                "    real_latents.device\n",
                ") + latent_dataset_train.latent_means.to(real_latents.device)\n",
                "\n",
                "with torch.no_grad():\n",
                "    sampled_rates = ae_model.decode(sampled_latents).cpu()\n",
                "    decoded_rates_from_real_latents = ae_model.decode(real_latents).cpu()\n",
                "\n",
                "fig, ax = plt.subplots(1, 2, figsize=cm2inch(12, 4))\n",
                "im = ax[0].imshow(sampled_rates[0], aspect=\"auto\")\n",
                "ax[0].set_title(\"Sampled rates\")\n",
                "fig.colorbar(im, ax=ax[0], orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "\n",
                "im = ax[1].imshow(decoded_rates_from_real_latents[0], aspect=\"auto\")\n",
                "ax[1].set_title(\"Real rates\")\n",
                "fig.colorbar(im, ax=ax[1], orientation=\"vertical\", fraction=0.046, pad=0.04)\n",
                "fig.tight_layout()\n",
                "plt.show()\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import torch\n",
                "from einops import rearrange\n",
                "from sklearn.linear_model import Ridge\n",
                "from sklearn.metrics import r2_score\n",
                "\n",
                "\n",
                "# function to train ridge regression on training data\n",
                "def train_ridge_regression(train_rates, train_behavior, alpha=1e-6):\n",
                "    # reshape and convert to numpy\n",
                "    train_rates = rearrange(train_rates, \"b c l -> (b l) c\").numpy()\n",
                "    train_behavior = rearrange(train_behavior, \"b c l -> (b l) c\").numpy()\n",
                "\n",
                "    # train ridge regression\n",
                "    ridge_regression_model = Ridge(alpha=alpha)\n",
                "    ridge_regression_model.fit(train_rates, train_behavior)\n",
                "\n",
                "    return ridge_regression_model\n",
                "\n",
                "\n",
                "# function to evaluate ridge regression model on validation data\n",
                "def evaluate_ridge_regression(ridge_regression_model, val_rates, val_behavior):\n",
                "    # reshape and convert to numpy\n",
                "    bs_val = val_rates.shape[0]\n",
                "    val_rates = rearrange(val_rates, \"b c l -> (b l) c\").numpy()\n",
                "    val_behavior = rearrange(val_behavior, \"b c l -> (b l) c\").numpy()\n",
                "\n",
                "    # predict behavior using ridge regression model\n",
                "    predicted_behavior = ridge_regression_model.predict(val_rates)\n",
                "\n",
                "    # calculate r2 score\n",
                "    r2 = r2_score(val_behavior, predicted_behavior)\n",
                "    print(f\"r2 score on val: {r2:.3f}\")\n",
                "\n",
                "    return {\n",
                "        \"predicted_val_behavior\": rearrange(\n",
                "            predicted_behavior, \"(b l) c -> b c l\", b=bs_val\n",
                "        ),\n",
                "        \"real_val_behavior\": rearrange(val_behavior, \"(b l) c -> b c l\", b=bs_val),\n",
                "    }\n",
                "\n",
                "from einops import repeat\n",
                "# function to generate rates and train decoded behavior\n",
                "def gen_rates_and_train_decoded_behavior(\n",
                "    ema_denoiser,\n",
                "    scheduler,\n",
                "    ae,\n",
                "    cfg,\n",
                "    train_latent_dataloader,\n",
                "    val_latent_dataloader,\n",
                "    num_samples=100,\n",
                "    device=\"cuda\",\n",
                "    num_samples_per_batch=1,\n",
                "    test_velocity=None,\n",
                "):\n",
                "    avg_denoiser = ema_denoiser.averaged_model\n",
                "    avg_denoiser.eval()\n",
                "    ae.eval()\n",
                "\n",
                "    train_rates = []\n",
                "    train_behavior = []\n",
                "    for batch in train_latent_dataloader:\n",
                "        signal = batch[\"signal\"]\n",
                "        behavior = batch[\"behavior\"]\n",
                "        with torch.no_grad():\n",
                "            output_rates = ae(signal)[0].cpu()\n",
                "        train_rates.append(output_rates)\n",
                "        train_behavior.append(behavior.cpu())\n",
                "\n",
                "    train_rates = torch.cat(train_rates, 0)  # [b c l]\n",
                "    train_behavior = torch.cat(train_behavior, 0)  # [b 2 l]\n",
                "    print(train_rates, train_behavior)\n",
                "\n",
                "    val_rates = []\n",
                "    val_behavior = []\n",
                "    for batch in val_latent_dataloader:\n",
                "        signal = batch[\"signal\"]\n",
                "        behavior = batch[\"behavior\"]\n",
                "        with torch.no_grad():\n",
                "            output_rates = ae(signal)[0].cpu()\n",
                "        val_rates.append(output_rates)\n",
                "        val_behavior.append(behavior.cpu())\n",
                "\n",
                "    val_rates = torch.cat(val_rates, 0)  # [b c l]\n",
                "    val_behavior = torch.cat(val_behavior, 0)  # [b 2 l]\n",
                "    print(val_rates, val_behavior)\n",
                "\n",
                "    # train ridge regression model on training data\n",
                "    ridge_regression_model = train_ridge_regression(train_rates, train_behavior)\n",
                "\n",
                "    # evaluate ridge regression model on validation data\n",
                "    evaluation_results = evaluate_ridge_regression(\n",
                "        ridge_regression_model, val_rates, val_behavior\n",
                "    )\n",
                "\n",
                "    # sample from the denoiser\n",
                "    print(val_latent_dataloader.dataset.behavior[:num_samples].shape)\n",
                "    if num_samples > len(val_latent_dataloader.dataset.behavior):\n",
                "        num_samples = len(val_latent_dataloader.dataset.behavior)\n",
                "    \n",
                "    sampled_latents = sample_with_velocity(\n",
                "        ema_denoiser=ema_denoiser,\n",
                "        scheduler=scheduler,\n",
                "        cfg=cfg,\n",
                "        batch_size=num_samples * num_samples_per_batch,\n",
                "        behavior_vel=repeat(\n",
                "            val_latent_dataloader.dataset.behavior[:num_samples],\n",
                "            \"B C L -> (S B) C L\",\n",
                "            S=num_samples_per_batch,\n",
                "        ),\n",
                "        device=device,\n",
                "    )\n",
                "    print(sampled_latents.shape)\n",
                "    sampled_latents = sampled_latents * latent_dataset_train.latent_stds.to(\n",
                "        sampled_latents.device\n",
                "    ) + latent_dataset_train.latent_means.to(\n",
                "        sampled_latents.device\n",
                "    )\n",
                "\n",
                "    with torch.no_grad():\n",
                "        sampled_rates = ae.decode(sampled_latents).cpu()\n",
                "\n",
                "    sampled_rates = rearrange(sampled_rates, \"(s b) c l -> (s b l) c\", s=num_samples_per_batch).numpy()\n",
                "    predicted_sampled_behavior = ridge_regression_model.predict(sampled_rates)\n",
                "    predicted_sampled_behavior = rearrange(\n",
                "        predicted_sampled_behavior, \"(s b l) c -> s b c l\", b=num_samples, s=num_samples_per_batch\n",
                "    )\n",
                "\n",
                "    target_sampled_behavior = (\n",
                "        val_latent_dataloader.dataset.behavior[:num_samples].cpu().numpy()\n",
                "    )\n",
                "    evaluation_results[\"sampled_behavior\"] = predicted_sampled_behavior\n",
                "    evaluation_results[\"real_behavior\"] = (\n",
                "        val_latent_dataloader.dataset.behavior[:num_samples].cpu().numpy()\n",
                "    )\n",
                "\n",
                "    return evaluation_results"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "ret_dict = gen_rates_and_train_decoded_behavior(\n",
                "    ema_denoiser=ema_model,\n",
                "    scheduler=scheduler,\n",
                "    ae=ae_model,\n",
                "    cfg=cfg,\n",
                "    train_latent_dataloader=train_latent_dataloader,\n",
                "    val_latent_dataloader=val_latent_dataloader,\n",
                "    num_samples=100,\n",
                "    device=\"cuda\",\n",
                "    num_samples_per_batch=5, \n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "ret_dict.keys()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "target_sampled_behavior = ret_dict[\"predicted_val_behavior\"]\n",
                "predicted_sampled_behavior = ret_dict[\"sampled_behavior\"]\n",
                "\n",
                "\n",
                "real_behavior = ret_dict[\"real_behavior\"]\n",
                "\n",
                "print(target_sampled_behavior.shape, predicted_sampled_behavior.shape)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "plt.plot(target_sampled_behavior[0, 0], color=\"blue\", linestyle='--')\n",
                "plt.plot(target_sampled_behavior[0, 1], color=\"blue\")\n",
                "\n",
                "plt.plot(real_behavior[0, 0], color=\"grey\", linestyle='--')\n",
                "plt.plot(real_behavior[0, 1], color=\"grey\")\n",
                "\n",
                "for i in range(len(predicted_sampled_behavior)):\n",
                "    plt.plot(predicted_sampled_behavior[i, 0, 0], color=\"red\", linestyle='--', alpha=0.2)\n",
                "    plt.plot(predicted_sampled_behavior[i, 0, 1], color=\"red\", alpha=0.2)\n",
                "plt.plot(predicted_sampled_behavior.mean(0)[0,0], color=\"red\", linestyle='--')\n",
                "plt.plot(predicted_sampled_behavior.mean(0)[0,1], color=\"red\")\n",
                "\n",
                "# create legend with colors for real, predicted and sampled, and linestyle for x and y\n",
                "# define label styles\n",
                "\n",
                "import matplotlib.lines as mlines\n",
                "real_line_dashed = mlines.Line2D([], [], color='grey', linestyle='--', label='Real(X)')\n",
                "real_line = mlines.Line2D([], [], color='grey', linestyle='-', label='Real(Y)')\n",
                "predicted_line_dashed = mlines.Line2D([], [], color='blue', linestyle='--', label='Predicted(X)')\n",
                "predicted_line = mlines.Line2D([], [], color='blue', linestyle='-', label='Predicted(Y)')\n",
                "sampled_line_dashed = mlines.Line2D([], [], color='red', linestyle='--', label='Sampled(X)')\n",
                "sampled_line = mlines.Line2D([], [], color='red', linestyle='-', label='Sampled(Y)')\n",
                "\n",
                "plt.legend(handles=[real_line, real_line_dashed, predicted_line, predicted_line_dashed, sampled_line, sampled_line_dashed])\n",
                "\n",
                "plt.title(\"Decoded vs sampled behavior\")\n",
                "plt.show()\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "plt.plot(target_sampled_behavior[0, 0], color=\"blue\", linestyle='--')\n",
                "plt.plot(target_sampled_behavior[0, 1], color=\"blue\")\n",
                "\n",
                "plt.plot(real_behavior[0, 0], color=\"grey\", linestyle='--')\n",
                "plt.plot(real_behavior[0, 1], color=\"grey\")\n",
                "\n",
                "for i in range(len(predicted_sampled_behavior)):\n",
                "    plt.plot(predicted_sampled_behavior[i, 0, 0], color=\"red\", linestyle='--', alpha=0.2)\n",
                "    plt.plot(predicted_sampled_behavior[i, 0, 1], color=\"red\", alpha=0.2)\n",
                "plt.plot(predicted_sampled_behavior.mean(0)[0,0], color=\"red\", linestyle='--')\n",
                "plt.plot(predicted_sampled_behavior.mean(0)[0,1], color=\"red\")\n",
                "\n",
                "# create legend with colors for real, predicted and sampled, and linestyle for x and y\n",
                "# define label styles\n",
                "\n",
                "import matplotlib.lines as mlines\n",
                "real_line_dashed = mlines.Line2D([], [], color='grey', linestyle='--', label='Real(X)')\n",
                "real_line = mlines.Line2D([], [], color='grey', linestyle='-', label='Real(Y)')\n",
                "predicted_line_dashed = mlines.Line2D([], [], color='blue', linestyle='--', label='Predicted(X)')\n",
                "predicted_line = mlines.Line2D([], [], color='blue', linestyle='-', label='Predicted(Y)')\n",
                "sampled_line_dashed = mlines.Line2D([], [], color='red', linestyle='--', label='Sampled(X)')\n",
                "sampled_line = mlines.Line2D([], [], color='red', linestyle='-', label='Sampled(Y)')\n",
                "\n",
                "plt.legend(handles=[real_line, real_line_dashed, predicted_line, predicted_line_dashed, sampled_line, sampled_line_dashed])\n",
                "\n",
                "plt.title(\"Decoded vs sampled behavior\")\n",
                "plt.show()\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "\n",
                "predicted_traj = np.cumsum(ret_dict[\"predicted_val_behavior\"], axis=-1)\n",
                "real_traj = np.cumsum(ret_dict[\"real_behavior\"], axis=-1)\n",
                "sampled_traj = np.cumsum(ret_dict[\"sampled_behavior\"], axis=-1)\n",
                "\n",
                "print(predicted_traj.shape, real_traj.shape, sampled_traj.shape)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from matplotlib.lines import Line2D\n",
                "\n",
                "for i_b_i, b_i in enumerate(range(0, 70, 7)): \n",
                "    if i_b_i % 5 == 0:\n",
                "        plt.figure(figsize=cm2inch(6, 4))\n",
                "\n",
                "    plt.plot(real_traj[b_i, 0], real_traj[b_i, 1], color=\"grey\", linestyle='-')\n",
                "    plt.plot(predicted_traj[b_i, 0], predicted_traj[b_i, 1], color=\"blue\")\n",
                "    plt.plot(sampled_traj.mean(0)[b_i, 0], sampled_traj.mean(0)[b_i, 1], color=\"red\", linestyle='-')\n",
                "    for i in range(len(sampled_traj)):\n",
                "        plt.plot(sampled_traj[i, b_i, 0], sampled_traj[i, b_i, 1], color=\"red\", alpha=0.2)\n",
                "            \n",
                "    if i_b_i % 5 == 4:\n",
                "        plt.xlabel(\"X position\")\n",
                "        plt.ylabel(\"Y position\")\n",
                "        legend_elements = [\n",
                "            Line2D([0], [0], color='grey', linestyle='-', label='Real'),\n",
                "            Line2D([0], [0], color='blue', linestyle='-', label='Predicted'),\n",
                "            Line2D([0], [0], color='red', linestyle='-', label='Sampled'),\n",
                "        ]\n",
                "        plt.legend(handles=legend_elements)\n",
                "        plt.title(\"Decoded vs sampled trajectory\")\n",
                "        plt.show()\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "plt.plot(latent_dataset_val.behavior.mean(dim=0)[0], label=\"mean(vel_x)\")\n",
                "plt.plot(latent_dataset_val.behavior.mean(dim=0)[1], label=\"mean(vel_y)\")\n",
                "plt.fill_between(\n",
                "    np.arange(140),\n",
                "    latent_dataset_val.behavior.mean(dim=0)[0]\n",
                "    + latent_dataset_val.behavior.std(dim=0)[0],\n",
                "    latent_dataset_val.behavior.mean(dim=0)[0]\n",
                "    - latent_dataset_val.behavior.std(dim=0)[0],\n",
                "    label=\"std(vel_x)\",\n",
                "    alpha=0.5,\n",
                ")\n",
                "plt.fill_between(\n",
                "    np.arange(140),\n",
                "    latent_dataset_val.behavior.mean(dim=0)[1]\n",
                "    + latent_dataset_val.behavior.std(dim=0)[1],\n",
                "    latent_dataset_val.behavior.mean(dim=0)[1]\n",
                "    - latent_dataset_val.behavior.std(dim=0)[1],\n",
                "    label=\"std(vel_y)\",\n",
                "    alpha=0.5,\n",
                ")\n",
                "\n",
                "plt.legend()\n",
                "plt.title(\"Behavior statistics\")\n",
                "plt.show()\n",
                "\n",
                "for i in range(10):\n",
                "    plt.plot(latent_dataset_val.behavior[i, 0], color=\"grey\", alpha=0.2)\n",
                "plt.title(\"Velocity X\")\n",
                "plt.show()\n",
                "\n",
                "for i in range(10):\n",
                "    plt.plot(latent_dataset_val.behavior[i, 1], color=\"grey\", alpha=0.2)\n",
                "plt.title(\"Velocity Y\")\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "\n",
                "# define a parameterized half-heart trajectory (left half)\n",
                "def parameterized_half_heart(t):\n",
                "    x = 16 * np.sin(t) ** 3\n",
                "    y = 13 * np.cos(t) - 5 * np.cos(2 * t) - 2 * np.cos(3 * t) - np.cos(4 * t)\n",
                "    return np.array([x, y])\n",
                "\n",
                "# compute velocity (derivative) of the trajectory\n",
                "def compute_velocity(position, dt=1.0):\n",
                "    return np.gradient(position, dt, axis=1)\n",
                "\n",
                "# enforce that the final x displacement is zero\n",
                "def constrain_x_displacement(velocity_x):\n",
                "    final_x_displacement = np.sum(velocity_x)\n",
                "    correction = final_x_displacement / len(velocity_x)\n",
                "    velocity_x -= correction\n",
                "    return velocity_x\n",
                "\n",
                "# adjust velocity to match the per-time mean and std\n",
                "def scale_velocity_to_statistics(velocity_normalized, mean_velocity, std_velocity):\n",
                "    return velocity_normalized * std_velocity + mean_velocity\n",
                "\n",
                "# generate both lobes of the heart shape velocity vectors\n",
                "def generate_heart_velocity(mean_velocity_x, std_velocity_x, mean_velocity_y, std_velocity_y, num_points=140):\n",
                "    t_vals = np.linspace(0, np.pi, num_points)\n",
                "    \n",
                "    # half-heart trajectories\n",
                "    half_heart_position = parameterized_half_heart(t_vals)\n",
                "    half_heart_velocity = compute_velocity(half_heart_position)\n",
                "    velocity_norms = np.linalg.norm(half_heart_velocity, axis=0)\n",
                "    half_heart_velocity_normalized = half_heart_velocity / velocity_norms\n",
                "\n",
                "    # scale velocity to match statistics\n",
                "    scaled_velocity_x = scale_velocity_to_statistics(half_heart_velocity_normalized[0], mean_velocity_x, std_velocity_x)\n",
                "    scaled_velocity_y = scale_velocity_to_statistics(half_heart_velocity_normalized[1], mean_velocity_y, std_velocity_y)\n",
                "    scaled_velocity_x = constrain_x_displacement(scaled_velocity_x)\n",
                "    half_heart_velocity_scaled = np.vstack([scaled_velocity_x, scaled_velocity_y])\n",
                "\n",
                "    # create the second lobe by mirroring the x-axis velocities\n",
                "    mirrored_velocity = np.copy(half_heart_velocity_scaled)\n",
                "    mirrored_velocity[0] *= -1\n",
                "\n",
                "    # combine both lobes into one output\n",
                "    full_heart_velocity = np.stack([half_heart_velocity_scaled, mirrored_velocity], axis=0)\n",
                "    return full_heart_velocity\n",
                "\n",
                "# function to plot the heart trajectory\n",
                "def plot_heart_velocity(velocity, mean_velocity_x, std_velocity_x, mean_velocity_y, std_velocity_y):\n",
                "    fig, axs = plt.subplots(2, 1, figsize=cm2inch(10, 10))\n",
                "    axs[0].plot(velocity[0][0], 'b-', alpha=0.5, label='velocity x (left lobe)')\n",
                "    axs[0].plot(velocity[1][0], 'g-', alpha=0.5, label='velocity x (right lobe)')\n",
                "    axs[0].plot(mean_velocity_x, 'k--', alpha=0.7, label='mean velocity x')\n",
                "    axs[0].fill_between(range(len(std_velocity_x)), mean_velocity_x - std_velocity_x, mean_velocity_x + std_velocity_x, color='gray', alpha=0.3)\n",
                "    axs[0].legend()\n",
                "    axs[1].plot(velocity[0][1], 'r-', alpha=0.5, label='velocity y (left lobe)')\n",
                "    axs[1].plot(velocity[1][1], 'm-', alpha=0.5, label='velocity y (right lobe)')\n",
                "    axs[1].plot(mean_velocity_y, 'k--', alpha=0.7, label='mean velocity y')\n",
                "    axs[1].fill_between(range(len(std_velocity_y)), mean_velocity_y - std_velocity_y, mean_velocity_y + std_velocity_y, color='gray', alpha=0.3)\n",
                "    axs[1].legend()\n",
                "    axs[0].set_title('Adjusted Velocity Profiles')\n",
                "    axs[1].set_xlabel('time')\n",
                "    plt.tight_layout()\n",
                "    plt.show()\n",
                "\n",
                "# for alpha in [0.1, 0.2, 0.4, 0.8, 1.0, 1.5, 2.0, 3.0]:\n",
                "#     # example usage\n",
                "#     np.random.seed(42)\n",
                "#     num_points = 140\n",
                "#     mean_velocity_x = latent_dataset_val.behavior.mean(dim=0)[0].numpy()\n",
                "#     mean_velocity_y = latent_dataset_val.behavior.mean(dim=0)[1].numpy()\n",
                "#     std_velocity_x = alpha*latent_dataset_val.behavior.std(dim=0)[0].numpy()\n",
                "#     std_velocity_y = alpha*latent_dataset_val.behavior.std(dim=0)[1].numpy()\n",
                "\n",
                "#     # generate the heart velocities\n",
                "#     heart_velocity = generate_heart_velocity(mean_velocity_x, std_velocity_x, mean_velocity_y, std_velocity_y)\n",
                "\n",
                "#     # # plot the velocity profiles\n",
                "#     # plot_heart_velocity(heart_velocity, mean_velocity_x, std_velocity_x, mean_velocity_y, std_velocity_y)\n",
                "\n",
                "#     # plot the heart trajectory\n",
                "#     def plot_heart_trajectory(velocity):\n",
                "#         adjusted_position_left = np.cumsum(velocity[0], axis=1)\n",
                "#         adjusted_position_right = np.cumsum(velocity[1], axis=1)\n",
                "\n",
                "#         plt.figure(figsize=cm2inch(8, 8))\n",
                "#         plt.plot(adjusted_position_left[0], adjusted_position_left[1], 'b-', label='left lobe')\n",
                "#         plt.plot(adjusted_position_right[0], adjusted_position_right[1], 'r-', label='right lobe')\n",
                "#         plt.xlabel('x position')\n",
                "#         plt.ylabel('y position')\n",
                "#         plt.legend()\n",
                "#         plt.title('Heart Shape Trajectory, alpha={}'.format(alpha))\n",
                "#         # plt.grid(True)\n",
                "#         plt.show()\n",
                "\n",
                "#     # plot the adjusted heart trajectory\n",
                "#     plot_heart_trajectory(heart_velocity)\n",
                "\n",
                "alpha = 0.8\n",
                "mean_velocity_x = latent_dataset_val.behavior.mean(dim=0)[0].numpy()\n",
                "mean_velocity_y = latent_dataset_val.behavior.mean(dim=0)[1].numpy()\n",
                "std_velocity_x = alpha*latent_dataset_val.behavior.std(dim=0)[0].numpy()\n",
                "std_velocity_y = alpha*latent_dataset_val.behavior.std(dim=0)[1].numpy()\n",
                "\n",
                "# generate the heart velocities\n",
                "heart_velocity = generate_heart_velocity(mean_velocity_x, std_velocity_x, mean_velocity_y, std_velocity_y)\n",
                "\n",
                "# # plot the velocity profiles\n",
                "# plot_heart_velocity(heart_velocity, mean_velocity_x, std_velocity_x, mean_velocity_y, std_velocity_y)\n",
                "\n",
                "# plot the heart trajectory\n",
                "def plot_heart_trajectory(velocity):\n",
                "    adjusted_position_left = np.cumsum(velocity[0], axis=1)\n",
                "    adjusted_position_right = np.cumsum(velocity[1], axis=1)\n",
                "\n",
                "    plt.figure(figsize=cm2inch(8, 8))\n",
                "    plt.plot(adjusted_position_left[0], adjusted_position_left[1], 'b-', label='left lobe')\n",
                "    plt.plot(adjusted_position_right[0], adjusted_position_right[1], 'r-', label='right lobe')\n",
                "    plt.xlabel('x position')\n",
                "    plt.ylabel('y position')\n",
                "    plt.legend()\n",
                "    plt.title('Heart Shape Trajectory, alpha={}'.format(alpha))\n",
                "    # plt.grid(True)\n",
                "    plt.show()\n",
                "\n",
                "# plot the adjusted heart trajectory\n",
                "plot_heart_trajectory(heart_velocity)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "heart_velocity = torch.from_numpy(heart_velocity).float()\n",
                "heart_velocity"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# function to generate samples with a given velocity (e.g., heart-shaped)\n",
                "def generate_samples_with_velocity(\n",
                "    ema_denoiser,\n",
                "    scheduler,\n",
                "    cfg,\n",
                "    ae,\n",
                "    ridge_regression_model,\n",
                "    latent_means,\n",
                "    latent_stds,\n",
                "    test_velocity,\n",
                "    num_samples_per_batch=1,\n",
                "    device=\"cuda\"\n",
                "):\n",
                "    # ema_denoiser.eval()\n",
                "    ae.eval()\n",
                "\n",
                "    # repeat the given velocity to match the number of samples per batch\n",
                "    repeated_velocity = repeat(test_velocity, \"b c l -> (s b) c l\", s=num_samples_per_batch)\n",
                "\n",
                "    # sample latents using the given velocity\n",
                "    sampled_latents = sample_with_velocity(\n",
                "        ema_denoiser=ema_denoiser,\n",
                "        scheduler=scheduler,\n",
                "        cfg=cfg,\n",
                "        batch_size=repeated_velocity.shape[0],\n",
                "        behavior_vel=repeated_velocity,\n",
                "        device=device,\n",
                "    )\n",
                "\n",
                "    # denormalize the sampled latents\n",
                "    sampled_latents = sampled_latents * latent_stds.to(sampled_latents.device) + latent_means.to(sampled_latents.device)\n",
                "\n",
                "    with torch.no_grad():\n",
                "        sampled_rates = ae.decode(sampled_latents).cpu()\n",
                "\n",
                "    sampled_rates = rearrange(sampled_rates, \"(s b) c l -> (s b l) c\", s=num_samples_per_batch).numpy()\n",
                "    predicted_behavior = ridge_regression_model.predict(sampled_rates)\n",
                "    predicted_behavior = rearrange(predicted_behavior, \"(s b l) c -> s b c l\", b=test_velocity.shape[0], s=num_samples_per_batch)\n",
                "\n",
                "    return predicted_behavior\n",
                "\n",
                "\n",
                "# function to generate rates, train decoded behavior, and predict from given velocity\n",
                "def gen_rates_train_and_predict_given_velocity(\n",
                "    ema_denoiser,\n",
                "    scheduler,\n",
                "    ae,\n",
                "    cfg,\n",
                "    train_latent_dataloader,\n",
                "    val_latent_dataloader,\n",
                "    device=\"cuda\",\n",
                "    num_samples_per_batch=1,\n",
                "    test_velocity=None,\n",
                "):\n",
                "\n",
                "    ae.eval()\n",
                "\n",
                "    train_rates = []\n",
                "    train_behavior = []\n",
                "    for batch in train_latent_dataloader:\n",
                "        signal = batch[\"signal\"]\n",
                "        behavior = batch[\"behavior\"]\n",
                "        with torch.no_grad():\n",
                "            output_rates = ae(signal)[0].cpu()\n",
                "        train_rates.append(output_rates)\n",
                "        train_behavior.append(behavior.cpu())\n",
                "\n",
                "    train_rates = torch.cat(train_rates, 0)  # [b c l]\n",
                "    train_behavior = torch.cat(train_behavior, 0)  # [b 2 l]\n",
                "    print(train_rates, train_behavior)\n",
                "\n",
                "    val_rates = []\n",
                "    val_behavior = []\n",
                "    for batch in val_latent_dataloader:\n",
                "        signal = batch[\"signal\"]\n",
                "        behavior = batch[\"behavior\"]\n",
                "        with torch.no_grad():\n",
                "            output_rates = ae(signal)[0].cpu()\n",
                "        val_rates.append(output_rates)\n",
                "        val_behavior.append(behavior.cpu())\n",
                "\n",
                "    val_rates = torch.cat(val_rates, 0)  # [b c l]\n",
                "    val_behavior = torch.cat(val_behavior, 0)  # [b 2 l]\n",
                "    print(val_rates, val_behavior)\n",
                "\n",
                "    # train ridge regression model on training data\n",
                "    ridge_regression_model = train_ridge_regression(train_rates, train_behavior)\n",
                "\n",
                "    # evaluate ridge regression model on validation data\n",
                "    evaluation_results = evaluate_ridge_regression(ridge_regression_model, val_rates, val_behavior)\n",
                "\n",
                "    # ensure test_velocity is provided\n",
                "    if test_velocity is None:\n",
                "        raise ValueError(\"test_velocity parameter is required!\")\n",
                "\n",
                "    # sample from the denoiser using the given velocity\n",
                "    sampled_behavior = generate_samples_with_velocity(\n",
                "        ema_denoiser=ema_denoiser,\n",
                "        scheduler=scheduler,\n",
                "        cfg=cfg,\n",
                "        ae=ae,\n",
                "        ridge_regression_model=ridge_regression_model,\n",
                "        latent_means=latent_dataset_train.latent_means,\n",
                "        latent_stds=latent_dataset_train.latent_stds,\n",
                "        test_velocity=test_velocity,\n",
                "        num_samples_per_batch=num_samples_per_batch,\n",
                "        device=device,\n",
                "    )\n",
                "\n",
                "    evaluation_results[\"sampled_behavior\"] = sampled_behavior\n",
                "\n",
                "    return evaluation_results"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "results = gen_rates_train_and_predict_given_velocity(\n",
                "    ema_denoiser=ema_model,\n",
                "    scheduler=scheduler,\n",
                "    ae=ae_model,\n",
                "    cfg=cfg,\n",
                "    train_latent_dataloader=train_latent_dataloader,\n",
                "    val_latent_dataloader=val_latent_dataloader,\n",
                "    device=\"cuda\",\n",
                "    num_samples_per_batch=5,\n",
                "    test_velocity=heart_velocity,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "sampled_heart_velocity = results[\"sampled_behavior\"]\n",
                "real_heart_velocity = heart_velocity.cpu().numpy()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "sampled_heart_velocity.shape, real_heart_velocity.shape"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "sampled_heart_traj = np.cumsum(sampled_heart_velocity, axis=-1)\n",
                "real_heart_traj = np.cumsum(real_heart_velocity, axis=-1)\n",
                "\n",
                "# sampled_heart_traj = rearrange(sampled_heart_traj, \"s b c l -> s c (b l)\")\n",
                "# real_heart_traj = rearrange(real_heart_traj, \"b c l -> c (b l)\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "plt.plot(real_heart_velocity[0, 0], color=\"grey\", linestyle='--')\n",
                "plt.plot(real_heart_velocity[1, 0], color=\"grey\")\n",
                "\n",
                "plt.plot(sampled_heart_velocity.mean(0)[0, 0], color=\"red\", linestyle='--')\n",
                "plt.plot(sampled_heart_velocity.mean(0)[1, 0], color=\"red\")\n",
                "\n",
                "legend_elements = [\n",
                "    Line2D([0], [0], color='grey', linestyle='--', label='Real(X)'),\n",
                "    Line2D([0], [0], color='grey', linestyle='-', label='Real(Y)'),\n",
                "    Line2D([0], [0], color='red', linestyle='--', label='Sampled(X)'),\n",
                "    Line2D([0], [0], color='red', linestyle='-', label='Sampled(Y)'),\n",
                "]\n",
                "\n",
                "plt.legend(handles=legend_elements)\n",
                "plt.title(\"Real vs sampled heart velocity\")\n",
                "plt.show()\n",
                "\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "plt.plot(real_heart_traj[0, 0], real_heart_traj[0, 1], 'b-', label='real')\n",
                "plt.plot(real_heart_traj[1, 0], real_heart_traj[1, 1], 'b-')\n",
                "plt.plot(sampled_heart_traj.mean(0)[0, 0], sampled_heart_traj.mean(0)[0, 1], 'r-', label='sampled')\n",
                "plt.plot(sampled_heart_traj.mean(0)[1, 0], sampled_heart_traj.mean(0)[1, 1], 'r-')\n",
                "plt.xlabel('x position')\n",
                "plt.ylabel('y position')\n",
                "plt.legend()\n",
                "\n",
                "plt.title('Heart Shape Trajectory')\n",
                "\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "timeseries",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.8.18"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
