{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ab12f31b-b486-4d80-a6ea-7ef47ac614a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b11186f2-ec68-4943-b13a-c4321b74857d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-01-30 03:13:58.966942: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "E0000 00:00:1738206838.994247   35023 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "E0000 00:00:1738206839.002662   35023 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "/root/workspace/GitHub/pmlr-v202-finzi23a/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from collections import defaultdict\n",
    "import functools\n",
    "import itertools\n",
    "import pprint\n",
    "\n",
    "import orbax.checkpoint\n",
    "import numpy as np\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import torch.utils.data.dataloader\n",
    "import tensorflow as tf\n",
    "import sqlalchemy as sa\n",
    "import seaborn as sns\n",
    "sns.set_theme(style='whitegrid', font_scale=1.3, palette=sns.color_palette('husl'),)\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "\n",
    "from userdiffusion import samplers, unet\n",
    "from userfm import cs, datasets, event_constraints, diffusion, sde_diffusion, flow_matching, utils, main as main_module, plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "82fd6df7-944a-4ac8-b1e9-d30dce4864a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# somehow, this line of code prevents a segmentation fault in nn.Dense\n",
    "# when calling model.init\n",
    "tf.config.experimental.set_visible_devices([], 'GPU')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "26f2a2a9-1783-4794-bfcb-8cf73dde97a5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<sqlalchemy.orm.session.SessionTransaction at 0x7f7a5e6a6600>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "engine = cs.get_engine()\n",
    "cs.create_all(engine)\n",
    "session = cs.orm.Session(engine)\n",
    "session.begin()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "abdb5d9c-b9a4-4ed4-9170-92195c5f8b09",
   "metadata": {},
   "outputs": [],
   "source": [
    "config_alt_ids = {\n",
    "    # Lorenz\n",
    "    # ('0y35hp7d', 'DM'): {},\n",
    "    # ('3bjjfgwa', 'FM'): {'sample': {'use_score': True}},\n",
    "    # ('c0ijllm1', 'FM+Reg'): {'sample': {'use_score': True}},\n",
    "    # FitzHughNagumo\n",
    "    ('jzke0dh6', 'DM'): {}, # jzke0dh6, epoch_1999, false, cs.DatasetFitzHughNagumo\n",
    "    ('trauw532', 'FM'): {'sample': {'use_score': True}}, # trauw532, epoch_1999, false, cs.DatasetFitzHughNagumo\n",
    "    ('7io88gsu', 'FM+Reg'): {'sample': {'use_score': True}}, # 7io88gsu, epoch_1999, false, cs.DatasetFitzHughNagumo\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "67950db0-e006-4b07-a1e6-9372399ff2fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfgs = session.execute(sa.select(cs.Config).where(cs.Config.alt_id.in_([c[0] for c in config_alt_ids])))\n",
    "cfgs = {c.alt_id: c for (c,) in cfgs}\n",
    "reference_cfg = cfgs[next(iter(cfgs.keys()))]\n",
    "cfg_info = {}\n",
    "for k in config_alt_ids:\n",
    "    cfg = cfgs[k[0]]\n",
    "    assert cfg.rng_seed == reference_cfg.rng_seed\n",
    "    assert cfg.dataset == reference_cfg.dataset\n",
    "    cfg_info[k] = dict(\n",
    "        cfg=cfg,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "14e6d166-d85a-487e-acb4-dc6b729da54e",
   "metadata": {},
   "outputs": [],
   "source": [
    "key = jax.random.key(reference_cfg.rng_seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f22b503a-c4e9-4838-8ff6-b1c21daedb76",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6450/6450 [21:00<00:00,  5.12it/s]\n"
     ]
    }
   ],
   "source": [
    "key, key_dataset = jax.random.split(key)\n",
    "reference_cfg.dataset.batch_count_test = 64\n",
    "ds = datasets.get_dataset(reference_cfg.dataset, key=key_dataset)\n",
    "splits = datasets.split_dataset(reference_cfg.dataset, ds)\n",
    "dataloaders = {}\n",
    "for n, s in splits.items():\n",
    "    dataloaders[n] = torch.utils.data.dataloader.DataLoader(\n",
    "        list(tf.data.Dataset.from_tensor_slices(s).batch(reference_cfg.dataset.batch_size).as_numpy_iterator()),\n",
    "        batch_size=1,\n",
    "        collate_fn=lambda x: x[0],\n",
    "    )\n",
    "data_std = splits['train'].std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3d9f7a72-ecb8-4e67-8532-a11237e53524",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/workspace/GitHub/pmlr-v202-finzi23a/.venv/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1330: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()\n",
    "x_sample = next(iter(dataloaders['train']))\n",
    "ckpt_name = 'epoch_1999'\n",
    "\n",
    "for info in cfg_info.values():\n",
    "    cfg = info['cfg']\n",
    "    assert cfg.rng_seed == reference_cfg.rng_seed\n",
    "    assert cfg.dataset == reference_cfg.dataset\n",
    "\n",
    "    cfg_unet = unet.unet_64_config(\n",
    "        splits['train'].shape[-1],\n",
    "        base_channels=cfg.model.architecture.base_channel_count,\n",
    "        attention=cfg.model.architecture.attention,\n",
    "    )\n",
    "    model = unet.UNet(cfg_unet)\n",
    "    \n",
    "    key, key_jaxlightning = jax.random.split(key)\n",
    "    if isinstance(cfg.model, cs.ModelDiffusion):\n",
    "        jax_lightning = diffusion.JaxLightning(cfg, key_jaxlightning, dataloaders, data_std, None, model)\n",
    "    elif isinstance(cfg.model, cs.ModelFlowMatching):\n",
    "        jax_lightning = flow_matching.JaxLightning(cfg, key_jaxlightning, dataloaders, data_std, None, model)\n",
    "    else:\n",
    "        raise ValueError(f'Unknown model: {cfg.model}')\n",
    "        \n",
    "    jax_lightning.params = orbax_checkpointer.restore(cfg.run_dir/ckpt_name)\n",
    "    jax_lightning.params_ema = orbax_checkpointer.restore(cfg.run_dir/f'{ckpt_name}_ema')\n",
    "\n",
    "    info['jax_lightning'] = jax_lightning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b2cd2005-a3f4-4b90-8bc1-c7d3c2b8197b",
   "metadata": {},
   "outputs": [],
   "source": [
    "constraint = event_constraints.get_event_constraint(cfg.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "018cdb30-0c27-4849-ad77-6d761bca1dac",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                                                                                                                                                                          | 0/64 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 70.22528076171875 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|███████▌                                                                                                                                                                                                                                          | 2/64 [01:43<52:45, 51.06s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 74.07532501220703 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|███████████▎                                                                                                                                                                                                                                      | 3/64 [02:31<50:39, 49.83s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 72.88363647460938 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|███████████████▏                                                                                                                                                                                                                                  | 4/64 [03:20<49:14, 49.24s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 71.77883911132812 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|██████████████████▉                                                                                                                                                                                                                               | 5/64 [04:08<47:58, 48.79s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 74.27108764648438 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|██████████████████████▋                                                                                                                                                                                                                           | 6/64 [04:55<46:47, 48.40s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 73.96844482421875 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|██████████████████████████▍                                                                                                                                                                                                                       | 7/64 [05:42<45:21, 47.75s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 71.61772155761719 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|██████████████████████████████▎                                                                                                                                                                                                                   | 8/64 [06:28<44:12, 47.37s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 74.14404296875 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|██████████████████████████████████                                                                                                                                                                                                                | 9/64 [07:14<42:58, 46.87s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 71.82096862792969 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█████████████████████████████████████▋                                                                                                                                                                                                           | 10/64 [08:01<42:17, 47.00s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 71.85791015625 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|█████████████████████████████████████████▍                                                                                                                                                                                                       | 11/64 [08:48<41:34, 47.07s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 71.61241149902344 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 19%|█████████████████████████████████████████████▏                                                                                                                                                                                                   | 12/64 [09:36<40:59, 47.29s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 86.2186508178711 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|████████████████████████████████████████████████▉                                                                                                                                                                                                | 13/64 [10:25<40:33, 47.72s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 74.12003326416016 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|████████████████████████████████████████████████████▋                                                                                                                                                                                            | 14/64 [11:12<39:37, 47.54s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 73.4232406616211 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 23%|████████████████████████████████████████████████████████▍                                                                                                                                                                                        | 15/64 [12:00<38:50, 47.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 73.73075103759766 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 25%|████████████████████████████████████████████████████████████▎                                                                                                                                                                                    | 16/64 [12:47<38:05, 47.62s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 70.26171875 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 27%|████████████████████████████████████████████████████████████████                                                                                                                                                                                 | 17/64 [13:36<37:26, 47.79s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 73.88910675048828 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|███████████████████████████████████████████████████████████████████▊                                                                                                                                                                             | 18/64 [14:23<36:30, 47.63s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 70.24373626708984 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███████████████████████████████████████████████████████████████████████▌                                                                                                                                                                         | 19/64 [15:11<35:48, 47.76s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 70.95075225830078 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 31%|███████████████████████████████████████████████████████████████████████████▎                                                                                                                                                                     | 20/64 [16:00<35:18, 48.14s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 72.40174102783203 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 33%|███████████████████████████████████████████████████████████████████████████████                                                                                                                                                                  | 21/64 [16:47<34:20, 47.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 72.1256103515625 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 34%|██████████████████████████████████████████████████████████████████████████████████▊                                                                                                                                                              | 22/64 [17:37<33:57, 48.52s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 73.83699035644531 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 36%|██████████████████████████████████████████████████████████████████████████████████████▌                                                                                                                                                          | 23/64 [18:25<33:03, 48.37s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 73.29657745361328 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|██████████████████████████████████████████████████████████████████████████████████████████▍                                                                                                                                                      | 24/64 [19:13<32:05, 48.14s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 72.83537292480469 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 39%|██████████████████████████████████████████████████████████████████████████████████████████████▏                                                                                                                                                  | 25/64 [20:03<31:36, 48.64s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 69.63021087646484 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 41%|█████████████████████████████████████████████████████████████████████████████████████████████████▉                                                                                                                                               | 26/64 [20:51<30:42, 48.49s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 72.29289245605469 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|█████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                                                                                           | 27/64 [21:38<29:36, 48.01s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 72.2069320678711 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                                                                                       | 28/64 [22:25<28:38, 47.74s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 76.44893646240234 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 45%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                                                                                   | 29/64 [23:12<27:44, 47.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 76.86505126953125 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 47%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                                                                                                | 30/64 [24:01<27:07, 47.88s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 70.1072998046875 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 48%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                                                                            | 31/64 [24:48<26:16, 47.76s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 72.1590805053711 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                                        | 32/64 [25:35<25:15, 47.35s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 85.25033569335938 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 52%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                                                                    | 33/64 [26:22<24:32, 47.51s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 75.59387969970703 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 53%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                                 | 34/64 [27:11<23:59, 47.97s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 74.04283142089844 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 55%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                                                             | 35/64 [28:01<23:21, 48.32s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 71.77711486816406 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 56%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                         | 36/64 [28:49<22:29, 48.21s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 75.75042724609375 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 58%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                                                     | 37/64 [29:37<21:47, 48.43s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 73.31507873535156 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 59%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                  | 38/64 [30:24<20:45, 47.92s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 73.52549743652344 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 61%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                                              | 39/64 [31:11<19:46, 47.47s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 72.28459930419922 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 62%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                                          | 40/64 [31:57<18:53, 47.22s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 71.66193389892578 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 64%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                                      | 41/64 [32:46<18:18, 47.76s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 74.83935546875 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                                  | 42/64 [33:34<17:31, 47.80s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 71.40079498291016 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 67%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                                               | 43/64 [34:21<16:37, 47.50s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 74.15164184570312 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 69%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                           | 44/64 [35:09<15:55, 47.79s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 72.7043228149414 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                       | 45/64 [35:57<15:09, 47.87s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 76.62957000732422 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 72%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                   | 46/64 [36:46<14:23, 47.95s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 71.76721954345703 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 73%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                                | 47/64 [37:33<13:33, 47.87s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 74.98538970947266 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 75%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                            | 48/64 [38:22<12:48, 48.04s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 72.36785125732422 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 77%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                        | 49/64 [39:10<12:02, 48.19s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 74.51898193359375 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 78%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                    | 50/64 [39:57<11:09, 47.83s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 72.78501892089844 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                 | 51/64 [40:45<10:21, 47.79s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 71.39997863769531 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 81%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                             | 52/64 [41:32<09:31, 47.62s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 74.11803436279297 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 83%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                         | 53/64 [42:20<08:44, 47.70s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 71.78073120117188 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 84%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                     | 54/64 [43:08<07:58, 47.83s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 72.95416259765625 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 86%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                  | 55/64 [43:56<07:10, 47.81s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 73.7789077758789 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                              | 56/64 [44:43<06:20, 47.58s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 73.1640853881836 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 89%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                          | 57/64 [45:30<05:31, 47.32s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 76.65107727050781 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 91%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                      | 58/64 [46:17<04:44, 47.43s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 78.8259506225586 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 92%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                  | 59/64 [47:04<03:55, 47.14s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 72.083740234375 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 94%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉               | 60/64 [47:52<03:09, 47.42s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 73.5163803100586 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 95%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋           | 61/64 [48:41<02:23, 47.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 68.40653228759766 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 97%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍       | 62/64 [49:29<01:35, 47.90s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 75.61605834960938 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 98%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏   | 63/64 [50:16<00:47, 47.60s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 82.04244995117188 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [51:04<00:00, 47.88s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xf std 80.90313720703125 and std_max: 300.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [46:10<00:00, 43.29s/it]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [42:01<00:00, 39.40s/it]\n"
     ]
    }
   ],
   "source": [
    "nlls = defaultdict(lambda: defaultdict(list))\n",
    "for (config_alt_id, source), info in cfg_info.items():\n",
    "    # use same key for each model\n",
    "    _, key_nll = jax.random.split(key)\n",
    "    for batch in tqdm(dataloaders['test']):\n",
    "        key_nll, key_nll_batch = jax.random.split(key_nll)\n",
    "        if isinstance(info['cfg'].model, cs.ModelFlowMatching):\n",
    "            x_noise, nll_no_div, nll = info['jax_lightning'].compute_nll(key_nll_batch, 1., batch, **config_alt_ids[k]['sample'])\n",
    "        else:\n",
    "            x_noise, nll_no_div, nll = info['jax_lightning'].compute_nll(key_nll_batch, 1., batch)\n",
    "        nlls[(config_alt_id, source)]['nll_no_div'].append(nll_no_div)\n",
    "        nlls[(config_alt_id, source)]['nll'].append(nll)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b348fea6-cf1c-4bb0-a381-296e43db4776",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "k=('jzke0dh6', 'DM'), NLL=-7.365\n",
      "k=('trauw532', 'FM'), NLL=-13.942\n",
      "k=('7io88gsu', 'FM+Reg'), NLL=-14.408\n"
     ]
    }
   ],
   "source": [
    "nlls_concat = defaultdict(dict)\n",
    "for k, out in nlls.items():\n",
    "    for kk, arr in out.items():\n",
    "        nlls_concat[k][kk] = jnp.concat(arr)\n",
    "    print(f\"{k=}, NLL={nlls_concat[k]['nll'].mean():.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ea9cf712-76fa-4542-b6e8-462b792e941d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import pickle\n",
    "# with open('nlls_fitzhugh.pkl', 'wb') as f:\n",
    "#     pickle.dump(dict(jax.tree.map(np.array,nlls)), f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "df7db466-f486-434f-abcd-f19c4407f266",
   "metadata": {},
   "outputs": [],
   "source": [
    "# with open('nlls_fitzhugh.pkl', 'rb') as f:\n",
    "#     test_nll = pickle.load(f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
