{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ab12f31b-b486-4d80-a6ea-7ef47ac614a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b11186f2-ec68-4943-b13a-c4321b74857d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-01-29 00:19:27.893851: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "E0000 00:00:1738109967.916173   22323 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "E0000 00:00:1738109967.922031   22323 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "/root/workspace/GitHub/pmlr-v202-finzi23a/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import functools\n",
    "import itertools\n",
    "import pprint\n",
    "\n",
    "import orbax.checkpoint\n",
    "import numpy as np\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import torch.utils.data.dataloader\n",
    "import tensorflow as tf\n",
    "import sqlalchemy as sa\n",
    "import seaborn as sns\n",
    "sns.set_theme(style='whitegrid', font_scale=1.3, palette=sns.color_palette('husl'),)\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from userdiffusion import samplers, unet\n",
    "from userfm import cs, datasets, diffusion, sde_diffusion, flow_matching, utils, main as main_module, plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "82fd6df7-944a-4ac8-b1e9-d30dce4864a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hide every GPU from TensorFlow. TensorFlow is only used in this notebook\n",
    "# for its `tf.data` batching pipeline.\n",
    "# Empirically, this line also prevents a segmentation fault in nn.Dense\n",
    "# when calling model.init (presumably TensorFlow and JAX otherwise contend\n",
    "# for the same device; TODO confirm the root cause).\n",
    "tf.config.experimental.set_visible_devices([], 'GPU')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "26f2a2a9-1783-4794-bfcb-8cf73dde97a5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<sqlalchemy.orm.session.SessionTransaction at 0x7fb4bb7b1f00>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Connect to the experiment database (presumably `create_all` creates any\n",
    "# missing tables; verify in `userfm.cs`).\n",
    "engine = cs.get_engine()\n",
    "cs.create_all(engine)\n",
    "# One session is shared by the later cells. The transaction begun here is\n",
    "# never committed or closed in this notebook; it is only used for a SELECT.\n",
    "session = cs.orm.Session(engine)\n",
    "session.begin()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "abdb5d9c-b9a4-4ed4-9170-92195c5f8b09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runs to compare, keyed by (config alt_id, display label).\n",
    "# - The alt_id selects the stored config and checkpoint of a training run.\n",
    "# - The label names the run in the plots and printed tables.\n",
    "# - The value holds per-entry overrides; its 'sample' dict is forwarded as\n",
    "#   keyword arguments to `sample` for flow-matching models.\n",
    "# The same alt_id may appear under several labels to evaluate one checkpoint\n",
    "# with different sampling settings.\n",
    "config_alt_ids = {\n",
    "    # Lorenz\n",
    "    ('0y35hp7d', 'DM'): {},\n",
    "    # ('fba4g7bp', 'FMOT'): {'sample': {'use_score': False}},\n",
    "    # ('1g2n8baa', 'FMOT+Reg'): {'sample': {'use_score': False}},\n",
    "    # ('eug367ja', 'Flow Matching (VE)'): {'sample': {'use_score': False}},\n",
    "    ('3bjjfgwa', 'FM (no score)'): {'sample': {'use_score': False}},\n",
    "    ('c0ijllm1', 'FM+Reg (no score)'): {'sample': {'use_score': False}},\n",
    "    ('3bjjfgwa', 'FM'): {'sample': {'use_score': True}},\n",
    "    ('c0ijllm1', 'FM+Reg'): {'sample': {'use_score': True}},\n",
    "    # FitzHughNagumo\n",
    "    # ('wyrwide1', 'Diffusion (VE SDE)'): {},\n",
    "    # ('gcior3bc', 'Flow Matching (OT)'): {'sample': {'use_score': False}},\n",
    "    # ('tybh75p1', 'Flow Matching (VE)'): {'sample': {'use_score': False}},\n",
    "    # ('tybh75p1', 'Flow Matching (VE Score)'): {'sample': {'use_score': True}},\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "71d5b2f5-2d6d-4e50-bc64-f0799d351243",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the stored config of every run listed in `config_alt_ids`.\n",
    "# `dict.fromkeys` removes repeated alt_ids while keeping their order.\n",
    "requested_alt_ids = list(dict.fromkeys(c[0] for c in config_alt_ids))\n",
    "if not requested_alt_ids:\n",
    "    raise ValueError('config_alt_ids is empty: nothing to evaluate')\n",
    "cfg_rows = session.execute(sa.select(cs.Config).where(cs.Config.alt_id.in_(requested_alt_ids)))\n",
    "cfgs = {c.alt_id: c for (c,) in cfg_rows}\n",
    "\n",
    "# Fail here, with the offending ids, instead of with a bare KeyError when the\n",
    "# checkpoints are loaded.\n",
    "missing_alt_ids = [alt_id for alt_id in requested_alt_ids if alt_id not in cfgs]\n",
    "if missing_alt_ids:\n",
    "    raise KeyError(f'No stored config found for alt_id(s): {missing_alt_ids}')\n",
    "\n",
    "# The query has no ORDER BY, so the row order is not guaranteed. Take the\n",
    "# reference run from `config_alt_ids` to keep the choice deterministic.\n",
    "reference_cfg = cfgs[requested_alt_ids[0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "14e6d166-d85a-487e-acb4-dc6b729da54e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root PRNG key of the notebook, seeded from the reference run's config.\n",
    "# Later cells split it before each random operation.\n",
    "key = jax.random.key(reference_cfg.rng_seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f22b503a-c4e9-4838-8ff6-b1c21daedb76",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3300/3300 [04:57<00:00, 11.08it/s]\n"
     ]
    }
   ],
   "source": [
    "# Regenerate the dataset of the reference run and split it.\n",
    "key, key_dataset = jax.random.split(key)\n",
    "ds = datasets.get_dataset(reference_cfg.dataset, key=key_dataset)\n",
    "splits = datasets.split_dataset(reference_cfg.dataset, ds)\n",
    "data_std = splits['train'].std()\n",
    "\n",
    "# Each split is batched up front with tf.data and materialised as a list of\n",
    "# NumPy batches. The DataLoader then yields one pre-made batch at a time:\n",
    "# batch_size=1 wraps it in a list and collate_fn unwraps it again.\n",
    "dataloaders = {\n",
    "    split_name: torch.utils.data.dataloader.DataLoader(\n",
    "        list(\n",
    "            tf.data.Dataset.from_tensor_slices(split_data)\n",
    "            .batch(reference_cfg.dataset.batch_size)\n",
    "            .as_numpy_iterator()\n",
    "        ),\n",
    "        batch_size=1,\n",
    "        collate_fn=lambda batches: batches[0],\n",
    "    )\n",
    "    for split_name, split_data in splits.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3d9f7a72-ecb8-4e67-8532-a11237e53524",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/workspace/GitHub/pmlr-v202-finzi23a/.venv/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1330: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()\n",
    "x_sample = next(iter(dataloaders['train']))\n",
    "ckpt_name = 'epoch_1999'\n",
    "\n",
    "# Build one model wrapper per (alt_id, label) entry and restore its weights.\n",
    "cfg_info = {}\n",
    "for k in config_alt_ids:\n",
    "    cfg = cfgs[k[0]]\n",
    "    # All compared runs must share the reference seed and dataset.\n",
    "    assert cfg.rng_seed == reference_cfg.rng_seed\n",
    "    assert cfg.dataset == reference_cfg.dataset\n",
    "\n",
    "    architecture = cfg.model.architecture\n",
    "    model = unet.UNet(unet.unet_64_config(\n",
    "        splits['train'].shape[-1],\n",
    "        base_channels=architecture.base_channel_count,\n",
    "        attention=architecture.attention,\n",
    "    ))\n",
    "\n",
    "    key, key_jaxlightning = jax.random.split(key)\n",
    "    # Pick the wrapper class that matches the model family of the run.\n",
    "    if isinstance(cfg.model, cs.ModelDiffusion):\n",
    "        lightning_cls = diffusion.JaxLightning\n",
    "    elif isinstance(cfg.model, cs.ModelFlowMatching):\n",
    "        lightning_cls = flow_matching.JaxLightning\n",
    "    else:\n",
    "        raise ValueError(f'Unknown model: {cfg.model}')\n",
    "    jax_lightning = lightning_cls(cfg, key_jaxlightning, dataloaders, data_std, None, model)\n",
    "\n",
    "    # Raw and EMA parameters are stored as separate checkpoints in the run directory.\n",
    "    checkpoint_dir = cfg.run_dir\n",
    "    jax_lightning.params = orbax_checkpointer.restore(checkpoint_dir/ckpt_name)\n",
    "    jax_lightning.params_ema = orbax_checkpointer.restore(checkpoint_dir/f'{ckpt_name}_ema')\n",
    "\n",
    "    cfg_info[k] = {'cfg': cfg, 'jax_lightning': jax_lightning}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b2cd2005-a3f4-4b90-8bc1-c7d3c2b8197b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `constraint(x)` maps a batch of trajectories to one value per trajectory.\n",
    "# The later cells treat a trajectory as an event when its value is > 0.\n",
    "if isinstance(reference_cfg.dataset, cs.DatasetLorenz):\n",
    "    def constraint(x):\n",
    "        \"\"\"Return 0.6 minus the mean non-DC Fourier magnitude of component 0.\n",
    "\n",
    "        The real FFT is taken over the last axis that remains after selecting\n",
    "        index 0 of the final axis; the zero-frequency bin is excluded.\n",
    "        \"\"\"\n",
    "        fourier_magnitudes = jnp.abs(jnp.fft.rfft(x[..., 0], axis=-1))\n",
    "        return -(fourier_magnitudes[..., 1:].mean(-1) - .6)\n",
    "elif isinstance(reference_cfg.dataset, cs.DatasetFitzHughNagumo):\n",
    "    def constraint(x):\n",
    "        \"\"\"Return the maximum, over the second-to-last axis, of the mean of the\n",
    "        first two entries of the last axis, minus 2.5.\"\"\"\n",
    "        return jnp.max(x[..., :2].mean(-1), -1) - 2.5\n",
    "else:\n",
    "    # The original message referenced the misspelled name `referenc_cfg`,\n",
    "    # which raised NameError instead of this ValueError.\n",
    "    raise ValueError(f'Unknown dataset: {reference_cfg.dataset}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "acda3a0e-c8a3-4e62-8626-1a2a5662601f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trajectories the models are evaluated against. Later cells use them for\n",
    "# the conditioning, for the shape of the sampled batch and for the data plots.\n",
    "evaluation_trajectories = splits['train']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ae985767-3a8c-4c19-889c-3a4d48a52a9b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-01-29 00:25:01.649118: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3020] Can't reduce memory use below -3.52GiB (-3785230798 bytes) by rematerialization; only reduced to 48.85GiB (52455755224 bytes), down from 49.05GiB (52664581612 bytes) originally\n",
      "2025-01-29 00:25:15.396441: W external/xla/xla/tsl/framework/bfc_allocator.cc:497] Allocator (GPU_0_bfc) ran out of memory trying to allocate 21.44GiB (rounded to 23016960000)requested by op \n",
      "2025-01-29 00:25:15.403324: W external/xla/xla/tsl/framework/bfc_allocator.cc:508] **__________________________________________________________________________________________________\n",
      "E0129 00:25:15.403666   22323 pjrt_stream_executor_client.cc:3086] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 23016960000 bytes.\n"
     ]
    },
    {
     "ename": "XlaRuntimeError",
     "evalue": "RESOURCE_EXHAUSTED: Out of memory while trying to allocate 23016960000 bytes.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mXlaRuntimeError\u001b[0m                           Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[12], line 26\u001b[0m\n\u001b[1;32m     22\u001b[0m         info[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mevent_samples\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m samplers\u001b[38;5;241m.\u001b[39msde_sample(\n\u001b[1;32m     23\u001b[0m             info[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mjax_lightning\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mdiffusion, event_scores, key_samples, x_shape\u001b[38;5;241m=\u001b[39mevaluation_trajectories\u001b[38;5;241m.\u001b[39mshape, nsteps\u001b[38;5;241m=\u001b[39minfo[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcfg\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mtime_step_count_sampling, traj\u001b[38;5;241m=\u001b[39mkeep_path\n\u001b[1;32m     24\u001b[0m         )\n\u001b[1;32m     25\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(info[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcfg\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mmodel, cs\u001b[38;5;241m.\u001b[39mModelDiffusion):\n\u001b[0;32m---> 26\u001b[0m     info[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msamples\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43minfo\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mjax_lightning\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey_samples\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1.\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcond\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx_shape\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mevaluation_trajectories\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mkeep_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkeep_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     27\u001b[0m     \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mscore\u001b[39m(x, t):\n\u001b[1;32m     28\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(t, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mshape\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m t\u001b[38;5;241m.\u001b[39mshape:\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/src/userfm/diffusion.py:111\u001b[0m, in \u001b[0;36mJaxLightning.sample\u001b[0;34m(self, key, tmax, cond, x_shape, params, keep_path)\u001b[0m\n\u001b[1;32m    108\u001b[0m         t \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mones((x_shape[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1\u001b[39m)) \u001b[38;5;241m*\u001b[39m t\n\u001b[1;32m    109\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscore(x, t, cond, params)\n\u001b[0;32m--> 111\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43msamplers\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msde_sample\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdiffusion\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscore\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx_shape\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnsteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcfg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtime_step_count_sampling\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtraj\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkeep_path\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/src/userdiffusion/samplers.py:258\u001b[0m, in \u001b[0;36msde_sample\u001b[0;34m(diffusion, scorefn, key, x_shape, nsteps, traj)\u001b[0m\n\u001b[1;32m    256\u001b[0m key0, key1 \u001b[38;5;241m=\u001b[39m random\u001b[38;5;241m.\u001b[39msplit(key)\n\u001b[1;32m    257\u001b[0m xf \u001b[38;5;241m=\u001b[39m diffusion\u001b[38;5;241m.\u001b[39mnoise(key0, x_shape) \u001b[38;5;241m*\u001b[39m diffusion\u001b[38;5;241m.\u001b[39msigma(diffusion\u001b[38;5;241m.\u001b[39mtmax)\n\u001b[0;32m--> 258\u001b[0m samples, xt \u001b[38;5;241m=\u001b[39m \u001b[43meuler_maruyama_integrate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdiffusion\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscorefn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mxf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimesteps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    259\u001b[0m \u001b[43m                                       \u001b[49m\u001b[43mkey1\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    260\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m xt \u001b[38;5;28;01mif\u001b[39;00m traj \u001b[38;5;28;01melse\u001b[39;00m samples\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/src/userdiffusion/samplers.py:94\u001b[0m, in \u001b[0;36meuler_maruyama_integrate\u001b[0;34m(diff, scorefn, x0, ts, key)\u001b[0m\n\u001b[1;32m     91\u001b[0m   x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m (t2 \u001b[38;5;241m-\u001b[39m t1) \u001b[38;5;241m*\u001b[39m xdot \u001b[38;5;241m+\u001b[39m diffusion(x, t1) \u001b[38;5;241m*\u001b[39m noise \u001b[38;5;241m*\u001b[39m jnp\u001b[38;5;241m.\u001b[39msqrt(t1 \u001b[38;5;241m-\u001b[39m t2)\n\u001b[1;32m     92\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m (x, key), x\n\u001b[0;32m---> 94\u001b[0m (x, _), xs \u001b[38;5;241m=\u001b[39m \u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscan\u001b[49m\u001b[43m(\u001b[49m\u001b[43mupdate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mx0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt12\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     95\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x, xs\n",
      "    \u001b[0;31m[... skipping hidden 11 frame]\u001b[0m\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1298\u001b[0m, in \u001b[0;36mExecuteReplicated.__call__\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m   1296\u001b[0m   \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_handle_token_bufs(result_token_bufs, sharded_runtime_token)\n\u001b[1;32m   1297\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1298\u001b[0m   results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mxla_executable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecute_sharded\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_bufs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1300\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dispatch\u001b[38;5;241m.\u001b[39mneeds_check_special():\n\u001b[1;32m   1301\u001b[0m   out_arrays \u001b[38;5;241m=\u001b[39m results\u001b[38;5;241m.\u001b[39mdisassemble_into_single_device_arrays()\n",
      "\u001b[0;31mXlaRuntimeError\u001b[0m: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 23016960000 bytes."
     ]
    }
   ],
   "source": [
    "# Maximum number of trajectories sampled per sampler call. Sampling every\n",
    "# evaluation trajectory in a single call exhausted GPU memory (XLA\n",
    "# RESOURCE_EXHAUSTED inside the sampler's `jax.lax.scan`). Set this to None\n",
    "# to sample everything in one call.\n",
    "sample_chunk_size = reference_cfg.dataset.batch_size\n",
    "keep_path = isinstance(reference_cfg.dataset, cs.DatasetGaussianMixture)\n",
    "if keep_path:\n",
    "    # Chunks are joined along axis 0, which is only known to be the trajectory\n",
    "    # axis for final samples, so full sampling paths are produced in one call.\n",
    "    sample_chunk_size = None\n",
    "\n",
    "\n",
    "def sample_chunk(info, overrides, trajectories, key_chunk):\n",
    "    \"\"\"Sample one chunk of trajectories from a single model.\n",
    "\n",
    "    Args:\n",
    "        info: entry of `cfg_info` holding the run's `cfg` and `jax_lightning`.\n",
    "        overrides: entry of `config_alt_ids` for the run; its 'sample' dict is\n",
    "            forwarded to `sample` for flow-matching models.\n",
    "        trajectories: evaluation trajectories that provide the conditioning\n",
    "            and the shape of the sampled batch.\n",
    "        key_chunk: PRNG key used for both the plain and the event sampler.\n",
    "\n",
    "    Returns:\n",
    "        `(samples, event_samples)`. `event_samples` is None for flow-matching\n",
    "        models that are not variance-exploding conditional SDEs sampled with\n",
    "        `use_score`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the model is neither flow matching nor diffusion.\n",
    "    \"\"\"\n",
    "    jax_lightning = info['jax_lightning']\n",
    "    model_cfg = info['cfg'].model\n",
    "    cond = main_module.condition_on_initial_time_steps(trajectories, reference_cfg.dataset.time_step_count_conditioning)\n",
    "\n",
    "    if isinstance(model_cfg, cs.ModelFlowMatching):\n",
    "        samples = jax_lightning.sample(key_chunk, 1., cond, x_shape=trajectories.shape, keep_path=keep_path, **overrides['sample'])\n",
    "        has_event_sampler = (\n",
    "            isinstance(model_cfg.conditional_flow, cs.ConditionalSDE)\n",
    "            and isinstance(model_cfg.conditional_flow.sde_diffusion, cs.SDEVarianceExploding)\n",
    "            and overrides['sample']['use_score']\n",
    "        )\n",
    "    elif isinstance(model_cfg, cs.ModelDiffusion):\n",
    "        samples = jax_lightning.sample(key_chunk, 1., cond, x_shape=trajectories.shape, keep_path=keep_path)\n",
    "        has_event_sampler = True\n",
    "    else:\n",
    "        raise ValueError(f'Unknown model: {model_cfg}')\n",
    "    if not has_event_sampler:\n",
    "        return samples, None\n",
    "\n",
    "    def score(x, t):\n",
    "        # Broadcast a scalar time to one time per trajectory of the chunk.\n",
    "        if not hasattr(t, 'shape') or not t.shape:\n",
    "            t = jnp.ones((trajectories.shape[0], 1, 1)) * t\n",
    "        return jax_lightning.score(x, t, cond, jax_lightning.params_ema)\n",
    "\n",
    "    event_scores = samplers.event_scores(\n",
    "        jax_lightning.diffusion, score, constraint, reg=1e-3\n",
    "    )\n",
    "    event_samples = samplers.sde_sample(\n",
    "        jax_lightning.diffusion, event_scores, key_chunk, x_shape=trajectories.shape, nsteps=model_cfg.time_step_count_sampling, traj=keep_path\n",
    "    )\n",
    "    return samples, event_samples\n",
    "\n",
    "\n",
    "# use same sampling key for all models\n",
    "key, key_samples = jax.random.split(key)\n",
    "trajectory_total = evaluation_trajectories.shape[0]\n",
    "if sample_chunk_size is None or trajectory_total <= sample_chunk_size:\n",
    "    # Everything fits in one call: `key_samples` is used unchanged.\n",
    "    chunks = [(key_samples, evaluation_trajectories)]\n",
    "else:\n",
    "    # One derived key per chunk; every model sees the same chunk keys.\n",
    "    chunks = [\n",
    "        (jax.random.fold_in(key_samples, chunk_index), evaluation_trajectories[start:start + sample_chunk_size])\n",
    "        for chunk_index, start in enumerate(range(0, trajectory_total, sample_chunk_size))\n",
    "    ]\n",
    "\n",
    "for k, info in cfg_info.items():\n",
    "    chunk_samples, chunk_event_samples = zip(*(\n",
    "        sample_chunk(info, config_alt_ids[k], trajectories, key_chunk)\n",
    "        for key_chunk, trajectories in chunks\n",
    "    ))\n",
    "    info['samples'] = chunk_samples[0] if len(chunks) == 1 else jnp.concatenate(chunk_samples, axis=0)\n",
    "    if chunk_event_samples[0] is None:\n",
    "        # Drop a stale entry left by an earlier run of this cell.\n",
    "        info.pop('event_samples', None)\n",
    "    else:\n",
    "        info['event_samples'] = (\n",
    "            chunk_event_samples[0] if len(chunks) == 1 else jnp.concatenate(chunk_event_samples, axis=0)\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cae76cf7-b3d2-46bc-b554-46404d2b4b0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot up to `trajectory_count` sampled event trajectories per model.\n",
    "trajectory_count = 10\n",
    "frames, keys = [], []\n",
    "for (_, source), info in cfg_info.items():\n",
    "    # A trajectory counts as an event when its constraint value is positive.\n",
    "    event_trajectories = info['samples'][constraint(info['samples']) > 0]\n",
    "    for i, trajectory in zip(range(trajectory_count), event_trajectories):\n",
    "        frames.append(pd.DataFrame(dict(\n",
    "            Source=source,\n",
    "            Values=trajectory[:, 0],\n",
    "        )))\n",
    "        # Keys are collected with the frames so that they stay aligned when a\n",
    "        # model yields fewer than `trajectory_count` events.\n",
    "        keys.append(str(i))\n",
    "if not frames:\n",
    "    raise ValueError('No model produced a sample that satisfies the event constraint')\n",
    "df = pd.concat(frames, axis=0, keys=keys).reset_index(names=['Trajectory', 'Time Step'])\n",
    "sns.relplot(\n",
    "    kind='line',\n",
    "    data=df,\n",
    "    x='Time Step', y='Values',\n",
    "    hue='Trajectory',\n",
    "    col='Source',\n",
    "    col_order=[c[1] for c in cfg_info],\n",
    ")\n",
    "print('Model-sampled events')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2fc9b77-6ffd-4c0f-9220-e8f4c614326b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot up to `trajectory_count` non-event and event trajectories of the data.\n",
    "trajectory_count = 5\n",
    "# Evaluate the constraint once and reuse it for both groups.\n",
    "constraint_values = constraint(evaluation_trajectories)\n",
    "frames = [\n",
    "    pd.DataFrame(dict(\n",
    "        IsEvent=is_event,\n",
    "        Values=trajectory[:, 0]\n",
    "    ))\n",
    "    for is_event, group in (\n",
    "        (False, evaluation_trajectories[constraint_values <= 0]),\n",
    "        (True, evaluation_trajectories[constraint_values > 0]),\n",
    "    )\n",
    "    for i, trajectory in zip(range(trajectory_count), group)\n",
    "]\n",
    "if not frames:\n",
    "    raise ValueError('evaluation_trajectories is empty: nothing to plot')\n",
    "# One key per frame actually built, so the labels stay aligned when a group\n",
    "# holds fewer than `trajectory_count` trajectories.\n",
    "df = pd.concat(\n",
    "    frames, axis=0, keys=[str(i) for i in range(len(frames))]\n",
    ").reset_index(names=['Trajectory', 'Time Step'])\n",
    "sns.relplot(\n",
    "    kind='line',\n",
    "    data=df,\n",
    "    x='Time Step', y='Values',\n",
    "    hue='Trajectory',\n",
    "    col='IsEvent',\n",
    ")\n",
    "print('Data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c32586b-e33f-4603-aabb-daa5bcdba99d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constraint values of the unconditional samples, one column per model,\n",
    "# melted into long form for seaborn.\n",
    "df = pd.concat([\n",
    "    pd.Series(constraint(info['samples']), name=source)\n",
    "    for (_, source), info in cfg_info.items()\n",
    "], axis=1).melt(var_name='Source', value_name='Constraint Value')\n",
    "df_data = pd.DataFrame({'Source': 'Data', 'Constraint Value': constraint(splits['train'])})\n",
    "# 128 equal-width bins spanning both the model and the data values. `bins` is\n",
    "# reused by the conditional histogram cell.\n",
    "constraint_values = pd.concat((df, df_data))['Constraint Value']\n",
    "bins = np.histogram_bin_edges(\n",
    "    constraint_values, bins=128, range=(constraint_values.min(), constraint_values.max())\n",
    ")\n",
    "plot = (\n",
    "    sns.displot(\n",
    "        data=df,\n",
    "        stat='density',\n",
    "        x='Constraint Value',\n",
    "        col='Source',\n",
    "        col_order=[c[1] for c in cfg_info],\n",
    "        hue='Source',\n",
    "        hue_order=[c[1] for c in cfg_info],\n",
    "        common_norm=False,\n",
    "        bins=bins,\n",
    "        facet_kws=dict(\n",
    "            # sharey=False,\n",
    "        )\n",
    "    )\n",
    "    .set(yscale='log' if isinstance(reference_cfg.dataset, cs.DatasetFitzHughNagumo) else 'linear')\n",
    "    .set_titles('')\n",
    ")\n",
    "# Draw the data histogram behind the model histogram in every facet.\n",
    "plot.map(\n",
    "    sns.histplot,\n",
    "    data=df_data,\n",
    "    bins=bins,\n",
    "    stat='density',\n",
    "    color='tab:grey',\n",
    "    x='Constraint Value',\n",
    "    zorder=-1,\n",
    ").set_xlabels('').set_ylabels('')\n",
    "for (row, col, hue), data in plot.facet_data():\n",
    "    ax = plot.axes[row][col]\n",
    "    # The event threshold is at constraint value 0.\n",
    "    ax.axvline(x=0, c='r', ls=':')\n",
    "    ax.xaxis.set_tick_params(labelbottom=True)\n",
    "    ax.yaxis.set_tick_params(labelleft=True)\n",
    "plot.tight_layout()\n",
    "sns.move_legend(\n",
    "    plot,\n",
    "    loc='upper center',\n",
    "    ncol=len(cfg_info) + 1,\n",
    "    title='',\n",
    "    bbox_to_anchor=(.455, 1.06),\n",
    "    frameon=True,\n",
    "    fancybox=True,\n",
    ")\n",
    "\n",
    "# KL(data || model) over the histogram bins, using per-bin probability mass.\n",
    "data_hist = np.histogram(df_data['Constraint Value'], bins=bins)[0] / len(df_data)\n",
    "# Bins without data contribute 0; selecting the others up front avoids\n",
    "# evaluating log(0).\n",
    "observed = data_hist > 0.\n",
    "for (row, col, hue), data in plot.facet_data():\n",
    "    print(plot.col_names[col])\n",
    "    model_hist = np.histogram(data['Constraint Value'], bins=bins)[0] / len(data)\n",
    "    # 1e-12 keeps the ratio finite where the model histogram is empty.\n",
    "    kl_divergence = data_hist[observed] * np.log(data_hist[observed] / (model_hist[observed] + 1e-12))\n",
    "    print(kl_divergence.sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "786505d2-c9c2-4f57-aa67-f7219cbc9e95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plots.save_all_subfigures(plot, f'event_histogram.unconditional.{reference_cfg.dataset.__class__.__name__}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7c880ee-97ef-4ef8-8226-4ba48efdb514",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: shape of the training split.\n",
    "splits['train'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f7d73f0-0f6c-4f72-9502-b95fe45186c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Labels of the runs that produced constraint-guided event samples.\n",
    "event_sources = [source for (_, source), info in cfg_info.items() if 'event_samples' in info]\n",
    "df = pd.concat([\n",
    "    pd.Series(constraint(info['event_samples']), name=source)\n",
    "    for (_, source), info in cfg_info.items()\n",
    "    if 'event_samples' in info\n",
    "], axis=1).melt(var_name='Source', value_name='Constraint Value')\n",
    "# reuse bins from previous plot\n",
    "data_color = 'tab:gray'\n",
    "plot = (\n",
    "    sns.displot(\n",
    "        data=df,\n",
    "        stat='density',\n",
    "        x='Constraint Value',\n",
    "        row='Source',\n",
    "        row_order=event_sources,\n",
    "        # row_order=['Data', *event_sources],\n",
    "        hue='Source',\n",
    "        hue_order=[*event_sources, 'Data'],\n",
    "        # One colour per plotted source (previously hard-coded to three).\n",
    "        palette=[*sns.color_palette(n_colors=len(event_sources)), data_color],\n",
    "        common_norm=False,\n",
    "        bins=bins,\n",
    "        facet_kws=dict(\n",
    "            # sharex=True\n",
    "        ),\n",
    "        height=1.8,\n",
    "        aspect=2.2,\n",
    "    )\n",
    "    .set(yscale='log' if isinstance(reference_cfg.dataset, cs.DatasetFitzHughNagumo) else 'linear')\n",
    "    .set_titles('')\n",
    ")\n",
    "# Only the training trajectories that are events serve as the data reference.\n",
    "train_is_event = constraint(splits['train']) > 0\n",
    "df_data = pd.DataFrame({'Source': 'Data', 'Constraint Value': constraint(splits['train'][train_is_event])})\n",
    "plot.map(\n",
    "    sns.histplot,\n",
    "    data=df_data,\n",
    "    bins=bins,\n",
    "    stat='density',\n",
    "    color=data_color,\n",
    "    x='Constraint Value',\n",
    "    zorder=-1,\n",
    ").set_xlabels('').set_ylabels('')\n",
    "for (row, col, hue), data in plot.facet_data():\n",
    "    ax = plot.axes[row][col]\n",
    "    # The event threshold is at constraint value 0.\n",
    "    ax.axvline(x=0, c='r', ls=':')\n",
    "    ax.xaxis.set_tick_params(labelbottom=True)\n",
    "    ax.yaxis.set_tick_params(labelleft=True)\n",
    "    # Only the bottom row keeps its x axis.\n",
    "    if row != len(plot.row_names) - 1:\n",
    "        ax.xaxis.set_visible(False)\n",
    "plot.tight_layout()\n",
    "sns.move_legend(\n",
    "    plot,\n",
    "    loc='upper center',\n",
    "    ncol=len(event_sources) + 1,\n",
    "    title='',\n",
    "    bbox_to_anchor=(.455, 1.06),\n",
    "    frameon=True,\n",
    "    fancybox=True,\n",
    ")\n",
    "\n",
    "\n",
    "def bin_probabilities(values):\n",
    "    \"\"\"Return the fraction of the in-range `values` that falls into each bin.\"\"\"\n",
    "    counts = np.histogram(values, bins=bins)[0]\n",
    "    # max(..., 1) avoids 0/0 when no value lies inside the bin range.\n",
    "    return counts / max(counts.sum(), 1)\n",
    "\n",
    "\n",
    "# KL(data || model) over the histogram bins. Per-bin probability mass is used\n",
    "# instead of density: summing densities scales the result by 1 / bin width.\n",
    "data_hist = bin_probabilities(df_data['Constraint Value'])\n",
    "# Bins without data contribute 0; selecting the others up front avoids\n",
    "# evaluating log(0).\n",
    "observed = data_hist > 0.\n",
    "for (row, col, hue), data in plot.facet_data():\n",
    "    print(plot.row_names[row])\n",
    "    model_hist = bin_probabilities(data['Constraint Value'])\n",
    "    # 1e-12 keeps the ratio finite where the model histogram is empty.\n",
    "    kl_divergence = data_hist[observed] * np.log(data_hist[observed] / (model_hist[observed] + 1e-12))\n",
    "    print(kl_divergence.sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d17b88f4-345a-4c74-81df-d344c7259f64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the conditional event histogram (presumably one file per subfigure;\n",
    "# see `plots.save_all_subfigures`). The output name contains the dataset\n",
    "# class name of the reference run.\n",
    "plots.save_all_subfigures(plot, f'event_histogram.conditional.{reference_cfg.dataset.__class__.__name__}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b108c607-90d9-4470-b7ef-ef1157a26192",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Monte-Carlo estimate of the event probability: the fraction of trajectories\n",
    "# with a positive constraint value, with the standard error of that mean.\n",
    "print('Event Likelihood: Direct Monte-Carlo')\n",
    "labelled_samples = [('Data', splits['train'])]\n",
    "labelled_samples += [(label, info['samples']) for (_, label), info in cfg_info.items()]\n",
    "for label, samples in labelled_samples:\n",
    "    is_event = constraint(samples) > 0\n",
    "    probability = is_event.mean()\n",
    "    standard_error = is_event.std()/jnp.sqrt(len(is_event))\n",
    "    print(f'{label}: P(E) = {probability:.3f}+-{standard_error:.3f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "018cdb30-0c27-4849-ad77-6d761bca1dac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Negative log-likelihood of the first 10 evaluation trajectories under each\n",
    "# model. The same `key_nll` is passed to every model.\n",
    "key, key_nll = jax.random.split(key)\n",
    "for (_, source), info in cfg_info.items():\n",
    "    x_noise, nll_no_div, nll = info['jax_lightning'].compute_nll(key_nll, 1., evaluation_trajectories[:10])\n",
    "    # The `{name=}` specifiers print the expression text, so the variable\n",
    "    # names are part of the output.\n",
    "    print(f'{source=}, {nll_no_div.mean()=}, {nll.mean()=}, {x_noise.mean()=}, {x_noise.std()=}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b348fea6-cf1c-4bb0-a381-296e43db4776",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
