{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "focal-bible",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"../../\")\n",
    "\n",
    "import pickle\n",
    "import timeit\n",
    "\n",
    "import bayesflow as bf\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import tensorflow as tf\n",
    "from bayesflow.computational_utilities import maximum_mean_discrepancy\n",
    "from tqdm.autonotebook import tqdm\n",
    "from train import build_trainer, configurator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a016067-22db-4aa6-a544-1b4b1946197e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable on-demand GPU memory allocation. TensorFlow requires the memory-growth setting to be\n",
    "# identical on all visible GPUs, so it is applied to every device instead of only the first one.\n",
    "physical_devices = tf.config.list_physical_devices(\"GPU\")\n",
    "for device in physical_devices:\n",
    "    try:\n",
    "        tf.config.experimental.set_memory_growth(device, True)\n",
    "    except (ValueError, RuntimeError) as err:\n",
    "        # Invalid device or cannot modify virtual devices once initialized.\n",
    "        # Best effort: continue with TensorFlow's default allocation, but report it.\n",
    "        print(f\"Could not enable memory growth on {device.name}: {err}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "typical-jewel",
   "metadata": {},
   "source": [
    "# Set up Forward Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6f29cb9-cbb1-42d9-bbd4-91c2891fdbe3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fashion-MNIST ships with Keras; `load_data()` returns a (train, test) pair of (images, labels) tuples.\n",
    "fashion_mnist = tf.keras.datasets.fashion_mnist\n",
    "train_split, test_split = fashion_mnist.load_data()\n",
    "train_images, train_labels = train_split\n",
    "test_images, test_labels = test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b160c4f-62f0-4102-b03a-3f3364d0cdb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both entries hold the same clean images; the observation is presumably derived from them\n",
    "# later by `configurator` (imported from train.py) -- confirm there.\n",
    "forward_train = {\"prior_draws\": train_images, \"sim_data\": train_images}\n",
    "\n",
    "# Shuffle the official test split once with a fixed seed and carve out a small validation set.\n",
    "num_val = 500\n",
    "perm = np.random.default_rng(seed=42).permutation(test_images.shape[0])\n",
    "val_idx, test_idx = perm[:num_val], perm[num_val:]\n",
    "\n",
    "forward_val = {\"prior_draws\": test_images[val_idx], \"sim_data\": test_images[val_idx]}\n",
    "forward_test = {\"prior_draws\": test_images[test_idx], \"sim_data\": test_images[test_idx]}\n",
    "\n",
    "val_labels = test_labels[val_idx]\n",
    "# NOTE: `test_labels` is rebound to the held-out subset here, so this cell is not idempotent;\n",
    "# re-run the data loading cell before executing this cell a second time.\n",
    "test_labels = test_labels[test_idx]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "61d3b727-d160-4bf6-ba6e-35036989b926",
   "metadata": {},
   "source": [
    "# Sanity Check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c718b67-c189-4155-9679-fe50b13efe23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visual sanity check: configured network input (left) next to the clean ground truth (right).\n",
    "how_many = 5\n",
    "conf = configurator(\n",
    "    {\n",
    "        \"sim_data\": forward_train[\"sim_data\"][:how_many],\n",
    "        \"prior_draws\": forward_train[\"prior_draws\"][:how_many],\n",
    "    }\n",
    ")\n",
    "\n",
    "# `plt.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9; use the colormap\n",
    "# registry instead, as the plotting helpers further down in this notebook already do.\n",
    "greys = matplotlib.colormaps[\"Greys\"]\n",
    "\n",
    "f, axarr = plt.subplots(how_many, 2)\n",
    "# Column titles only on the first row.\n",
    "axarr[0, 0].set_title(\"Blurred\")\n",
    "axarr[0, 1].set_title(\"True\")\n",
    "for i in range(how_many):\n",
    "    axarr[i, 0].imshow(conf[\"summary_conditions\"][i, :, :, 0], cmap=greys)\n",
    "    axarr[i, 1].imshow(forward_train[\"prior_draws\"][i].reshape(28, 28), cmap=greys)\n",
    "    axarr[i, 0].axis(\"off\")\n",
    "    axarr[i, 1].axis(\"off\")\n",
    "f.tight_layout()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "seven-moderator",
   "metadata": {},
   "source": [
    "## Set up Network, Amortizer and Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b19eba61-ee55-4cbb-b50d-8dad16e6c720",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling settings passed explicitly to the evaluation cells below (they take precedence over\n",
    "# the default arguments of the plotting helpers).\n",
    "# Number of sampling steps for CMPE (two-step sampling); forwarded as `n_steps` to `amortizer.sample`.\n",
    "cmpe_steps = 2\n",
    "# Step size for FMPE, following Flow Matching for Scalable Simulation-Based Inference,\n",
    "# https://arxiv.org/pdf/2305.17161.pdf; forwarded as `step_size` to `amortizer.sample`.\n",
    "fmpe_step_size = 1 / 248"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6288e1c5-87de-4ba1-aa78-af796dc2a883",
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_id(method, architecture, num_train):\n",
    "    \"\"\"Return the run identifier `<method>-<architecture>-<num_train>`.\n",
    "\n",
    "    The identifier is used as dictionary key for checkpoints, arguments and trainers and as\n",
    "    sub-folder name for figures and evaluation results.\n",
    "    \"\"\"\n",
    "    return \"{}-{}-{}\".format(method, architecture, num_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "723ec30d-8e63-4174-b682-ede31ad627e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps every run identifier (see `to_id`) to the checkpoint folder of that training run.\n",
    "# Paths are relative to the notebook's working directory; each folder is expected to contain\n",
    "# an `args.pickle` file, which is read in the next cell.\n",
    "checkpoint_path_dict = {\n",
    "    to_id(\"cmpe\", \"naive\", 2000): \"checkpoints/cmpe-naive-2000-23-12-02-155659/\",\n",
    "    to_id(\"cmpe\", \"naive\", 60000): \"checkpoints/cmpe-naive-60000-23-12-02-160801/\",\n",
    "    to_id(\"fmpe\", \"naive\", 2000): \"checkpoints/fmpe-naive-2000-23-12-02-161806/\",\n",
    "    to_id(\"fmpe\", \"naive\", 60000): \"checkpoints/fmpe-naive-60000-23-12-02-161806/\",\n",
    "    to_id(\"cmpe\", \"unet\", 2000): \"checkpoints/cmpe-unet-2000-23-12-02-144825/\",\n",
    "    to_id(\"cmpe\", \"unet\", 60000): \"checkpoints/cmpe-unet-60000-23-12-02-161035/\",\n",
    "    to_id(\"fmpe\", \"unet\", 2000): \"checkpoints/fmpe-unet-2000-23-12-02-161806/\",\n",
    "    to_id(\"fmpe\", \"unet\", 60000): \"checkpoints/fmpe-unet-60000-23-12-02-161806/\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d58d1944-9298-4cd7-9140-2c75d6d4b481",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the arguments stored as `args.pickle` inside every checkpoint folder.\n",
    "# NOTE(security): unpickling can execute arbitrary code -- only load checkpoints you created or trust.\n",
    "arg_dict = {}\n",
    "for key, checkpoint_path in checkpoint_path_dict.items():\n",
    "    args_file = os.path.join(checkpoint_path, \"args.pickle\")\n",
    "    with open(args_file, \"rb\") as f:\n",
    "        arg_dict[key] = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "601ed234-68b5-45d1-b0d7-e79a8279019c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One trainer per run, built from its checkpoint folder and the arguments loaded above.\n",
    "trainer_dict = {\n",
    "    key: build_trainer(checkpoint_path, arg_dict[key])\n",
    "    for key, checkpoint_path in checkpoint_path_dict.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9df37450-5d93-4e17-afa5-fea046810f59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the training/validation loss curves of every run as figures/<run id>/loss_history.pdf.\n",
    "for key, trainer in trainer_dict.items():\n",
    "    fig_dir = f\"figures/{key}\"\n",
    "    os.makedirs(fig_dir, exist_ok=True)\n",
    "    history = trainer.loss_history.get_plottable()\n",
    "    loss_fig = bf.diagnostics.plot_losses(history[\"train_losses\"], history[\"val_losses\"])\n",
    "    loss_path = os.path.join(fig_dir, \"loss_history.pdf\")\n",
    "    loss_fig.savefig(loss_path, bbox_inches=\"tight\", dpi=300)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "57162bc6-f4ea-439c-bfaa-17a4e72aae46",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "362ec7be-89bb-46b3-9447-3a1bee1b5635",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global Matplotlib styling for all evaluation figures.\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"axes.labelsize\": 24,\n",
    "        \"xtick.labelsize\": 16,\n",
    "        \"ytick.labelsize\": 16,\n",
    "        \"legend.fontsize\": 24,\n",
    "        \"text.usetex\": False,\n",
    "        \"font.family\": \"serif\",\n",
    "        # Only takes effect when `text.usetex` is enabled. This is a plain raw string (not a\n",
    "        # format string), so the braces must not be doubled.\n",
    "        \"text.latex.preamble\": r\"\\usepackage{amsmath}\",\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fca3dd4-11b5-4be1-baab-5936435ee247",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure the held-out test split once; the evaluation helpers and cells below read `conf`\n",
    "# (keys `parameters` and `summary_conditions`) as a notebook-level global.\n",
    "conf = configurator(forward_test)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ae8dc7cc-1287-4c7b-b537-50a1fc236f9c",
   "metadata": {},
   "source": [
    "## Per-Class Generation: Means and STDs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b3bcbcb-da48-4286-9dfa-dac457824c9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Human-readable Fashion-MNIST class names; per the dataset's label table, list position i\n",
    "# corresponds to integer label i.\n",
    "class_names = [\n",
    "    \"T-Shirt/Top\",\n",
    "    \"Trouser\",\n",
    "    \"Pullover\",\n",
    "    \"Dress\",\n",
    "    \"Coat\",\n",
    "    \"Sandal\",\n",
    "    \"Shirt\",\n",
    "    \"Sneaker\",\n",
    "    \"Bag\",\n",
    "    \"Ankle Boot\",\n",
    "]\n",
    "\n",
    "# Row labels (top to bottom) of the figure produced by `create_mean_std_plots`.\n",
    "y_labels = [r\"Parameter $\\theta$\", r\"Observation $x$\", \"Mean\", \"Std.Dev\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e03d9ead-2de2-4227-a97f-2a4637d192b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_indices_per_class(labels, seed=42):\n",
    "    \"\"\"Pick one random example index for every class that occurs in `labels`.\n",
    "\n",
    "    The indices `0 .. len(labels) - 1` are shuffled with a generator seeded by `seed`; for each\n",
    "    class, the first shuffled index carrying that label is selected.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    labels : array-like of shape (n,)\n",
    "        Class label of every example.\n",
    "    seed   : int, optional, default: 42\n",
    "        Seed for `np.random.default_rng`, making the selection reproducible.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        Maps each distinct label (in ascending order) to the selected index into `labels`.\n",
    "        Empty if `labels` is empty.\n",
    "    \"\"\"\n",
    "    labels = np.asarray(labels)\n",
    "    perm = np.random.default_rng(seed).permutation(labels.shape[0])\n",
    "    shuffled_labels = labels[perm]\n",
    "    # argmax over the boolean mask gives the first shuffled position of a class. Every class\n",
    "    # returned by np.unique occurs at least once, so a match is guaranteed.\n",
    "    return {c: perm[np.argmax(shuffled_labels == c)] for c in np.unique(labels)}"
    "\n",
    "\n",
    "def create_mean_std_plots(\n",
    "    trainer, seed=42, filepath=None, n_samples=500, cmpe_steps=30, fmpe_step_size=1 / 248, method=\"\"\n",
    "):\n",
    "    \"\"\"Helper function for displaying Figure 7 in main paper.\n",
    "\n",
    "    For one randomly chosen test image per class, shows (top to bottom) the true parameter image,\n",
    "    the observation, and the pixel-wise mean and standard deviation of the posterior samples.\n",
    "    Reads the notebook globals `conf`, `test_labels`, `class_names` and `y_labels`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    trainer        : object whose `amortizer.sample` draws the posterior samples.\n",
    "    seed           : int, default: 42. Seed for selecting the example of each class.\n",
    "    filepath       : str or None. If given, the figure is saved to this path.\n",
    "    n_samples      : int, default: 500. Number of posterior samples per image.\n",
    "    cmpe_steps     : int. Passed as `n_steps` to the sampler if `method == \"cmpe\"`.\n",
    "    fmpe_step_size : float. Passed as `step_size` to the sampler for any other `method`.\n",
    "    method         : str. `\"cmpe\"` selects the CMPE sampler call, anything else the FMPE call.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    f : the created matplotlib figure.\n",
    "    \"\"\"\n",
    "\n",
    "    idx_dict = random_indices_per_class(test_labels, seed=seed)\n",
    "    cmap = matplotlib.colormaps[\"binary\"]\n",
    "    # squeeze=False keeps `axarr` two-dimensional even when only a single class is present.\n",
    "    f, axarr = plt.subplots(4, len(idx_dict), figsize=(12, 4), squeeze=False)\n",
    "    for i, (c, idx) in tqdm(enumerate(idx_dict.items()), total=len(idx_dict)):\n",
    "        # Prepare input dict for network\n",
    "        inp = {\n",
    "            \"parameters\": conf[\"parameters\"][idx : (idx + 1)],\n",
    "            \"summary_conditions\": conf[\"summary_conditions\"][idx : (idx + 1)],\n",
    "        }\n",
    "\n",
    "        # Obtain samples and clip to prior range, instead of rejecting\n",
    "        if method == \"cmpe\":\n",
    "            samples = trainer.amortizer.sample(inp, n_steps=cmpe_steps, n_samples=n_samples)\n",
    "        else:\n",
    "            samples = trainer.amortizer.sample(inp, n_samples=n_samples, step_size=fmpe_step_size)\n",
    "        samples = np.clip(samples, a_min=-1.01, a_max=1.01)\n",
    "\n",
    "        # Plot truth, observation, posterior mean and posterior standard deviation\n",
    "        axarr[0, i].imshow(inp[\"parameters\"].reshape((28, 28, 1)), cmap=cmap)\n",
    "        axarr[1, i].imshow(inp[\"summary_conditions\"].reshape((28, 28, 1)), cmap=cmap)\n",
    "        axarr[2, i].imshow(samples.mean(0).reshape(28, 28, 1), cmap=cmap)\n",
    "        axarr[3, i].imshow(samples.std(0).reshape(28, 28, 1), cmap=cmap)\n",
    "\n",
    "        # Look the name up by class label, not by column position: the two only coincide\n",
    "        # when every class is present in `test_labels`.\n",
    "        axarr[0, i].set_title(class_names[c])\n",
    "\n",
    "    for j, label in enumerate(y_labels):\n",
    "        axarr[j, 0].set_ylabel(label, rotation=0, labelpad=55, fontsize=12)\n",
    "\n",
    "    # get rid of axis\n",
    "    for ax in axarr.flat:\n",
    "        for side in (\"right\", \"left\", \"top\", \"bottom\"):\n",
    "            ax.spines[side].set_visible(False)\n",
    "        ax.set_yticklabels([])\n",
    "        ax.set_yticks([])\n",
    "        ax.set_xticklabels([])\n",
    "        ax.set_xticks([])\n",
    "    f.tight_layout()\n",
    "\n",
    "    if filepath is not None:\n",
    "        f.savefig(filepath, dpi=300, bbox_inches=\"tight\")\n",
    "    return f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "165cc48a-5de2-4fab-8f44-e7ca07a846bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render and save the per-class mean/std overview for every run as figures/<run id>/main.pdf.\n",
    "for key, trainer in trainer_dict.items():\n",
    "    print(key)\n",
    "    fig_dir = f\"figures/{key}\"\n",
    "    os.makedirs(fig_dir, exist_ok=True)\n",
    "    plot_kwargs = dict(\n",
    "        seed=42,\n",
    "        filepath=os.path.join(fig_dir, \"main.pdf\"),\n",
    "        method=arg_dict[key].method,\n",
    "        cmpe_steps=cmpe_steps,\n",
    "        fmpe_step_size=fmpe_step_size,\n",
    "    )\n",
    "    f = create_mean_std_plots(trainer, **plot_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5fda05e2-da0a-416a-82e9-9b5dad398964",
   "metadata": {},
   "source": [
    "## Per-Class Generation: Samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6371b6d7-3b1f-4c04-9a7e-ce19c258b9bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_sample_plots(\n",
    "    trainer, seed=42, filepath=None, cmpe_steps=30, fmpe_step_size=1 / 248, method=\"\", n_samples=5\n",
    "):\n",
    "    \"\"\"Helper function for displaying individual posterior samples per class.\n",
    "\n",
    "    For one randomly chosen test image per class (one row each), shows the true parameter image,\n",
    "    the observation and `n_samples` individual posterior samples.\n",
    "    Reads the notebook globals `conf`, `test_labels` and `class_names`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    trainer        : object whose `amortizer.sample` draws the posterior samples.\n",
    "    seed           : int, default: 42. Seed for selecting the example of each class.\n",
    "    filepath       : str or None. If given, the figure is saved to this path.\n",
    "    cmpe_steps     : int. Passed as `n_steps` to the sampler if `method == \"cmpe\"`.\n",
    "    fmpe_step_size : float. Passed as `step_size` to the sampler for any other `method`.\n",
    "    method         : str. `\"cmpe\"` selects the CMPE sampler call, anything else the FMPE call.\n",
    "    n_samples      : int, default: 5. Number of sample columns shown per class.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    f : the created matplotlib figure.\n",
    "    \"\"\"\n",
    "\n",
    "    idx_dict = random_indices_per_class(test_labels, seed=seed)\n",
    "    cmap = matplotlib.colormaps[\"binary\"]\n",
    "    # squeeze=False keeps `axarr` two-dimensional even when only a single class is present.\n",
    "    f, axarr = plt.subplots(len(idx_dict), 2 + n_samples, figsize=(8.27, 11.69), squeeze=False)\n",
    "    titles = [r\"Param. $\\theta$\", r\"Obs. $x$\"] + n_samples * [\"Sample\"]\n",
    "    for i, (c, idx) in tqdm(enumerate(idx_dict.items()), total=len(idx_dict)):\n",
    "        # Prepare input dict for network\n",
    "        inp = {\n",
    "            \"parameters\": conf[\"parameters\"][idx : (idx + 1)],\n",
    "            \"summary_conditions\": conf[\"summary_conditions\"][idx : (idx + 1)],\n",
    "        }\n",
    "\n",
    "        # Obtain samples and clip to prior range, instead of rejecting\n",
    "        if method == \"cmpe\":\n",
    "            samples = trainer.amortizer.sample(inp, n_steps=cmpe_steps, n_samples=n_samples)\n",
    "        else:\n",
    "            samples = trainer.amortizer.sample(inp, n_samples=n_samples, step_size=fmpe_step_size)\n",
    "        samples = np.clip(samples, a_min=-1.01, a_max=1.01)\n",
    "\n",
    "        # Plot truth and observation, followed by the individual samples\n",
    "        axarr[i, 0].imshow(inp[\"parameters\"].reshape((28, 28, 1)), cmap=cmap)\n",
    "        axarr[i, 1].imshow(inp[\"summary_conditions\"].reshape((28, 28, 1)), cmap=cmap)\n",
    "        for j in range(n_samples):\n",
    "            axarr[i, 2 + j].imshow(samples[j].reshape(28, 28, 1), cmap=cmap)\n",
    "\n",
    "        # Look the name up by class label, not by row position: the two only coincide\n",
    "        # when every class is present in `test_labels`.\n",
    "        axarr[i, 0].set_ylabel(class_names[c], fontsize=12)\n",
    "\n",
    "    for i, title in enumerate(titles):\n",
    "        axarr[0, i].set_title(title, fontsize=12)\n",
    "\n",
    "    # get rid of axis\n",
    "    for ax in axarr.flat:\n",
    "        for side in (\"right\", \"left\", \"top\", \"bottom\"):\n",
    "            ax.spines[side].set_visible(False)\n",
    "        ax.set_yticklabels([])\n",
    "        ax.set_yticks([])\n",
    "        ax.set_xticklabels([])\n",
    "        ax.set_xticks([])\n",
    "    f.tight_layout()\n",
    "\n",
    "    if filepath is not None:\n",
    "        f.savefig(filepath, dpi=300, bbox_inches=\"tight\")\n",
    "    return f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaf7518e-8e89-4005-84d1-b2291db60660",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render and save the per-class sample grid for every run as figures/<run id>/samples_main.pdf.\n",
    "for key, trainer in trainer_dict.items():\n",
    "    print(key)\n",
    "    fig_dir = f\"figures/{key}\"\n",
    "    os.makedirs(fig_dir, exist_ok=True)\n",
    "    plot_kwargs = dict(\n",
    "        seed=42,\n",
    "        filepath=os.path.join(fig_dir, \"samples_main.pdf\"),\n",
    "        method=arg_dict[key].method,\n",
    "        cmpe_steps=cmpe_steps,\n",
    "        fmpe_step_size=fmpe_step_size,\n",
    "    )\n",
    "    f = create_sample_plots(trainer, **plot_kwargs)\n",
    "    f.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "e2d54ffa-b8d2-427e-af7a-d8c4e12ac4e7",
   "metadata": {},
   "source": [
    "### Averaged RMSE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4913dd61-1dfb-4982-833c-c0ce3b9fe484",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Posterior draws per test image and number of test images entering the RMSE.\n",
    "n_samples = 100\n",
    "n_datasets = 100\n",
    "parameters = conf[\"parameters\"][:n_datasets]\n",
    "\n",
    "for key, trainer in trainer_dict.items():\n",
    "    print(key, end=\"\")\n",
    "\n",
    "    # Method-specific sampler arguments: CMPE takes a number of steps, FMPE a step size.\n",
    "    if arg_dict[key].method == \"cmpe\":\n",
    "        sampler_kwargs = {\"n_steps\": cmpe_steps}\n",
    "    else:\n",
    "        sampler_kwargs = {\"step_size\": fmpe_step_size}\n",
    "\n",
    "    # sample once, to avoid contaminating timing with tracing\n",
    "    c = conf[\"summary_conditions\"][0, None]\n",
    "    print(\" Initializing...\")\n",
    "    trainer.amortizer.sample({\"summary_conditions\": c}, n_samples=n_samples, **sampler_kwargs)\n",
    "\n",
    "    # store samples\n",
    "    post_samples = np.zeros((n_datasets, n_samples, conf[\"parameters\"].shape[-1]))\n",
    "\n",
    "    # Timed section: one sampler call per test image.\n",
    "    tic = timeit.default_timer()\n",
    "    for i in range(n_datasets):\n",
    "        print(f\"{i+1:03}/{n_datasets}\", end=\"\\r\")\n",
    "        c = conf[\"summary_conditions\"][i, None]\n",
    "        post_samples[i] = trainer.amortizer.sample(\n",
    "            {\"summary_conditions\": c}, n_samples=n_samples, **sampler_kwargs\n",
    "        )\n",
    "    toc = timeit.default_timer()\n",
    "\n",
    "    duration = toc - tic\n",
    "    rmse = bf.computational_utilities.aggregated_rmse(parameters, post_samples)\n",
    "\n",
    "    # Persist total duration, RMSE and the raw samples under evaluation/<run id>/.\n",
    "    output_dir = f\"evaluation/{key}\"\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "    with open(os.path.join(output_dir, \"rmse.csv\"), \"w\") as f:\n",
    "        f.write(f\"duration,rmse\\n{duration},{float(rmse)}\\n\")\n",
    "    np.save(os.path.join(output_dir, \"rmse_samples.npy\"), post_samples)\n",
    "    # The printed duration is per posterior sample, in milliseconds.\n",
    "    print(f\"duration: {duration/(n_datasets * n_samples) * 1000:.2f}ms\\nRMSE:{float(rmse):.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2b7c0f5-87ea-4e67-93f8-50ad53f6be8d",
   "metadata": {},
   "source": [
    "RMSE for predicting only zeros:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfdf263f-46ef-4445-8b80-fce4ba28f1a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline for comparison: aggregated RMSE of an all-zero prediction. The zeros are shaped like the\n",
    "# first `n_datasets` observations with a singleton axis inserted at position 1; `parameters` and\n",
    "# `n_datasets` come from the RMSE cell above.\n",
    "bf.computational_utilities.aggregated_rmse(parameters, tf.zeros_like(conf[\"summary_conditions\"][:n_datasets, None]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "774258a6-bf11-49c7-bb6f-d2de2154f5b5",
   "metadata": {},
   "source": [
    "RMSE for predicting the noisy image:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "952a3544-c360-4b2f-9ace-c2c5321ad61f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline for comparison: aggregated RMSE when the observation itself is used as the single prediction.\n",
    "# The observations are indexed as images ([i, :, :, 0]) elsewhere in this notebook, whereas the last axis\n",
    "# of `parameters` is used as the flat parameter dimension. Flatten the observations to\n",
    "# (n_datasets, 1, n_pixels) so the error is taken pixel by pixel instead of being broadcast over\n",
    "# mismatched axes. NOTE(review): assumes the same row-major pixel order as `parameters` -- confirm\n",
    "# against `configurator`.\n",
    "observed = np.asarray(conf[\"summary_conditions\"][:n_datasets])\n",
    "bf.computational_utilities.aggregated_rmse(parameters, observed.reshape(observed.shape[0], 1, -1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0b37340-626e-4d2e-8b07-ef85fbae7cc4",
   "metadata": {},
   "source": [
    "### MMD"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27142160-3875-4339-86d5-9255f32a4d29",
   "metadata": {},
   "source": [
    "Split the test images into six parts (due to memory limits) and calculate the maximum mean discrepancy for each part."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "174b2720-6293-41ba-9a82-d74ad04d9029",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The MMD is evaluated on equally sized chunks of the configured test set (memory limits).\n",
    "# The chunk size is derived from the data (1583 for 9500 images and six chunks); examples that\n",
    "# do not fill a whole chunk are left out.\n",
    "# The former `parameters = conf[\"parameters\"]` assignment was unused here and silently replaced the\n",
    "# `parameters` variable that the RMSE baseline cells read, so it has been dropped.\n",
    "n_splits = 6\n",
    "split_size = conf[\"parameters\"].shape[0] // n_splits\n",
    "\n",
    "for key, trainer in trainer_dict.items():\n",
    "    print(key)\n",
    "\n",
    "    # One posterior sample per test image, kept as a tensor (to_numpy=False).\n",
    "    if arg_dict[key].method == \"cmpe\":\n",
    "        samples = trainer.amortizer.sample(conf, n_steps=cmpe_steps, n_samples=1, to_numpy=False)\n",
    "    else:\n",
    "        samples = trainer.amortizer.sample(conf, n_samples=1, step_size=fmpe_step_size, to_numpy=False)\n",
    "\n",
    "    mmds = np.zeros((n_splits,))\n",
    "    for i in range(n_splits):\n",
    "        chunk = slice(i * split_size, (i + 1) * split_size)\n",
    "        # Index 0 selects the single sample drawn per image.\n",
    "        mmds[i] = maximum_mean_discrepancy(conf[\"parameters\"][chunk], samples[chunk, 0]).numpy()\n",
    "\n",
    "    # Persist the per-chunk values under evaluation/<run id>/ and report mean and spread.\n",
    "    output_dir = f\"evaluation/{key}\"\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "    np.save(os.path.join(output_dir, \"mmds.npy\"), mmds)\n",
    "    print(f\"{mmds.mean():.5f}, std: {mmds.std():.5f}\")"
   ]
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "python3",
   "name": "tf2-gpu.2-11.m108",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-11:m108"
  },
  "kernelspec": {
   "display_name": "cons-mod",
   "language": "python",
   "name": "cons-mod"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
