{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import argparse\n",
                "import os\n",
                "import sys\n",
                "\n",
                "# Move from the notebooks directory to its parent (../notebooks -> ..), add that\n",
                "# directory to sys.path so the `ntldm` package is importable, and make relative\n",
                "# paths used below (data/, exp/, matplotlibrc) resolve from it. Guarded so that\n",
                "# re-running this cell does not keep walking up the directory tree.\n",
                "if \"_PROJECT_ROOT\" not in globals():\n",
                "    _PROJECT_ROOT = os.path.dirname(os.getcwd())\n",
                "    sys.path.append(_PROJECT_ROOT)\n",
                "    os.chdir(_PROJECT_ROOT)\n",
                "\n",
                "# set before torch is imported (below) so only GPU 1 is visible to CUDA\n",
                "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
                "\n",
                "import accelerate\n",
                "import auraloss  # freq loss\n",
                "import lovely_tensors as lt\n",
                "import matplotlib.pyplot as plt\n",
                "import matplotlib\n",
                "import numpy as np\n",
                "import scipy.io as io\n",
                "import seaborn as sns\n",
                "import torch\n",
                "import torch.nn as nn\n",
                "import wandb\n",
                "import yaml\n",
                "from diffusers.optimization import get_scheduler\n",
                "from omegaconf import OmegaConf\n",
                "from scipy.signal import welch\n",
                "from tqdm.auto import tqdm\n",
                "\n",
                "from ntldm.data.phoneme import get_phoneme_dataloaders\n",
                "from ntldm.networks import S4AE, AutoEncoder, CountWrapper\n",
                "from ntldm.utils.plotting_utils import *\n",
                "from ntldm.losses import latent_regularizer\n",
                "\n",
                "# Relative paths from here on (e.g. matplotlibrc) assume the working directory\n",
                "# set by the os.chdir call at the top of this cell.\n",
                "\n",
                "\n",
                "# lovely_tensors: patch torch tensor reprs for more readable printing\n",
                "lt.monkey_patch()\n",
                "matplotlib.rc_file('matplotlibrc')\n",
                "\n",
                "# automatically reload edited modules (e.g. ntldm) before each cell executes\n",
                "%load_ext autoreload\n",
                "%autoreload 2\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Experiment configuration (model / dataset / training hyper-parameters).\n",
                "cfg_yaml = \"\"\"\n",
                "model:\n",
                "  C_in: 128\n",
                "  C: 256\n",
                "  C_latent: 32\n",
                "  kernel: s4\n",
                "  num_blocks: 6\n",
                "  num_blocks_decoder: 0\n",
                "  num_lin_per_mlp: 2\n",
                "  bidirectional: False # important!\n",
                "dataset:\n",
                "  system_name: phoneme\n",
                "  datapath: data/phoneme/competitionData\n",
                "  max_seqlen: 512\n",
                "training:\n",
                "  lr: 0.001\n",
                "  num_epochs: 400\n",
                "  num_warmup_epochs: 20\n",
                "  batch_size: 256\n",
                "  random_seed: 42\n",
                "  precision: bf16\n",
                "  latent_beta: 0.001\n",
                "  latent_td_beta: 0.02\n",
                "  tk_k: 5\n",
                "  mask_prob: 0.20\n",
                "exp_name: autoencoder-count_s4-phoneme_td0.02_l6\n",
                "\"\"\"\n",
                "\n",
                "# YAML text -> plain dict -> OmegaConf DictConfig (attribute access, .get, to_yaml)\n",
                "cfg_dict = yaml.safe_load(cfg_yaml)\n",
                "cfg = OmegaConf.create(cfg_dict)\n",
                "print(OmegaConf.to_yaml(cfg))\n",
                "\n",
                "# Optionally persist this config for sweeps:\n",
                "# with open(f\"conf/sweeps_count/{cfg.exp_name}.yaml\", \"w\") as f:\n",
                "#     f.write(OmegaConf.to_yaml(cfg))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import math\n",
                "from ntldm.data.phoneme import get_phoneme_dataloaders\n",
                "\n",
                "# Seed torch and numpy before building the dataloaders and the model so that\n",
                "# subsequent random draws are reproducible.\n",
                "torch.manual_seed(cfg.training.random_seed)\n",
                "np.random.seed(cfg.training.random_seed)\n",
                "\n",
                "# train / val / test dataloaders for the phoneme dataset\n",
                "train_dataloader, val_dataloader, test_dataloader = get_phoneme_dataloaders(\n",
                "    cfg.dataset.datapath,\n",
                "    batch_size=cfg.training.batch_size,\n",
                "    max_seqlen=cfg.dataset.max_seqlen,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Visualise dataset"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# First training sample as a heatmap. Assumes dataset items are (channels, time)\n",
                "# arrays, consistent with batches being indexed [batch, channel, time] below.\n",
                "fig, ax = plt.subplots()\n",
                "im = ax.imshow(train_dataloader.dataset[0][\"signal\"], aspect=\"auto\", cmap=\"Greys\")\n",
                "fig.colorbar(im, ax=ax, label=\"count\")\n",
                "ax.set(xlabel=\"time bin\", ylabel=\"channel\", title=\"train sample 0\")\n",
                "plt.show()\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Define model, optimizer and learning-rate scheduler\n",
                "\n",
                "- wrap the model, optimizer, scheduler and dataloaders with `accelerate.Accelerator` (mixed precision, device placement)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
                "\n",
                "# Autoencoder built from cfg.model; optional keys fall back to the defaults given here.\n",
                "ae = AutoEncoder(\n",
                "    C_in=cfg.model.C_in,\n",
                "    C=cfg.model.C,\n",
                "    C_latent=cfg.model.C_latent,\n",
                "    L=cfg.dataset.max_seqlen,\n",
                "    kernel=cfg.model.kernel,\n",
                "    num_blocks=cfg.model.num_blocks,\n",
                "    num_blocks_decoder=cfg.model.get(\"num_blocks_decoder\", cfg.model.num_blocks),\n",
                "    num_lin_per_mlp=cfg.model.get(\"num_lin_per_mlp\", 2),  # default 2\n",
                "    bidirectional=cfg.model.get(\"bidirectional\", False),  # default is false for phoneme\n",
                ")\n",
                "\n",
                "print(\"Number of params\", sum(p.numel() for p in ae.parameters() if p.requires_grad) / 1e6, \"M\")\n",
                "\n",
                "ae = CountWrapper(ae, use_sin_enc=cfg.model.get(\"use_sin_enc\", False))\n",
                "print(ae)\n",
                "\n",
                "ae = ae.to(device)\n",
                "optimizer = torch.optim.AdamW(ae.parameters(), lr=cfg.training.lr)  # default wd=0.01 for now\n",
                "\n",
                "# Cosine schedule with linear warmup over the first num_warmup_epochs epochs.\n",
                "# The cosine horizon is COSINE_HORIZON_FACTOR x the real training length, so the\n",
                "# learning rate has not fully decayed to zero by the last epoch.\n",
                "COSINE_HORIZON_FACTOR = 1.5\n",
                "num_batches = len(train_dataloader)\n",
                "lr_scheduler = get_scheduler(\n",
                "    name=\"cosine\",\n",
                "    optimizer=optimizer,\n",
                "    num_warmup_steps=num_batches * cfg.training.num_warmup_epochs,\n",
                "    num_training_steps=int(num_batches * cfg.training.num_epochs * COSINE_HORIZON_FACTOR),\n",
                ")\n",
                "\n",
                "# torch.fft doesnt support half if L!=2^x; for positive x, x & (x - 1) == 0\n",
                "# exactly when x is a power of two.\n",
                "if (cfg.dataset.max_seqlen & (cfg.dataset.max_seqlen - 1)) != 0:\n",
                "    cfg.training.precision = \"no\"\n",
                "\n",
                "# Accelerate setup: mixed precision, wandb as tracker, and device placement /\n",
                "# wrapping of everything the training loop uses.\n",
                "accelerator = accelerate.Accelerator(\n",
                "    mixed_precision=cfg.training.precision,\n",
                "    log_with=\"wandb\",\n",
                ")\n",
                "\n",
                "(\n",
                "    ae,\n",
                "    optimizer,\n",
                "    lr_scheduler,\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ") = accelerator.prepare(\n",
                "    ae,\n",
                "    optimizer,\n",
                "    lr_scheduler,\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Resume from the final checkpoint of the training loop below. It saves to\n",
                "# epoch_{(epoch + 20) // 20 * 20}, which for the last epoch (num_epochs - 1)\n",
                "# is epoch_400 with the current config.\n",
                "last_epoch = cfg.training.num_epochs - 1\n",
                "ckpt_dir = os.path.join(\"exp\", cfg.exp_name, f\"epoch_{(last_epoch + 20) // 20 * 20}\")\n",
                "if os.path.isdir(ckpt_dir):\n",
                "    accelerator.load_state(ckpt_dir)\n",
                "else:\n",
                "    print(f\"No checkpoint at {ckpt_dir}; nothing loaded.\")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Set up loss and validation metric"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "criterion_poisson = nn.PoissonNLLLoss(log_input=False, full=True, reduction=\"none\")\n",
                "\n",
                "\n",
                "def compute_val_loss(net, dataloader, plot_channels=(0, 92)):\n",
                "    \"\"\"Masked Poisson NLL of `net` on `dataloader`; also logs a rate plot to wandb.\n",
                "\n",
                "    Args:\n",
                "        net: model whose forward returns a tuple with the predicted rates first.\n",
                "        dataloader: yields dicts with \"signal\" and \"mask\" (0 on padding) tensors.\n",
                "        plot_channels: channels whose predicted rates and spikes (first sample of\n",
                "            the last batch, clipped to [0, 3]) are plotted, one row per channel.\n",
                "\n",
                "    Returns:\n",
                "        Mean per-batch Poisson NLL (padding masked out) multiplied by\n",
                "        cfg.training.mask_prob -- presumably to match the scale of the training\n",
                "        loss, which only counts the ~mask_prob fraction of masked-out entries.\n",
                "\n",
                "    Raises:\n",
                "        ValueError: if `dataloader` yields no batches.\n",
                "    \"\"\"\n",
                "    net.eval()\n",
                "    poisson_loss_total = 0.0\n",
                "    batch_count = 0\n",
                "\n",
                "    for batch in dataloader:\n",
                "        signal_mask = batch[\"mask\"].cpu()\n",
                "        with torch.no_grad():\n",
                "            output_rates = net(batch[\"signal\"])[0].cpu()\n",
                "        signal = batch[\"signal\"].cpu()\n",
                "\n",
                "        # pointwise Poisson NLL with padded positions zeroed out\n",
                "        poisson_loss = criterion_poisson(output_rates, signal) * signal_mask\n",
                "        poisson_loss_total += poisson_loss.mean().item()\n",
                "        batch_count += 1\n",
                "\n",
                "    if batch_count == 0:\n",
                "        raise ValueError(\"compute_val_loss: dataloader yielded no batches\")\n",
                "\n",
                "    avg_poisson_loss = poisson_loss_total / batch_count * cfg.training.mask_prob\n",
                "    print(f\"Validation loss: {avg_poisson_loss:.4f}, mask_prob {cfg.training.mask_prob}\")\n",
                "\n",
                "    fig, axes = plt.subplots(len(plot_channels), 1, figsize=(10, 2), dpi=300, squeeze=False)\n",
                "    for row, channel in enumerate(plot_channels):\n",
                "        axis = axes[row, 0]\n",
                "        axis.plot(output_rates[0, channel].clip(0, 3).numpy(), label=\"pred\")\n",
                "        axis.plot(\n",
                "            signal[0, channel].clip(0, 3).numpy(),\n",
                "            label=\"spikes\",\n",
                "            alpha=0.5,\n",
                "            color=\"grey\",\n",
                "        )\n",
                "        # legend on this row's axes (plt.legend() would only target the current axes)\n",
                "        axis.legend()\n",
                "    wandb.log({\"val_rates\": wandb.Image(fig)})\n",
                "    plt.close(fig)\n",
                "\n",
                "    return avg_poisson_loss\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import cm2inch\n",
                "from einops import rearrange\n",
                "\n",
                "\n",
                "def plot_rate_traces_real(model, dataloader, figsize=(12, 5), idx=0):\n",
                "    \"\"\"Overlay predicted rates on observed spikes for the 2 most active channels.\n",
                "\n",
                "    Uses the first batch of `dataloader`. For sample `idx`, every time bin of the\n",
                "    unpadded part (length taken from the mask) is drawn as a vertical black line\n",
                "    with opacity min(1, 0.1 * count), under the predicted rate trace in red.\n",
                "\n",
                "    Args:\n",
                "        model: returns a tuple whose first element is the predicted rates.\n",
                "        dataloader: yields dicts with \"signal\" and \"mask\" tensors indexed\n",
                "            [batch, channel, time].\n",
                "        figsize: figure size in cm (converted with cm2inch).\n",
                "        idx: sample index within the first batch.\n",
                "    \"\"\"\n",
                "    model.eval()\n",
                "    batch = next(iter(dataloader))\n",
                "    signal_mask = batch[\"mask\"].cpu()\n",
                "    with torch.no_grad():\n",
                "        output_rates = model(batch[\"signal\"])[0].cpu()\n",
                "    signal = batch[\"signal\"].cpu()\n",
                "\n",
                "    # Rank channels by mean count over batch (dim 0) and time (dim 2); keep the top 2.\n",
                "    mean_firing_rates = signal.mean(dim=(0, 2))\n",
                "    channels = torch.argsort(mean_firing_rates, descending=True)[:2].tolist()\n",
                "\n",
                "    fig, axes = plt.subplots(1, len(channels), figsize=cm2inch(figsize), dpi=150, squeeze=False)\n",
                "    axes = axes[0]\n",
                "    for axis, channel in zip(axes, channels):\n",
                "        L_actual = int(signal_mask[idx, channel].sum().item())\n",
                "        rates = output_rates[idx, channel, :L_actual].numpy()\n",
                "        counts = signal[idx, channel, :L_actual].numpy()\n",
                "        axis.vlines(\n",
                "            np.arange(L_actual),\n",
                "            np.zeros(L_actual),\n",
                "            np.full(L_actual, rates.max()),\n",
                "            color=\"black\",\n",
                "            alpha=np.minimum(1.0, counts * 0.1),\n",
                "        )\n",
                "        axis.plot(rates, label=\"pred\", color=\"red\")\n",
                "        axis.set_title(f\"channel {channel}\")\n",
                "\n",
                "    axes[-1].legend()\n",
                "\n",
                "    fig.suptitle(\"rate traces for channels\")\n",
                "    fig.tight_layout()\n",
                "    plt.show()\n",
                "\n",
                "\n",
                "def imshow_rates_real(model, dataloader, figsize=(12, 5), idx=0):\n",
                "    \"\"\"Show predicted rates next to observed spikes for one sample as heatmaps.\n",
                "\n",
                "    Uses the first batch of `dataloader`; both panels are cropped to the unpadded\n",
                "    length of sample `idx` (read from the mask of channel 0).\n",
                "\n",
                "    Args:\n",
                "        model: returns a tuple whose first element is the predicted rates.\n",
                "        dataloader: yields dicts with \"signal\" and \"mask\" tensors indexed\n",
                "            [batch, channel, time].\n",
                "        figsize: figure size in cm (converted with cm2inch).\n",
                "        idx: sample index within the first batch.\n",
                "    \"\"\"\n",
                "    model.eval()\n",
                "    batch = next(iter(dataloader))\n",
                "    with torch.no_grad():\n",
                "        output_rates = model(batch[\"signal\"])[0].cpu()\n",
                "    signal = batch[\"signal\"].cpu()\n",
                "\n",
                "    L_actual = int(batch[\"mask\"][idx, 0].sum().item())\n",
                "\n",
                "    fig, ax = plt.subplots(1, 2, figsize=cm2inch(figsize), dpi=150)\n",
                "    panels = ((output_rates, \"Inferred rates\"), (signal, \"Spikes\"))\n",
                "    for axis, (values, title) in zip(ax, panels):\n",
                "        im = axis.imshow(values[idx, :, :L_actual].numpy(), aspect=\"auto\", cmap=\"Greys\")\n",
                "        fig.colorbar(im, ax=axis)\n",
                "        axis.set_title(title)\n",
                "    # no legend: images are not legend handles, so ax.legend() would only warn\n",
                "\n",
                "    fig.suptitle(f\"inferred rates, idx {idx}\")\n",
                "    fig.tight_layout()\n",
                "    plt.show()\n",
                "\n",
                "\n",
                "def compute_latents(model, dataloader):\n",
                "    \"\"\"Encode every batch of `dataloader`, collecting latents and masks on CPU.\n",
                "\n",
                "    Args:\n",
                "        model: returns an ``(output_rates, z)`` tuple; only ``z`` is kept.\n",
                "        dataloader: yields dicts with \"signal\" and \"mask\" tensors.\n",
                "\n",
                "    Returns:\n",
                "        dict with \"latents\" (all ``z`` concatenated along dim 0) and\n",
                "        \"signal_masks\" (all masks concatenated along dim 0).\n",
                "    \"\"\"\n",
                "    model.eval()\n",
                "    collected = {\"latents\": [], \"signal_masks\": []}\n",
                "    for batch in dataloader:\n",
                "        with torch.no_grad():\n",
                "            _, z = model(batch[\"signal\"])\n",
                "            collected[\"latents\"].append(z.cpu())\n",
                "        collected[\"signal_masks\"].append(batch[\"mask\"].cpu())\n",
                "    return {key: torch.cat(chunks, 0) for key, chunks in collected.items()}\n",
                "\n",
                "\n",
                "def reconstruct_spikes(model, dataloader):\n",
                "    \"\"\"Run `model` over `dataloader` and sample Poisson spikes from its rates.\n",
                "\n",
                "    Args:\n",
                "        model: returns an ``(output_rates, z)`` tuple.\n",
                "        dataloader: yields dicts with \"signal\" and \"mask\" tensors.\n",
                "\n",
                "    Returns:\n",
                "        dict of CPU tensors concatenated along dim 0: \"latents\" (``z``),\n",
                "        \"spikes\" (input signals), \"rec_spikes\" (``torch.poisson(rates)`` times\n",
                "        the mask) and \"signal_masks\".\n",
                "    \"\"\"\n",
                "    model.eval()\n",
                "    keys = (\"latents\", \"spikes\", \"rec_spikes\", \"signal_masks\")\n",
                "    collected = {key: [] for key in keys}\n",
                "    for batch in dataloader:\n",
                "        signal, signal_mask = batch[\"signal\"], batch[\"mask\"]\n",
                "        with torch.no_grad():\n",
                "            output_rates, z = model(signal)\n",
                "            z = z.cpu()\n",
                "        mask_cpu = signal_mask.cpu()\n",
                "        collected[\"latents\"].append(z)\n",
                "        collected[\"spikes\"].append(signal.cpu())\n",
                "        collected[\"rec_spikes\"].append(torch.poisson(output_rates.cpu()) * mask_cpu)\n",
                "        collected[\"signal_masks\"].append(mask_cpu)\n",
                "    return {key: torch.cat(collected[key], 0) for key in keys}\n",
                "\n",
                "\n",
                "def plot_corrcoef(rec_dict, figsize=cm2inch(12, 4)):\n",
                "    \"\"\"Compare neuron-neuron correlation matrices of real and reconstructed spikes.\n",
                "\n",
                "    Each trial is cropped to its unpadded length (mask of channel 0), trials are\n",
                "    concatenated along time and np.corrcoef is taken across channels. Three\n",
                "    panels: real and reconstructed correlations (fixed [-1, 1] color scale) and\n",
                "    their absolute difference.\n",
                "\n",
                "    Args:\n",
                "        rec_dict: dict with \"spikes\", \"rec_spikes\" and \"signal_masks\" tensors\n",
                "            indexed [trial, channel, time] (as returned by reconstruct_spikes).\n",
                "        figsize: figure size in inches.\n",
                "    \"\"\"\n",
                "    lengths = [int(mask[0].sum().item()) for mask in rec_dict[\"signal_masks\"]]\n",
                "\n",
                "    def _concat_trials(spikes):\n",
                "        # crop each trial to its unpadded length, then concatenate along time\n",
                "        return torch.cat([trial[:, :n] for trial, n in zip(spikes, lengths)], 1).numpy()\n",
                "\n",
                "    real_spikes = _concat_trials(rec_dict[\"spikes\"])\n",
                "    rec_spikes = _concat_trials(rec_dict[\"rec_spikes\"])\n",
                "    print(f'rec_spikes shape {rec_spikes.shape}, real_spikes shape {real_spikes.shape}')\n",
                "\n",
                "    corrcoef_real = np.corrcoef(real_spikes, rowvar=True)\n",
                "    corrcoef_rec = np.corrcoef(rec_spikes, rowvar=True)\n",
                "    print(corrcoef_real.shape, corrcoef_rec.shape)\n",
                "\n",
                "    # replace the trivial self-correlations (always 1) by ~0 so the diagonal is drawn neutral\n",
                "    np.fill_diagonal(corrcoef_real, 0.01)\n",
                "    np.fill_diagonal(corrcoef_rec, 0.01)\n",
                "    corr_diff = np.abs(corrcoef_rec - corrcoef_real)\n",
                "\n",
                "    corr_kwargs = dict(cmap=\"coolwarm\", vmin=-1, vmax=1)\n",
                "    corr_bounds = np.linspace(-1.01, 1.01, 50)\n",
                "    panels = [\n",
                "        (corrcoef_real, \"neuron correlations gt\", corr_kwargs, corr_bounds),\n",
                "        (corrcoef_rec, \"neuron correlations ae\", corr_kwargs, corr_bounds),\n",
                "        (corr_diff, \"neuron |corr_real-corr_recon|\", dict(cmap=\"magma\"), np.linspace(0, 1.01, 50)),\n",
                "    ]\n",
                "\n",
                "    fig, axs = plt.subplots(1, 3, figsize=figsize)\n",
                "    for ax, (matrix, title, imshow_kwargs, boundaries) in zip(axs, panels):\n",
                "        # a single imshow per panel, so the colorbar describes the image actually shown\n",
                "        im = ax.imshow(matrix, **imshow_kwargs)\n",
                "        ax.set_title(title)\n",
                "        ax.axis(\"off\")\n",
                "        fig.colorbar(\n",
                "            im,\n",
                "            ax=ax,\n",
                "            orientation=\"vertical\",\n",
                "            fraction=0.046,\n",
                "            pad=0.04,\n",
                "            boundaries=boundaries,\n",
                "        )\n",
                "\n",
                "    plt.tight_layout()\n",
                "    plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# rec_dict is otherwise only created further down (training loop / eval section);\n",
                "# build it here so this cell also runs on a fresh kernel.\n",
                "rec_dict = reconstruct_spikes(ae, val_dataloader)\n",
                "plot_corrcoef(rec_dict, figsize=cm2inch(20, 6))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# visual check of the reconstructions on the first test batch\n",
                "for plot_fn in (plot_rate_traces_real, imshow_rates_real):\n",
                "    plot_fn(ae, test_dataloader, figsize=(12, 5), idx=0)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## training loop"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# train loop\n",
                "criterion_poisson = nn.PoissonNLLLoss(log_input=False, full=True, reduction=\"none\")\n",
                "\n",
                "rec_losses, latent_losses, total_losses, lrs, val_rate_losses = [], [], [], [], []\n",
                "avg_poisson_loss, avg_rate_loss = 0, 0\n",
                "import wandb\n",
                "\n",
                "# os.environ[\"WANDB_MODE\"] = \"online\"\n",
                "wandb.init(project=\"ntldm\", entity=\"anon-project\")\n",
                "with tqdm(range(0, cfg.training.num_epochs)) as pbar:\n",
                "    for epoch in pbar:\n",
                "        ae.train()\n",
                "\n",
                "        for i, data in enumerate(train_dataloader):\n",
                "            optimizer.zero_grad()\n",
                "\n",
                "            signal = data[\"signal\"]\n",
                "            signal_mask = data[\"mask\"]\n",
                "            L_actual = signal_mask[:, 0, :].sum(-1)\n",
                "\n",
                "            # applying mask (coordinated dropout)\n",
                "            mask_prob = cfg.training.get(\"mask_prob\", 0.25)\n",
                "            \n",
                "            mask = (\n",
                "                torch.rand_like(signal[:]) > mask_prob\n",
                "            ).float()  # if mask_prob=0.2, 80% will be 1 and rest 0\n",
                "            input_signal = signal * (\n",
                "                mask / (1 - mask_prob)\n",
                "            )  # mask and scale unmasked by 1/(1-p)\n",
                "\n",
                "            output_rates, z = ae(input_signal)\n",
                "\n",
                "            numel = signal.shape[0] * signal.shape[1] * signal.shape[2]\n",
                "\n",
                "            # computing loss on masked parts\n",
                "            unmasked = (1 - mask) if mask_prob > 0 else torch.ones_like(mask)\n",
                "            poisson_loss = criterion_poisson(output_rates, signal) * unmasked\n",
                "\n",
                "            poisson_loss = poisson_loss * signal_mask # also mask out padding\n",
                "\n",
                "            poisson_loss = poisson_loss.mean()\n",
                "\n",
                "            rec_loss = poisson_loss\n",
                "\n",
                "            latent_loss = latent_regularizer(z, cfg) / numel\n",
                "            loss = rec_loss + cfg.training.latent_beta * latent_loss\n",
                "\n",
                "            accelerator.backward(loss)\n",
                "            accelerator.clip_grad_norm_(ae.parameters(), 2.0)\n",
                "\n",
                "            optimizer.step()\n",
                "            lr_scheduler.step()\n",
                "\n",
                "            pbar.set_postfix(\n",
                "                **{\n",
                "                    \"rec_loss\": rec_loss.item(),\n",
                "                    \"latent_loss\": latent_loss.item(),\n",
                "                    \"total_loss\": loss.item(),\n",
                "                    \"lr\": optimizer.param_groups[0][\"lr\"],\n",
                "                    \"val_poisson_loss\": avg_poisson_loss,\n",
                "                }\n",
                "            )\n",
                "            rec_losses.append(rec_loss.item())\n",
                "            latent_losses.append(latent_loss.item())\n",
                "            total_losses.append(loss.item())\n",
                "            lrs.append(optimizer.param_groups[0][\"lr\"])\n",
                "            wandb.log(\n",
                "                {\n",
                "                    \"rec_loss\": rec_loss.item(),\n",
                "                    \"latent_loss\": latent_loss.item(),\n",
                "                    \"total_loss\": loss.item(),\n",
                "                    \"lr\": optimizer.param_groups[0][\"lr\"],\n",
                "                    \"epoch\": epoch,\n",
                "                }\n",
                "            )\n",
                "        # eval\n",
                "\n",
                "        if accelerator.is_main_process and (\n",
                "            (epoch) % 10 == 0 or epoch == cfg.training.num_epochs - 1\n",
                "        ):\n",
                "            avg_poisson_loss = compute_val_loss(ae, val_dataloader)\n",
                "            wandb.log({\"val_poisson_loss\": avg_poisson_loss})\n",
                "        if accelerator.is_main_process and (\n",
                "            (epoch) % 20 == 0 or epoch == cfg.training.num_epochs - 1\n",
                "        ):\n",
                "\n",
                "            ae.eval()\n",
                "            plot_rate_traces_real(ae, val_dataloader, figsize=(12, 5), idx=1)\n",
                "            imshow_rates_real(ae, val_dataloader, figsize=(12, 5), idx=1)\n",
                "\n",
                "            rec_dict = reconstruct_spikes(ae, val_dataloader)\n",
                "            # plot reconstructed spikes\n",
                "            plt.figure(figsize=cm2inch((6, 4)))\n",
                "            # bins = np.linspace(0, 20, 20)\n",
                "            plt.hist(rec_dict['spikes'][:,:].sum(2).flatten(), density=True, color='grey', bins=100, alpha=0.5)\n",
                "            plt.hist(rec_dict['rec_spikes'][:,:].sum(2).flatten(), density=True, color='darkblue', bins=100, alpha=0.5)\n",
                "\n",
                "            plt.xlim(0, 1000)\n",
                "\n",
                "            plt.legend(['gt', 'ae'])\n",
                "            # compute wasserstein distance\n",
                "            from scipy.stats import wasserstein_distance\n",
                "            wass = wasserstein_distance(rec_dict['spikes'][:,:].sum(2).flatten(), rec_dict['rec_spikes'][:,:].sum(2).flatten())\n",
                "            plt.title('spike count distribution (val set)')\n",
                "            plt.show()\n",
                "\n",
                "            \n",
                "            fig, ax = plt.subplots(1, 2, figsize=cm2inch((10, 4)))\n",
                "            # left panel: per-entry spike counts, ae vs gt, with the identity line\n",
                "            ax[0].scatter(rec_dict['spikes'].sum(2).flatten(), rec_dict['rec_spikes'].sum(2).flatten(), alpha=0.1)\n",
                "            ax[0].set_xlabel('gt spike count')\n",
                "            ax[0].set_ylabel('ae spike count')\n",
                "            ax[0].plot([0, 1000], [0, 1000], color='black')\n",
                "            ax[0].set_xlim(0, 1000)\n",
                "            ax[0].set_ylim(0, 1000)\n",
                "\n",
                "            # right panel: per-neuron mean rate (mean over axes 0 and 2), ae vs gt\n",
                "            gt_rate = rec_dict['spikes'].mean((0, 2)).flatten()\n",
                "            ae_rate = rec_dict['rec_spikes'].mean((0, 2)).flatten()\n",
                "            ax[1].scatter(gt_rate, ae_rate, alpha=0.5)\n",
                "            ax[1].set_xlabel('gt neuron mean spike rate')\n",
                "            ax[1].set_ylabel('ae neuron mean spike rate')\n",
                "            # identity line spanning the gt rate range\n",
                "            ax[1].plot([0, gt_rate.max()], [0, gt_rate.max()], color='black')\n",
                "            fig.tight_layout()\n",
                "            plt.show()\n",
                "\n",
                "            plot_corrcoef(rec_dict)\n",
                "\n",
                "            # checkpoint directory is tagged with the smallest multiple of 20 strictly greater than `epoch`\n",
                "            ckpt_dir = f\"exp/{cfg.exp_name}/epoch_{(epoch+20)//20*20}\"\n",
                "            accelerator.save_state(ckpt_dir)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# debug: inspect values presumably left over from the training loop above\n",
                "(poisson_loss, signal_mask)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Evaluation on the validation set"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# reconstruct the validation set with the autoencoder `ae`\n",
                "rec_dict = reconstruct_spikes(ae, val_dataloader)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# shape of the reconstructed spike array\n",
                "rec_dict[\"rec_spikes\"].shape"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# per-trial valid length = sum of the signal mask on channel 0 (assumes a 0/1 mask -- TODO confirm)\n",
                "L_actuals = [\n",
                "    int(rec_dict['signal_masks'][i, 0].sum().item())\n",
                "    for i in range(rec_dict['signal_masks'].shape[0])\n",
                "]\n",
                "\n",
                "# 512 is presumably the padded trial length -- TODO confirm\n",
                "fig, ax = plt.subplots()\n",
                "ax.hist(L_actuals, bins=np.linspace(min(L_actuals), 512, 20))\n",
                "ax.set_xlabel('valid trial length (time bins)')\n",
                "ax.set_ylabel('number of trials')\n",
                "ax.set_title('trial length distribution (val set)')\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# per-trial latent trajectories (channel x time), cropped to each trial's valid length\n",
                "latents = [rec_dict[\"latents\"][i].numpy() for i in range(rec_dict[\"latents\"].shape[0])]\n",
                "latents = [lat[:, : L_actuals[i]] for i, lat in enumerate(latents)]\n",
                "\n",
                "# overlay every 20th trial; the colormap does not depend on the panel, so build it once\n",
                "trial_step = 20\n",
                "plot_trials = list(range(0, len(latents), trial_step))\n",
                "cmap = sns.color_palette(\"hsv\", as_cmap=True, n_colors=len(plot_trials))\n",
                "# up to 32 panels (4 x 8 grid); surplus panels are hidden when there are fewer latent channels\n",
                "n_panels = min(32, latents[0].shape[0]) if latents else 32\n",
                "\n",
                "fig, ax = plt.subplots(4, 8, figsize=(15, 4), dpi=300)\n",
                "ax = ax.flatten()\n",
                "for i in range(n_panels):\n",
                "    for k, j in enumerate(plot_trials):\n",
                "        # a Colormap maps floats in [0, 1]; an int is a raw LUT index that saturates past cmap.N,\n",
                "        # so pass the trial's fractional position in the plotted subset instead of j itself\n",
                "        ax[i].plot(latents[j][i, :], alpha=0.2, linewidth=2, c=cmap(k / len(plot_trials)))\n",
                "    ax[i].set_title(f\"latent {i}\")\n",
                "    ax[i].set_xticks([])\n",
                "    ax[i].set_yticks([])\n",
                "for unused_ax in ax[n_panels:]:\n",
                "    unused_ax.set_visible(False)\n",
                "\n",
                "fig.tight_layout()\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# --- spike count distribution, gt vs ae, on shared bins ---\n",
                "plt.figure(figsize=cm2inch((6, 4)))\n",
                "bins = np.linspace(0, 800, 100)\n",
                "plt.hist(rec_dict['spikes'].sum(2).flatten(), density=True, color='grey', bins=bins, alpha=0.5)\n",
                "plt.hist(rec_dict['rec_spikes'].sum(2).flatten(), density=True, color='darkblue', bins=bins, alpha=0.5)\n",
                "plt.xlim(0, 800)\n",
                "plt.xlabel('spike count')\n",
                "plt.legend(['gt', 'ae'])\n",
                "plt.title('spike count distribution (val set)')\n",
                "plt.show()\n",
                "\n",
                "# --- per-neuron mean firing rate (mean over axes 0 and 2), gt vs ae ---\n",
                "gt_rate = np.asarray(rec_dict['spikes'].mean((0, 2))).ravel()\n",
                "ae_rate = np.asarray(rec_dict['rec_spikes'].mean((0, 2))).ravel()\n",
                "plt.scatter(gt_rate, ae_rate, alpha=0.5)\n",
                "plt.xlabel('gt neuron mean firing rate')\n",
                "plt.ylabel('ae neuron mean firing rate')\n",
                "plt.plot([0, gt_rate.max()], [0, gt_rate.max()], color='black')\n",
                "plt.show()\n",
                "\n",
                "# --- neuron whose rate the ae under-estimates most (smallest ae/gt ratio) ---\n",
                "# neurons silent in gt would give 0/0 = nan, which argmin would select; exclude them via +inf\n",
                "with np.errstate(divide='ignore', invalid='ignore'):\n",
                "    ratios = np.where(gt_rate > 0, ae_rate / gt_rate, np.inf)\n",
                "min_idx = int(ratios.argmin())\n",
                "print(f'neuron {min_idx}: ae/gt rate ratio = {ratios[min_idx]:.3f}')\n",
                "\n",
                "# --- spike trains of that neuron in trial 0, cropped to the trial's mask length ---\n",
                "n_bins = rec_dict['spikes'].shape[-1]\n",
                "trial_len = int(torch.sum(rec_dict['signal_masks'][0, 0]).item())\n",
                "fig, ax = plt.subplots(2, 1, figsize=cm2inch((6, 4)))\n",
                "ax[0].vlines(np.arange(n_bins), 0, rec_dict['spikes'][0, min_idx, :], label='gt', alpha=0.1, color='grey')\n",
                "ax[1].vlines(np.arange(n_bins), 0, rec_dict['rec_spikes'][0, min_idx, :], label='ae', alpha=0.1, color='darkblue')\n",
                "ax[0].set_xlim(0, trial_len)\n",
                "ax[0].set_xticks([])\n",
                "ax[1].set_xlim(0, trial_len)\n",
                "ax[0].legend()\n",
                "ax[1].legend()\n",
                "fig.tight_layout()\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# correlation plot for the val-set reconstruction; cm2inch gets a (width, height) tuple like every other call here\n",
                "plot_corrcoef(rec_dict, figsize=cm2inch((20, 6)))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# first example's original and phonemized sentence; `data` is presumably a batch from an earlier cell -- TODO confirm\n",
                "(data['original_sentence'][0], data['phonemized_sentence'][0])"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# size of the flattened array of spike counts (summed over axis 2)\n",
                "rec_dict['spikes'].sum(2).reshape(-1).shape"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "timeseries",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.9.18"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
