{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import argparse\n",
                "import os\n",
                "import sys\n",
                "\n",
                "# Work from the repository root (../notebooks -> ..): the relative paths used below\n",
                "# ('matplotlibrc', 'exp/', 'conf/') are resolved against it, and it must be on sys.path\n",
                "# so `ntldm` is importable. REPO_ROOT is remembered across re-runs so that executing\n",
                "# this cell twice does not keep walking up the directory tree.\n",
                "REPO_ROOT = globals().get(\"REPO_ROOT\") or os.path.dirname(os.getcwd())\n",
                "if REPO_ROOT not in sys.path:\n",
                "    sys.path.append(REPO_ROOT)\n",
                "os.chdir(REPO_ROOT)\n",
                "\n",
                "import accelerate\n",
                "import auraloss  # freq loss\n",
                "import lovely_tensors as lt\n",
                "import matplotlib.pyplot as plt\n",
                "import matplotlib\n",
                "import numpy as np\n",
                "import scipy.io as io\n",
                "import seaborn as sns\n",
                "import torch\n",
                "import torch.nn as nn\n",
                "import wandb\n",
                "import yaml\n",
                "from diffusers.optimization import get_scheduler\n",
                "from omegaconf import OmegaConf\n",
                "from scipy.signal import welch\n",
                "from tqdm.auto import tqdm\n",
                "\n",
                "from ntldm.data.lds import get_lds_dataset\n",
                "from ntldm.data.latent_attractor import get_attractor_dataset\n",
                "from ntldm.data.lds import get_lds_dataset\n",
                "from ntldm.networks import S4AE, AutoEncoder, CountWrapper\n",
                "from ntldm.utils.plotting_utils import *\n",
                "from ntldm.losses import latent_regularizer, latent_regularizer_v2\n",
                "\n",
                "# Nicer tensor reprs (lovely_tensors) and the project-wide matplotlib style.\n",
                "# 'matplotlibrc' is a relative path, so this relies on the working directory being\n",
                "# the repository root (changed near the top of this cell).\n",
                "lt.monkey_patch()\n",
                "matplotlib.rc_file(\"matplotlibrc\")\n",
                "\n",
                "# Pick up edits to the ntldm package without restarting the kernel.\n",
                "%load_ext autoreload\n",
                "%autoreload 2\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# cfg_yaml = \"\"\"\n",
                "# model:\n",
                "#   C_in: 128\n",
                "#   C: 256\n",
                "#   C_latent: 16\n",
                "#   kernel: s4\n",
                "#   num_blocks: 4\n",
                "#   num_lin_per_mlp: 2\n",
                "# dataset:\n",
                "#   system_name: LDS\n",
                "#   signal_length: 1024\n",
                "#   n_ic: 6000\n",
                "#   mean_rate: 0.25\n",
                "#   split_frac: 0.9\n",
                "#   random_seed: 42\n",
                "#   softplus_beta: 2.0\n",
                "# training:\n",
                "#   lr: 0.001\n",
                "#   weight_decay: 0.0\n",
                "#   num_epochs: 50\n",
                "#   num_warmup_epochs: 1\n",
                "#   batch_size: 64\n",
                "#   random_seed: 42\n",
                "#   precision: bf16\n",
                "#   latent_beta: 0.001\n",
                "#   latent_td_beta: 0.001\n",
                "#   mask_prob: 0.2\n",
                "# exp_name: autoencoder-count_s4-FHN_newarch_mask0.2_linearinout_mean0.5\n",
                "# \"\"\"\n",
                "\n",
                "\n",
                "# Active experiment configuration (Lorenz system, 8 latent channels).\n",
                "cfg_yaml = \"\"\"\n",
                "model:\n",
                "  C_in: 128\n",
                "  C: 256\n",
                "  C_latent: 8 # or 8\n",
                "  kernel: s4\n",
                "  num_blocks: 4\n",
                "  num_blocks_decoder: 0\n",
                "  num_lin_per_mlp: 2\n",
                "dataset:\n",
                "  system_name: Lorenz\n",
                "  signal_length: 256\n",
                "  n_ic: 5000\n",
                "  mean_rate: 0.5\n",
                "  split_frac_train: 0.7\n",
                "  split_frac_val: 0.1\n",
                "  random_seed: 42\n",
                "  softplus_beta: 2.0\n",
                "training:\n",
                "  lr: 0.001\n",
                "  weight_decay: 0.0\n",
                "  num_epochs: 200\n",
                "  num_warmup_epochs: 10\n",
                "  batch_size: 512\n",
                "  random_seed: 42\n",
                "  precision: bf16\n",
                "  latent_beta: 0.001\n",
                "  latent_td_beta: 0.01\n",
                "  mask_prob: 0.2\n",
                "exp_name: autoencoder-count_s4-Lorenz_z=8_new_regularization_05\n",
                "\"\"\"\n",
                "# Previous experiment name, kept for reference:\n",
                "#   autoencoder-count_s4-Lorenz_z=8_true_pointwise_decoder_with_test_low_rate_new_reg_005\n",
                "\n",
                "# Parse the YAML into a plain dict, then wrap it in an OmegaConf config so fields can be\n",
                "# read with attribute access (cfg.model.C_in) and with .get(key, default).\n",
                "cfg_dict = yaml.safe_load(cfg_yaml)\n",
                "cfg = OmegaConf.create(cfg_dict)\n",
                "print(OmegaConf.to_yaml(cfg))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import math\n",
                "\n",
                "# Seed torch and numpy so the synthetic datasets and their splits are reproducible.\n",
                "torch.manual_seed(cfg.training.random_seed)\n",
                "np.random.seed(cfg.training.random_seed)\n",
                "\n",
                "# Both generators return (train, val, test) dataloaders; the test fraction is\n",
                "# 1 - train_frac - valid_frac. Split fractions are read from the config for both systems;\n",
                "# the LDS branch keeps its previously hard-coded values as fallbacks.\n",
                "if cfg.dataset.system_name == \"LDS\":\n",
                "    train_dataloader, val_dataloader, test_dataloader = get_lds_dataset(\n",
                "        n_latents=6,\n",
                "        n_neurons=cfg.model.C_in,\n",
                "        sequence_length=cfg.dataset.signal_length,\n",
                "        rotation_angles=[math.pi / 111, math.pi / 72, math.pi / 27],\n",
                "        # n_ic // 10 initial conditions x n_reps=10 (presumably repeats per IC) -- TODO confirm\n",
                "        n_ic=cfg.dataset.n_ic // 10,\n",
                "        n_reps=10,\n",
                "        noise_variance=0.02,\n",
                "        mean_spike_count=cfg.dataset.signal_length * cfg.dataset.mean_rate,\n",
                "        batch_size=cfg.training.batch_size,\n",
                "        train_frac=cfg.dataset.get(\"split_frac_train\", 0.7),\n",
                "        valid_frac=cfg.dataset.get(\"split_frac_val\", 0.15),\n",
                "        random_seed=cfg.training.random_seed,\n",
                "        time_last=True,\n",
                "    )\n",
                "elif cfg.dataset.system_name == \"Lorenz\":\n",
                "    train_dataloader, val_dataloader, test_dataloader = get_attractor_dataset(\n",
                "        system_name=cfg.dataset.system_name,\n",
                "        n_neurons=cfg.model.C_in,\n",
                "        sequence_length=cfg.dataset.signal_length,\n",
                "        noise_std=0.05,\n",
                "        n_ic=cfg.dataset.n_ic,\n",
                "        mean_spike_count=cfg.dataset.mean_rate * cfg.dataset.signal_length,\n",
                "        train_frac=cfg.dataset.split_frac_train,\n",
                "        valid_frac=cfg.dataset.split_frac_val,\n",
                "        random_seed=cfg.training.random_seed,\n",
                "        batch_size=cfg.training.batch_size,\n",
                "        softplus_beta=cfg.dataset.get(\"softplus_beta\", 2.0),\n",
                "    )\n",
                "else:\n",
                "    raise ValueError(f\"Unknown system name: {cfg.dataset.system_name!r}\")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Visualise dataset"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import cm2inch, plot_dataset_visualizations_order\n",
                "\n",
                "# Training split. `.dataset.dataset` unwraps the DataLoader and then an inner wrapper\n",
                "# (presumably the torch Subset produced by the split -- TODO confirm).\n",
                "plot_dataset_visualizations_order(\n",
                "    train_dataloader.dataset.dataset,\n",
                "    indices=[0, 1, 5, 90],\n",
                "    figsize=cm2inch((6, 8)),\n",
                "    green=True,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Validation split, same sample indices and styling as the training-split plot.\n",
                "plot_dataset_visualizations_order(\n",
                "    val_dataloader.dataset.dataset,\n",
                "    indices=[0, 1, 5, 90],\n",
                "    figsize=cm2inch((6, 8)),\n",
                "    green=True,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Test split, same sample indices and styling as the training-split plot.\n",
                "plot_dataset_visualizations_order(\n",
                "    test_dataloader.dataset.dataset,\n",
                "    indices=[0, 1, 5, 90],\n",
                "    figsize=cm2inch((6, 8)),\n",
                "    green=True,\n",
                ")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Define model, optimizer and learning-rate scheduler\n",
                "\n",
                "- wrap the model, optimizer, scheduler and dataloaders with `accelerate` (`accelerator.prepare`)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
                "\n",
                "ae = AutoEncoder(\n",
                "    C_in=cfg.model.C_in,\n",
                "    C=cfg.model.C,\n",
                "    C_latent=cfg.model.C_latent,\n",
                "    L=cfg.dataset.signal_length,\n",
                "    kernel=cfg.model.kernel,\n",
                "    num_blocks=cfg.model.num_blocks,\n",
                "    num_blocks_decoder=cfg.model.num_blocks_decoder,\n",
                "    num_lin_per_mlp=cfg.model.get(\"num_lin_per_mlp\", 2),  # default 2\n",
                ")\n",
                "ae = CountWrapper(ae, use_sin_enc=cfg.model.get(\"use_sin_enc\", False))\n",
                "ae = ae.to(device)\n",
                "\n",
                "# Use the configured weight decay so the run matches the config saved with the model.\n",
                "# (Previously AdamW's own default of 0.01 was used whatever the config said.)\n",
                "optimizer = torch.optim.AdamW(\n",
                "    ae.parameters(),\n",
                "    lr=cfg.training.lr,\n",
                "    weight_decay=cfg.training.get(\"weight_decay\", 0.01),\n",
                ")\n",
                "\n",
                "# Cosine schedule with linear warmup over `num_warmup_epochs` epochs. The decay horizon\n",
                "# is stretched to lr_decay_horizon x the real training length, so the cosine has not\n",
                "# reached its minimum when training ends.\n",
                "num_batches = len(train_dataloader)\n",
                "lr_decay_horizon = 1.3\n",
                "lr_scheduler = get_scheduler(\n",
                "    name=\"cosine\",\n",
                "    optimizer=optimizer,\n",
                "    num_warmup_steps=num_batches * cfg.training.num_warmup_epochs,\n",
                "    num_training_steps=int(num_batches * cfg.training.num_epochs * lr_decay_horizon),\n",
                ")\n",
                "\n",
                "# torch.fft doesn't support half precision when L is not a power of 2, so fall back to\n",
                "# full precision then. (In Python `&` binds tighter than `!=`, so this is (L & (L-1)) != 0.)\n",
                "if cfg.dataset.signal_length & (cfg.dataset.signal_length - 1) != 0:\n",
                "    cfg.training.precision = \"no\"\n",
                "\n",
                "# Accelerate setup: device placement and mixed precision for everything prepared below.\n",
                "accelerator = accelerate.Accelerator(\n",
                "    mixed_precision=cfg.training.precision,\n",
                "    log_with=\"wandb\",\n",
                ")\n",
                "\n",
                "(\n",
                "    ae,\n",
                "    optimizer,\n",
                "    lr_scheduler,\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ") = accelerator.prepare(\n",
                "    ae,\n",
                "    optimizer,\n",
                "    lr_scheduler,\n",
                "    train_dataloader,\n",
                "    val_dataloader,\n",
                "    test_dataloader,\n",
                ")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Set up losses"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "    \n",
                "def compute_val_loss(net, dataloader, criterion=None, plot_channels=(0, 71)):\n",
                "    \"\"\"Evaluate `net` on `dataloader`; return (scaled Poisson NLL, rate MSE).\n",
                "\n",
                "    Args:\n",
                "        net: count autoencoder; `net(signal)[0]` is taken as the predicted rates.\n",
                "        dataloader: yields dicts with \"signal\" (spike counts) and \"rates\"\n",
                "            (ground-truth rates).\n",
                "        criterion: element-wise loss with reduction=\"none\". Defaults to the\n",
                "            notebook-level `criterion_poisson` (defined in the training cell), so the\n",
                "            validation loss uses exactly the training criterion.\n",
                "        plot_channels: channel indices drawn (one row each) in the diagnostic figure.\n",
                "\n",
                "    Returns:\n",
                "        avg_poisson_loss: mean per-batch Poisson NLL, multiplied by mask_prob so it is\n",
                "            on the scale of the training loss. That loss averages over all entries but\n",
                "            only the masked ones (roughly a mask_prob fraction) contribute.\n",
                "        avg_rate_loss: mean per-batch MSE between predicted and true rates.\n",
                "\n",
                "    Raises:\n",
                "        ValueError: if `dataloader` yields no batches.\n",
                "    \"\"\"\n",
                "    if criterion is None:\n",
                "        criterion = criterion_poisson\n",
                "    mask_prob = cfg.training.get(\"mask_prob\", 0.25)  # same default as the train loop\n",
                "\n",
                "    was_training = net.training\n",
                "    net.eval()\n",
                "    poisson_loss_total = 0.0\n",
                "    rates_loss_total = 0.0\n",
                "    batch_count = 0\n",
                "\n",
                "    for batch in dataloader:\n",
                "        with torch.no_grad():\n",
                "            output_rates = net(batch[\"signal\"])[0].cpu()\n",
                "        signal = batch[\"signal\"].cpu()\n",
                "        real_rates = batch[\"rates\"].cpu()\n",
                "\n",
                "        poisson_loss_total += criterion(output_rates, signal).mean().item()\n",
                "        rates_loss_total += ((output_rates - real_rates) ** 2).mean().item()\n",
                "        batch_count += 1\n",
                "    net.train(was_training)  # restore the caller's mode\n",
                "\n",
                "    if batch_count == 0:\n",
                "        raise ValueError(\"compute_val_loss: dataloader yielded no batches\")\n",
                "\n",
                "    # With masking disabled (mask_prob == 0) training counts every entry, so no rescaling;\n",
                "    # multiplying by 0 would report a validation loss of exactly 0.\n",
                "    loss_scale = mask_prob if mask_prob > 0 else 1.0\n",
                "    avg_poisson_loss = poisson_loss_total / batch_count * loss_scale\n",
                "    avg_rate_loss = rates_loss_total / batch_count\n",
                "\n",
                "    # Diagnostic figure for the first sample of the last batch. Prediction, spikes and true\n",
                "    # rates come from the SAME channel on each row (previously prediction and spikes used\n",
                "    # channel 92 while the true rates used channel 71).\n",
                "    fig, ax = plt.subplots(\n",
                "        len(plot_channels), 1, figsize=cm2inch((10, 2)), dpi=300, squeeze=False\n",
                "    )\n",
                "    for row, channel in enumerate(plot_channels):\n",
                "        ax[row, 0].plot(output_rates[0, channel].clip(0, 3).numpy(), label=\"pred\")\n",
                "        ax[row, 0].plot(\n",
                "            signal[0, channel].clip(0, 3).numpy(), label=\"spikes\", alpha=0.5, color=\"grey\"\n",
                "        )\n",
                "        ax[row, 0].plot(real_rates[0, channel].clip(0, 3).numpy(), label=\"real\")\n",
                "        ax[row, 0].legend()  # per-axes legend; plt.legend() only hit the last axes\n",
                "    # accelerator.log({\"val_rates\": wandb.Image(fig)})\n",
                "    plt.close(fig)\n",
                "\n",
                "    return avg_poisson_loss, avg_rate_loss\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# train loop\n",
                "# Element-wise Poisson NLL on rates (log_input=False); reduction=\"none\" so the\n",
                "# coordinated-dropout weights below can be applied per entry.\n",
                "criterion_poisson = nn.PoissonNLLLoss(log_input=False, full=True, reduction=\"none\")\n",
                "\n",
                "# Probability of hiding an input entry (coordinated dropout); constant, so read once.\n",
                "mask_prob = cfg.training.get(\"mask_prob\", 0.25)\n",
                "\n",
                "rec_losses, latent_losses, total_losses, lrs, val_rate_losses = [], [], [], [], []\n",
                "avg_poisson_loss, avg_rate_loss = 0, 0\n",
                "# Attach the full config and experiment name to the run so it can be reproduced from wandb.\n",
                "wandb.init(\n",
                "    project=\"ntldm-lorenz\",\n",
                "    entity=\"anon-project\",\n",
                "    name=cfg.exp_name,\n",
                "    config=OmegaConf.to_container(cfg, resolve=True),\n",
                ")\n",
                "\n",
                "with tqdm(range(0, cfg.training.num_epochs)) as pbar:\n",
                "    for epoch in pbar:\n",
                "        ae.train()\n",
                "\n",
                "        for data in train_dataloader:\n",
                "            optimizer.zero_grad()\n",
                "\n",
                "            signal = data[\"signal\"]\n",
                "\n",
                "            # Coordinated dropout: mask == 1 keeps an entry (probability 1 - mask_prob);\n",
                "            # kept entries are rescaled by 1 / (1 - mask_prob).\n",
                "            mask = (torch.rand_like(signal) > mask_prob).float()\n",
                "            input_signal = signal * (mask / (1 - mask_prob))\n",
                "\n",
                "            output_rates, z = ae(input_signal)\n",
                "\n",
                "            numel = signal.shape[0] * signal.shape[1] * signal.shape[2]\n",
                "\n",
                "            # Reconstruction loss only on entries hidden from the encoder (mask == 0);\n",
                "            # every entry counts when masking is disabled.\n",
                "            loss_weight = (1 - mask) if mask_prob > 0 else torch.ones_like(mask)\n",
                "            rec_loss = (criterion_poisson(output_rates, signal) * loss_weight).mean()\n",
                "\n",
                "            latent_loss = latent_regularizer_v2(z, cfg) / numel\n",
                "            loss = rec_loss + cfg.training.latent_beta * latent_loss\n",
                "\n",
                "            accelerator.backward(loss)\n",
                "            accelerator.clip_grad_norm_(ae.parameters(), 2.0)\n",
                "\n",
                "            optimizer.step()\n",
                "            lr_scheduler.step()\n",
                "\n",
                "            # Each .item() forces a device sync, so convert every scalar exactly once.\n",
                "            step_metrics = {\n",
                "                \"rec_loss\": rec_loss.item(),\n",
                "                \"latent_loss\": latent_loss.item(),\n",
                "                \"total_loss\": loss.item(),\n",
                "                \"lr\": optimizer.param_groups[0][\"lr\"],\n",
                "                \"epoch\": epoch,\n",
                "            }\n",
                "            pbar.set_postfix(\n",
                "                **step_metrics,\n",
                "                val_poisson_loss=avg_poisson_loss,\n",
                "                val_rate_loss=avg_rate_loss,\n",
                "            )\n",
                "            rec_losses.append(step_metrics[\"rec_loss\"])\n",
                "            latent_losses.append(step_metrics[\"latent_loss\"])\n",
                "            total_losses.append(step_metrics[\"total_loss\"])\n",
                "            lrs.append(step_metrics[\"lr\"])\n",
                "            wandb.log(dict(step_metrics))\n",
                "\n",
                "        # Validate every 5 epochs and after the final epoch.\n",
                "        if accelerator.is_main_process and (\n",
                "            (epoch + 1) % 5 == 0 or epoch == cfg.training.num_epochs - 1\n",
                "        ):\n",
                "            avg_poisson_loss, avg_rate_loss = compute_val_loss(ae, val_dataloader)\n",
                "            val_rate_losses.append(avg_rate_loss)\n",
                "            wandb.log({\"val_poisson_loss\": avg_poisson_loss, \"val_rate_loss\": avg_rate_loss})\n",
                "\n",
                "        if accelerator.is_main_process and epoch % 50 == 0:\n",
                "            # plot_rec_rates is not defined in this notebook; presumably it comes from\n",
                "            # `from ntldm.utils.plotting_utils import *` -- TODO confirm.\n",
                "            plot_rec_rates(ae, val_dataloader)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "save_model = True\n",
                "load_model = False\n",
                "\n",
                "if save_model:\n",
                "    run_dir = f\"exp/{cfg.exp_name}\"\n",
                "    config_path = f\"conf/sweeps_count/{cfg.exp_name}.yaml\"\n",
                "\n",
                "    # Weights and config are written once, from the main process only.\n",
                "    if accelerator.is_main_process:\n",
                "        os.makedirs(run_dir, exist_ok=True)\n",
                "        torch.save(accelerator.unwrap_model(ae).state_dict(), f\"{run_dir}/model.pt\")\n",
                "        os.makedirs(os.path.dirname(config_path), exist_ok=True)\n",
                "        with open(config_path, \"w\") as f:\n",
                "            f.write(OmegaConf.to_yaml(cfg))\n",
                "        print(f\"Saved model weights to {run_dir}/model.pt\")\n",
                "        print(f\"Saved config to {config_path}\")\n",
                "\n",
                "    # Full accelerate training state. Left outside the main-process guard because\n",
                "    # accelerate also writes per-process state (e.g. RNG) -- TODO confirm for your version.\n",
                "    accelerator.save_state(f\"{run_dir}/accelerator_state_end_epoch\")\n",
                "\n",
                "elif load_model:\n",
                "    accelerator.load_state(f\"exp/{cfg.exp_name}/accelerator_state_end_epoch\")\n",
                "\n",
                "else:\n",
                "    print(\"Choose either save_model or load_model\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_inferred_latents\n",
                "\n",
                "# Inferred latents for test sample idx=1. n_latents now follows the model config\n",
                "# instead of a hard-coded 8, which only matched while C_latent == 8.\n",
                "plot_inferred_latents(\n",
                "    ae,\n",
                "    test_dataloader,\n",
                "    n_latents=cfg.model.C_latent,\n",
                "    y_stack=4,\n",
                "    figsize=cm2inch(10, 10),\n",
                "    color='royalblue',\n",
                "    idx=1,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from ntldm.utils.plotting_utils import plot_rate_traces\n",
                "\n",
                "# Rate traces for test sample idx=10.\n",
                "plot_rate_traces(\n",
                "    ae,\n",
                "    test_dataloader,\n",
                "    idx=10,\n",
                "    figsize=cm2inch(10, 10),\n",
                "    true_data=False,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Side-by-side spike plots for two test samples. `binary` is passed only where it was\n",
                "# passed before, so the first plot keeps the function's own default.\n",
                "for spike_plot_kwargs in (\n",
                "    {\"idx\": 45},\n",
                "    {\"idx\": 45, \"binary\": True},\n",
                "    {\"idx\": 10, \"binary\": False},\n",
                "    {\"idx\": 10, \"binary\": True},\n",
                "):\n",
                "    plot_spikes_next_to_each_other(\n",
                "        ae, test_dataloader, figsize=cm2inch(15, 5), **spike_plot_kwargs\n",
                "    )"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Test sequences length_factor x longer than the training window (presumably to check\n",
                "# behaviour beyond the trained length). The batch size shrinks by the same factor so\n",
                "# batch_size * sequence_length stays roughly constant.\n",
                "# NOTE(review): uses the attractor generator, so this only fits the Lorenz config.\n",
                "length_factor = 16\n",
                "# Unused train/val splits get descriptive names; `_1`/`_2` would clobber IPython's\n",
                "# output-history variables.\n",
                "_train_long, _val_long, test_dataloader_longer = get_attractor_dataset(\n",
                "    system_name=cfg.dataset.system_name,\n",
                "    n_neurons=cfg.model.C_in,\n",
                "    sequence_length=cfg.dataset.signal_length * length_factor,\n",
                "    noise_std=0.05,\n",
                "    n_ic=cfg.dataset.n_ic // 10,\n",
                "    mean_spike_count=cfg.dataset.mean_rate * cfg.dataset.signal_length * length_factor,\n",
                "    train_frac=cfg.dataset.split_frac_train,\n",
                "    valid_frac=cfg.dataset.split_frac_val,  # test is 1 - train - valid\n",
                "    random_seed=cfg.training.random_seed,\n",
                "    # max(1, ...) avoids a zero batch size when batch_size < length_factor\n",
                "    batch_size=max(1, cfg.training.batch_size // length_factor),\n",
                "    softplus_beta=cfg.dataset.get(\"softplus_beta\", 2.0),\n",
                ")\n",
                "test_dataloader_longer = accelerator.prepare(test_dataloader_longer)\n",
                "\n",
                "plot_rate_traces(ae, test_dataloader_longer, figsize=cm2inch(20, 10))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# 3-D latent trajectories for one long-sequence and one regular-length test sample;\n",
                "# indices=[2, 1, 0] presumably selects the latent channels mapped to the three axes.\n",
                "# Axes3D is not referenced directly: importing it registers the '3d' projection on\n",
                "# older matplotlib versions. FuncAnimation and cm are not used in this cell.\n",
                "from mpl_toolkits.mplot3d import Axes3D\n",
                "from matplotlib.animation import FuncAnimation\n",
                "from matplotlib import cm\n",
                "# NOTE(review): switches the backend to ipympl (must be installed) for this and every\n",
                "# later cell; widget figures do not render in static viewers such as GitHub.\n",
                "%matplotlib widget\n",
                "from ntldm.utils.plotting_utils import plot_3d_latent_trajectory\n",
                "plot_3d_latent_trajectory(ae, test_dataloader_longer, figsize=cm2inch(6, 6), sample_idx=7, indices=[2, 1, 0,])\n",
                "plot_3d_latent_trajectory(ae, test_dataloader, figsize=cm2inch(6, 6), sample_idx=8, indices=[2, 1, 0,])\n",
                "\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Run all analyses"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def evaluate_autoencoder(ae, val_dataloader, val_dataloader_longer, n_latents=8, save=False, save_path=None, idx=0):\n",
                "    \"\"\"Produce the standard set of autoencoder diagnostic figures for one sample.\n",
                "\n",
                "    Args:\n",
                "        ae: trained count autoencoder.\n",
                "        val_dataloader: dataloader at the training sequence length (any split).\n",
                "        val_dataloader_longer: dataloader with longer sequences for the long-horizon plots.\n",
                "        n_latents: number of latent channels passed to `plot_inferred_latents`\n",
                "            (previously ignored in favour of a hard-coded 8).\n",
                "        save: forwarded to every plotting helper.\n",
                "        save_path: base directory; figure paths are `<save_path>/ae_figures/<name>`.\n",
                "            Required when `save=True`.\n",
                "        idx: sample index used for every plot.\n",
                "\n",
                "    Raises:\n",
                "        ValueError: if `save=True` but `save_path` is None.\n",
                "    \"\"\"\n",
                "    from ntldm.utils.plotting_utils import (\n",
                "        cm2inch,\n",
                "        plot_3d_latent_trajectory,\n",
                "        plot_inferred_latents,\n",
                "        plot_rate_traces,\n",
                "        plot_spikes_next_to_each_other,\n",
                "    )\n",
                "\n",
                "    if save_path is None:\n",
                "        if save:\n",
                "            raise ValueError(\"save_path is required when save=True\")\n",
                "        figure_dir = \"\"  # nothing is saved, so the path prefix is irrelevant\n",
                "    else:\n",
                "        figure_dir = save_path + '/ae_figures/'\n",
                "        os.makedirs(figure_dir, exist_ok=True)\n",
                "\n",
                "    # Regular-length sequences.\n",
                "    plot_rate_traces(ae, val_dataloader, idx=idx, figsize=cm2inch(10, 10), true_data=False,\n",
                "                     save=save, save_path=figure_dir + 'rate_traces')\n",
                "    plot_spikes_next_to_each_other(ae, val_dataloader, idx=idx, figsize=cm2inch(15, 5),\n",
                "                                   save=save, save_path=figure_dir + 'spikes_next_to_each_other')\n",
                "    plot_inferred_latents(ae, val_dataloader, n_latents=n_latents, y_stack=4,\n",
                "                          figsize=cm2inch(10, 10), color='midnightblue', idx=idx,\n",
                "                          save=save, save_path=figure_dir + 'inferred_latents')\n",
                "    plot_3d_latent_trajectory(ae, val_dataloader, figsize=cm2inch(8, 8), sample_idx=idx,\n",
                "                              save=save, save_path=figure_dir + '3d_latent_trajectory')\n",
                "\n",
                "    # Long sequences.\n",
                "    plot_rate_traces(ae, val_dataloader_longer, figsize=cm2inch(20, 10), idx=idx,\n",
                "                     true_data=False, save=save, save_path=figure_dir + 'rate_traces_longer')\n",
                "    plot_3d_latent_trajectory(ae, val_dataloader_longer, figsize=cm2inch(8, 8), sample_idx=idx,\n",
                "                              save=save, save_path=figure_dir + '3d_latent_trajectory_longer')\n",
                "    \n",
                "    \n",
                "# Figures go under the run directory by default. Set NTLDM_RESULTS_DIR to redirect them\n",
                "# (e.g. to a shared results folder) rather than hard-coding a machine-specific path.\n",
                "save_path = os.environ.get(\"NTLDM_RESULTS_DIR\", 'exp/' + cfg.exp_name)\n",
                "\n",
                "# Re-create the 16x-longer test set (same settings as the earlier long-sequence cell) so\n",
                "# this cell runs on its own. The unused train/val splits get descriptive names; `_1`/`_2`\n",
                "# would clobber IPython's output-history variables.\n",
                "_train_long, _val_long, test_dataloader_longer = get_attractor_dataset(\n",
                "    system_name=cfg.dataset.system_name,\n",
                "    n_neurons=cfg.model.C_in,\n",
                "    sequence_length=cfg.dataset.signal_length * 16,\n",
                "    noise_std=0.05,\n",
                "    n_ic=cfg.dataset.n_ic // 10,\n",
                "    mean_spike_count=cfg.dataset.mean_rate * cfg.dataset.signal_length * 16,\n",
                "    train_frac=cfg.dataset.split_frac_train,\n",
                "    valid_frac=cfg.dataset.split_frac_val,  # test is 1 - train - valid\n",
                "    random_seed=cfg.training.random_seed,\n",
                "    batch_size=max(1, cfg.training.batch_size // 16),\n",
                "    softplus_beta=cfg.dataset.get(\"softplus_beta\", 2.0),\n",
                ")\n",
                "test_dataloader_longer = accelerator.prepare(test_dataloader_longer)\n",
                "\n",
                "evaluate_autoencoder(ae, test_dataloader, test_dataloader_longer, n_latents=cfg.model.C_latent, save=True, save_path=save_path, idx=2)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import gc\n",
                "\n",
                "# Release CUDA memory held by the trained model.\n",
                "# The previous cell already prepared test_dataloader_longer and ran this exact evaluation,\n",
                "# so the repeated `accelerator.prepare(...)` (which presumably re-wraps the already\n",
                "# prepared dataloader) and the duplicate evaluation are removed.\n",
                "# `del ae` alone frees nothing while the accelerator, optimizer and LR scheduler still\n",
                "# reference the model's parameters, so drop those references first.\n",
                "# NOTE(review): tensors left over from the training loop (e.g. `loss`, `z`) may still\n",
                "# pin GPU memory.\n",
                "accelerator.free_memory()\n",
                "del ae, optimizer, lr_scheduler\n",
                "gc.collect()\n",
                "torch.cuda.empty_cache()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "timeseries",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.8.18"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
