{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import argparse\n",
                "import os\n",
                "import sys\n",
                "\n",
                "# Move from ../notebooks to the repository root exactly once.\n",
                "# The guard keeps this cell idempotent: without it, every re-run in the same\n",
                "# kernel would climb one directory higher and append another path entry.\n",
                "if \"_REPO_ROOT\" not in globals():\n",
                "    _REPO_ROOT = os.path.dirname(os.getcwd())\n",
                "    sys.path.append(_REPO_ROOT)\n",
                "    os.chdir(_REPO_ROOT)\n",
                "\n",
                "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
                "\n",
                "import accelerate\n",
                "import auraloss  # freq loss\n",
                "import lovely_tensors as lt\n",
                "import matplotlib.pyplot as plt\n",
                "import matplotlib\n",
                "import numpy as np\n",
                "import scipy.io as io\n",
                "import seaborn as sns\n",
                "import torch\n",
                "import torch.nn as nn\n",
                "import wandb\n",
                "import yaml\n",
                "from diffusers.optimization import get_scheduler\n",
                "from omegaconf import OmegaConf\n",
                "from scipy.signal import welch\n",
                "from tqdm.auto import tqdm\n",
                "\n",
                "from ntldm.data.lds import get_lds_dataset\n",
                "from ntldm.data.latent_attractor import get_attractor_dataset\n",
                "from ntldm.data.lds import get_lds_dataset\n",
                "from ntldm.networks import S4AE, AutoEncoder, CountWrapper\n",
                "from ntldm.utils.plotting_utils import *\n",
                "from ntldm.losses import latent_regularizer\n",
                "\n",
                "# always run from ../ntldm\n",
                "\n",
                "\n",
                "lt.monkey_patch()\n",
                "matplotlib.rc_file('matplotlibrc')\n",
                "\n",
                "%load_ext autoreload\n",
                "%autoreload 2\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# cfg_yaml = \"\"\"\n",
                "# model:\n",
                "#   C_in: 128\n",
                "#   C: 256\n",
                "#   C_latent: 16\n",
                "#   kernel: s4\n",
                "#   num_blocks: 4\n",
                "#   num_lin_per_mlp: 2\n",
                "# dataset:\n",
                "#   system_name: LDS\n",
                "#   signal_length: 1024\n",
                "#   n_ic: 6000\n",
                "#   mean_rate: 0.25\n",
                "#   split_frac: 0.9\n",
                "#   random_seed: 42\n",
                "#   softplus_beta: 2.0\n",
                "# training:\n",
                "#   lr: 0.001\n",
                "#   weight_decay: 0.0\n",
                "#   num_epochs: 50\n",
                "#   num_warmup_epochs: 1\n",
                "#   batch_size: 64\n",
                "#   random_seed: 42\n",
                "#   precision: bf16\n",
                "#   latent_beta: 0.001\n",
                "#   latent_td_beta: 0.001\n",
                "#   mask_prob: 0.2\n",
                "# exp_name: autoencoder-count_s4-FHN_newarch_mask0.2_linearinout_mean0.5\n",
                "# \"\"\"\n",
                "\n",
                "\n",
                "cfg_yaml = \"\"\"\n",
                "model:\n",
                "  C_in: 182\n",
                "  C: 256\n",
                "  C_latent: 16\n",
                "  kernel: s4\n",
                "  num_blocks: 4\n",
                "  num_blocks_decoder: 0\n",
                "  num_lin_per_mlp: 2\n",
                "  bidirectional: False\n",
                "dataset:\n",
                "  system_name: monkey\n",
                "  task: mc_maze\n",
                "  datapath: data/000128/sub-Jenkins/\n",
                "  signal_length: 140\n",
                "training:\n",
                "  lr: 0.001\n",
                "  num_epochs: 260\n",
                "  num_warmup_epochs: 10\n",
                "  batch_size: 64\n",
                "  random_seed: 42\n",
                "  precision: bf16\n",
                "  latent_beta: 0.001\n",
                "  latent_td_beta: 0.2\n",
                "  tk_k: 5\n",
                "  mask_prob: 0.5\n",
                "exp_name: autoencoder-count_s4-monkey_new\n",
                "\"\"\"\n",
                "\n",
                "# Parse the YAML text into a plain dict, wrap it in an OmegaConf config\n",
                "# (accessed below as cfg.model.*, cfg.dataset.*, cfg.training.*) and echo it.\n",
                "cfg = OmegaConf.create(yaml.safe_load(cfg_yaml))\n",
                "print(OmegaConf.to_yaml(cfg))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import math\n",
                "from ntldm.data.monkey import get_monkey_dataloaders\n",
                "\n",
                "# Seed torch and numpy from the training config before the loaders are built.\n",
                "torch.manual_seed(cfg.training.random_seed)\n",
                "np.random.seed(cfg.training.random_seed)\n",
                "\n",
                "# Build the train/val/test loaders for the configured task and data path.\n",
                "# NOTE(review): bin_width=5 is hard-coded rather than read from cfg; presumably\n",
                "# milliseconds -- confirm against get_monkey_dataloaders.\n",
                "train_loader, val_loader, test_loader = get_monkey_dataloaders(\n",
                "        cfg.dataset.task, cfg.dataset.datapath, bin_width=5, batch_size=cfg.training.batch_size\n",
                "    )"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Alias the loaders under the *_dataloader names used by the rest of the notebook.\n",
                "train_dataloader = train_loader\n",
                "val_dataloader = val_loader\n",
                "test_dataloader = test_loader"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# train_dataloader, val_dataloader, test_dataloader = accelerator.prepare(train_dataloader, val_dataloader, test_dataloader)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Visualise dataset"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# plt.figure(figsize=cm2inch((6, 4)))\n",
                "# plt.hist(train_loader.dataset.dataset.samples[0].flatten(), density=True, cumulative=True, histtype='step')\n",
                "# # plt.yscale('log')\n",
                "\n",
                "# Show the \"signal\" entry of the first training sample as a greyscale image.\n",
                "plt.imshow(train_loader.dataset[0]['signal'], aspect='auto', cmap='Greys')\n",
                "plt.colorbar()\n",
                "\n",
                "# from ntldm.utils.plotting_utils import plot_dataset_visualizations, cm2inch\n",
                "# plt.rcParams['figure.dpi'] = 300\n",
                "# plot_dataset_visualizations(train_loader.dataset, indices=[0, 1, 5, 90], figsize=cm2inch((6, 8)))"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Define model, Optimizer and Learning rate scheduler\n",
                "\n",
                "- wrap the accelerator around the model, optimizer, learning-rate scheduler and dataloaders"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
                "\n",
                "\n",
                "# Build the autoencoder from the config; optional keys fall back to the defaults\n",
                "# given in the .get(...) calls.\n",
                "ae = AutoEncoder(\n",
                "    C_in=cfg.model.C_in,\n",
                "    C=cfg.model.C,\n",
                "    C_latent=cfg.model.C_latent,\n",
                "    L=cfg.dataset.signal_length,\n",
                "    kernel=cfg.model.kernel,\n",
                "    num_blocks=cfg.model.num_blocks,\n",
                "    num_blocks_decoder=cfg.model.get(\"num_blocks_decoder\", cfg.model.num_blocks),\n",
                "    num_lin_per_mlp=cfg.model.get(\"num_lin_per_mlp\", 2),  # default 2\n",
                "    bidirectional=cfg.model.get(\"bidirectional\", True),\n",
                ")\n",
                "\n",
                "# Trainable parameter count in millions (counted before the CountWrapper is applied).\n",
                "print(\"Number of params\", sum(p.numel() for p in ae.parameters() if p.requires_grad)/1e6, \"M\")\n",
                "\n",
                "ae = CountWrapper(ae, use_sin_enc=cfg.model.get(\"use_sin_enc\", False))\n",
                "print(ae)\n",
                "\n",
                "ae = ae.to(device)\n",
                "optimizer = torch.optim.AdamW(\n",
                "    ae.parameters(), lr=cfg.training.lr\n",
                ")  # default wd=0.01 for now\n",
                "\n",
                "num_batches = len(train_dataloader)\n",
                "lr_scheduler = get_scheduler(\n",
                "    name=\"cosine\",\n",
                "    optimizer=optimizer,\n",
                "    num_warmup_steps=num_batches * cfg.training.num_warmup_epochs,  # warm up for num_warmup_epochs epochs\n",
                "    num_training_steps=num_batches * cfg.training.num_epochs * 1.1,  # 1.1x the optimizer steps actually taken (a float)\n",
                "    # num_training_steps=cfg.training.num_epochs # AS changed\n",
                ")\n",
                "\n",
                "# check if signal length is power of 2: L & (L - 1) is non-zero for any other positive L\n",
                "if cfg.dataset.signal_length & (cfg.dataset.signal_length - 1) != 0:\n",
                "    cfg.training.precision = \"no\"  # torch.fft doesnt support half if L!=2^x\n",
                "\n",
                "\n",
                "# Accelerate setup\n",
                "# The accelerator is created after the precision override above so that it sees\n",
                "# the final value of cfg.training.precision.\n",
                "\n",
                "accelerator = accelerate.Accelerator(\n",
                "    mixed_precision=cfg.training.precision,\n",
                "    log_with=\"wandb\",\n",
                ")\n",
                "\n",
                "\n",
                "# Hand model, optimizer, scheduler and all three loaders to accelerate; the\n",
                "# returned objects replace the originals under the same names.\n",
                "(\n",
                "    ae,\n",
                "    optimizer,\n",
                "    lr_scheduler,\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ") = accelerator.prepare(\n",
                "    ae,\n",
                "    optimizer,\n",
                "    lr_scheduler,\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ")\n",
                "\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# set up losses"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def compute_val_loss(net, dataloader):\n",
                "    \"\"\"Average the Poisson NLL of ``net`` over ``dataloader`` and log an example plot.\n",
                "\n",
                "    Each batch contributes the mean of ``criterion_poisson`` over all of its entries;\n",
                "    the batch means are averaged and multiplied by ``cfg.training.mask_prob``.\n",
                "    NOTE(review): the scaling presumably makes the value comparable to the training\n",
                "    loss, which is non-zero only on the masked-out fraction of entries -- confirm.\n",
                "\n",
                "    Side effects: switches ``net`` to eval mode, prints the loss, and logs a figure\n",
                "    (sample 0 of the last batch, channels 0 and 92, values clipped to [0, 3]) to\n",
                "    wandb under ``val_rates``.\n",
                "\n",
                "    Args:\n",
                "        net: called as ``net(signal)``; element 0 of its output is used as the rates.\n",
                "        dataloader: iterable of batches exposing a ``\"signal\"`` entry.\n",
                "\n",
                "    Returns:\n",
                "        The scaled average loss as a Python float.\n",
                "    \"\"\"\n",
                "    net.eval()\n",
                "    poisson_loss_total = 0\n",
                "    batch_count = 0\n",
                "\n",
                "    for batch in dataloader:\n",
                "        signal = batch[\"signal\"]\n",
                "        with torch.no_grad():\n",
                "            output_rates = net(signal)[0].cpu()\n",
                "\n",
                "        signal = signal.cpu()  # move signal to cpu\n",
                "\n",
                "        # element-wise Poisson negative log-likelihood, reduced to one mean per batch\n",
                "        poisson_loss = criterion_poisson(output_rates, signal)\n",
                "        poisson_loss_total += poisson_loss.mean().item()\n",
                "        batch_count += 1\n",
                "\n",
                "    # compute average loss over all batches, scaled by the masking probability\n",
                "    avg_poisson_loss = poisson_loss_total / batch_count * cfg.training.mask_prob\n",
                "    print(f\"Validation loss: {avg_poisson_loss:.4f}, mask_prob {cfg.training.mask_prob}\")\n",
                "\n",
                "    # `output_rates` and `signal` still hold the last batch here (both already on CPU)\n",
                "    fig, ax = plt.subplots(2, 1, figsize=(10, 2), dpi=300)\n",
                "    for row in range(2):  # plot channels 0 and 92\n",
                "        channel = 92 * row\n",
                "        ax[row].plot(output_rates[0, channel].clip(0, 3).numpy(), label=\"pred\")\n",
                "        ax[row].plot(\n",
                "            signal[0, channel].clip(0, 3).numpy(),\n",
                "            label=\"spikes\",\n",
                "            alpha=0.5,\n",
                "            color=\"grey\",\n",
                "        )\n",
                "        # per-axes legend: plt.legend() only ever targets the current (last) axes,\n",
                "        # which left the first subplot without one\n",
                "        ax[row].legend()\n",
                "    wandb.log({\"val_rates\": wandb.Image(fig)})\n",
                "    plt.close(fig)\n",
                "\n",
                "    return avg_poisson_loss\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Restore accelerate-managed state from the epoch_180 checkpoint directory\n",
                "# (same exp/<exp_name>/epoch_<n> layout the training loop writes).\n",
                "# NOTE(review): hard-coded checkpoint; this fails on a fresh run until that\n",
                "# directory exists -- confirm the cell is meant to run only after training.\n",
                "accelerator.load_state(f'exp/{cfg.exp_name}/epoch_180')\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import cm2inch\n",
                "from einops import rearrange\n",
                "    \n",
                "def plot_rate_traces_real(model, dataloader, figsize=(12, 5), idx=0, channels=(136, 8)):\n",
                "    \"\"\"Plot predicted rates on top of a spike raster for selected channels of one sample.\n",
                "\n",
                "    Only the first batch of ``dataloader`` is evaluated.\n",
                "\n",
                "    Args:\n",
                "        model: called as ``model(signal)``; element 0 of its output is used as the rates.\n",
                "        dataloader: iterable of batches whose ``\"signal\"`` entry is indexed\n",
                "            ``[sample, channel, bin]``.\n",
                "        figsize: figure size, converted with ``cm2inch``.\n",
                "        idx: index of the sample inside the first batch.\n",
                "        channels: channel indices to draw, one subplot each. Defaults to the two\n",
                "            channels that used to be hard-coded.\n",
                "    \"\"\"\n",
                "    model.eval()\n",
                "    for batch in dataloader:\n",
                "        signal = batch[\"signal\"]\n",
                "        with torch.no_grad():\n",
                "            output_rates = model(signal)[0].cpu()\n",
                "\n",
                "        signal = signal.cpu()  # move signal to cpu\n",
                "        break\n",
                "\n",
                "    # squeeze=False keeps `ax` indexable even when a single channel is requested\n",
                "    fig, ax = plt.subplots(\n",
                "        len(channels), 1, figsize=cm2inch(figsize), dpi=150, squeeze=False\n",
                "    )\n",
                "    ax = ax[:, 0]\n",
                "\n",
                "    for i, channel in enumerate(channels):\n",
                "        spikes = signal[idx, channel].numpy()\n",
                "        rates = output_rates[idx, channel].numpy()\n",
                "        L = spikes.shape[0]\n",
                "        # one vertical line per bin from 0 up to the peak rate; its opacity is\n",
                "        # 0.2 per spike in that bin, capped at fully opaque\n",
                "        ax[i].vlines(\n",
                "            np.arange(L),\n",
                "            np.zeros(L),\n",
                "            np.full(L, rates.max()),\n",
                "            color=\"black\",\n",
                "            alpha=np.minimum(1.0, spikes * 0.2),\n",
                "        )\n",
                "        ax[i].plot(rates, label=\"pred\", color=\"red\")\n",
                "        ax[i].set_title(f\"channel {channel}\")\n",
                "\n",
                "    ax[-1].legend()\n",
                "\n",
                "    fig.suptitle(\"rate traces for channels\")\n",
                "    fig.tight_layout()\n",
                "    plt.show()\n",
                "\n",
                "\n",
                "def imshow_rates_real(model, dataloader, figsize=(12, 5), idx=0):\n",
                "    \"\"\"Show predicted rates (top) and the input signal (bottom) of one sample as images.\n",
                "\n",
                "    Only the first batch of ``dataloader`` is evaluated.\n",
                "\n",
                "    Args:\n",
                "        model: called as ``model(signal)``; element 0 of its output is used as the rates.\n",
                "        dataloader: iterable of batches exposing a ``\"signal\"`` entry.\n",
                "        figsize: figure size, converted with ``cm2inch``.\n",
                "        idx: index of the sample inside the first batch.\n",
                "    \"\"\"\n",
                "    model.eval()\n",
                "    for batch in dataloader:\n",
                "        signal = batch[\"signal\"]\n",
                "        with torch.no_grad():\n",
                "            output_rates = model(signal)[0].cpu()\n",
                "\n",
                "        signal = signal.cpu()  # move signal to cpu\n",
                "        break\n",
                "\n",
                "    fig, ax = plt.subplots(2, 1, figsize=cm2inch(figsize), dpi=150)\n",
                "\n",
                "    im_rates = ax[0].imshow(output_rates[idx].numpy(), aspect='auto', cmap='Greys')\n",
                "    im_signal = ax[1].imshow(signal[idx].numpy(), aspect='auto', cmap='Greys')\n",
                "    ax[0].set_title(\"predicted rates\")\n",
                "    ax[1].set_title(\"input signal\")\n",
                "    fig.colorbar(im_rates, ax=ax[0])\n",
                "    fig.colorbar(im_signal, ax=ax[1])\n",
                "\n",
                "    # Title and layout are applied before the single plt.show(); previously the\n",
                "    # figure was shown first, so the suptitle never made it into the output.\n",
                "    # The old legend() call is gone as well: images provide no legend handles.\n",
                "    fig.suptitle(f\"inferred rates, idx {idx}\")\n",
                "    fig.tight_layout()\n",
                "    plt.show()\n",
                "\n",
                "\n",
                "def plot_latents_with_colors(latents, behavior_train):\n",
                "    \"\"\"Plot direction-averaged latent trajectories for the first four latent channels.\n",
                "\n",
                "    ``behavior_train`` is cumulatively summed along axis 1 to obtain positions. The\n",
                "    displacement over the first 50 steps gives one angle per trial, and the angles\n",
                "    are binned into 8 equal sectors of [-pi, pi]. Top row: mean trajectory per\n",
                "    sector. Bottom row: the same curves minus the mean across sectors.\n",
                "\n",
                "    NOTE(review): trial ``i`` of ``latents`` is assumed to correspond to trial ``i``\n",
                "    of ``behavior_train``. If ``latents`` was collected from a shuffling dataloader\n",
                "    the sector colouring is misaligned -- confirm.\n",
                "\n",
                "    Args:\n",
                "        latents: CPU tensor indexed ``[trial, channel, step]`` with at least 4 channels.\n",
                "        behavior_train: tensor indexed ``[trial, step, component]`` with two components\n",
                "            (presumably x/y velocity, since it is integrated to a position -- confirm).\n",
                "    \"\"\"\n",
                "    n_classes = 8\n",
                "\n",
                "    def compute_angles(position_arr):\n",
                "        \"\"\"Angle of the displacement between the first and last position of each trial.\"\"\"\n",
                "        delta_positions = position_arr[:, -1, :] - position_arr[:, 0, :]\n",
                "        # arctan2 takes the sign of both components into account (full [-pi, pi] range)\n",
                "        return np.arctan2(delta_positions[:, 1], delta_positions[:, 0])\n",
                "\n",
                "    def classify_angles(angles):\n",
                "        \"\"\"Map each sector id (0..7) to the indices of the trials falling into it.\"\"\"\n",
                "        edges = np.linspace(-np.pi, np.pi, num=n_classes + 1)\n",
                "        angle_classes = np.digitize(angles, edges) - 1\n",
                "        # digitize treats the right edge as exclusive, so an angle of exactly pi\n",
                "        # lands in a non-existent ninth bin; fold it into the last sector instead\n",
                "        # of silently dropping the trial. NaN angles stay excluded.\n",
                "        angle_classes = np.where(angles >= np.pi, n_classes - 1, angle_classes)\n",
                "        return {i: np.where(angle_classes == i)[0] for i in range(n_classes)}\n",
                "\n",
                "    # explicit conversion instead of handing a torch tensor to numpy\n",
                "    position_arr = np.cumsum(behavior_train.cpu().numpy(), axis=1)\n",
                "    angles = compute_angles(position_arr[:, :50, :])\n",
                "    angle_indices = classify_angles(angles)\n",
                "\n",
                "    # sector centres and one colour per sector\n",
                "    angle_centers = np.array([-2.74889357, -1.96349541, -1.17809725, -0.39269908, 0.39269908, 1.17809725, 1.96349541, 2.74889357])\n",
                "    colors = np.array([[1.0, 0.34742682, 0.0, 1.0], [0.91139597, 1.0, 0.0, 1.0], [0.17021876, 1.0, 0.0, 1.0], [0.0, 1.0, 0.57095525, 1.0], [0.0, 0.68787024, 1.0, 1.0], [0.07646876, 0.0, 1.0, 1.0], [0.81764597, 0.0, 1.0, 1.0], [1.0, 0.0, 0.44117682, 1.0]])\n",
                "\n",
                "    latents_np = latents.numpy()\n",
                "    mean_trajectory = [\n",
                "        np.mean(latents_np[angle_indices[class_theta], :, :], axis=0)\n",
                "        for class_theta in range(n_classes)\n",
                "    ]\n",
                "    # mean across sectors, computed once; subtracted in the bottom row\n",
                "    across_class_mean = np.mean(np.stack(mean_trajectory, axis=0), axis=0)\n",
                "\n",
                "    fig, axs = plt.subplots(2, 4, figsize=(12, 3))\n",
                "    axs = axs.flatten()\n",
                "\n",
                "    for q in range(4):\n",
                "        for classth in range(n_classes):\n",
                "            trace = mean_trajectory[classth][q, :]\n",
                "            axs[q].plot(trace, label=f\"$\\\\theta=$ {angle_centers[classth]:.2f}\", color=colors[classth])\n",
                "            axs[q + 4].plot(trace - across_class_mean[q], label=f\"angle {angle_centers[classth]:.2f}\", color=colors[classth])\n",
                "\n",
                "    axs[0].legend(bbox_to_anchor=(0.6, 1), loc=\"best\", title=\"angle(residuals of latents below)\", title_fontsize=\"small\", ncol=8)\n",
                "\n",
                "    # plt.tight_layout()\n",
                "    plt.show()\n",
                "\n",
                "\n",
                "def compute_latents(model, dataloader):\n",
                "    model.eval()\n",
                "    latents = []\n",
                "    for batch in dataloader:\n",
                "        signal = batch[\"signal\"]\n",
                "        with torch.no_grad():\n",
                "            output_rates, z = model(signal)\n",
                "            z = z.cpu()\n",
                "        latents.append(z)\n",
                "    return torch.cat(latents, 0)\n",
                "            \n",
                "# Behavior arrays with their last two axes swapped by permute(0, 2, 1).\n",
                "# NOTE(review): presumably [trial, 2, step] -> [trial, step, 2] -- confirm.\n",
                "# The test loader's dataset is reached through one extra .dataset hop, whereas\n",
                "# the train loader's dataset exposes .behavior directly.\n",
                "behavior_test = test_dataloader.dataset.dataset.behavior.permute(0, 2, 1)\n",
                "behavior_test = behavior_test[len(behavior_test)//4:] # hack to get test split and not val split\n",
                "behavior_train = train_dataloader.dataset.behavior.permute(0, 2, 1)\n",
                "\n",
                "\n",
                "## decoding the smoothed rates to behavior\n",
                "\n",
                "def _collect_rates_and_behavior(model, dataloader):\n",
                "    \"\"\"Run ``model`` over ``dataloader``; return concatenated (rates, behavior) CPU tensors.\"\"\"\n",
                "    rates, behavior = [], []\n",
                "    for batch in dataloader:\n",
                "        with torch.no_grad():\n",
                "            rates.append(model(batch[\"signal\"])[0].cpu())\n",
                "        behavior.append(batch[\"behavior\"].cpu())\n",
                "    return torch.cat(rates, 0), torch.cat(behavior, 0)  # [B C L], [B 2 L]\n",
                "\n",
                "\n",
                "def decode_rates_to_behavior(model, train_dataloader, val_dataloader):\n",
                "    \"\"\"Ridge-regress behavior from the model's rates and report the validation R2.\n",
                "\n",
                "    Rates are computed for both loaders, trials and bins are flattened into one\n",
                "    sample axis, and behavior is divided by 1e3. A Ridge model is fitted on the\n",
                "    training data for each alpha in ``np.logspace(-9, 0, 9)`` and scored on the\n",
                "    validation data. Example traces are plotted for the first alpha, followed by\n",
                "    the R2-vs-alpha curve.\n",
                "\n",
                "    Args:\n",
                "        model: called as ``model(signal)``; element 0 of its output is used as the rates.\n",
                "        train_dataloader: iterable of batches with ``\"signal\"`` and ``\"behavior\"`` entries.\n",
                "        val_dataloader: same structure, used only for scoring and plotting.\n",
                "\n",
                "    Returns:\n",
                "        The highest validation R2 over the alpha grid.\n",
                "    \"\"\"\n",
                "    # function-local imports: sklearn is only needed for this evaluation\n",
                "    from sklearn.linear_model import Ridge\n",
                "    from sklearn.metrics import r2_score\n",
                "\n",
                "    model.eval()\n",
                "\n",
                "    train_rates, train_behavior = _collect_rates_and_behavior(model, train_dataloader)\n",
                "    print(train_rates, train_behavior)\n",
                "\n",
                "    val_rates, val_behavior = _collect_rates_and_behavior(model, val_dataloader)\n",
                "    print(val_rates, val_behavior)\n",
                "\n",
                "    # decode rates to behavior: one regression sample per (trial, bin) pair\n",
                "    bs_val = val_rates.shape[0]\n",
                "    train_rates = rearrange(train_rates, 'b c l -> (b l) c').numpy()\n",
                "    val_rates = rearrange(val_rates, 'b c l -> (b l) c').numpy()\n",
                "    train_behavior = rearrange(train_behavior, 'b c l -> (b l) c').numpy()/1e3\n",
                "    val_behavior = rearrange(val_behavior, 'b c l -> (b l) c').numpy()/1e3\n",
                "\n",
                "    # Single alpha grid shared by the fit, the x-axis and the title. The plot and\n",
                "    # title used to be built from different grids, which mislabelled the curve and\n",
                "    # raised IndexError whenever the best alpha was one of the last three.\n",
                "    alphas = np.logspace(-9, 0, 9)\n",
                "    r2s_per_alpha = []\n",
                "\n",
                "    for i, alpha in enumerate(alphas):\n",
                "        ridge = Ridge(alpha=alpha)\n",
                "        ridge.fit(train_rates, train_behavior)\n",
                "        predicted_behavior = ridge.predict(val_rates)\n",
                "        r2 = r2_score(val_behavior, predicted_behavior)\n",
                "        r2s_per_alpha.append(r2)\n",
                "\n",
                "        # scientific notation: the grid starts at 1e-9, which ':.5f' printed as 0.00000\n",
                "        print(f\"Alpha {alpha:.2e} R2 score: {r2:.4f}\")\n",
                "\n",
                "        if i == 0:\n",
                "            # example traces for the first alpha only; back to [trial, component, bin]\n",
                "            predicted_behavior_rearranged = rearrange(predicted_behavior, '(b l) c -> b c l', b=bs_val)\n",
                "            val_behavior_rearranged = rearrange(val_behavior, '(b l) c -> b c l', b=bs_val)\n",
                "\n",
                "            fig, ax = plt.subplots(2, 4, figsize=(6, 2), dpi=200)\n",
                "            ax = ax.flatten()\n",
                "            # every 8th trial, at most 8 panels; bounded by the number of validation\n",
                "            # trials so a small validation set no longer raises IndexError\n",
                "            for i_j, j in enumerate(range(0, min(64, bs_val), 8)):\n",
                "                ax[i_j].plot(val_behavior_rearranged[j, 0, :], color='black', label='target(x)')\n",
                "                ax[i_j].plot(predicted_behavior_rearranged[j, 0, :], color='red', label='predicted(x)')\n",
                "                ax[i_j].plot(val_behavior_rearranged[j, 1, :], color='black', linestyle=':', label='target(y)')\n",
                "                ax[i_j].plot(predicted_behavior_rearranged[j, 1, :], color='red', linestyle=':', label='predicted(y)')\n",
                "                if i_j < 4:\n",
                "                    ax[i_j].set_xticks([])\n",
                "\n",
                "            fig.suptitle(f'Predicted hand velocities from smoothed rates', y=-0.02)\n",
                "            # create a common legend from the first panel's handles\n",
                "            handles, labels = ax[0].get_legend_handles_labels()\n",
                "            fig.legend(handles, labels, loc='upper center', ncol=4)\n",
                "            fig.tight_layout()\n",
                "            plt.show()\n",
                "\n",
                "    best = int(np.argmax(r2s_per_alpha))\n",
                "\n",
                "    plt.figure(figsize=cm2inch((5, 3)), dpi=150)\n",
                "    plt.plot(alphas, r2s_per_alpha, '-o')\n",
                "    plt.xscale('log')\n",
                "    plt.xlabel('alpha')\n",
                "    plt.xticks(np.logspace(-9, 0, 4))  # 1e-9, 1e-6, 1e-3, 1\n",
                "    plt.ylabel('R2 score')\n",
                "    plt.title(f'Rates -> behavior\\n(max={r2s_per_alpha[best]:.4f} at alpha={alphas[best]:.2e})')\n",
                "    plt.show()\n",
                "    return np.max(r2s_per_alpha)\n",
                "\n",
                "\n",
                "decode_rates_to_behavior(ae, train_dataloader, val_dataloader)\n",
                "\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# eval load\n",
                "accelerator.load_state(f'exp/{cfg.exp_name}/epoch_180') # best checkpoint\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## training loop"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# train loop\n",
                "# Poisson NLL on rates (log_input=False), kept element-wise so it can be masked.\n",
                "criterion_poisson = nn.PoissonNLLLoss(log_input=False, full=True, reduction=\"none\")\n",
                "\n",
                "rec_losses, latent_losses, total_losses, lrs, val_rate_losses = [], [], [], [], []\n",
                "avg_poisson_loss, avg_rate_loss = 0, 0\n",
                "import wandb\n",
                "\n",
                "# Probability of hiding an entry from the model (coordinated dropout). It is\n",
                "# constant for the whole run, so it is read once instead of once per batch.\n",
                "mask_prob = cfg.training.get(\"mask_prob\", 0.25)\n",
                "\n",
                "# os.environ[\"WANDB_MODE\"] = \"online\"\n",
                "wandb.init(project=\"ntldm\", entity=\"anon-project\")\n",
                "with tqdm(range(0, cfg.training.num_epochs)) as pbar:\n",
                "    for epoch in pbar:\n",
                "        ae.train()\n",
                "\n",
                "        for i, data in enumerate(train_dataloader):\n",
                "            optimizer.zero_grad()\n",
                "\n",
                "            signal = data[\"signal\"]\n",
                "\n",
                "            # applying mask (coordinated dropout): 1 = shown to the model, 0 = hidden\n",
                "            mask = (\n",
                "                torch.rand_like(signal) > mask_prob\n",
                "            ).float()  # if mask_prob=0.2, 80% will be 1 and rest 0\n",
                "            input_signal = signal * (\n",
                "                mask / (1 - mask_prob)\n",
                "            )  # mask and scale the visible entries by 1/(1-p)\n",
                "\n",
                "            output_rates, z = ae(input_signal)\n",
                "\n",
                "            numel = signal.shape[0] * signal.shape[1] * signal.shape[2]\n",
                "\n",
                "            # the loss counts only the hidden entries (all entries when masking is off);\n",
                "            # the mean still runs over every element, hidden or not\n",
                "            held_out = (1 - mask) if mask_prob > 0 else torch.ones_like(mask)\n",
                "            poisson_loss = criterion_poisson(output_rates, signal) * held_out\n",
                "            poisson_loss = poisson_loss.mean()\n",
                "\n",
                "            rec_loss = poisson_loss\n",
                "\n",
                "            latent_loss = latent_regularizer(z, cfg) / numel\n",
                "            loss = rec_loss + cfg.training.latent_beta * latent_loss\n",
                "\n",
                "            accelerator.backward(loss)\n",
                "            accelerator.clip_grad_norm_(ae.parameters(), 2.0)\n",
                "\n",
                "            optimizer.step()\n",
                "            lr_scheduler.step()\n",
                "\n",
                "            # read the scalars once and reuse them for the progress bar, the\n",
                "            # history lists and wandb\n",
                "            step_metrics = {\n",
                "                \"rec_loss\": rec_loss.item(),\n",
                "                \"latent_loss\": latent_loss.item(),\n",
                "                \"total_loss\": loss.item(),\n",
                "                \"lr\": optimizer.param_groups[0][\"lr\"],\n",
                "            }\n",
                "            pbar.set_postfix(**step_metrics, val_poisson_loss=avg_poisson_loss)\n",
                "            rec_losses.append(step_metrics[\"rec_loss\"])\n",
                "            latent_losses.append(step_metrics[\"latent_loss\"])\n",
                "            total_losses.append(step_metrics[\"total_loss\"])\n",
                "            lrs.append(step_metrics[\"lr\"])\n",
                "            wandb.log({**step_metrics, \"epoch\": epoch})\n",
                "\n",
                "        # eval: validation loss every 10 epochs, plots + decoding + checkpoint every\n",
                "        # 20 epochs, and both after the final epoch\n",
                "        is_last_epoch = epoch == cfg.training.num_epochs - 1\n",
                "\n",
                "        if accelerator.is_main_process and ((epoch + 1) % 10 == 0 or is_last_epoch):\n",
                "            avg_poisson_loss = compute_val_loss(ae, val_dataloader)\n",
                "            wandb.log({\"val_poisson_loss\": avg_poisson_loss})\n",
                "\n",
                "        if accelerator.is_main_process and ((epoch + 1) % 20 == 0 or is_last_epoch):\n",
                "            ae.eval()\n",
                "            plot_rate_traces_real(ae, val_dataloader, figsize=(6, 6), idx=1)\n",
                "            imshow_rates_real(ae, val_dataloader, figsize=(6, 3), idx=1)\n",
                "            latents = compute_latents(ae, train_dataloader)\n",
                "            plot_latents_with_colors(latents, behavior_train)\n",
                "            r2_val = decode_rates_to_behavior(ae, train_dataloader, val_dataloader)\n",
                "            wandb.log({\"r2_val\": r2_val})\n",
                "            # Name the checkpoint after the number of completed epochs. The former\n",
                "            # (epoch+20)//20*20 gave the same name on multiples of 20 but rounded the\n",
                "            # final epoch up (e.g. 50 epochs -> epoch_60) when num_epochs was not one.\n",
                "            accelerator.save_state(f\"exp/{cfg.exp_name}/epoch_{epoch + 1}\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## eval"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "## decoding the smoothed rates to behavior\n",
                "\n",
                "decode_rates_to_behavior(ae, train_dataloader, val_dataloader)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def reconstruct_spikes(model, dataloader):\n",
                "    \"\"\"Run `model` over `dataloader`; collect latents, input spikes and resampled spikes.\n",
                "\n",
                "    For each batch the model is called on batch['signal'] without gradient\n",
                "    tracking. Its first output is used as the rate of a torch.poisson draw,\n",
                "    its second output is stored as the latent.\n",
                "\n",
                "    Returns a dict with keys 'latents', 'spikes' and 'rec_spikes'; each value is\n",
                "    the per-batch CPU tensors concatenated along dim 0.\n",
                "    \"\"\"\n",
                "    model.eval()\n",
                "    collected = {'latents': [], 'spikes': [], 'rec_spikes': []}\n",
                "    for batch in dataloader:\n",
                "        inputs = batch[\"signal\"]\n",
                "        with torch.no_grad():\n",
                "            rates, latent = model(inputs)\n",
                "        collected['latents'].append(latent.cpu())\n",
                "        collected['spikes'].append(inputs.cpu())\n",
                "        # Poisson sampling is random: results vary between calls unless torch is seeded\n",
                "        collected['rec_spikes'].append(torch.poisson(rates.cpu()))\n",
                "\n",
                "    return {name: torch.cat(chunks, 0) for name, chunks in collected.items()}\n",
                "\n",
                "rec_dict = reconstruct_spikes(ae, test_dataloader)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# Distribution of spike counts summed over dim 2: recorded test spikes ('gt')\n",
                "# vs. Poisson samples drawn from the autoencoder rates ('ae').\n",
                "plt.figure(figsize=cm2inch((6, 4)))\n",
                "# Spike counts are integer-valued, so use unit-width bins centred on 0..20.\n",
                "# The earlier np.linspace(0, 20, 20) produced 19 bins of width 20/19, which\n",
                "# merged counts 0 and 1 (and 19 and 20) into single bars.\n",
                "bins = np.arange(0, 22) - 0.5\n",
                "plt.hist(rec_dict['spikes'].sum(2).flatten(), density=True, color='grey', bins=bins, alpha=0.5)\n",
                "plt.hist(rec_dict['rec_spikes'].sum(2).flatten(), density=True, color='darkblue', bins=bins, alpha=0.5)\n",
                "\n",
                "plt.xlabel('spike count')\n",
                "plt.ylabel('density')\n",
                "plt.legend(['gt', 'ae'])\n",
                "plt.title('spike count distribution (test set)')\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "rec_dict['spikes'].sum(2).flatten().shape"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def compute_latents(model, dataloader):\n",
                "    \"\"\"Encode every batch of `dataloader` and return all latents as one CPU tensor.\n",
                "\n",
                "    The model is switched to eval mode and called on batch['signal'] without\n",
                "    gradient tracking; the second element of its output is kept and the\n",
                "    per-batch results are concatenated along dim 0.\n",
                "    \"\"\"\n",
                "    model.eval()\n",
                "    chunks = []\n",
                "    with torch.no_grad():\n",
                "        for batch in dataloader:\n",
                "            _, latent = model(batch[\"signal\"])\n",
                "            chunks.append(latent.cpu())\n",
                "    return torch.cat(chunks, 0)\n",
                "            \n",
                "latents_test = compute_latents(ae, test_dataloader)\n",
                "latents_train = compute_latents(ae, train_dataloader)\n",
                "\n",
                "behavior_test = test_dataloader.dataset.dataset.behavior.permute(0, 2, 1)\n",
                "behavior_test = behavior_test[len(behavior_test)//4:] # hack to get test split and not val split\n",
                "behavior_train = train_dataloader.dataset.behavior.permute(0, 2, 1)\n",
                "\n",
                "print(latents_test.shape, behavior_test.shape, latents_train.shape, behavior_train.shape)\n",
                "        \n",
                "\n",
                "    \n",
                "def plot_inferred_latents(model, dataloader, figsize=(10, 4), n_latents=4):\n",
                "    \"\"\"Plot the first latent channels inferred for sample 0 of the first batch.\n",
                "\n",
                "    Args:\n",
                "        model: callable returning (output_rates, z) for batch['signal'].\n",
                "        dataloader: iterable of dict batches that contain a 'signal' entry.\n",
                "        figsize: figure size forwarded to plt.subplots.\n",
                "        n_latents: maximum number of latent channels to draw, one panel each;\n",
                "            clamped to the size of dim 1 of z.\n",
                "\n",
                "    Raises:\n",
                "        ValueError: if `dataloader` yields no batch.\n",
                "    \"\"\"\n",
                "    model.eval()\n",
                "    first_batch = next(iter(dataloader), None)\n",
                "    if first_batch is None:\n",
                "        raise ValueError('dataloader is empty; nothing to plot')\n",
                "    with torch.no_grad():\n",
                "        _, z = model(first_batch[\"signal\"])\n",
                "    z = z.cpu()\n",
                "\n",
                "    n_panels = min(n_latents, z.shape[1])\n",
                "    # squeeze=False keeps `ax` a 2-D array even when only one panel is drawn\n",
                "    fig, ax = plt.subplots(1, n_panels, figsize=figsize, dpi=150, squeeze=False)\n",
                "    ax = ax.flatten()\n",
                "    for i in range(n_panels):\n",
                "        ax[i].plot(z[0, i].numpy())\n",
                "        ax[i].set_title(f\"latent {i}\")\n",
                "    fig.suptitle(\"inferred latents\")\n",
                "    plt.tight_layout()\n",
                "    plt.show()\n",
                "\n",
                "plot_inferred_latents(ae, val_dataloader, figsize=(10, 2))\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## classification from latents to behavior (velocities) using S4"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "f"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def inverse_softplus(x: torch.Tensor) -> torch.Tensor:\n",
                "    \"\"\"Return y with softplus(y) == x, i.e. log(exp(x) - 1), for x > 0.\n",
                "\n",
                "    Evaluated as x + log(1 - exp(-x)) using expm1. Inputs of 0 give -inf and\n",
                "    negative inputs give NaN.\n",
                "    \"\"\"\n",
                "    one_minus_exp_neg = -torch.expm1(-x)  # 1 - exp(-x)\n",
                "    return x + torch.log(one_minus_exp_neg)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "val_dataloader.dataset.dataset.behavior"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# accelerator.load_state(f'exp/{cfg.exp_name}/epoch_220')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Run the autoencoder on the first validation batch only. `signal`, `spikes`\n",
                "# and `output_rates` are reused by the cells below.\n",
                "ae.eval()\n",
                "# next(iter(...)) raises StopIteration on an empty loader instead of silently\n",
                "# leaving `signal` / `output_rates` from an earlier cell in place.\n",
                "batch = next(iter(val_dataloader))\n",
                "signal = batch[\"signal\"]\n",
                "with torch.no_grad():\n",
                "    output_rates = ae(signal)[0].cpu()\n",
                "\n",
                "spikes = signal.cpu()  # move signal to cpu\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Two figures for sample 0 of the validation batch: autoencoder output rates\n",
                "# and the input signal. Titles make the otherwise identical-looking images\n",
                "# distinguishable.\n",
                "fig, ax = plt.subplots()\n",
                "im = ax.imshow(output_rates[0].cpu().numpy(), aspect='auto')\n",
                "ax.set_title('inferred rates (sample 0)')\n",
                "fig.colorbar(im, ax=ax)\n",
                "\n",
                "fig, ax = plt.subplots()\n",
                "im = ax.imshow(signal[0].cpu().numpy(), aspect='auto')\n",
                "ax.set_title('input signal (sample 0)')\n",
                "fig.colorbar(im, ax=ax)\n",
                "plt.show()\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.eval_utils import *\n",
                "from einops import rearrange\n",
                "avg_rates = average_rates(spikes.permute(0, 2, 1))\n",
                "print(avg_rates.shape)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "avg_inferred_rates = output_rates.mean((0, 2))\n",
                "avg_inferred_rates.shape, avg_rates.shape"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# avg_rates (from the spikes) against avg_inferred_rates (output_rates averaged\n",
                "# over dims 0 and 2).\n",
                "plt.scatter(avg_rates, avg_inferred_rates)\n",
                "# Reference line y = x. The earlier end point (avg_rates.max(), avg_inferred_rates.max())\n",
                "# always ran through the upper-right corner of the data, which hides any\n",
                "# systematic over- or under-estimation.\n",
                "upper = max(float(avg_rates.max()), float(avg_inferred_rates.max()))\n",
                "plt.plot([0, upper], [0, upper], color='red')\n",
                "plt.xlabel('average rate from spikes')\n",
                "plt.ylabel('average inferred rate')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Swap dims 1 and 2 once and reuse the views for all three comparison modes.\n",
                "spikes_swapped = spikes.permute(0, 2, 1)\n",
                "rates_swapped = output_rates.permute(0, 2, 1)\n",
                "plot_rate_comparisons(spikes_swapped, [rates_swapped,], mode='neur')\n",
                "plot_rate_comparisons(spikes_swapped, [rates_swapped,], mode='neurtime')\n",
                "plot_rate_comparisons(spikes_swapped, [rates_swapped,], mode='neursample')\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "def compute_angles(position_arr):\n",
                "    \"\"\"Return the direction of the net displacement of every trajectory.\n",
                "\n",
                "    `position_arr` is indexed as [trajectory, step, coordinate]. The result is\n",
                "    arctan2(dy, dx) of the last position minus the first one, so both\n",
                "    component signs are taken into account.\n",
                "    \"\"\"\n",
                "    displacement = position_arr[:, -1, :] - position_arr[:, 0, :]\n",
                "    dx = displacement[:, 0]\n",
                "    dy = displacement[:, 1]\n",
                "    return np.arctan2(dy, dx)\n",
                "\n",
                "def classify_angles(angles):\n",
                "    \"\"\"Assign each angle (radians) to one of 8 equal sectors spanning [-pi, pi].\n",
                "\n",
                "    Returns a dict mapping sector index 0..7 to the array of positions in\n",
                "    `angles` that fall into that sector. NaN angles are assigned to no sector.\n",
                "    \"\"\"\n",
                "    n_classes = 8\n",
                "    edges = np.linspace(-np.pi, np.pi, num=n_classes + 1)\n",
                "    angles = np.asarray(angles)\n",
                "    # np.digitize returns 0 below the first edge and len(edges) at or above the\n",
                "    # last one, so an angle of exactly +pi (or a float32 value rounding just\n",
                "    # past +/-pi) used to get class 8 or -1 and was silently dropped. Clip\n",
                "    # those into the last/first sector.\n",
                "    angle_classes = np.clip(np.digitize(angles, edges) - 1, 0, n_classes - 1)\n",
                "    # keep NaN out of every sector (digitize would otherwise push it into the last one)\n",
                "    angle_classes = np.where(np.isnan(angles), -1, angle_classes)\n",
                "\n",
                "    class_indices = {i: np.where(angle_classes == i)[0] for i in range(n_classes)}\n",
                "    return class_indices\n",
                "\n",
                "\n",
                "def plot_latents(latents, behavior_train):\n",
                "    \"\"\"Plot class-averaged latent trajectories, one colour per angle class.\n",
                "\n",
                "    `behavior_train` is cumulatively summed along axis 1 to get positions, and\n",
                "    every trial is assigned to one of 8 angle classes from its displacement\n",
                "    over the first 50 steps. Top row: per-class mean of latent channels 0-3.\n",
                "    Bottom row: the same curves minus their mean across the 8 classes.\n",
                "\n",
                "    Relies on `compute_angles` and on the `classify_angles` variant that\n",
                "    returns only the class -> indices dict.\n",
                "    \"\"\"\n",
                "    n_classes = 8\n",
                "    n_channels = 4  # latent channels shown, one column each\n",
                "\n",
                "    position_arr = np.cumsum(behavior_train.cpu(), axis=1)\n",
                "    angles = compute_angles(position_arr[:, :50, :])\n",
                "    angle_indices = classify_angles(angles)\n",
                "\n",
                "    # sector centres of the 8 classes and one fixed colour per class\n",
                "    angle_centers = np.array([-2.74889357, -1.96349541, -1.17809725, -0.39269908, 0.39269908, 1.17809725, 1.96349541, 2.74889357])\n",
                "    colors = np.array([[1.0, 0.34742682, 0.0, 1.0], [0.91139597, 1.0, 0.0, 1.0], [0.17021876, 1.0, 0.0, 1.0], [0.0, 1.0, 0.57095525, 1.0], [0.0, 0.68787024, 1.0, 1.0], [0.07646876, 0.0, 1.0, 1.0], [0.81764597, 0.0, 1.0, 1.0], [1.0, 0.0, 0.44117682, 1.0]])\n",
                "\n",
                "    latents_np = latents.numpy()\n",
                "    mean_trajectory = [\n",
                "        np.mean(latents_np[angle_indices[class_theta], :, :], axis=0)\n",
                "        for class_theta in range(n_classes)\n",
                "    ]\n",
                "    # Mean across classes, computed once; it used to be rebuilt for every plotted line.\n",
                "    across_class_mean = np.mean(np.stack(mean_trajectory, axis=0), axis=0)\n",
                "\n",
                "    fig, axs = plt.subplots(2, n_channels, figsize=(12, 3))\n",
                "    axs = axs.flatten()\n",
                "\n",
                "    for q in range(n_channels):\n",
                "        for classth in range(n_classes):\n",
                "            axs[q].plot(mean_trajectory[classth][q, :], label=f\"$\\\\theta=$ {angle_centers[classth]:.2f}\", color=colors[classth])\n",
                "\n",
                "    axs[0].legend(bbox_to_anchor=(0.6, 1), loc=\"best\", title=\"angle(residuals of latents below)\", title_fontsize=\"small\", ncol=8)\n",
                "\n",
                "    # bottom row: residual of each class mean around the across-class mean\n",
                "    for q in range(n_channels):\n",
                "        for classth in range(n_classes):\n",
                "            axs[n_channels + q].plot(mean_trajectory[classth][q, :] - across_class_mean[q], label=f\"angle {angle_centers[classth]:.2f}\", color=colors[classth])\n",
                "\n",
                "    # plt.tight_layout()\n",
                "    plt.show()\n",
                "\n",
                "\n",
                "plot_latents(latents_train, behavior_train)\n",
                "# # Compute the angles for the provided data\n",
                "# import matplotlib.colors as mcolors\n",
                "# import matplotlib.cm as cm\n",
                "\n",
                "\n",
                "# angles = compute_angles(position_arr).numpy()\n",
                "\n",
                "# # Map the angles to colors using a colormap\n",
                "# norm = mcolors.Normalize(vmin=np.min(angles), vmax=np.max(angles))\n",
                "# cmap = cm.hsv\n",
                "# colors = cmap(norm(angles))\n",
                "\n",
                "# # Plot the traces with colors according to the angle\n",
                "# plt.figure(figsize=cm2inch(6, 5), dpi=300)\n",
                "# for i in range(position_arr.shape[0]):\n",
                "#     plt.plot(position_arr[i,:,0], position_arr[i,:,1], '.', color=colors[i], ms=0.1)\n",
                "\n",
                "# # create a colorbar associated with the axes and the colormap\n",
                "# sm = cm.ScalarMappable(norm=norm, cmap=cmap)\n",
                "# sm.set_array([])  # You have to set_array for ScalarMappable\n",
                "# cbar = plt.colorbar(sm, ax=plt.gca(),label='angle')\n",
                "# plt.xlabel('x position')\n",
                "# plt.ylabel('y position')\n",
                "# plt.xlim(-20, 20)\n",
                "# plt.ylim(-20, 20)\n",
                "# plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "angles"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import matplotlib.cm as cm\n",
                "import matplotlib.colors as mcolors\n",
                "\n",
                "N_ANGLE_CLASSES = 8\n",
                "MAX_PANELS = 16  # size of the 4 x 4 grid below\n",
                "\n",
                "position_arr = np.cumsum(behavior_train.cpu(), axis=1)\n",
                "# Angle of the displacement over the first 50 steps of every trial. It is no\n",
                "# longer recomputed on the full trajectories after classification, so `angles`\n",
                "# stays consistent with `angle_indices`.\n",
                "angles = compute_angles(position_arr[:, :50, :])\n",
                "\n",
                "\n",
                "def classify_angles_with_edges(angles, n_classes=N_ANGLE_CLASSES):\n",
                "    \"\"\"Classify angles into `n_classes` equal sectors of [-pi, pi]; also return the edges.\n",
                "\n",
                "    Deliberately not named `classify_angles`: that function returns only the\n",
                "    index dict and is called by `plot_latents`, which broke once this cell had\n",
                "    redefined it with a tuple return value.\n",
                "    \"\"\"\n",
                "    edges = np.linspace(-np.pi, np.pi, num=n_classes + 1)\n",
                "    angles = np.asarray(angles)\n",
                "    # digitize gives 0 / len(edges) for values at or beyond the outer edges\n",
                "    # (e.g. exactly +pi); clip so those trials are not silently dropped.\n",
                "    angle_classes = np.clip(np.digitize(angles, edges) - 1, 0, n_classes - 1)\n",
                "    angle_classes = np.where(np.isnan(angles), -1, angle_classes)\n",
                "    class_indices = {i: np.where(angle_classes == i)[0] for i in range(n_classes)}\n",
                "    return class_indices, edges\n",
                "\n",
                "\n",
                "angle_indices, angle_edges = classify_angles_with_edges(angles)\n",
                "# get the middle between the edges\n",
                "angle_centers = (angle_edges[:-1] + angle_edges[1:]) / 2\n",
                "\n",
                "# One colour per class, read from the hsv colormap at the class centre. The\n",
                "# earlier version built one colour per trial and indexed that array with the\n",
                "# class id, i.e. it used the colours of trials 0..7.\n",
                "norm = mcolors.Normalize(vmin=-np.pi, vmax=np.pi)\n",
                "cmap = cm.hsv\n",
                "colors = cmap(norm(angle_centers))\n",
                "\n",
                "latents_train_np = latents_train.numpy()\n",
                "mean_trajectory = [\n",
                "    np.mean(latents_train_np[angle_indices[class_theta], :, :], axis=0)\n",
                "    for class_theta in range(N_ANGLE_CLASSES)\n",
                "]\n",
                "\n",
                "# draw at most MAX_PANELS channels, fewer if dim 1 of the latents is smaller\n",
                "n_panels = min(MAX_PANELS, latents_train_np.shape[1])\n",
                "fig, axs = plt.subplots(4, 4, figsize=(10, 8))\n",
                "axs = axs.flatten()\n",
                "for q in range(n_panels):\n",
                "    for classth in range(N_ANGLE_CLASSES):\n",
                "        axs[q].plot(\n",
                "            mean_trajectory[classth][q, :],\n",
                "            label=f\"$\\\\theta=${angle_centers[classth]/np.pi:.2f}$\\\\pi$\",\n",
                "            color=colors[classth],\n",
                "        )\n",
                "    axs[q].set_title(f\"latent {q}\")\n",
                "for unused_ax in axs[n_panels:]:\n",
                "    unused_ax.set_axis_off()\n",
                "\n",
                "# single legend for the whole grid, attached to the first panel\n",
                "axs[0].legend(\n",
                "    bbox_to_anchor=(0.6, 1),\n",
                "    loc=\"best\",\n",
                "    title=\"angle\",\n",
                "    title_fontsize=\"small\",\n",
                "    ncol=4,\n",
                ")\n",
                "\n",
                "fig.tight_layout()\n",
                "plt.show()\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "timeseries",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.9.18"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
