{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ab12f31b-b486-4d80-a6ea-7ef47ac614a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b11186f2-ec68-4943-b13a-c4321b74857d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2026-01-12 15:48:23.191345: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "E0000 00:00:1768258103.211621  628379 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "E0000 00:00:1768258103.217675  628379 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "/home/xxx/GitHub/Aligning-Flow-Div-User-Defined-Sampling/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import functools\n",
    "import itertools\n",
    "import pprint\n",
    "import os\n",
    "\n",
    "import orbax.checkpoint\n",
    "import numpy as np\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import torch.utils.data.dataloader\n",
    "import tensorflow as tf\n",
    "import sqlalchemy as sa\n",
    "import seaborn as sns\n",
    "sns.set_theme(style='whitegrid', font_scale=1.3, palette=sns.color_palette('husl'),)\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from userdiffusion import samplers, unet\n",
    "from userfm import cs, datasets, diffusion, sde_diffusion, flow_matching, utils, main as main_module, plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d83fa9ea-2afc-4c1b-9cd8-7c4ed43b9d9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ['CUDA_VISIBLE_DEVICES'] = '3'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "82fd6df7-944a-4ac8-b1e9-d30dce4864a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# somehow, this line of code prevents a segmentation fault in nn.Dense\n",
    "# when calling model.init\n",
    "tf.config.experimental.set_visible_devices([], 'GPU')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "26f2a2a9-1783-4794-bfcb-8cf73dde97a5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<sqlalchemy.orm.session.SessionTransaction at 0x759e369cf7c0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "engine = cs.get_engine()\n",
    "cs.create_all(engine)\n",
    "session = cs.orm.Session(engine)\n",
    "session.begin()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "abdb5d9c-b9a4-4ed4-9170-92195c5f8b09",
   "metadata": {},
   "outputs": [],
   "source": [
    "config_alt_ids = {\n",
    "    # Lorenz\n",
    "    # ('0y35hp7d', 'DM'): {},\n",
    "    # ('fba4g7bp', 'FMOT'): {'sample': {'use_score': False}},\n",
    "    # ('1g2n8baa', 'FMOT+Reg'): {'sample': {'use_score': False}},\n",
    "    # ('eug367ja', 'Flow Matching (VE)'): {'sample': {'use_score': False}},\n",
    "    # ('3bjjfgwa', 'FM (no score)'): {'sample': {'use_score': False}},\n",
    "    # ('c0ijllm1', 'FM+Reg (no score)'): {'sample': {'use_score': False}},\n",
    "    # ('3bjjfgwa', 'FM'): {'sample': {'use_score': True}},\n",
    "    # ('c0ijllm1', 'FM+Reg'): {'sample': {'use_score': True}},\n",
    "    # FitzHughNagumo\n",
    "    # ('wyrwide1', 'Diffusion (VE SDE)'): {},\n",
    "    # ('gcior3bc', 'Flow Matching (OT)'): {'sample': {'use_score': False}},\n",
    "    # ('tybh75p1', 'Flow Matching (VE)'): {'sample': {'use_score': False}},\n",
    "    # ('tybh75p1', 'Flow Matching (VE Score)'): {'sample': {'use_score': True}},\n",
    "    # Gaussian Mixture\n",
    "    ('mv0mvu9d', 'Diffusion'): {},\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "71d5b2f5-2d6d-4e50-bc64-f0799d351243",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfgs = session.execute(sa.select(cs.Config).where(cs.Config.alt_id.in_([c[0] for c in config_alt_ids])))\n",
    "cfgs = {c.alt_id: c for (c,) in cfgs}\n",
    "reference_cfg = cfgs[next(iter(cfgs.keys()))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "14e6d166-d85a-487e-acb4-dc6b729da54e",
   "metadata": {},
   "outputs": [],
   "source": [
    "key = jax.random.key(reference_cfg.rng_seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f22b503a-c4e9-4838-8ff6-b1c21daedb76",
   "metadata": {},
   "outputs": [],
   "source": [
    "key, key_dataset = jax.random.split(key)\n",
    "ds = datasets.get_dataset(reference_cfg.dataset, key=key_dataset)\n",
    "splits = datasets.split_dataset(reference_cfg.dataset, ds)\n",
    "dataloaders = {}\n",
    "for n, s in splits.items():\n",
    "    dataloaders[n] = torch.utils.data.dataloader.DataLoader(\n",
    "        list(tf.data.Dataset.from_tensor_slices(s).batch(reference_cfg.dataset.batch_size).as_numpy_iterator()),\n",
    "        batch_size=1,\n",
    "        collate_fn=lambda x: x[0],\n",
    "    )\n",
    "data_std = splits['train'].std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3d9f7a72-ecb8-4e67-8532-a11237e53524",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/xxx/GitHub/Aligning-Flow-Div-User-Defined-Sampling/.venv/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1330: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()\n",
    "x_sample = next(iter(dataloaders['train']))\n",
    "ckpt_name = 'epoch_2499'\n",
    "\n",
    "cfg_info = {}\n",
    "for k in config_alt_ids:\n",
    "    cfg = cfgs[k[0]]\n",
    "    assert cfg.rng_seed == reference_cfg.rng_seed\n",
    "    assert cfg.dataset == reference_cfg.dataset\n",
    "\n",
    "    cfg_unet = unet.unet_64_config(\n",
    "        splits['train'].shape[-1],\n",
    "        base_channels=cfg.model.architecture.base_channel_count,\n",
    "        attention=cfg.model.architecture.attention,\n",
    "    )\n",
    "    model = unet.UNet(cfg_unet)\n",
    "    \n",
    "    key, key_jaxlightning = jax.random.split(key)\n",
    "    if isinstance(cfg.model, cs.ModelDiffusion):\n",
    "        jax_lightning = diffusion.JaxLightning(cfg, key_jaxlightning, dataloaders, data_std, None, model)\n",
    "    elif isinstance(cfg.model, cs.ModelFlowMatching):\n",
    "        jax_lightning = flow_matching.JaxLightning(cfg, key_jaxlightning, dataloaders, data_std, None, model)\n",
    "    else:\n",
    "        raise ValueError(f'Unknown model: {cfg.model}')\n",
    "        \n",
    "    jax_lightning.params = orbax_checkpointer.restore(cfg.run_dir/ckpt_name)\n",
    "    jax_lightning.params_ema = orbax_checkpointer.restore(cfg.run_dir/f'{ckpt_name}_ema')\n",
    "\n",
    "    cfg_info[k] = dict(\n",
    "        cfg=cfg,\n",
    "        jax_lightning=jax_lightning,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b2cd2005-a3f4-4b90-8bc1-c7d3c2b8197b",
   "metadata": {},
   "outputs": [],
   "source": [
    "if isinstance(reference_cfg.dataset, cs.DatasetLorenz):\n",
    "    def constraint(x):\n",
    "        fourier_magnitudes = jnp.abs(jnp.fft.rfft(x[..., 0], axis=-1))\n",
    "        return -(fourier_magnitudes[..., 1:].mean(-1) - .6)\n",
    "elif isinstance(reference_cfg.dataset, cs.DatasetFitzHughNagumo):\n",
    "    def constraint(x):\n",
    "        return jnp.max(x[..., :2].mean(-1), -1) - 2.5\n",
    "else:\n",
    "    raise ValueError(f'Unknown dataset: {referenc_cfg.dataset}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "acda3a0e-c8a3-4e62-8626-1a2a5662601f",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluation_trajectories = splits['train']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ae985767-3a8c-4c19-889c-3a4d48a52a9b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'constraint' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[14], line 32\u001b[0m\n\u001b[1;32m     29\u001b[0m             t \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mones((evaluation_trajectories\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1\u001b[39m)) \u001b[38;5;241m*\u001b[39m t\n\u001b[1;32m     30\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m info[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mjax_lightning\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mscore(x, t, cond, info[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mjax_lightning\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mparams_ema)\n\u001b[1;32m     31\u001b[0m     event_scores \u001b[38;5;241m=\u001b[39m samplers\u001b[38;5;241m.\u001b[39mevent_scores(\n\u001b[0;32m---> 32\u001b[0m         info[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mjax_lightning\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mdiffusion, score, \u001b[43mconstraint\u001b[49m, reg\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-3\u001b[39m\n\u001b[1;32m     33\u001b[0m     )\n\u001b[1;32m     34\u001b[0m     info[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mevent_samples\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m samplers\u001b[38;5;241m.\u001b[39msde_sample(\n\u001b[1;32m     35\u001b[0m         info[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mjax_lightning\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mdiffusion, event_scores, key_samples, x_shape\u001b[38;5;241m=\u001b[39mevaluation_trajectories\u001b[38;5;241m.\u001b[39mshape, nsteps\u001b[38;5;241m=\u001b[39minfo[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcfg\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mtime_step_count_sampling, traj\u001b[38;5;241m=\u001b[39mkeep_path\n\u001b[1;32m     36\u001b[0m     )\n\u001b[1;32m     37\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
      "\u001b[0;31mNameError\u001b[0m: name 'constraint' is not defined"
     ]
    }
   ],
   "source": [
    "cond = main_module.condition_on_initial_time_steps(evaluation_trajectories, reference_cfg.dataset.time_step_count_conditioning)\n",
    "trajectory_count = reference_cfg.dataset.batch_size\n",
    "keep_path = isinstance(reference_cfg.dataset, cs.DatasetGaussianMixture)\n",
    "# use same sampling key for all models\n",
    "key, key_samples = jax.random.split(key)\n",
    "for k, info in cfg_info.items():\n",
    "    cfg = info['cfg']\n",
    "    if isinstance(info['cfg'].model, cs.ModelFlowMatching):\n",
    "        info['samples'] = info['jax_lightning'].sample(key_samples, 1., cond, x_shape=evaluation_trajectories.shape, keep_path=keep_path, **config_alt_ids[k]['sample'])\n",
    "        if (\n",
    "            isinstance(info['cfg'].model.conditional_flow, cs.ConditionalSDE)\n",
    "            and isinstance(info['cfg'].model.conditional_flow.sde_diffusion, cs.SDEVarianceExploding)\n",
    "            and config_alt_ids[k]['sample']['use_score']\n",
    "        ):\n",
    "            def score(x, t):\n",
    "                if not hasattr(t, 'shape') or not t.shape:\n",
    "                    t = jnp.ones((evaluation_trajectories.shape[0], 1, 1)) * t\n",
    "                return info['jax_lightning'].score(x, t, cond, info['jax_lightning'].params_ema)\n",
    "            # event_scores = samplers.event_scores(\n",
    "            #     info['jax_lightning'].diffusion, score, constraint, reg=1e-3\n",
    "            # )\n",
    "            # info['event_samples'] = samplers.sde_sample(\n",
    "            #     info['jax_lightning'].diffusion, event_scores, key_samples, x_shape=evaluation_trajectories.shape, nsteps=info['cfg'].model.time_step_count_sampling, traj=keep_path\n",
    "            # )\n",
    "    elif isinstance(info['cfg'].model, cs.ModelDiffusion):\n",
    "        info['samples'] = info['jax_lightning'].sample(key_samples, 1., cond, x_shape=evaluation_trajectories.shape, keep_path=keep_path)\n",
    "        def score(x, t):\n",
    "            if not hasattr(t, 'shape') or not t.shape:\n",
    "                t = jnp.ones((evaluation_trajectories.shape[0], 1, 1)) * t\n",
    "            return info['jax_lightning'].score(x, t, cond, info['jax_lightning'].params_ema)\n",
    "        # event_scores = samplers.event_scores(\n",
    "        #     info['jax_lightning'].diffusion, score, constraint, reg=1e-3\n",
    "        # )\n",
    "        # info['event_samples'] = samplers.sde_sample(\n",
    "        #     info['jax_lightning'].diffusion, event_scores, key_samples, x_shape=evaluation_trajectories.shape, nsteps=info['cfg'].model.time_step_count_sampling, traj=keep_path\n",
    "        # )\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown model: {info['cfg'].model}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cae76cf7-b3d2-46bc-b554-46404d2b4b0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "trajectory_count = 10\n",
    "df = pd.concat([\n",
    "    *itertools.chain.from_iterable([\n",
    "        [\n",
    "            pd.DataFrame(dict(\n",
    "                Source=source,\n",
    "                Values=trajectory[:, 0],\n",
    "            ))\n",
    "            for i, trajectory in zip(range(trajectory_count), info['samples'][constraint(info['samples']) > 0])\n",
    "        ]\n",
    "        for (_, source), info in cfg_info.items()\n",
    "    ])\n",
    "], axis=0, keys=len(cfg_info) * list(map(str, range(trajectory_count)))).reset_index(names=['Trajectory', 'Time Step'])\n",
    "sns.relplot(\n",
    "    kind='line',\n",
    "    data=df,\n",
    "    x='Time Step', y='Values',\n",
    "    hue='Trajectory',\n",
    "    col='Source',\n",
    "    col_order=[c[1] for c in cfg_info],\n",
    ")\n",
    "print('Model-sampled events')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2fc9b77-6ffd-4c0f-9220-e8f4c614326b",
   "metadata": {},
   "outputs": [],
   "source": [
    "trajectory_count = 5\n",
    "df = pd.concat([\n",
    "    *[\n",
    "        pd.DataFrame(dict(\n",
    "            IsEvent=False,\n",
    "            Values=trajectory[:, 0]\n",
    "        )) for i, trajectory in zip(\n",
    "            range(trajectory_count),\n",
    "            evaluation_trajectories[constraint(evaluation_trajectories) <= 0]\n",
    "        )\n",
    "    ],\n",
    "    *[\n",
    "        pd.DataFrame(dict(\n",
    "            IsEvent=True,\n",
    "            Values=trajectory[:, 0]\n",
    "        )) for i, trajectory in zip(\n",
    "            range(trajectory_count),\n",
    "            evaluation_trajectories[constraint(evaluation_trajectories) > 0]\n",
    "        )\n",
    "    ],\n",
    "], axis=0, keys=map(str, range(2 * trajectory_count))).reset_index(names=['Trajectory', 'Time Step'])\n",
    "sns.relplot(\n",
    "    kind='line',\n",
    "    data=df,\n",
    "    x='Time Step', y='Values',\n",
    "    hue='Trajectory',\n",
    "    col='IsEvent',\n",
    ")\n",
    "print('Data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c32586b-e33f-4603-aabb-daa5bcdba99d",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.concat([\n",
    "    pd.Series(constraint(info['samples']), name=source)\n",
    "    for (_, source), info in cfg_info.items()\n",
    "], axis=1).melt(var_name='Source', value_name='Constraint Value')\n",
    "df_data = pd.DataFrame({'Source': 'Data', 'Constraint Value': constraint(splits['train'])})\n",
    "bins = np.histogram(np.zeros(2), bins=128, range=pd.concat((df, df_data))['Constraint Value'].agg(['min', 'max']))[1]\n",
    "plot = (\n",
    "    sns.displot(\n",
    "        data=df,\n",
    "        stat='density',\n",
    "        x='Constraint Value',\n",
    "        col='Source',\n",
    "        col_order=[c[1] for c in cfg_info],\n",
    "        hue='Source',\n",
    "        hue_order=[c[1] for c in cfg_info],\n",
    "        common_norm=False,\n",
    "        bins=bins,\n",
    "        facet_kws=dict(\n",
    "            # sharey=False,\n",
    "        )\n",
    "    )\n",
    "    .set(yscale='log' if isinstance(reference_cfg.dataset, cs.DatasetFitzHughNagumo) else 'linear')\n",
    "    .set_titles('')\n",
    ")\n",
    "plot.map(\n",
    "    sns.histplot,\n",
    "    data=df_data,\n",
    "    bins=bins,\n",
    "    stat='density',\n",
    "    color='tab:grey',\n",
    "    x='Constraint Value',\n",
    "    zorder=-1,\n",
    ").set_xlabels('').set_ylabels('')\n",
    "for (row, col, hue), data in plot.facet_data():\n",
    "    ax = plot.axes[row][col]\n",
    "    ax.axvline(x=0, c='r', ls=':')\n",
    "    ax.xaxis.set_tick_params(labelbottom=True)\n",
    "    ax.yaxis.set_tick_params(labelleft=True)\n",
    "plot.tight_layout()\n",
    "sns.move_legend(\n",
    "    plot,\n",
    "    loc='upper center',\n",
    "    ncol=len(cfg_info) + 1,\n",
    "    title='',\n",
    "    bbox_to_anchor=(.455, 1.06),\n",
    "    frameon=True,\n",
    "    fancybox=True,\n",
    ")\n",
    "\n",
    "data_hist = np.histogram(df_data['Constraint Value'], bins=bins)[0] / len(df_data)\n",
    "for (row, col, hue), data in plot.facet_data():\n",
    "    print(plot.col_names[col])\n",
    "    model_hist = np.histogram(data['Constraint Value'], bins=bins)[0] / len(data)\n",
    "    kl_divergence = np.where(data_hist == 0., 0., data_hist * np.log(data_hist / (model_hist + 1e-12)))\n",
    "    print(kl_divergence.sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "786505d2-c9c2-4f57-aa67-f7219cbc9e95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plots.save_all_subfigures(plot, f'event_histogram.unconditional.{reference_cfg.dataset.__class__.__name__}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7c880ee-97ef-4ef8-8226-4ba48efdb514",
   "metadata": {},
   "outputs": [],
   "source": [
    "splits['train'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f7d73f0-0f6c-4f72-9502-b95fe45186c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.concat([\n",
    "    pd.Series(constraint(info['event_samples']), name=source)\n",
    "    for (_, source), info in cfg_info.items()\n",
    "    if 'event_samples' in info\n",
    "], axis=1).melt(var_name='Source', value_name='Constraint Value')\n",
    "# reuse bins from previous plot\n",
    "data_color = 'tab:gray'\n",
    "plot = (\n",
    "    sns.displot(\n",
    "        data=df,\n",
    "        stat='density',\n",
    "        x='Constraint Value',\n",
    "        row='Source',\n",
    "        row_order=[c[1] for c, info in cfg_info.items() if 'event_samples' in info],\n",
    "        # row_order=['Data', *(c[1] for c, info in cfg_info.items() if 'event_samples' in info)],\n",
    "        hue='Source',\n",
    "        hue_order=[*(c[1] for c, info in cfg_info.items() if 'event_samples' in info), 'Data'],\n",
    "        palette=[*sns.color_palette()[:3], data_color],\n",
    "        common_norm=False,\n",
    "        bins=bins,\n",
    "        facet_kws=dict(\n",
    "            # sharex=True\n",
    "        ),\n",
    "        height=1.8,\n",
    "        aspect=2.2,\n",
    "    )\n",
    "    .set(yscale='log' if isinstance(reference_cfg.dataset, cs.DatasetFitzHughNagumo) else 'linear')\n",
    "    .set_titles('')\n",
    ")\n",
    "df_data = pd.DataFrame({'Source': 'Data', 'Constraint Value': constraint(splits['train'][constraint(splits['train']) > 0])})\n",
    "plot.map(\n",
    "    sns.histplot,\n",
    "    data=df_data,\n",
    "    bins=bins,\n",
    "    stat='density',\n",
    "    color=data_color,\n",
    "    x='Constraint Value',\n",
    "    zorder=-1,\n",
    ").set_xlabels('').set_ylabels('')\n",
    "for (row, col, hue), data in plot.facet_data():\n",
    "    ax = plot.axes[row][col]\n",
    "    ax.axvline(x=0, c='r', ls=':')\n",
    "    ax.xaxis.set_tick_params(labelbottom=True)\n",
    "    ax.yaxis.set_tick_params(labelleft=True)\n",
    "    if row != len(plot.row_names) - 1:\n",
    "        ax.xaxis.set_visible(False)\n",
    "plot.tight_layout()\n",
    "sns.move_legend(\n",
    "    plot,\n",
    "    loc='upper center',\n",
    "    ncol=len(cfg_info) + 1,\n",
    "    title='',\n",
    "    bbox_to_anchor=(.455, 1.06),\n",
    "    frameon=True,\n",
    "    fancybox=True,\n",
    ")\n",
    "\n",
    "data_hist = np.histogram(df_data['Constraint Value'], bins=bins, density=True)[0]\n",
    "for (row, col, hue), data in plot.facet_data():\n",
    "    print(plot.row_names[row])\n",
    "    model_hist = np.histogram(data['Constraint Value'], bins=bins, density=True)[0]\n",
    "    kl_divergence = np.where(data_hist == 0., 0., data_hist * np.log(data_hist / model_hist))\n",
    "    print(kl_divergence.sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d17b88f4-345a-4c74-81df-d344c7259f64",
   "metadata": {},
   "outputs": [],
   "source": [
    "plots.save_all_subfigures(plot, f'event_histogram.conditional.{reference_cfg.dataset.__class__.__name__}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b108c607-90d9-4470-b7ef-ef1157a26192",
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Event Likelihood: Direct Monte-Carlo')\n",
    "for (_, source), info in ((('', 'Data'), {'samples': splits['train']}), *cfg_info.items()):\n",
    "    is_event = constraint(info['samples']) > 0\n",
    "    print(f'{source}: P(E) = {is_event.mean():.3f}+-{is_event.std()/jnp.sqrt(len(is_event)):.3f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "018cdb30-0c27-4849-ad77-6d761bca1dac",
   "metadata": {},
   "outputs": [],
   "source": [
    "key, key_nll = jax.random.split(key)\n",
    "for (_, source), info in cfg_info.items():\n",
    "    x_noise, nll_no_div, nll = info['jax_lightning'].compute_nll(key_nll, 1., evaluation_trajectories[:10])\n",
    "    print(f'{source=}, {nll_no_div.mean()=}, {nll.mean()=}, {x_noise.mean()=}, {x_noise.std()=}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b348fea6-cf1c-4bb0-a381-296e43db4776",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
