{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
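        "# Figure 5: Failure modes of C2ST\n",
        "\n",
        "We approximate samples from a 2D mixture of Gaussians with a single moment-matched Gaussian, and compare neural-network C2ST estimates against the C2ST of the optimal classifier (ground truth). The sections below show how the estimate can deviate from the ground truth, e.g. when too few samples are used or when the classifier is too weak.\n"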
      ],
      "metadata": {
        "id": "wjEBWMQXKv7z"
      },
      "id": "wjEBWMQXKv7z"
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "39c4d6b8-9316-4eba-a648-3095ed4b5d22",
      "metadata": {
        "id": "39c4d6b8-9316-4eba-a648-3095ed4b5d22"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch import ones, zeros, float32, as_tensor, tensor, eye, sum, Tensor, manual_seed\n",
        "from torch.distributions import MultivariateNormal, Normal\n",
        "from typing import Any\n",
        "import matplotlib as mpl\n",
        "import seaborn as sns\n",
        "import numpy as np\n",
        "from sklearn.neural_network import MLPClassifier\n",
        "\n",
        "import time\n",
        "import IPython.display as IPd\n",
        "from svgutils.compose import *\n",
        "\n",
        "from sklearn.datasets import fetch_openml\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.mixture import GaussianMixture\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import pickle"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "a215b377-4ddb-4f7f-a6fa-2e2ea93ef155",
      "metadata": {
        "id": "a215b377-4ddb-4f7f-a6fa-2e2ea93ef155"
      },
      "outputs": [],
      "source": [
        "from labproject.metrics.c2st import c2st_optimal, c2st_nn, c2st_knn, c2st_rf, c2st_scores\n",
        "from labproject.data import toy_mog_2d"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "f3f6191a-b33e-4536-b511-310304c7708e",
      "metadata": {
        "id": "f3f6191a-b33e-4536-b511-310304c7708e"
      },
      "source": [
        "## Visualize data and fit"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "f79753b9-4d8b-4118-aa62-b74704e03745",
      "metadata": {
        "id": "f79753b9-4d8b-4118-aa62-b74704e03745"
      },
      "outputs": [],
      "source": [
        "# make every sampling step below reproducible: seed torch, then numpy\n",
        "_ = torch.manual_seed(seed=0)\n",
        "_ = np.random.seed(seed=0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "ff550e7b-23a3-429b-b68a-7d4c0d3aeb69",
      "metadata": {
        "id": "ff550e7b-23a3-429b-b68a-7d4c0d3aeb69"
      },
      "outputs": [],
      "source": [
        "# draw the reference samples from the 2D mixture-of-Gaussians toy distribution\n",
        "data = toy_mog_2d()\n",
        "data_samples = data.sample(\n",
        "    (10_000,)\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "id": "243aa674-446d-4fb0-80eb-012024c472a7",
      "metadata": {
        "id": "243aa674-446d-4fb0-80eb-012024c472a7"
      },
      "outputs": [],
      "source": [
        "# moment-matched Gaussian fit to the MoG samples; this plays the role of the generative model\n",
        "mean = data_samples.mean(dim=0)\n",
        "cov = torch.cov(data_samples.T)\n",
        "gen_model = MultivariateNormal(loc=mean, covariance_matrix=cov)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "id": "f66a5f10-e716-4fcd-a683-3836b7fb5b52",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 218
        },
        "id": "f66a5f10-e716-4fcd-a683-3836b7fb5b52",
        "outputId": "ad6d6d50-bb8a-4421-ea9c-8a6377276d41"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 90x215 with 2 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGkAAADJCAYAAAAgnhzUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAT20lEQVR4nO2dXWgT2d/Hv2nTF/OCTfrertZa+i9K2dZdBV9gUSyisVRLBe90watdvFi8E2+8ci+8UfbCCxEXURAWF8Xu4pWIII/gPmsrPOxjtbWmptXaNtE2sW/O/C+yZ3JmMjOZSTKTOfF8QEwmk5nT88nvnN+cczJxiaIoguNoSgpdAE5muCQG4JIYgEtiAC6JAbgkBuCSGMBt9QlEUUQikQAAeDweuFwuq09ZdFgeSYlEAj6fDz6fT5LFMQdv7hiAS2IALokBuCQG4JIYgEtiAC6JAbgkBuCSGMDyYaFiZC6+qvla0Jv/KuWRZIK5+KquILJPvuGRZAArKt4MXJIOhZZD4JI00BI0FVuUPW+sqrS8LFySCkpBSjGlJck5sc+CiKnYouWimJKk9enOV0alJ6e0xIXxaFx6viHgRWmJC58F69eWMiHJTEaVrTAtQRMfUhOVlWXJZHhxRcB4NI4NAW9W5zKLoyVl03GT95iRRZ9HKWc0mkCJyox/a5VH9VhWXCc5VlKumdVcfDVjhelFD5EzOruYJqmt2vpkgcaRkowIIhWq12nriVITFI4l4HIlo2d0dhETs3G4FYaag+oRZCWOlKSFWpaVKSXOJJyOnrFYKnomZuN4N5dAWUlqUKYuuEb2XmXSYEVTBzAkiVRmqeKTTT83mxIrmzelnOnpBMrdyePX1qgnCV/kdZLaJ18piE6FaUi2ZaQp1BNE5PynPnm88dlUhieIyX924jhJSmhBtBySDhNIWgwkZZGoAtJlZRL0diyMcncJ3o4l92/YuA4rgoBVQUR7TSVaqzzSOezA8ZIAuaDKshL8835Bc9/WKo+mLOXx1ATNhd+g3F2C9nofAODFuwUsryZltFBNnl2CAIdLUlYuLWh0djFt/7bqSryKJZsmpSwC2aYliMhR0qzSJ9nRHwEm5pPu3LmDX375BaOjo7Ltly5dynuhaOgoUgqq9pQjMpdAZC6Bak85RmcXMTq7iBczi3gVS0jCxqNx6R8AvIrJBQ3/z/+rCkpGkYDmhuR2ZfAoP0RWjZobknT69GlcvHgRIyMj2LlzJ27cuCG9dvnyZUsKpoT0QURQZC6BZ29iaPRXoNFfgWdvYpIwkkYDkGRVlpVIj1/MyCOooqxEM4LW/6cFAKT+aFNtcj9llmklhpq7wcFBPH36FG63GydPnsS+ffvg9Xpx+PBhFOLL6ySCGv0VGH4dk71WF1yDyFwCzUGPJKqtulIWhUabOMKKIMj6I7sx3CeVlpYCADo6OnD79m0cOHAA9fX1Bf0qy/DrGNYF5P3CxNwn2XNaFgC8nknP4pyOoRLu378foVAIQ0NDAIAtW7bg+vXr6Ovrw/j4uIXFM8bfL2alf+9nkv3O9L+ySH9Fos9d4pIJyhRFTsCQpPPnz+PEiROybT09PXj48CFCoZAlBQOS2dNnQcSGgBeLK4K0fTaxLD3++8UsAGBDtQcbqpPjakpRz97E8OxNTNpW7nZlJchf4dZN/wFrkgfDsX7kyBF0d3fj48ePCIfDCIfD8Hq9OHfuXF4LlGn8S20EmsihHxNRJLFo9Fdgeu4TIm/1K5nw4t0Ctm/rwPKqKKXfQ5Pp71VmeFZg6jrp1KlTuHLlCqqrq6WEweVyYWxszJLCEciFY2uVR0qrgWSHXlvjxfhMPE3U+GwC6wKVaYlFuduFt2MThs47PptAc4MPjf4KKRmRXvt30s9xM7N3795FJBKBz2dfO95YVYmp2CI2BLzSdY4gpqYMphWJgpJ1gUpMRBelxyTC9HjxLhUxRDQ9Al5ZViJrfq3GlKSuri
4sLS1ZLinodae17XQ0AfIRBxJNWpAMkPRf27d14PGT5zIZSrZv65ANrBYSU5L6+/vR3t6Ozs5OuN2pt96/fz/vBaNRiyZAHk1kKoFOyUkEke3ftFfLRLGCKUlnzpzBhQsX0NLSYlV5DNNWXYnR2UVZPzE990kmBgC6Wqow/DombadFZeKb9uq04xUCU5KCwSCOHTtmVVkMs6nWh3/eL6C7yYehyQU0Bz2IzCXSZk4BYGp+CV0tVQAgydKawFNeGBOUxyX9kSOnKnbs2IH+/n4cPHgQ5eXl0vZ8izNyraEmilDtSZXt2ZsYpuaXAKRXNpBM0QHIok2N5qBHNsBKMjvHzcwmEglUVVXh0aNHsu1WR5fetcj80iq6m1KJzNDkguxil8ijxdGQi1wikEgDIMklxyATfnbjsvr2nvF4XMoGFxYW4PXqD1RqLbNSTvwRyAiAv0L786Z2EUpDBBJhNEQQuYhWzsqqRVK+F6QYOlpvby8GBwfR2tqqOqBq1cWsniB6eGZxVcDi6jIqNQZLtdbJEbFEotpyLUFMRhCQEkSwQxBgUNKuXbtw7do1nD17Nu8FoFFbSUoLokcb6MjxVySbPTLXo4S+8FSLQiKRPub8Uqosm2p9smPY1RcRDEkaGRnByMgIxsbG8PLlS4RCIbjdbty7dw+bN2/G8ePHcy6I3lJfIufFTPpqUgCyPglA2mgA+fSPR+Oy15RS6egkry2uCFhcEaRj6GV0BV13d/XqVQDAnj17MDw8jJqaGgBALBbDoUOHLCkYmfkkM6klLmiuKB2aXEhr0tQW0yu3KZeG0UkBLVMpyK5mjmDqyJFIBMFgUHru8XgwNTWVcyHUomg8GpcJiswl0laUrgjJimwOeqRpBGX2pfbJJx8AM9+KKJQgwKSk3t5e7N27FwMDAxBFETdv3sTRo0fzXih6/QARND33SbaiFJCvKqWjSe+TPxVbzOoiVKsPsloQkEUKfuvWLTx48AAulws9PT3o6+vT3d9ICq6MpIkP6gtGyDQDPfhZW+NFXXANmoMetFWnL1y0soO3QxCQxbq7gYEBDAwMWFGWNEpckKa76Xmgx0+eGxqlLgZBAEP3cSCTemTK+/GT59Jrq4Jo6/psOwUBDl/BqgaZA2rYuA7Lq6K03Io0dYB11zF2yyEwIam5wYfxmbg0B0Saubo6D+o1Bj6VFKqC84GjSy6IyaZsRRCwvCpKcpZXRdTVedBc45UlDIB6ys2yIMDBkkilk+QBgHSNtCIIMkHEy7q1yffYOWRjB46UpLboZFUQ4S5xYVUQpSW/RFBbwKM5v8N6FAEOkUQvPCHrGQD5MI0gJqNKEOXTBkAygopVEOAQSWqsW+vBxIdE2jCPIMrlaTVxxSIIcKgkEk3r1nqkqQpajDKDK2ZBgMNmZjPdeImmkGNpduOov0i5KNJMllaMcgiOGxbKprKLWRDgsEgikEq3+tZprODov/ZLk6GF45o7TjpcEgNwSQzAJTEAl8QAXBID2PI7s4R4PPP3Vb80jPz2ruWS6N+Wra+vt/p0zGHkmya8uWMAy0fBBUHAzMwMAP6z2moYqRPLJXFyhzd3DMAlMQCXxABcEgNwSQzAJTEAl8QAXBIDcEkMwCUxAJfEAFwSA3BJDMAlMQCXxABcEgPYshCFrHPgM7PZYXkkJRIJ+Hw++Hw+2aIUjnF4c8cAXBIDcEkMwCUxAJfEAFwSA3BJDMAlMQD/erdB5ldStyvwl/HbezoCWorWa3bJ+qIl6Ykw836rZX1xknIVk+mYVgj7IiRZIUbvXPkWVbSS7BSjde58ySo6SdnKiS4tyZ4HKio09jRXlnyIKipJZgQppQBACTUhqfZ6PsRlQ1FIMiqHrviSDDPEytcFUZTeb0ZWPqKJaUlG5CgjIpMcLcj7spWVC0xKMitHS0wkETV8zmZPQHYsIsuIqFyjiTlJmQRpydESUl6SefhyWRAQSUQlUfSx7YgqZiSZiR4tOVpCns1Nah7z62CT9D
5yLKUsweK7LDjqltNaGI0eNTmkgvVEeNxlqtsTqyvS46+DTQCSUQXIRRFJetFUtM1dNtGjjBxajpYMAHgyPSp7vq2uTdo/sbqCZ3OTUlQpmz+ro8mxkZRN30NHj5ocpQiaQEXqDv7RpdT6wG11bQBSUUVHlJloyiWSmJSUqXkjgjzuMpkYWgThf98Op237tqHr3/MkZamJ0mr2rJDkyObOjCAzctSEAEDjmrXS46lPH6T9iKwn06NS80dEkWbPzN+UrShHStLCrKBAhUcmhpYBAH+NPk4e59/nW9u2S/sQWUpRhcCwpDt37iAcDiMUCqGtLVXYS5cu4YcffshbgbSiyIwgNTl/jT6WZBDW+2tlz4k0Imvq0wcASdl0P2U3hhainD59GhcvXsTIyAh27tyJGzduSK9dvnw5b4XJRdD/zYXTBDWuWYvI5D9S5a/312K9vxaTkxOYnJzA4+d/4/Hzv6XzEGlkf0C7ibQTQ5E0ODiIp0+fwu124+TJk9i3bx+8Xi8OHz6MfOUduQoC5M1bZPIfKXImJ5M/IjyJCem4GwNfSY8fP/8b2zu+AZAUFZ5/DwCyaMoVWxKH0tJSAEBHRwdu376NAwcOoL6+Pi/fNzIjSHlxSgSNRV8ASDVtQEoOIJeiZGPgK5koIBlNW9u2Z/X35BtDkvbv349QKISff/4Z3d3d2LJlC65fv46+vj58/vzZ0gKqXaTSfRCQEkSiZ72/VmrG5l7PyP5XsrW7W3pMRNHR5AQM9Unnz5/HiRMnZNt6enrw8OFDhEKhnAqQKYoA/SSBFgSkBG0MfCWJ6ahrVf0HAH8NDQHQjzQj6I045DqfZHgF65EjR9Dd3Y2PHz8iHA4jHA7D6/Xi3LlzORVAt3BUU0oPjtJZHJBKrWlBpPKJDDWUogDIEgk9lCMOgHUj4aYUnzp1CleuXEF1dbWUMLhcLoyNjWV1cjNRBCSbOfo6iCQJdDZGoyeI3uf59Cvp+faObxCef4+tbdulpCG6lMC2ujbZgKudmJJ09+5dRCIRaZjHSpRRRJo5IohOFOh+iESREUEAZILU+LahS3aNRA8JEayeqjC1YL+rqwtLKgs08olWFAGpRIFu5rSiyAhE0NbuboxF30hRBEAWRQDSoshoU2f7aqH+/n60t7ejs7MTbnfqrffv3zd9Yr3xOa0o0kKtLzIKEQRAEtTctAlAKopIU1eIKAJMSjpz5gwuXLiAlpaWnE6azdo4tSGfXPoiZTOn7IuUgmjsjCLApKRgMIhjx47l5cRq6DV1Sui+SMnz6VeaougmDoBqM0f3Q2rTEwQ7oggwKWnHjh3o7+/HwYMHUV5eLm03Iy5TFGml3UbZ2t2Nv4aGJBnK7I3sQ5o4IL2Zo+eRlM0ciSIrJ/mUmDpSIpFAVVUVHj16JNtuZXRlIjz/Pi2a6FEE0kfR24ggOoKIoI2BdgDQbebsFASYlHT16tW8npwmurRkauHi1KcP2Nq2XdYvkcqnRw9oOfQ+TU3rDAkiUWRUkBUYktTb24vBwUG0traqDqhmezFrlm11bXgyPYpvG7pkUwjh+feywVG9UQM6eugLVrVEQdkPWb0qSAtDR9y1axeuXbuGs2fP5r0AuUKiiVT8en+tTJgSOnqIoI2Bdl1BdDZntyDAoKSRkRGMjIxgbGwML1++RCgUgtvtxr1797B582YcP37c0MnmV1bhLc9PM0GiqXHNWmlKgZalhVr0AOpNHJDeD2lh5VcyDR2Z9EV79uzB8PAwampqAACxWAyHDh2yrHA0idUVacQhupRAoMKT1uxlmv+Z+vRBlmKbFWRXoqDE1NEjkQiCwaD03OPxYGpqKm+FEURRNXn4OtgkjTqQfokWBSQHWzPNomrJIWk2kL5CtdCCAJOSent7sXfvXgwMDEAURdy8eRNHjx7NS0ECFRVpX1NZFgTpWomI8rjLZKKS703JUoPsR/odAJIccuxslg/bdYsA04sjb926hQcPHsDlcqGnpwd9fX26+9
OLI6eiMd3Fkco0XGs9N71cWG9VKoFeiqVc301nb2bSbDtvuGHrClazkgD5ugZAXZQeygtSZbMG2Lv4PhsctzhS2S81ewKyiCKVnGlknIa8h5DtNyMA+wUBBVgLLmRIwdXWeQPpTZ9ZcokcoDByCLaf2V/m1h1kJQmEVkTRla0mTG99tpoccs5MZS4kjmvuALkoIBVVynkcrekM5X40ZuQAhRcEFEhSpmgCUhWoFlUEPRlKzMoh5XQCBSuFEVGAXBaN3oi52hCOmVFrp8ghOKs0OtCVTDeFmfY1g9PkEApaKqPRpCTfczlOlUMo+D1YC11BhT6/EQouCShMRfnL3EwIAhzUJ2Xb9Jk9B4s4qtR0JeYijFUZWjj2rym2is4FR/RJHH24JAbgkhiAS2IALokBuCQGsOV3ZgnxeNzq0zGHkd/etVwS/duy9fX1Vp+OOYzcXo43dwxg+UIUQRAwM5O86QX/We10jNSJ5ZI4ucObOwbgkhiAS2KAopT0/fff49y5c1ndQWxychJNTU2Zd7SRopQEAE1NTfjzzz9NveePP/7A7t278/qdq3xQFJJEUcSpU6fQ3t6O3bt3Y3Q0+XWYDRs2AEhG1o8//ogtW7Zg/fr1uH79OgYGBtDW1oaffvpJOs7ly5fx+++/F+AvyIBYBPz222/id999Jy4vL4vT09NiQ0ODePXqVbGlpUUURVE8fvy42NfXJ4qiKP7666/i2rVrxXfv3okfP34U/X6/GI1GZcdzWrUURSQ9ePAAAwMDKCsrQ21trWpfdPDgQQBAS0sLOjs7UVdXB7/fj2AwiFgsZnOJzVEUklwul2wgt6ws/Qtm9G126DuMsUBRSOrp6cHNmzextLSEWCyGe/fuFbpIeYWtj5QGhw4dwpMnT9DZ2YmGhgZs2rSp0EXKK3zsjgGKorkrdrgkBuCSGIBLYgAuiQG4JAbgkhiAS2IALokBuCQG4JIY4L+AO2TBjjO6GQAAAABJRU5ErkJggg==\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "# plot KDE of the MoG samples (top) and of samples from the Gaussian fit (bottom)\n",
        "\n",
        "_ = torch.manual_seed(0)\n",
        "\n",
        "with mpl.rc_context(fname=\"matplotlibrc\"):\n",
        "    n_samples = 10000\n",
        "    samples_mog = data_samples\n",
        "    samples_np = samples_mog.numpy()\n",
        "    samples_normal_approx = gen_model.sample((100_000,)).numpy()\n",
        "\n",
        "    n_plot = 10\n",
        "    # slice with n_plot rather than a second hard-coded 10 so the two cannot drift apart\n",
        "    samples_to_plot = samples_np[:n_plot, :]\n",
        "\n",
        "    al = 0.8  # alpha of the filled KDE contours\n",
        "    # ms / mec are not used by the KDE panels below\n",
        "    ms = 8\n",
        "    mec = 'k'\n",
        "\n",
        "    densities = [samples_np, samples_normal_approx]\n",
        "    cmaps = ['Blues', 'BuGn']\n",
        "\n",
        "    fig, axs = plt.subplots(2, 1, figsize=(0.9, 2.15))\n",
        "    for i_a, (ax, density, cmap) in enumerate(zip(axs, densities, cmaps)):\n",
        "        sns.kdeplot(x=density[:, 0], y=density[:, 1], fill=True, thresh=0.05, levels=10, cmap=cmap, ax=ax, alpha=al)\n",
        "        ax.set_xticks([]); ax.set_yticks([])\n",
        "        # identical limits on both panels so the densities are directly comparable\n",
        "        ax.set_xlim([-7, 4]); ax.set_ylim([-6, 4])\n",
        "\n",
        "    axs[1].set_xlabel('dim1')\n",
        "    axs[0].set_ylabel('dim2')\n",
        "    axs[1].set_ylabel('dim2')\n",
        "\n",
        "    plt.subplots_adjust(hspace=0.3)\n",
        "\n",
        "    plt.savefig(\"svg/fig2_illustration.svg\", bbox_inches=\"tight\", transparent=True)\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "6d4745a2-51ae-428d-a0a0-4cc746bc51db",
      "metadata": {
        "id": "6d4745a2-51ae-428d-a0a0-4cc746bc51db"
      },
      "source": [
        "## Too few samples"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "id": "813f4fdb-e500-49f7-97b9-340937e15336",
      "metadata": {
        "id": "813f4fdb-e500-49f7-97b9-340937e15336"
      },
      "outputs": [],
      "source": [
        "# reference C2ST from c2st_optimal; drawn as the 'Ground truth' line in the plots below\n",
        "c2st_gt = c2st_optimal(data, gen_model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "id": "742ca764-e77d-494c-a5cc-052e3c3b9c9d",
      "metadata": {
        "id": "742ca764-e77d-494c-a5cc-052e3c3b9c9d"
      },
      "outputs": [],
      "source": [
        "# number of samples drawn from each distribution: 10, 100, 1000, 10000\n",
        "budgets = [10**k for k in range(1, 5)]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "id": "507fde3e-dc26-4c24-9c01-c4e364dccab2",
      "metadata": {
        "id": "507fde3e-dc26-4c24-9c01-c4e364dccab2"
      },
      "outputs": [],
      "source": [
        "# neural-net C2ST estimate for each sample budget; torch is reseeded on every iteration\n",
        "estimates = []\n",
        "for budget in budgets:\n",
        "    _ = torch.manual_seed(0)\n",
        "    estimates.append(\n",
        "        c2st_nn(\n",
        "            data.sample((budget,)),\n",
        "            gen_model.sample((budget,)),\n",
        "            seed=0,\n",
        "        ).item()\n",
        "    )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "id": "4035cc12-7753-4760-8a67-4451d50b5602",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 123
        },
        "id": "4035cc12-7753-4760-8a67-4451d50b5602",
        "outputId": "f838c863-9f09-4877-9837-e5783ada214c"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 160x70 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAALMAAABqCAYAAAD3JWAUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWCUlEQVR4nO2deVBUV9r/Pw3Ivm/GFoFJMLEAUTEqiAtBBcEtOvrDStSMNcm4JZm4TabGCqJO3miJpWMyRs0YcUmM5pc4OiQUaEbQvEBYDC6oLKNEZTMoOwJNc98/kJZmbfa253yqbnXfc8495+m+3zr3OeeeRSZJkoRAoAPoDbQBAkFvIcQs0BmEmAU6gxCzQGcQYhboDELMAp1BiFmgMwgxC3QGIWaBziDELNAZeiTm8vJyRo4cSW5ubqu469ev4+Pjw4gRI1i4cCFVVVU9KUog6JRuizkxMRE/Pz8yMzPbjF+yZAk7duzg1q1buLu789e//rXbRgoEmtBtMR88eJB9+/Yhl8tbxd27d4+ysjKmTp0KwJtvvsnJkyfbzauqqkp1VFZW8uDBA6qqqhBjoARdodtiPnz4MJMnT24zLi8vj6FDh6rO5XI59+/fbzcvc3Nz1WFhYcHgwYMxNzenurq6u+bpBEqlkh07duDp6YmHhwfDhw/n/fffR6FQ9Ksd4eHhhIeHtwrfvHkzly5d6lJeBw8e5MSJEwD87ne/IzIyshcsbKRPGoANDQ2tC9ITbc2u8s4773Dp0iUuXrxIRkYG165d486dO3zwwQcDbRoA8fHxKJXKLl2TkJBAbW1tn9jTJwpzcnKioKBAdV5QUICTk1O76SsrK1VHUVFRX5g0YEiSRH19fadHS5cqPz+fyMhIjhw5gq2tLQDGxsZ8/PHHvPDCC0BjjRkUFISHhwd79uwhKysLf39/vLy88PX1JSUlBWhdA7q6upKbm0tkZCShoaEEBwfz4osvEhoaSl1dHQA7d+5k+PDh+Pr6kpyc3Op3HT58mNTUVN58803S09Px9/dn/vz5jBgxgqSkJGQymSptXFwc/v7+xMTEcPbsWcLCwvj+++8BiI6OZsKECbi4uPS4XWXQo6vbwdnZGVNTU+Lj45k6dSqff/45s2bNaje9mZlZX5ihFSiVSqKjoztNFxwcjIHB09vx008/4e7ujp2dnVq6wYMH89Zbb6nOq6urycjIAGD8+PFs3LiRRYsWkZSUxMKFC8nKyuqw3ISEBK5fv465uTnjx48nJiaGIUOG8I9//IPLly+jr6+Pj48P48ePV7tu+fLlHDlyhPDwcEaPHg2Ah4cHp0+fbresoKAg5s6di7+/PyEhIZw6dYqqqioSExMpLi7G1dWVP/7xj1hYWHT6f7VFr9bMISEhpKamAnDixAn+/Oc/4+7uTnJyMlu3bu3NonQeSZLUareYmBhGjx7N6NGjee6551Thvr6+QOPTLTs7m0WLFgHg4+ODra1tu71NTfj5+WFlZYW+vj6enp48evSIuLg4Zs2ahYWFBaampqo8O2PixIld/Zm8+uqr6Onp4ejoiIODA48ePepyHk30uGZu3sfc9OgA8PT0JDExsafZP/Po6+sTHBysUbrmjB07lhs3blBWVoaVlRVBQUEEBQUBqInc1NQUaLudIkkSCoUCmUym5sY0uRLQ6Lo00ZSuZfpBgwZp5Bs32dK8fJlMplZeS5o/jVqW21VEq6yPkclkGBgYdHo0FyiAi4sLb7zxBkuXLuXhw4dAozhOnz7dZmPa0tKSF154ga+//hqApKQk8vPz8fT0xN7enitXrgDw448/qrVn2mLatGmcPXuW0tJSamtr+ec//9lmOgMDA+rr69uMa15mk02dXdNT+sRnFvQOn3zyCXv37iUwMJCGhgZqamoYNWqUqmHXkuPHj7Ny5Uq2bt2KoaEh33zzDUZGRqxevZrFixfj6enJ2LFj8fb27rDc0aNHs27dOsaNG4etrS
3Ozs5tpps1axYrV67k8OHDreIiIiKYP38+Dg4OBAcHk52dDTT6zX/605+67Rd3hEzbZmdXVVVhbm4ONPqButw4FPQuws0Q6AxCzAKdQYhZoDMIMWspubm5yGSyVi9cmt7e9SUte1a6y/Lly/nll196JS9NEGLuYyRJ4nFdXadHW+3wQYMGsWLFCsrKygbA8p5z4cKFfh35KLrm+pgahYJ5Ebs7TXdmw1pMDA3VwuRyOTNmzGDdunUcOnSo1TURERF88cUXNDQ0MGXKFHbv3s39+/fx9/dX1d6RkZHExcURGRmJq6sr48ePJz09nX//+998+umnnD9/ntLSUmxsbDh9+jRDhgxp077IyEiio6MpLy/nP//5D2PGjOHYsWMYGhpy/Phxdu/ejVKpxMPDg/3797N3717y8/MJCQkhLi4OR0fHrv95XUSjmjkvL6+v7RC0w65duzh//jwxMTFq4bGxsSQmJpKSkkJ6ejq1tbXs37+/0/wCAwPJysqipqaGjIwMEhISyMzM5KWXXuLLL7/s8NqEhAS++uorbt68SU5ODjExMdy8eZP9+/fz448/kp6ejru7O1u2bGHTpk3I5XK+//77fhEyaFgzz5kzh8uXL/e1LTqJ8aBBnNmwVqN0bWFpaclnn33GW2+9xbVr11ThsbGxJCcn8/LLLwNQU1ODgYEBs2fP7rCcpvETbm5u7Nmzh0OHDpGZmUlCQgLPP/98h9c2jeMAVOM4fvjhB7Kzs1VjRBQKRaf59BUaiVnL3qs8U8hkslbuQ1cJDAwkMDCQ9evXq8KUSiVr165l3bp1AJSVlSGTySgpKWl3HAY8HT+RlpZGaGgo69evZ+HChejr63d6n9sax6FUKgkNDWXv3r1A40uvvhqv3BkaibmwsLDDUW9hYWG9ZpCgbXbt2sXIkSNV4yoCAgIICwvjD3/4AyYmJixevJiZM2eyfPlyHj16REFBAY6Ojnz77bdqo+yaiI+PJyAggFWrVlFWVsaqVauYM2dOl+3y9/cnIiKCTZs2MXjwYNauXYuZmRm7d+/u03EYbaFxb4YkSe0egr6nyd1omjI1Z84cFi5cyIQJE/Dw8MDFxYU1a9ZgaWnJX/7yF3x9ffHz88PLy6vN/EJDQ7ly5QpeXl4EBAQwatQobt++3WW7Ro0aRXh4ONOnT8fDw4NHjx6xbds2AObNm0dISAg5OTnd/+FdQKOxGd7e3v3mM4uxGYLuolHNLGpfwbOARmJuPuheINBWNBKzlZUV69evV01sfPfddzE3N8fPz6/DJQQEgv5EIzG/9957PH78GFdXV7777jtOnjxJeno6Gzdu5J133ulrGwUCjdCoAThy5EhVh/2KFSswNDTk448/BsDd3Z0bN270mkGiASjoLhrVzM0nW164cIHp06erzjuarCgQ9CcavTSxs7MjOTmZ8vJy8vPzVWKOj4/vcHEXgaA/0UjMe/bsITQ0lKKiIvbt24eZmRkffvghe/fuJSoqqq9tFAg0otsTWnNycnBwcFANPOkthM8s6C4a+cz19fX87W9/Y/369apVH93c3LCystKaRfwEAo3EvGLFCtLS0pDL5Sxbtozt27er4r777rs+M04g6Aoa+cwpKSlcvXoVgKVLlxIQEIClpSWrV68Wr7oFWoPG45lramowNjbG0dGRqKgoJk2ahFwu77XJjwJBT9HIzVi+fDkTJkxQ+cuurq5ERUWxatWqTpdMFQj6C41q5nXr1qmm5zSRn5/Pv/71r3YX1RMI+huNauakpCQWLVpETU2NKiw9PZ25c+cyd+7cPjNOIOgKGvUz+/v7s337dnx8fNTCL168SFhYGHFxcb1mUPN+5tLSUtHP/F9C83Wau52HJonKy8tbCRlgypQplJaW9tiI9oiNjVWbRCnQDUqqqsgpLKL8cQ2WJsa4PTeYZYsX9zhfjcSsUChoaGhotci1Uqns04FGj+vqkMQuVTpFWXU1Kbm/UA8okVFbUUFxRQVj793HY1jPxv
lo5Ga8++672NjYsGXLFrXwsLAw7t6926t7uTV3M67k5GBtadlreQv6h8d1dRRXVPKwspLiiorG708+80pKaGiS3JNPfZmMMc//hv9Z/P96VK5GYq6oqGDWrFnk5+czbtw4GhoaSEtLY8iQIZw9exYbG5seGdGc5mK+mZuLw5NtwwTaQ02dgl8ryvm1vIJfyyt4UF5OcUXFk/NyKruxboaduTkn3l3TI7s0cjMsLCyIj4/nwoUL/Pzzz+jp6fH222+3u0Or4NmmVqFQifNBeUUz4TZ+VjTr1WoPc2MjHCwscbC0YLCVFc9ZWzHYyor//1Myt/LV91TRk8n4jaNDj+3WuAkpk8kICAggICCgx4UK+pbM/AK+TUnlXvFDhtnbsWDcy7wkf7ogYl19/dOatKKCX8vKGz+f1LLljx93WoaZkRH2FhZPxWplxWBrK4ZYWzPYygoLE2P0ZLJWb4jtzM3ZcPxLJKBBkhrTAK9P6vq2ay3R6j1NhJvRdTLzC9jyzWla3tSRw5yoUSj4taKC0qrO9yQ3MRyEvYUljpYWOLYQ63PWVliamLQpVk3IuHefL/43gTsPfuU3jg68PmkiHr0wyUOI+RmmVqGgoLSMvJIS8h+VkFdSwpVf7lKjwUbxRgYGOFhaYG9p2SjUJrFaWTHExhpLExP09fSeqbE3Yn3mZ4Dyx49VYs0vKSXvUQn5pSUUl1e0qoHbw8TQkLeDZjDYyhK5jQ3WpqbPnFg7Q4hZS2iQJIrLK54ItqRRsCWl5JWUUNlBg8vMyAi5jQ1OtjYMs7fjx1tZ3H7wQC2NnkyGh9NQZoz07OufMaAIMfczdfX1FJSWPqlpS8l/It78klIUHWzpa29hgdzGmmF2tgyzs8PZ3g4Xe3tszc3VfNcxLi591sDSdoTP3A066y0AqKypIU/lGjytaX8tL2/XNTDQ0+M5ayuG2triZGuLi70dw+zsGGZvh5mREfoavg3tqwaWtiPE3EUy8wvY+qS3oOmPkwGBXiOpVypVtW1H3VumhoaNroGdTaNYn9S0Q21sMGxjH22BZvTIzfj6668JDw+nrq6OJUuWsHnzZrX4mJgYXn/9ddXaGmPGjGlzn2VtoEGSqK6tparpqHn6vbLZ97Tbd2hoca0ExFy91ipPW3MzhtrYMNTOFmfbJ66Bgz125uY61/jSBrpdMxcWFjJhwgRSU1OxtrYmODiYjRs3EhQUpEqzbds2zM3NWbu28z09mmheM685+BmvTZnS6hHeHk3blFXV1lLZTJBPxVjTSpxNaaprazXuGWiLQfr6zPEeg7O9PcPsbHG2t8Pc2Fhj10DQc7pdM587d46AgAAcHBpfQy5btoyTJ0+qiTklJYXq6mqOHj2Ks7Mzf//739tcAamqqkr1vbKyUvX9+p1f+CDvJAvGjcXK1JTqJyKtrq1TCbDqyfeqmhqqFArV4JXuMshAHzNDI0yNjDAzMsLUyBALY2PMjI0xNzIiISubgjaGvXq5urB0YrNhspJEjQZv0gRPMTU17dnTSuomH330kbRp0ybV+blz56QZM2aopVm2bJkUHR0tSZIk7du3T5o8eXKbeQHiEIdUWVnZXTlKkiRJ3X4GNjS09BxpNd75yJEjzJw5E4BVq1Zx9erVZ3a3UYH2020xOzk5qXY+AigoKFBzIWpqavjoo4/UrpEkiUFt7HdXWVmpOprnWVhYqBbX1lFUVKRKX1RUpFF887Dmm9IUFRWpxbWVpiVtldkVGzuzr6VNzW3RZEOdZ8m+pm3duk13q/S8vDzJxcVFKiwslOrq6qQZM2ZI3377rVqaF198UTp79qwkSZL0+eefS4GBgZ3mW1lZ2aXHTmfp24pvHlZUVKQW3zyurTQtj57a2Jl9LW1qbktHdumSfZrSbTFLkiSdOnVK8vT0lIYPHy5t2LBBkiRJ+v3vfy+dOXNGkiRJSktLk8aPHy+5u7tLr7zyinT37t1O8+yqmAcCbbfxv9U+rX5poq2rgGq7jf+t9m
mdmAWC7iJ69AU6gxCzQGcQYhboDELMAp1BiFmgMzxzYr59+3ar5XUFukN9fT1Tp04lNTW1y9c+U2IuLS3lwIEDqj5Kge6xZcsWhg0b1q1rnykxW1tbs2PHDiFmHeXo0aP4+Pjg5ubWrevFhFaB1vDNN98wZMgQUlNTycrK4ssvv+zS9ULMAq3hzJkzAISHhzN79uyuZ9Brozx6QFlZmeTp6SnduXNHFXbq1CnJ3d1dcnNzk8LDwwfOuA4YM2aM1NDQIB0/flzasWNHv5bt4uKi9n9pM/11fwfcZ05MTMTPz4/MzExVWGFhIRs2bCAuLo4bN25w6dIlYmJiBtDK1ty9excnJydkMhkXL15k0qRJA22SVtKf93fAxXzw4EH27duHXC5XhTWfXzho0CDV/EJtISgoiIkTJ5Kens7o0aM5evQoK1euZPny5YwaNQpvb2/Cw8MByMvLY+bMmfj4+ODs7MzGjRsBiIyM5Le//S1+fn44OzuzdetW3nvvPby8vJg6dSqPHz8mNzcXd3d3FixYgIeHB0FBQTx8+FDNFqVSycaNG/H29sbLy4tt27YBUFxczIwZMxg7diwvv/yy6hHe3/Tn/R1wMR8+fLjVOs95eXkMHTpUdS6Xy7l//35/m9YuMTExLF26lAMHDpCSksLIkSM5duwYV69e5cqVKyQkJJCdnU11dTUnTpxg0aJFJCUlcf36dQ4ePEhxcTEAycnJREdHc+nSJTZv3kxwcDBXr15FT0+P2NhYAG7evMmaNWvIyMjA09OTsLAwNVsOHTpEXV0daWlppKWlkZiYSFRUFF988QVeXl6kpaVx7Ngx4uPj+/1/gv69v1rZANRkfuFAk5GRwerVq7l16xYjRozAzc2N2tpapkyZQkhICB9++CGmpqZs2LCBCxcuEBERwfXr16mtrVXNRp88eTKWlpZYPtnqYtq0aQC4uLhQUlICwPPPP68Kf+ONN3jttdfU7IiNjeXnn39WibWqqopr164REhJCUFAQd+7cITg4WPWk0Ab66v5qpZidnJzUapKW8wsHmqCgIC5dusScOXMoKSmhvr4ePz8/EhMTSUpKIiYmBl9fX+Lj4zlw4AA5OTksWbKEV199lfPnz6v2Gzc0NFTLt63tw5qHNTQ0oK+vrxavVCrZuXMnCxYsAODhw4eYmJhgampKVlYW0dHRREVFsWvXLm7evKkVC8/01f3VruruCdOnT+eHH36gqKgIhULBsWPHmDVr1kCbpeLw4cO88sorpKenM3v2bL766is++eQTFixYQEBAABEREbi7u5OZmcm5c+d4//33WbRoEffu3SMvLw9lBwsktiQnJ4e0tDRVuYGBgWrxAQEBfPbZZygUCqqrq5k2bRrnzp1j+/bt7Ny5k9DQUD799FMePHigNTPj++r+amXNLJfL2blzJ9OnT6e2tpZ58+Yxf/78gTZLRWJiIr6+vgBcvnyZiIgIjI2NGTFiBJ6enpiYmODt7U1wcDCVlZUsXboUa2trHB0dGTdunEazlpuws7Nj27ZtZGdn4+HhwaFDh9TiV65cSU5ODmPGjEGhULB48WLmzZuHn58fr732Gl5eXhgYGLBlyxasra1782/oNn11f8W0KS0mNzcXf39/cnNzB9qUZwKtdDMEgu4gamaBziBqZoHOIMQs0BmEmAU6gxCzQGcQYhboDELMAp1BiFmgMwgxC3QGIWaBziDELNAZ/g/2vmhbBVEQWAAAAABJRU5ErkJggg==\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "# plot the neural-net C2ST estimate against the number of samples, with the optimal C2ST as reference\n",
        "\n",
        "good_vals = [c2st_gt] * len(budgets)\n",
        "\n",
        "with mpl.rc_context(fname=\"matplotlibrc\"):\n",
        "    fig, ax = plt.subplots(1, 1, figsize=(1.6, 0.7))\n",
        "    # attach legend labels to the artists themselves instead of relying on artist order\n",
        "    _ = ax.axhline(c2st_gt, color=\"gray\", alpha=0.6, label=\"Ground truth\")\n",
        "    _ = ax.plot(budgets, estimates, c=\"#458588\", label=\"Neural net\")\n",
        "    _ = ax.scatter(budgets, estimates, c=\"#458588\", s=15.0)\n",
        "    # shade the gap between the estimate and the ground truth\n",
        "    _ = ax.fill_between(budgets, good_vals, estimates, color=\"#458588\", alpha=0.1)\n",
        "    _ = ax.set_ylim([0.5, 1.0])\n",
        "    _ = ax.set_xlim([10, 12200])\n",
        "    _ = ax.legend(handlelength=0.7, handletextpad=0.4, labelspacing=0.1, loc=\"upper right\", bbox_to_anchor=[1.1, 1.2, 0.0, 0.0])\n",
        "    _ = ax.set_ylabel(\"C2ST\", labelpad=-5)\n",
        "    _ = ax.set_xscale(\"log\")\n",
        "    _ = ax.set_xticks(budgets)\n",
        "    # only the outermost ticks are labelled; these labels assume budgets == [10, 100, 1000, 10_000]\n",
        "    _ = ax.set_xticklabels([r\"$10^1$\", \"\", \"\", r\"$10^4$\"])\n",
        "    _ = ax.set_yticks([0.5, 1.0])\n",
        "    _ = ax.set_xlabel(\"#samples\", labelpad=-8.0)\n",
        "\n",
        "    locmin = mpl.ticker.LogLocator(base=10.0, subs=np.arange(0, 1.0, 0.1), numticks=12)\n",
        "    ax.xaxis.set_minor_locator(locmin)\n",
        "    ax.xaxis.set_minor_formatter(mpl.ticker.NullFormatter())\n",
        "\n",
        "    plt.savefig(\"svg/fig2_panel_a.svg\", bbox_inches=\"tight\", transparent=True)\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c08cd877-0baf-4c94-9336-0fc37afae320",
      "metadata": {
        "id": "c08cd877-0baf-4c94-9336-0fc37afae320"
      },
      "source": [
        "## A classifier that is too weak"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "id": "22aa6528-e5f9-4b4c-a0de-261542c7c60a",
      "metadata": {
        "id": "22aa6528-e5f9-4b4c-a0de-261542c7c60a"
      },
      "outputs": [],
      "source": [
        "# create bad c2st metric with control of hidden layer size\n",
        "\n",
        "def poor_c2st(\n",
        "    X: Tensor,\n",
        "    Y: Tensor,\n",
        "    seed: int = 1,\n",
        "    n_folds: int = 5,\n",
        "    metric: str = \"accuracy\",\n",
        "    hidden_size: int = 5,\n",
        "    clf_kwargs: dict[str, Any] = {}\n",
        "):\n",
        "    \"\"\"C2ST with a deliberately small single-hidden-layer MLP classifier.\n",
        "\n",
        "    The hidden layer has ``hidden_size * X.shape[-1]`` units, so ``hidden_size``\n",
        "    sets the classifier capacity per data dimension.\n",
        "\n",
        "    Args:\n",
        "        X: Samples from the first distribution.\n",
        "        Y: Samples from the second distribution.\n",
        "        seed: Seed forwarded to ``c2st_scores``.\n",
        "        n_folds: Number of folds forwarded to ``c2st_scores``.\n",
        "        metric: Score name forwarded to ``c2st_scores`` (e.g. \"accuracy\").\n",
        "        hidden_size: Hidden units per data dimension.\n",
        "        clf_kwargs: Extra ``MLPClassifier`` arguments that override the defaults below.\n",
        "            It is only read (never mutated), so the shared ``{}`` default is safe.\n",
        "\n",
        "    Returns:\n",
        "        1-element float32 tensor holding the mean of the scores from ``c2st_scores``.\n",
        "    \"\"\"\n",
        "    ndim = X.shape[-1]\n",
        "    defaults = {\n",
        "        \"activation\": \"relu\",\n",
        "        # explicit 1-tuple (note the trailing comma): a single hidden layer of this width;\n",
        "        # without the comma, (hidden_size * ndim) is just an int, not a tuple\n",
        "        \"hidden_layer_sizes\": (hidden_size * ndim,),\n",
        "        \"max_iter\": 1000,\n",
        "        \"solver\": \"adam\",\n",
        "        \"early_stopping\": True,\n",
        "        \"n_iter_no_change\": 50,\n",
        "    }\n",
        "    # caller-supplied kwargs take precedence over the defaults\n",
        "    defaults.update(clf_kwargs)\n",
        "\n",
        "    scores_ = c2st_scores(\n",
        "        X,\n",
        "        Y,\n",
        "        seed=seed,\n",
        "        n_folds=n_folds,\n",
        "        metric=metric,\n",
        "        z_score=True,\n",
        "        noise_scale=None,\n",
        "        verbosity=0,\n",
        "        clf_class=MLPClassifier,\n",
        "        clf_kwargs=defaults,\n",
        "    )\n",
        "\n",
        "    mean_score = np.float32(np.mean(scores_))\n",
        "    return torch.from_numpy(np.atleast_1d(mean_score))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "id": "6e6dbc38-46a0-46e8-b9c7-433a7cf1a8c0",
      "metadata": {
        "id": "6e6dbc38-46a0-46e8-b9c7-433a7cf1a8c0"
      },
      "outputs": [],
      "source": [
        "# C2ST of increasingly wide single-hidden-layer MLPs, all at the same sample budget\n",
        "\n",
        "budget = 10_000\n",
        "\n",
        "hidden_sizes = [2**k for k in range(5)]\n",
        "poors = []\n",
        "for hidden_size in hidden_sizes:\n",
        "    _ = torch.manual_seed(1)\n",
        "    poors.append(\n",
        "        poor_c2st(\n",
        "            data.sample((budget,)),\n",
        "            gen_model.sample((budget,)),\n",
        "            hidden_size=hidden_size,\n",
        "            seed=1,\n",
        "        ).item()\n",
        "    )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "id": "39910e45-fcc9-46f1-88d4-13f58d193a9a",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 134
        },
        "id": "39910e45-fcc9-46f1-88d4-13f58d193a9a",
        "outputId": "c462a12b-981e-4587-cfbf-2cc49139133e"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 160x70 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAK4AAAB1CAYAAADX9doCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAS2ElEQVR4nO2dfVQU1f/H37vLPrCwIKAIKyg+lQjIuiChkKJyTCPJVKIEFXw4ah09GXAUKUSkb6ZmFpzq5AMomGn+fiYqhdhBrZ8WoiD6NU0KEHlWE9hdnnb3/v4AJ1aeFmSB0fs6Zw4zd+7ez52ZN3c+M3Pv53IIIQQUCsvg9ncFKJSeQIVLYSVUuBRWQoVLYSVUuBRWQoVLYSVUuBRWQoVLYSVUuBRW8lTCrampgYuLCwoLC9vsu3HjBjw9PTFu3DgsXLgQSqXyaUxRKDr0WLiXLl2Cl5cXbt++3e7+4OBgfPLJJ7h16xbGjx+PuLi4DstSKpXMolAoUFlZCaVSCfo1mtIhpIeEhISQCxcukBEjRpCCggKdfXfv3iUODg7MdlFRERk5cmSHZQFod1EoFD2tHuUZx6ingk9MTOxwX0lJCYYNG8ZsS6VS3Lt3r6emKJQ29Fi4naHVatukcbkdeyUKhYJZVyqVGDp0qCGqRXmGMIhw7ezsUFZWxmyXlZXBzs6uw/wmJiaGqAblGcYgr8OGDx8OsViM8+fPAwD2798PPz8/Q5iiPKf0qnBfffVVZGdnAwAOHz6MjRs3Yvz48cjKykJsbGxvmqI853AIGVjvnJRKJUxNTQE0+77UjaC0B/1yRmElVLgUVkKFS2ElVLgUVkKFS2ElVLgUVkKFS2ElVLgUVkKFS2Elegm3pKTE0PWgULqFXsKdO3euoetBoXQLvYQ7wLozUCj69cctLy/vtHdXdHR0r1WIQtEHvTuS01aXMpDQq1ujXC7H1atX+6I+tFsjRS+oj0thJXoJNy0tzdD1oFC6hV7CNTc3R1hYGLKysgAA69atg6mpKby8vOiwc0q/oJdw33vvPdTV1cHBwQGnT5/GkSNHkJubi4iICKxdu9bQdaRQ2qDXw5mLiwuuX78OAFi1ahUEAgHi4+MBAOPHj8fNmzd7rUL04YyiD3q1uDwej1nPzMyEr68vs93Y2Nj7taJQukCv97hWVlbIyspCTU0NSktLGeGeP3++00AfFIrB0CfAWF5eHnF0dCSWlpbkwIEDhBBC4uLiiLW1NcnKyurVYGYKhYIGvaN0SY/jKuTn52PIkCEwNzfv1X8k6uNS9EEvH1etVuPzzz9HWFgYfvnlFwDAmDFjYG5ujg8//NCgFaRQ2kMv4a5atQpXrlyBVCrFkiVLsG3bNmbf6dOnDVY5CqUj9Ho4u3z5MvLy8gAAixcvxowZM2BmZoZ33nmHfg6m9At6CZcQgvr6eohEIlhbW+PUqVPw9vaGVCoFh8MxdB0plDbo5SqEhobipZdeYvxbBwcHnDp1CmvWrMGff/5p0ApSKO2hV4v7/vvvw93dXSettLQUJ0+exA8//GCIelEonaKXcH/77TcEBAQgJSWFScvNzUVCQoJBhatWq6FWqw1WPmXgYGTUveD4er3H9fHxwbZt2+Dp6amTfuHCBURHR+PcuXPdMtoZrd/jHj16FCKRqNfKpgxcujsgVy+Z19TUtBEtAEydOhWPHj3qlsHuUN/UBLTqJ0F5NqhWqlBwvwqK+gaYGxtjjE33J6vRS7hNTU3QarVtZs7RaDQG7WTj+tJLMJdIDFb+805+WTlO5V5D6YOHkFpZ4jWZK8bY2hjU5p+lZdh7+kc8vs3zVPXI/rsIbsX34GSvf78XvVyFdevWwcLCAlu2bNFJj46Oxt27d5GUlNSdundKa1fhj8JCDLG07LWyByq3S8vwv5ezUXz/AewHW2H+JHe8KLXtVRuNajXqGhuhamxEXWMjbpeWI/mXX9H64nMAzHR2goWJCdRaDd
QaLTRaLdQaDdSt/jJpGq1uPq1umlqjaUn/97ftyY3L4UA+0gH/eetNvY9HL+HW1tbCz88PpaWlmDRpErRaLa5cuQJbW1ukpqbCwsJCb4Nd8bwJ93ZpGWL/5zgznSanZYle8AZelNpCo9WirkVsqsZG1DX8K77Hf+saG6Fq0E1TPZGubmfuuYGElakpDq97V+/8eneyIYQgMzMTOTk54HK5cHd3x8svv9zjinbEsy5crVaLRyoVHigUeFCrwNHfslDWznOCEZcLHo+Lhqbefasi4vNhLBCgWqWCtp1LL+Dx8LLjOPB5PBhxueAb8WDE5cGIx4MRjwsjHg/8lm0ejws+tyWtVT5+S97H+f7dx8XOU2nIu1usY9NgLW5fwmbhEkJQW1+PB7UKRpgPFArcr63FQ0Xz+j9KFTQ9aP34PB6MBYKWhQ+xQAixUACxUAixQAATkRBigRCmQiGMhQKYCoUQC5u3TURCmIhEEAsE4HG54HA4iPruKHIKi3TE2xMBdZf/Ft9DeMq3IAC0hIDL4YADYOfiRXDqRt9ug8wsyUb08TNVDQ2MIO+3EuYDhQIPW9abNJoubXE5HAwyEcPK1BSVNTWoVtW1yfOirS3CX3sVYqEApiIRBEZG4LRc5N74zB7s7YXcwiJwORwdAQV5T3nqsjvDyd4OO4MX4dD/XURBZRVGWg9BkPeUbokWoC0ugH/9zNbtIAfARIcRUGu1jDDrmpr0Ks/M2BhWpqawkphiiJkZrM3MMMRMgiFmEgw1N4eVqSn4Rkbgcji91gL1hP8W33tqAfUXz61w1RoNiu7fR355JX7IvoJHKpVevzMRChlRDpY0i3GIxAzW5mawNpNgiJkZhHw+eJ1Muv0kbBZQf/FcCJcQgqqaWuRXVDQv5RUorLrf5W1dxOdjxXQfDDGXwLql5RQLhc2tIu0V1688kz6uqqEBf1VUtgi1EvnlFaipa+tHmgiFGD3UGhXV1aiortHZx+Vw4GxvB393eV9Vm9INWC9cjVaL4gcPmZY0v7wCpf/8gydvIzwuF8MHW2GsjQ0cpVI42klhb2kJvpFRh36moR9UKD1nQLsK736zB4umTtV5un+gUDACza+oQEFlFRra6UE2WCLBGJuhGGdrC0c7Kcba2MBEKOzwFk/9THbxVML9/vvvERMTg8bGRgQHB2Pz5s06+9PT0xEUFMTEXpg4cSISExM7LbO1cGdsjoWRQICZzk6oVqmQX1GJf5TKNr8xFvAxytoaL9jawnGYFI5SKQabSbr1gERhFz12FcrLyxEeHo7s7GwMGjQIc+bMQXp6Ol555RUmT1ZWFqKiorB+/fpOy1K2EqNCoWDWNS0deNKv5jBpHA4HwywtMGaoNV6USvGCrQ1GDB4MPo+n05rWt+PTUgY2YrFY/4fengZkOHjwIAkJCWG2Dxw4QEJDQ3XyzJ07l8ycOZPIZDLi7+9PiouL2y0LAF3o0q0AMD2+l5aUlGDYsGHMtlQqbRNy1MLCAuHh4cjJycHs2bOxaNGinpqjUHTosaugbed7+5P9dQ8cOMCsr1mzBpGRkaiurm4T/aa1e1BbWwtb2+aHsfLycsbfNTRKpRJDhzZ3aK6oqOjTCDr9ZXug2RWLxXqX0WPh2tnZ4fz588x2WVmZTgC8+vp6fPbZZ4iMjGTSCCHg8/ltyurohJmamvZLCCYTE5N+C/3UX7bZZrfHroKvry9+/vlnVFRUoKmpCcnJyfDz82P2i0QiJCUl4eTJkwCAxMREeHp6duu/ikLpiB4LVyqVYseOHfD19YWTkxNcXV3xxhtvYMWKFUhNTQUAHD58GHFxcXByckJycjL27t3baxWnPN8MuA8QFIo+0Df0FFZChUthJVS4FFZChUthJVS4FFYyYIVbU1MDFxcXFBYW9pnNXbt2wdnZGc7OzggNDe3zqbAiIiIQEhLSZ/ZSUlLg5OQEJycnhIeHG9zek9f00qVL8PT0hJOTE95+++3une8e9K8xOBcvXiTOzs
6Ez+eTgoKCPrH5+++/E2dnZ6JQKIhWqyXBwcFk165dfWKbEELOnj1LBg8eTJYuXdon9pRKJbGwsCAVFRWkqamJeHh4kIyMDIPZe/KaVldXExsbG3Lt2jVCCCFvvfUWSUhI0Lu8AdnifvPNN/jyyy8hlUr7zKaFhQUSEhJgYmICDocDV1dX3L17t09sP3z4EFFRUdi0aVOf2AOa475ptVrU1dUx4VyNjY0NZu/Ja5qRkYHJkydjwoQJAID4+HjMnz9f7/IG5NCdrjqbG4KxY8di7NixAIDKykokJCT0aky0zli1ahU++ugjFBcXd525l5BIJNi6dSvGjRsHsViMadOmYcoUww1VevKa5ufnQyKRYMGCBbhz5w68vb2xa9cuvcsbkC1uf1JYWIjp06dj5cqV8PHxMbi9vXv3wt7eHjNnzjS4rdbk5eVh//79KCoqQmlpKXg8Hnbu3Nln9tVqNdLS0rB9+3bk5ORApVLpzObUFVS4rcjNzYWXlxdWr16NqKioPrF55MgRnDlzBjKZDNHR0UhNTcW6desMbjc9PR0zZ86EtbU1hEIhQkJCejVAd1fY2NjAw8MDo0ePBo/Hw5tvvomsrCy9f0+F20JVVRVmz56N+Ph4rF27ts/sZmRk4MaNG8jNzUVsbCz8/f3xxRdfGNyuq6sr0tPToVAoQAjByZMn4ebmZnC7j5k1axZycnJQVFQEAEhLS4Ncrn8oACrcFnbv3o2amhrExsZCJpNBJpP1WavbH8yaNQvBwcFwc3PDhAkT0NDQgI0bN/aZfXt7e+zZswf+/v4YN24cKisrdfpudwXtHUZhJbTFpbASKlwKK6HCpbASKlwKK6HCpbASKlwKK6HCpbASKlwKK6HCpbASKlwKK6HCpbCS50a4hYWFcHBwaJPu4OCAwsJCZGdnY8WKFXr/DuidifJ6s5ye0tGxD2SeG+F2hbu7+3Mb24yNx06F28K5c+eYEQ85OTmQy+WQy+XYsmULk6ewsBDe3t6QyWRYvXo1k65UKrFs2TLI5XK4uroyIkhKSkJgYCDmzJmDF154AYGBgZ2OZC0pKcHs2bPh6emJ4cOHIyIiAgAwffp0pKWlMflcXFxw+/Zt/PXXX5g1axbkcjkmT56MixcvAgBCQkLw2muvwdHREceOHdOxsWfPHri6usLNzQ0LFy5EXV0dc+yNjY1Ml06ZTAYLCwsEBwcDaB4R7ObmBplMhqCgINTW1j7F2e4FDDKkcwBSUFBA+Hw+cXV11VkejzrNzMwk06ZNI4QQ4uzsTH766SdCCCGxsbFkxIgRhBBC/Pz8yNdff00IaZ5K4PHpi4yMJJ9++ikhhBCFQkEmTpxIrl27RhITE4mdnR159OgRUavVRC6Xk9TU1DZ1e1zOjh07yN69ewkhhFRXVxMzMzNSVVVFkpOTSWBgICGEkMuXL5MpU6YQQgjx8vIily9fJoQQkp+fTxwcHEhTUxNZunQpCQoKavc8WFpakkePHhFCCImKiiLZ2dk6x/6Yq1evklGjRpF79+6RmzdvEi8vL6JSqQghhMTFxZGwsLBunP3eZ0AOljQUUqkUubm5OmlP+q/3799HSUkJMwlLSEgI9u3bB6C5Vf72228BAEFBQVi+fDkA4MyZM1AqlTh48CCA5vgBeXl5AAAvLy8mAruzszMePnzYYf3Cw8ORmZmJnTt34saNG2hoaIBSqcSCBQsQERGBmpoaJCUlYdmyZVAoFMjKytLxTZuampiRyR0NfPT394eHhwdef/11zJ8/H25ubm2G7JSXlyMgIADJyckYNmwYjh8/jjt37mDy5MmMnVGjRnV4HH3BcyVcfeBwOCCt+ta3jqDeeh+HwwGPxwPQPNT70KFDzNCTyspKmJub4/DhwxCJRB2W/SRhYWHIz89HcHAw5s2bh7Nnz4IQAmNjY8ybNw9Hjx7FqVOn8PHHH0Oj0UAkEun8I5aUlDDTEHQUQDsxMRE5OTn48ccfERwcjJiYGJ1I8g0NDZg3bx
42bNjAiF+j0SAwMJAZUqRUKtHQ0ND1yTQg1Md9AisrKzg4OODEiRMAgO+++47Z5+vrywxZP3HiBOrr6wEAM2bMwFdffdU8Z3BVFeRyOf74449u287IyMCGDRsQEBCA4uJilJSUQNMy3/Dy5csRExMDHx8fSCQSmJubY+zYsUhJSQEA/Prrr3Bzc4O6nckKH6NSqTB69GjY29tj06ZNWLJkCXJycnTyrFixAu7u7li5ciWT5uPjg+PHj6OiogIAsH79emzdurXbx9eb0Ba3HVJSUhAaGoqYmBjm9ggACQkJWLx4Mfbt2wcPDw9IJBIAwObNm/Huu+/CxcUFarUaH3zwAWQyWRu3pCsiIyOxePFiDBo0CNbW1pg0aRL+/vtvjB49Gu7u7hCJRAgNDWXyHzp0CGvWrMH27dvB4/Fw7NgxCASCDssXi8WIiorC1KlTIRaLYWFhgaSkJNy5cwcAcPHiReYhbOLEiSCEYPjw4UhNTUVMTAx8fX2h1Wrh6OjYrRgIhoCOOWMBhBDcunULAQEBuH79er+/9x0IUFeBBezevRszZsxAfHw8FW0LtMWlsBLa4lJYCRUuhZVQ4VJYCRUuhZVQ4VJYCRUuhZVQ4VJYCRUuhZX8P3JaHTmkihsQAAAAAElFTkSuQmCC\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "# plot curve\n",
        "\n",
        "good_vals = [c2st_gt for _ in range(len(hidden_sizes))]\n",
        "\n",
        "with mpl.rc_context(fname=\"matplotlibrc\"):\n",
        "    fig, ax = plt.subplots(1, 1, figsize=(1.6, 0.7))\n",
        "    _ = ax.axhline(c2st_gt, color=\"gray\", alpha=0.6)\n",
        "    _ = ax.plot(hidden_sizes, poors, c=\"#458588\")\n",
        "    _ = ax.scatter(hidden_sizes, poors, c=\"#458588\", s=15.0)\n",
        "    _ = ax.fill_between(hidden_sizes, good_vals, poors, color=\"#458588\", alpha=0.1)\n",
        "    _ = ax.set_ylim([0.5, 1.0])\n",
        "    _ = ax.set_ylabel(\"C2ST\", labelpad=-5)\n",
        "    _ = ax.set_xscale(\"log\")\n",
        "\n",
        "    _ = ax.set_xticks([1, 2, 4, 8, 16])\n",
        "    _ = ax.set_xticklabels([\"1\", \"2\", \"4\", \"8\", \"16\"])\n",
        "    _ = ax.minorticks_off()\n",
        "    _ = ax.set_xlim([1, 17.4])\n",
        "    _ = ax.set_yticks([0.5, 1.0])\n",
        "    _ = ax.set_xlabel(\"Hidden layer size\", labelpad=7.4)\n",
        "    plt.savefig(\"svg/fig2_panel_b.svg\", bbox_inches=\"tight\", transparent=True)\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "5e7594ab-d765-4f86-b114-6382b442ac56",
      "metadata": {
        "id": "5e7594ab-d765-4f86-b114-6382b442ac56"
      },
      "source": [
        "## High-D Gaussian behavior"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "id": "b6f297c6-2857-47d2-b6c9-a014d0430939",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 120
        },
        "id": "b6f297c6-2857-47d2-b6c9-a014d0430939",
        "outputId": "23917f3d-4cfc-4504-8951-7bed05d04159"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 120x70 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAJUAAABnCAYAAAAJ+ABdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAATMElEQVR4nO2de1RVZf6HHzzg4RwODijkgMIoaAkIHEBFQCwu3gID0gwbRbExHWksLGfSWhPVLFvNRZIpHSeXhZoGpKlcFUzHMXNm5GLazKiNAoqIopXC4XKA9/cHP/eKvIHuwwHcz1p7LfY5e7/vZ7M+593f924hhBAoKMhIP3MLUOh7KKZSkB3FVAqyo5hKQXYUUynIjmIqBdlRTKUgO4qpFGRHMZWC7MhuqqysLLy8vBg5ciRvvPHGTd/v2bMHBwcH9Ho9er2exMREuSUomBkLObtpLl68SGBgIEePHsXOzo5p06axfPlypkyZIl3z1ltvodPpSE5Ovmt69fX10t9CCAwGAzY2Nmi1WiwsLOSSrSAzspZUhYWFhIeH4+joiJWVFQkJCWRkZHS45l//+he5ubn4+fkRExPD+fPnb5ueTqeTDltbWwYPHoxOp8NgMMgpW0FmZDVVVVUVQ4YMkc6dnZ1vMo29vT0vv/wypaWlTJ06lWeeeUZOCQo9AFlN1dbWdnMG/TpmkZ6eztSpUwH45S9/yVdffcX3339/y/Tq6uqko6amRk6pCiZEVlMNHTqU6upq6by6upqhQ4dK542Njbz99tsd7hFCYGVldcv0bGxsOhwKvQNZTRUZGcm+ffuoqanBaDSyefNmoqKipO+tra356KOPyM7OBuDDDz9k/PjxaLVaOWUomBshM5mZmWL06NFi5MiR4uWXXxZCCPHss8+KXbt2CSGEKC4uFuPGjROenp4iLCxMVFZWdirduro6AQhA1NXVyS1bQUZkbVIwJfX19eh0OqA91lJehz0XpUVdQXYUUynIjmIqBdmxNLeAvo4QgkuXLtHY2IiTkxNWVlbU1dXR1taGra3tTe14fQHFVCaisrKSP/7xj2RmZlJTU4OjoyMhISF4eXlJ7XKWlpa4u7sTGBjI8OHDzaxYPpTan8wIIUhNTeXVV1+lsbGRfv36ERERQVBQkFQq3eh5+GEp5enpSXR0NBqNxiy65UQxlYw0NTWRkJBAZmYmAOHh4URFRXH9+nWgvcfhzJkzrFmzBoPBwMMPP8xvfvMbzp07hxACe3t75syZw8CBA835GPeNYiqZaG5uJjY2lvz8fKysrEhNTcXKyorq6mr69+9PbGwsHh4eAHzzzTc89dRTlJWVodPpyMrK4uTJk3z33XfY2NiQmJjIoEGDzPxE907fixLNQFtbGwkJCeTn56PVasnJyUGr1VJdXY1WqyUxMVEyFMCIESP4+9//TkREBHV1dcyePZvQ0FAGDx5MfX09mzdvlkq33ohiKhl48803ycjIwMrKis8++4yGhgYqKytRq9XMnTuXn/70pzfdo9Pp2L17N8HBwXz33XfMnDmTqKgoBg0axPfff09GRgYtLS1meJr7RzHVfZKfny8Nm/7LX/7CwIEDKSsrw8LCgqeeeuqWhrqBVqtl165duLm5cfbsWRYtWkR8fDwajYaqqioKCgq66zFkRTHVfXDx4kUSEhIAWLJkCdHR0ZIRIiIicHd3v2saDg4O7NixA2tra/Lz80lPT+fJJ58EoLi4mH//+9+mewAToZjqHhFC8Oyzz1JbW4uvry9/+MMf2LFjB62trYwYMYLg4OBOp+Xr68uaNWsAWLFiBQaDgQkTJgCQnZ3d6+KrezLVt99+S3FxMWVlZbcdtdnX+fDDD8nLy0OtVrN161aOHDlCTU0NWq2WmJiYLk/MWLhwITExMRiNRubNm0dISAhOTk40NjaSk5NDL6mkA100VV5eHqGhoTz88MMsXLiQRYsW4eHhQUREBHv27DGVxh7HhQ
sXWLZsGdA+O2jgwIF88cUXADz++ONS00dXsLCwYP369QwaNIiysjJWr15NbGwsKpWKU6dOcfz4cVmfwZR0up0qMTERBwcHEhIS8Pb27vDdiRMn2LBhA1evXmXTpk0mEdqT2qlmzpzJ9u3bGTt2LF988QUfffQRFy5cwMPDg1mzZt1X2lu2bGHu3Lmo1WqOHz9OdXU1+/fvR6vVkpSU1CtGyXbaVOfOncPFxeWO11RWVuLq6iqLsB/TU0yVk5PD9OnTUalUlJSU0NjYSH5+Pmq1mqSkJGxtbe8rfSEE06ZNY8+ePURERFBQUMBf//pXLl++jF6vJyYmRqYnMR2dfv3dzlCnT5+mpKSEkpISamtrKSkpkU1cT8NgMPCrX/0KgGXLluHm5sbnn38OtNf27tdQ0P4aXLt2LdbW1uzbt4+srCyio6MBKCsro7Ky8r7zMDX3Xfs7dOgQ2dnZ0pGTkyOHrh7JqlWrKC8vx8XFhddff529e/fS1NSEs7MzAQEBsuXj5ubGq6++CrSb187ODj8/P6A9rr3VVLgeRVcHtV+9evWmz77++ut7GB7fNcw98eHUqVOif//+AhDbt28X5eXlIiUlRaSkpIjz58/Lnl9jY6MYOXKkAMSyZctEfX29eOedd0RKSoo4cuSI7PnJSZdLKj8/P4qKim4YknfeeYeIiAhZjd7TEELwwgsv0NzczJQpU4iJiSEvLw8Af3//DrOy5UKtVpOWlgbAmjVrOHv2LOHh4QDs37+furo62fOUiy6bavv27SQnJ5OUlERISAglJSWUlpaaQluPITs7Wxp9kJaWxtGjR7l06RIajcakP6ipU6cSGxtLa2srzz//PH5+fjg5OdHU1MS+fftMlu/90mVTBQQEsGjRIjZv3kxFRQXJycl37N/q7TQ0NPDiiy8C7fGNs7Mz+/fvB9rHS5m6ip+amoq1tTUHDhzg008/5fHHHwfag/Y7LW5iTrpsqgkTJrBz505OnDjBtm3bSEhIYOnSpabQ1iP4/e9/z9mzZxkyZAivvfYaRUVFNDU14eTkhL+/v8nzHzZsGCtWrADgpZdews7ODr1eD/TcoL3TprqxVlR8fDxFRUW4uroyceJESktLaW1tBejR7/l74cyZM9LaD6tXr+bq1ascO3YMaG85765JC8uXL2f48OFUVVXx5ptvEhkZiVqtprq6muLi4m7R0BU6/V+ZP38+7733HnPmzOnwuY2NDatWreLdd99l3rx5sgs0F0IIli5dSlNTE+Hh4cyYMUMKzvV6fYeFR0yNRqORgvbU1FQqKiqkoP3zzz/vsDhcT6DTpsrMzMTS0pLAwEDGjh3LjBkziI+PJzAwkMDAQDQaDVlZWabU2q3s3LmT3NxcrKyseP/99/nnP/9JTU0NGo2GSZMmdbue6OhonnjiCVpaWliyZAkBAQFSh/PevXu7Xc+duKcx6seOHeP06dOoVCpGjBhxU1+gKejObppr167h6elJVVUVK1eu5Ne//jXvv/8+RqOR6dOnd0ssdSsqKirw9PTEYDCwceNGJk+ezIYNGwCYO3cubm5uZtH1Y7ocFDQ3N1NUVER6ejqbNm3i0KFDPTJYvB9WrlxJVVWV1LKdl5eH0WjExcVFatk2Bz/72c9ISUkB2oN2S0tLxo4dC0Bubi5Go9Fs2n5Il021YMEC/vGPf/Dcc88xf/58ioqK+lTt79ChQ6xduxaA9evX87///Y9Tp06hUqmYPn262RewTU5ORq/X8+2337J06VKpz/Hq1ascOHDArNpu0OXXn4eHB//5z3+k87a2Nry8vDp8Zgq64/VnMBjQ6/WcPn2axMRE0tLSWLt2LQ0NDTz22GM8+uijsud5L5SUlDBu3DhaW1vJysrCx8eHbdu2YWFhwYIFC7q1EnErulxSubq6curUKem8urr6rkNiegsrVqzg9OnTODs786c//YmcnBwaGhoYPHiwNLy3J+Dv788rr7wCwOLFixkwYAA+Pj4IIdi5c6fZX4
NdNlVDQwN6vZ5JkyYxbdo0PD09OX/+POHh4VI1tzdSUFAgVds3btzImTNnOHnyJP369SMuLg6VSmVmhR357W9/i6+vL1euXGHBggVMmTIFW1tbrly5YvZRuF1eoOOtt94yhQ6zUl1dLbWx3ehj++CDD4D2cVKDBw82p7xb0r9/f7Zs2cKYMWPIz89n3bp1xMXFsXnzZoqLi3Fzc8PT09Ms2h74ae9Go5HIyEgOHjyIt7c3Bw8eZMuWLVy5cgV3d3d+/vOfmz04vxPr1q1jyZIlWFpacuDAAQwGA4cPH0atVrNw4UKzTJ9/4KdovfTSSxw8eBBbW1syMzMpKCjgypUr2NraEhcX16MNBe0xVXx8PC0tLcycOZNHHnkEFxcXmpqayMjIoKmpqds1PdCmWrt2LX/+85+B9k0Dzp07x8mTJ1GpVDz99NM9bhGQW2FhYcGGDRvw9vbm4sWLxMbGMm3aNHQ6HZcvX+bTTz+V+ma7iwfWVDt27JDGm//ud7/DycmJw4cPA/DEE0+YZOCdqbCxsWH37t04OjpSWlrK/PnzmTFjBpaWlnzzzTfdPm/wgTRVTk4O8fHxtLW1sXDhQqZMmSLVmMLCwvDx8TGzwq4zbNgwabWZvXv38sILL0iTWsvKysjLy+s2Yz1wpsrIyODJJ5/EaDQya9Ys5s2bR25uLgBBQUGEhoaaWeG9M27cOHbt2oVarWb37t0sX75c2gfo6NGj7N69u1u61B4YUwkhePvtt4mPj8doNDJ79mzmzp0rjbcPCgpi0qRJPT4wvxuRkZFkZ2ej0WgoKChgyZIlTJw4USqxtm7dSkNDg0k1PBBNCrW1tfziF79g165dALz44ouMHj1aGo4bGRlJcHBwrzfUDzly5AjTp0+ntrYWBwcH1qxZQ3l5OUajETs7O2bMmGGy7pw+bSohBJ988gnJycnU1NSgVqtZtWoVTU1NNDc3o1ariY2NZdSoUaaWbxbKy8uJi4ujrKwMgOeee45Ro0Zx7do1LCwsCAoKYuLEiajValnz7ZOmEkJQWFhISkoKX375JRYWFkyaNInJkydLQ55dXFyIjY3t9Yu23o3GxkZWrFjBu+++C4CjoyNJSUnS9zY2NoSGhuLv73/bLfK6Sp8y1YULF8jMzGTDhg18/fXX2Nvb4+/v3yH41mg0hIWFERAQ0CcXxr8dhw8f5vnnn5em03l7exMTE4OlZXtPnUajwcfHBx8fH5ycnO4rFJDdVFlZWaSkpNDc3MycOXN4/fXXO3xfVVXFnDlzuHjxIk5OTnzyySc89NBDd033x6bSarVUVlZSWlrK4cOH2b9/P2fOnMHZ2ZmhQ4cyYsQIHBwcpPs1Gg3jxo1j/PjxWFtby/nIvYa2tjYyMzNZtWoVx48fR6VS4efnR2hoKD/5yU+k62xtbXF3d8fV1RVnZ+cu9312+27vsbGxxMXFMW/ePDZu3EhhYSHbtm27ZXo/HNBfV1cnzS+MioqiqakJtVqNra0tAwYMwM7OTvrVSQ9nYcGwYcMYPXo0o0aNkq147+0IITh06BBbtmwhOzuba9euScPC3d3d6d+/f4frV65ciVar7XzpJecc+k2bNon58+dL5+np6SIxMVE6b25uFgMGDBDNzc1CCCGMRqOwtbWVzn8M/792gnKY/+jK+hXdutv7jY7aH+7NMmDAAC5fviynDAUT0JVpYLJueHS33d5v15p7u4D5h5NTr1+/jpOTE9D+mr2XJRAVukZ9fb0UT3Vler+spho6dCh/+9vfpPMf7/bu6OjItWvXaGlpwdLSkpaWFq5fv37bMT+3azbQ6XS9YgRBX6IrtcFu3e3dysqKRx99lI8//hiAjz/+mMcee0wJoPsa9xiT35a77fZeWVkpIiIihKenp5gwYYKoqKjoVLrmXvTsQeRe/+e9pvFToffw4DQpK3QbiqkUZEcxlYLsKKZSkB3FVAqy0+tMdezYMYKCgvD19SUsLIyKigpzS+qzvPLKK3h6euLl5c
Xq1as7f6PJGjlMxJgxY0RhYaEQQoh169aJ2bNnm1lR3yQnJ0dMnDhRtLS0CIPBIIYNGyb++9//dupeWbtpuoMvv/wSS0tL2traqKiowN7e3tyS+iRRUVFMnjwZlUrFpUuXaGlp6XTXWK8zlaWlJbW1tXh7e9PQ0NBjFvrqi1hZWfHaa6+xevVqZs2a1fkJtiYuRe+ZzMxMMWTIkA7H+PHjO1yTm5srXFxcREtLi5lU3pm0tDQRHBws2traxNGjR4Wrq6uora01t6wuU1dXJ8LCwsT69es7dX2PNdWtaG1tFRkZGR0+c3BwEJcvXzaTojvT1tYmwsLCRFpamvDy8hIFBQXmltRpTpw4Ib766ivp/L333hNJSUmdurdXmUoIIby8vERubq4QQojCwkLh4eFhZkV3pry8XOh0OrF48WJzS+kSWVlZIiQkRDQ3N4vGxkYRERFx0w/6dvS6mGrr1q0sXryYlStXYm9vz/bt280t6Y5UVFSg0+mknTF62op8t2PmzJmUlJTg6+uLSqVi1qxZnd7KVxmlYELq6+vR6/Wkp6eTmppKQEAADz30ECEhITzyyCPmlmcyel1J1ZtYvnw5kydPJjg4mOHDh6PX6wkNDe2Vq8p0BaWk6mZSUlKIjo5mzJgx5pZiMnpdN41Cz0cxlYLsKK8/BdlRSioF2VFMpSA7iqkUZEcxlYLsKKZSkB3FVAqyo5hKQXYUUynIjmIqBdlRTKUgO4qpFGTn/wDSlFUsOw7O2AAAAABJRU5ErkJggg==\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "# visualize marginal densities\n",
        "\n",
        "with mpl.rc_context(fname=\"matplotlibrc\"):\n",
        "    fig, ax = plt.subplots(1, 1, figsize=(1.2, 0.7))\n",
        "    x = np.linspace(-3, 3, 100)\n",
        "    gaussian = lambda x, mu=0.0: np.exp(-0.5 * (x - mu) ** 2) / np.sqrt(2 * np.pi)\n",
        "    _ = ax.plot(x, gaussian(x), label=r'$\\mathcal{N}(0, 1)$', c=\"k\")\n",
        "    _ = ax.plot(x, gaussian(x, 0.25), label='$\\mathcal{N}(0.25, 1)$', c=\"gray\")\n",
        "    _ = ax.set_xlabel(r\"$x_i$\", labelpad=-5)\n",
        "    _ = ax.set_ylabel(r\"$p(x_i)$\", labelpad=-5)\n",
        "    _ = ax.set_ylim([0, 0.5])\n",
        "    _ = ax.set_yticks([0, 0.5])\n",
        "    _ = ax.set_xlim([-3, 3])\n",
        "    _ = ax.set_xticks([-3, 3])\n",
        "    plt.savefig(\"svg/fig2_panel_c1.svg\", bbox_inches=\"tight\", transparent=True)\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "id": "14b9348b-9da4-48b1-bbb7-43d49d452d84",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "14b9348b-9da4-48b1-bbb7-43d49d452d84",
        "outputId": "baa5c96e-20d2-4e62-d16b-2a72556d4cb9"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "1 Wed Mar  6 15:06:45 2024\n",
            "2 Wed Mar  6 15:06:45 2024\n",
            "4 Wed Mar  6 15:06:45 2024\n",
            "8 Wed Mar  6 15:06:45 2024\n",
            "16 Wed Mar  6 15:06:45 2024\n",
            "32 Wed Mar  6 15:06:45 2024\n",
            "64 Wed Mar  6 15:06:45 2024\n",
            "128 Wed Mar  6 15:06:46 2024\n"
          ]
        }
      ],
      "source": [
        "# compute true c2sts across dimensionalities\n",
        "\n",
        "import pandas as pd\n",
        "\n",
        "df = []\n",
        "for dim in [1, 2, 4, 8, 16, 32, 64, 128]:\n",
        "    print(dim, time.ctime())\n",
        "    true = MultivariateNormal(0.0 * ones(dim), eye(dim))\n",
        "    model = MultivariateNormal(0.25 * ones(dim), eye(dim))\n",
        "\n",
        "    c2st_optimal_score = c2st_optimal(true, model, 100_000)\n",
        "\n",
        "    df.append(dict(\n",
        "        dim=dim,\n",
        "        c2st_optimal_score=c2st_optimal_score.item(),\n",
        "    ))\n",
        "df = pd.DataFrame(df)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "id": "3e6c9080-63c2-4f57-9b96-4a1d6ca07323",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 137
        },
        "id": "3e6c9080-63c2-4f57-9b96-4a1d6ca07323",
        "outputId": "54823ef2-d6f7-464d-fb7b-510e24473c3a"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 120x70 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAI8AAAB4CAYAAADL9KEyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAARS0lEQVR4nO2deUxUVxvGn2FREKgUW1AcAQuoIOKAiAtUBHFBWkBcQptqVJS6oFbcqqkEFWtxQ0NsKR202lpbbLC1VEuNLMYoogQqaKAsguzuUlBkgPf7g88bxxnwMswwA5xfMsnce+49572Xh7O+8x4BEREYDAXQUrcBjJ4LEw9DYZh4GArDxMNQGCYehsIw8TAUhomHoTBMPAyFYeJhKIxSxFNXV4cxY8agtLRUJi0vLw8TJ07EqFGjMG/ePDQ0NCijSIYG0GXxXL16FW5ubigoKJCb/sknnyAqKgr5+fmwt7dHZGRku3k1NDRwn/r6ety7dw8NDQ1gKygaCnWRxYsX06VLl8jS0pLu3LkjlXb37l2ysrLijsvKymj48OHt5gVA7qe+vr6rZjJUgE5XxXfs2LF20yorKzF06FDu2NzcHBUVFV0tkqEhdFk8HdHa2ipzTkur/Zayvr6e+97Q0AAzMzOV2MVQDioVj1AoRHV1NXdcXV0NoVDY7vUGBgaqNIehZFQ6VLewsMCAAQOQnp4OADh69Ch8fX1VWSSjG1GJeGbPno0bN24AAE6dOoXPP/8c9vb2yMzMxM6dO1VRJEMNCIg0cxzc0NAAQ0NDAG19IdakaR5shrmPU1FRgdTUVIVGwUw8fRixWAwLCwt4eXnB0tIS8fHxnbqfNVt9lNzcXDg6Okqd09bWRmlpaYcj4ldhNU8fJDk5GZ6enjLnW1paUFRUxDsfJp4+xPPnz7F27VrMmjULDx8+lEnX1taGjY0N7/yYePoI2dnZGDduHGJiYgAAoaGhOHLkCLS1tQG0Cefbb7/l3WQB6PrCqKqor69nC6NKoLm5mb766ivS1dUlADR48GA6d+4cl15eXk6pqalUXl7e6byZeHoxpaWlNGXKFO49BgQE0P3795WWv0rXthjdS0VFBQoLC2FjY4P09HSsXr0adXV1MDQ0xOHDh7FkyRIIBALlFchHYRUVFUpTK19YzdM5xGIxaWlpyfhCTZo0iYqKilRSJi/xODk5qaTwjmDi4U95eblc4WzYsIEkEonKyuU12iLNnEdk/J/8/Hy5vlMffPABdHRU1zPhlXNNTU2Hq+Hh4eFKM4jROcrKyrBt2zaZ852ds1EE3vM81NbEyf0w1ENiYiJEIhGuX78OPT09rjOs0JyNIvBp21ifR7N49uwZrVixgns/rq6uVFxc3KU5G0Xg1WwRq100hlu3biEoKAh5eXkAgM2bNyMyMhK6uroAoPra5lX4KKyqqkrFGpaF1TzStLa2UlxcHOnr6xMAMjMzo+TkZLXaxEs8DQ0NFBYWRteuXSMiojVr1pCBgQFNnjxZZVUkE0/bEDwlJYXy8vJo/vz53PuYMWMG1dTUqNs8fuJZvnw5rVy5kmpraykpKYlMTU2psLCQzpw5QwEBASoxrK+LR96kn46ODu3du5daWlrUbR4R8RSPg4MD9z0kJIRCQ0O5Yzs7O+VbRX1bPOXl5SQQCGQm/c6ePatu06TgNVR/uWwPAKmpqfD29uaOm5qautrtYrzGTz/9JHeQYmRkpAZr2ofXaGvQoEHIzMxEXV0dqqqqOPGkp6d3b+++l1NVVYX169cjISFBJq07Jv06DZ/q6ebNm2RnZ0cmJiZ0/PhxIiKKjIwkU1NTyszMVEmV2JearebmZjp8+DAZGRkRANLW1qbp06dzfR5tbW0Si8XqNlMGhf15CgsL6cmTJ8q0RYq+Ip7MzExydnbmnnXChAmUnZ1NRF1z1OoOeIlHIpHQoUOHKCwsjC5duiSV9sUXX6jEsN4unsePH9OqVau4jr
GxsTHFxsZqzEiKD7zEs3TpUlq4cCHt37+frKysaM+ePVyaqpYueqN4ysvL6eLFixQTE0NmZmbc8y1cuFAj5m06Cy/xjBkzhvteW1tLo0ePpiNHjhARkUgkUolhvU08YrFYZvg9cuRISklJUbdpCsN7bauxsRF6enowNTVFUlIS3N3dYW5urly3xl7K+fPnsWzZMqlzAoEAf/75J6ytrdVklRLgo7ADBw6Qo6OjVH8nOzubBg8eTAYGBipRdW+oea5cuUKzZ89uN1xeamqquk3sErxqnrCwMLi4uEidq6qqwh9//IHffvtNmVru8RAR0tLSEBkZiZSUFABt0dBe9/TTyHmbTsJrhjkjIwPz589HY2Mjdy4nJwd+fn7w8/NTmXGazqsRJogI58+fh7u7O7y8vJCSkgIdHR0EBwejoKAAYrG4az+w00T4VE8eHh509epVmfPp6enk4eGh3Lrw/2h6s/XqwqVAICALCwvO3v79+9Pq1auprKxM6h5Nn7fpLF32JBw7dqyybJFCk8XT3sKlvr4+bdiwQS3+T+qAV59HIpGgtbVVJpJpS0tLn1oYraysRGJiIsRisdyFy1OnTsHf318NlqkHXn0eT09P7NixQ+b8jh074OrqqnSj1IW8KFllZWU4ePAgJk+eDKFQiLVr1+LmzZsy92pra2PcuHHdaa764VM91dXV0fvvv0/W1tYUFBRECxYsIGtra3J3d6dHjx6ppErs7mbr9T5MYGAgubi4yDRNbm5uFB0dTVFRUaStra3RC5eqhndkMCJCamoqsrOzoaWlBRcXF7z//vuq0nS3RgYrLy+HpaWl3KZIS0sLU6ZMwbx58zBnzhyYm5tzaRUVFSgqKoKNjU3PHzkpgnq12z6qrnlaW1vp2rVrtGXLFhIKhXIn8cLCwqi2tlbpZfcWenVMwpdRI2xtbSEUCtHc3IzLly8jMTERZ86c6TACaGfj8/VFem2Ilfj4eISEhKC1tRUCgQBubm7Iz8/HgwcPuGsMDQ3h6+uLwMBA3L9/H+vWrUNLS0vvmcRTMT2q5nm9JnmV58+fo6SkBEVFRcjKysKuXbvk5mtiYgJ/f38EBgbC29sbenp6XFqf78N0kh4jnp9//hnLly8HEUEgEMDf3x/GxsYoLi5GcXExqqqq3pjnwYMHsWbNGpVGjuhL9AjxFBQUYNSoUW/82fPAgQNhY2ODIUOGICkpSSqN9WGUT4/4FywqKpIrnKVLl8Lb2xvW1tawtraGiYkJ518UHx+PTz/9lPVhVEiPqXns7Oyk3Br41CSsD6NauhyH+fTp0xg9ejRsbW3lLmEkJyfjnXfegUgkgkgkwpIlSzpdxtChQxEXF9dplwahUIipU6cy4aiILtU8NTU1mDBhAm7cuAFjY2P4+Phg06ZNmDlzJnfNrl27YGhoiPXr178xv1e3za6vr8fgwYMBALW1tTAwMEBlZSWKi4thbW0ttXcpQ/kMGDDgzS7GXZlhPHHiBC1evJg7Pn78OC1ZskTqmg8//JCmTZtGIpGI/Pz8OvRlQTvumuzT/R8+s/pdarb47F789ttvY+PGjcjOzsasWbPw8ccfd6VIhgbRpdEWn92Ljx8/zn1fuXIltm7diqdPn2LgwIEy9766u/F///2HIUOGyC33ZTPWEa/ujvz69fLSXj8HgDsuKSnBe++9J/O9PXqDfQMGDOgwHeiieIRCIbf5LCC7e3FjYyOio6OxdetW7hwRcSHQXofv+pWBgUGn1ro6ul5eWkfHfMrtbfa1i8IdHiKqrKwkS0tLqqmpoaamJpo+fTolJiZKXTNixAgurszRo0dpxowZCpWlyW6pRH3Tvi67ZCQkJJCDgwPZ2trSxo0biYgoODiYfv/9dyIiysrKIldXV7K3tydPT0+6e/euQuX0xT+OMlGFfRo7ScjQfNhmbQyFYeJhKAwTD0NhmHgYCsPEw1AYJh6GwvR48ZSXl+Ojjz5CSEgITp48qW5z5FJSUiITokZTuHXrFh
YuXIjQ0FBERkZ27malzBapke3bt9P169eJiGjmzJlqtkaWx48f0+bNm1UWTaSrpKWlcfEQfXx8OnVvj695ampquJX9VyPVawrGxsaIiorivCI1DQ8PD5iZmSEqKgpBQUGdurfHi2fYsGHcLyfkrfIzOqaxsRGrVq2CSCTCokWLOnVvj3CA74hly5Zh48aN0NPTw9KlS9VtTo9jx44dyMrKQl1dHU6fPg2xWMz/ZlW0o8rg6dOn5ODgQHfu3OHOJSQkkL29PdnY2FBERIT6jCNmH5ESVtVVwZUrV8jBwYF0dXW5h6+uriYLCwu6d+8eNTU10bRp0+ivv/5i9qnRPo3s88TFxeHrr7+WCmdy4cIFeHl54d1334Wuri4WLVqEX375hdmnRvs0ss9z7NgxmXN8/KW7C2ZfGxpZ88iDj7+0OumL9mnO070BoVCI6upq7vh1f2l10xft6zHi8fb2xsWLF1FbWwuJRIIffvgBvr6+6jaLoy/ap5F9HnmYm5tj37598Pb2xosXL+Dv7485c+ao2yyOvmgf82FmKEyPabYYmgcTD0NhmHgYCsPEw1AYJh6GwjDxMBSGiYehMEw8DIXpVeIpLS1Fv379uOCZdnZ28PX1RUlJCaqqqjB79uxutyk2NhaxsbFKzTMiIgIREREAAJFIBADIzMzEli1blFrOm+gxyxN8MTc3R05ODnccExODmTNn4tatWzh37ly327NixQqV5v/yWW/fvs1FDOsuelXNI4+X2wXExsbCysoKALB48WKsWrUKTk5OsLCwwI8//oi5c+fC2toan332GYC2LTA3bdoEZ2dnODo6cntZpKWlwdvbG3PnzoW9vT1mzJiBR48eobW1FStXrsTYsWPh7OzM1Qyv1hJJSUkQiURwdHREQEAA98e2srJCeHg4Jk6cCFtbWyQnJwMA8vLyMHXqVIwfPx4WFhaIjo6WeT6BQICHDx8iPDwcZ8+exc6dO+Hp6Sn1jzJmzBgUFBQo/d32evEAbS/v1W29gTbnqOzsbOzatQuhoaH45ptvkJOTg6NHj+LJkyeIj49HU1MTsrKykJWVhatXr3JbEmRkZCA6Ohq3b9+Gvr4+Tp48idzcXGRmZuKff/7BlStXUFhYiGfPnnHl3bt3DyEhIUhMTMTNmzfh5uaG0NBQLn3gwIHIyMjA3r17sW3bNgCAWCzG1q1bcf36daSnpyM8PFzu8w0aNAg7d+6En58fwsPDERwcjBMnTgAAbty4gbfeegsjR45U6jsF+oh4AEBfX1/q+KU7gqWlJRwcHGBqagojIyOYmJjgyZMn+Pvvv5GUlAQnJyeMHz8ehYWFyM3NBQA4ODjAwsICQFuf49GjR7CxscGLFy8wZcoUHDp0CLt375YKCpmZmQlXV1cu2GRISAguXrwoY8/L/ADgwIEDaG5uxp49e7Bt2zapgJ8dMXfuXKSnp6Ourg7ff/+9yn5V0ifEk5OTIxPds1+/ftx3ebvgtLS0YN++fcjJyUFOTg4yMjKwbt06AJDaZkkgEICIYGBggJycHGzfvh0PHjzApEmT8O+//3LXve7JR0SQSCTc8cs8X+YHAAsWLMCvv/4Ke3t7fPnll7yfV19fHwEBAUhISEBSUhIWLFjA+97O0OvFExMTg/79+8PLy6tT93l5eeG7776DRCLBs2fPMG3aNFy4cKHd6y9fvgwfHx94eXlh//79sLe3l+pnTJgwAdeuXUNJSQmANid1Dw+PDm24cOECdu/eDX9/fy7qbEtLi9xrdXR00NzczB0HBwcjIiICU6dOhZGREe/n7gy9brRVVVXFDV9bW1sxYsQInD9/XurF8mHFihUoKiqCk5MTJBIJgoKC4O/vj7S0NLnXu7m5YdSoUXBwcIC+vj6cnZ3h4+ODrKwsAG0xk+Pi4hAYGAiJRIJhw4YhPj6+QxsiIiLg7u4OY2NjjBgxAsOHD+fE9zoTJ07Ejh07sGnTJuzbtw8uLi7Q09NTaK8PvjBnsF4IESE/Px
/z589Hbm7um/eQUJBe32z1RQ4dOgQvLy/ExMSoTDgAq3kYXYDVPAyFYeJhKAwTD0NhmHgYCsPEw1AYJh6GwjDxMBSGiYehMP8DGhwbml3Mhi8AAAAASUVORK5CYII=\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "# plot curve\n",
        "\n",
        "with mpl.rc_context(fname=\"matplotlibrc\"):\n",
        "    fig, ax = plt.subplots(1, 1, figsize=(1.2, 0.7))\n",
        "    _ = ax.plot(df['dim'], df['c2st_optimal_score'], label='gt', color='k')\n",
        "    _ = ax.scatter(df['dim'], df['c2st_optimal_score'], label='gt', color='k')\n",
        "\n",
        "    _ = ax.set_xlabel(r\"Dimensionality\")\n",
        "    _ = ax.set_ylabel(\"C2ST\", labelpad=-4)\n",
        "    _ = ax.set_xscale('log')\n",
        "    _ = ax.set_ylim([0.5, 1.0])\n",
        "    _ = ax.set_yticks([0.5, 1.0])\n",
        "\n",
        "    plt.savefig(\"svg/fig2_panel_c2.svg\", bbox_inches=\"tight\", transparent=True)\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "29aee1c5-fd0c-4153-91db-9dc4505914cd",
      "metadata": {
        "id": "29aee1c5-fd0c-4153-91db-9dc4505914cd"
      },
      "source": [
        "## MNIST behavior"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "id": "b4090aaa-3067-4e40-90af-0228325dd5ea",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "b4090aaa-3067-4e40-90af-0228325dd5ea",
        "outputId": "c38e7e71-5c43-4a39-b40d-b1ccab58be73"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/sklearn/datasets/_openml.py:968: FutureWarning: The default value of `parser` will change from `'liac-arff'` to `'auto'` in 1.4. You can set `parser='auto'` to silence this warning. Therefore, an `ImportError` will be raised from 1.4 if the dataset is dense and pandas is not installed. Note that the pandas parser may return different data types. See the Notes Section in fetch_openml's API doc for details.\n",
            "  warn(\n"
          ]
        }
      ],
      "source": [
        "# download mnist\n",
        "\n",
        "mnist = fetch_openml('mnist_784', as_frame=False, cache=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "id": "012c157c-6767-49f3-83ce-9426116043a2",
      "metadata": {
        "id": "012c157c-6767-49f3-83ce-9426116043a2"
      },
      "outputs": [],
      "source": [
        "# preprocess data\n",
        "\n",
        "X = mnist.data.astype('float32')\n",
        "X /= 255.0\n",
        "y = mnist.target.astype('int64')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "id": "ae75703b-4407-41ff-963d-6f348e2a5c3c",
      "metadata": {
        "id": "ae75703b-4407-41ff-963d-6f348e2a5c3c"
      },
      "outputs": [],
      "source": [
        "# restrict to ones\n",
        "\n",
        "mask = (y == 1)\n",
        "X = X[mask]\n",
        "y = y[mask]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 25,
      "id": "36faf67d-78b8-4cb9-bab7-c53d5a014c59",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 74
        },
        "id": "36faf67d-78b8-4cb9-bab7-c53d5a014c59",
        "outputId": "3a853b1b-2a35-4d16-b7f1-18545942f11b"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "GaussianMixture(n_components=20, random_state=1)"
            ],
            "text/html": [
              "<style>#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 
div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GaussianMixture(n_components=20, random_state=1)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GaussianMixture</label><div class=\"sk-toggleable__content\"><pre>GaussianMixture(n_components=20, random_state=1)</pre></div></div></div></div></div>"
            ]
          },
          "metadata": {},
          "execution_count": 25
        }
      ],
      "source": [
        "# fit 20-component gmm\n",
        "\n",
        "_ = torch.manual_seed(1)\n",
        "_ = np.random.seed(1)\n",
        "\n",
        "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)\n",
        "gmm = GaussianMixture(\n",
        "    n_components=20,\n",
        "    random_state=1,\n",
        ")\n",
        "gmm.fit(X_train)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 26,
      "id": "7901813d-055a-4c84-90c3-0e162f45663b",
      "metadata": {
        "id": "7901813d-055a-4c84-90c3-0e162f45663b"
      },
      "outputs": [],
      "source": [
        "# sample from gmm\n",
        "\n",
        "gmm_samples = gmm.sample(10_000)[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 27,
      "id": "1291ae79-0bb9-47df-87a2-8298b6401685",
      "metadata": {
        "id": "1291ae79-0bb9-47df-87a2-8298b6401685"
      },
      "outputs": [],
      "source": [
        "# save for later use\n",
        "\n",
        "with open(\"many_gmm_samples.pkl\", \"wb\") as handle:\n",
        "    pickle.dump(gmm_samples, handle)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 36,
      "id": "c1ec0e40-1e6e-4030-8187-89ac3ec5e0c2",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 227
        },
        "id": "c1ec0e40-1e6e-4030-8187-89ac3ec5e0c2",
        "outputId": "71598c19-bdda-49a3-ada8-adb878efa943"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 65x65 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAEiElEQVR4nO2cW2wUVRyHv5nt0nZrb8plVaxsQyVFSQvxwSsxBRKDlxgvxAvG6oMxUZNGQ4yGR40+oEGfKA9gYogP1WiCErWgIYBExYCBxPS2KWCJBUsVStvtXsaH/2xLKf/udne6jHi+l0l2zpw5+e135lx2WstxHAfDFOwr3QC/YoJRMMEomGAUTDAKJhgFE4yCCUbBBKNQlG3BNfYTs9mOgtKeastYxhijYIJRMMEomGAUTDAKWY9KfuPc03cAMDxfvtvw5h89rd8Yo/CfMsYuKQGgY1MDH6/dAsCmk/cD0D9wJwBVnxyUsmVlAKQuXMjtXnm19CrGG2PsgBxTSU+q0zizfjkAtzd28sP5pQC0Rj4DYN3Q6wAEqiql8Lzr5NgVzelexhgFb4yZZVPSnF0ZA+BQ7Z7xzx7veQSA0Bc/AWDVLpITw6N53csYo+DrUSmwOALAucb5AKxacmxKmY6dtwBQE5bnXKzmWgCCv3bldW9fBmOHQgD0NIcB+ODJ7QA8EJroHo92rwHg+v3ucByQYIq7+wFIDA3l14a8rr6K8ZUxgfo6ALqemwtA9bK/gAlTTiTEgtUHXiby1G8ABGvdidzAWTmOXvLQtQM5DQ7GGAVfGTO8qAqApqYjALQuPDjp/OodGwCIvDnxeSLaO22dlm3hpGbeFmOMgm+MsRvqGX11EJhqyprfHwKgrrUPgGRwDk4iLiczvMXiJBK5tSenq/4HXDFjrCK5tbV0MQAdr5USbfx0UpnPhyoAiL8v85lgcEBO2FZGU/LFk2Cs4mIA7FLZL0n+/Y9e2F2Jx1c2AND7gjwZo03bxotsPL0MgJ3b7wUgvEt255KWJQUK8BKY6UoK+RmT/gZT8g1Oa4pLvKkRgHUffQPAS1V94+fe6Jdzh1+RY/iAmJI20onF8mruTDDGKORnjNvXnfhYxqLp/dreZ+Wai00B2DMSoL1V9m0XnJdhG/eaKdP8AmCMUSjIcF0UuZmOd6oBiN4no0/MkQnasTExqGVLCwt3HAUmhvLkNKYEqqW+5KDYZQXnANnZmw3GGIWCGDMamcu3d38IQI87kz8SuwGAtzevB6Bm1x84oVIAkv2nM9Y5boprV04rxWkwxijMijGBCpnKn3nsVgCaN3xFpS1znq2DKwBoa10FQPkpWeQlTvTltKGU6yIxE94Ec8lU/VTzbQCsfX4/ACVWnO9HpOtsO3oXANe4rhYPun2rQD/BZIvpSgreLAlcU05uFBvG6kcA+Pq4dKUV4Ur29cgquuxnecBWdYsp9t7DeTVhtjDGKHiyJEjT8syXALz33cMAvLh8NwDRkXmkEmJXeZ88S0raxRSnQC8EzBRjjIIno1J6gfju3gcBWFAnvwfd5O64be28h8pfpEzFvm4AkknXEJ/+5aExRsGbrc2aGwGobRMLxirkpZ23ljQDUH48RfWhP4Hspvt+wBijMHNjLrMhnezskco6J1cauugyf405mZl5MD59WHqN6UoKJhgFE4yCZf63w+UxxiiYYBRMMAomGAUTjIIJRsEEo2CCUTDBKPwLwW0yNWGIeawAAAAASUVORK5CYII=\n"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 65x65 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAEQElEQVR4nO2cW2hcVRSGv3PmkluT9EHTFEmbTsBIW6yxXrBCa7H1hlgq5iFFBS0IiVpJlCK1jylG8II+BAqV9kml+iC+NGAQJFoLxYKRqo1CWxytbZqYiprp3I4PayYmE1eamTmTTGR9L2fYs2efzX/+vdfa68A4nud5GLNwF3sC5YoJo2DCKJgwCiaMggmjYMIomDAKJoxCcL4dt7vtpZzHgvJp+sNr9jHHKJgwCiaMggmjYMIozDsqlRsXejYBMPxS/4z2Da91AXDDez8CkBodLWh8c4zCknKMW1Mj19pluElpe3M8AkBH3TAAsQYpSHqxGABOKIyXiOd/r2In+39lSTnmr23rAPh1s8vehz4G4JOLGwDoP3Y/AC37TwCQLrKUbY5RWBKOCSyvB+BKs0y39+H3ubUyCkDfuQcBWHEq4xCfXnqYYxSWhGOiT8veUnffbwA0hcZ46vsnAFj9gTzb8MAJX+9Z1sJcelaSuN7OIwBsq5oAYN/FTVQfkOXlDp0syb1tKSmUpWMmd9wBQGr77wCsC18CYDheBcA3r7QR1pziOHK1cF0ainKME5Sfe8lMfu4G5JpOzX+MUFjGSMSZePIuAFo7TwPwQuMgACsD0qe97zkArh/4Sh/QwnVpKcoxU07JkodTpsbIHPCCjStY3/UtAFuW/wBAaygNQNuRbgDWHPpaflPQbPPDHKOw6FHJra4GYKQnQtd1hwFYG7oMwPNRSfeb98ueElizGoDk2fMln5c/wmRCpBMMAeRV/xjpvRmAdx45zJbKCQB2nukAIPp5EwDNtbLEFkKQLLaUFHxxjBMIXLtPRQUA3tWrAFx4UdL9R++RZbK18g96RyWxC3VKeF41chyAtB+TzBNzjIIvjpkVtnNwQuEpV13eLUncUPcbACQ88cN42uOz18VF9T+V5mCYD+YYhQUJ114iDre0AtC99ygA9a4cCIfjUs3vONhD00eZBC6bKM5xIMzds/zGHKOwYAnegaPvArCxIjyjfcfAHgBufPU4npsT3eY4EJbKKVnMMQr+OiZnT/iz/U4Avnz7ICBOye4pu07tBmBtn9Rx07W1kEjI51j+h1G/8SfBy6nLxB+4HYC/H78yq+++czsBCAxJzTbWUglAxfgEqcxr1XLAlpJCwY5xgkG8dOYFeuaavHcjAL9slmEH2w4BkPCqCDmysY5Nyml6Zb+E5kBjAwDp1OIvn+mYYxQKdkzWJQCB+joAxiKywd5091kAVgWXAXA6PsljJ58BoOkteRbZcJs8/3OhU5jBrPpzkZhjFAqPStPqu4n1zQA0fCGVt++2NgKwp0ai07HB24i8PEdl3wf8ckoWc4yCL3lMODoO/Ft6bNkl7Wcy30eY5haf3hSWGnOMgi+OyatIXeZOyWKOUTBhFEwYBRNGwbH/dvhvzDEKJoyCCaNgwiiYMAomjIIJo2DCKJgwCv8Aw6skKuvtxSYAAAAASUVORK5CYII=\n"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 65x65 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAADzElEQVR4nO2b72tTVxyHn5ukibZa283InIoohmqHzDnU+mOCqO90c/gDO9jAd7MbCO6FIPsDxlQGls2q+GL4YhtOGChs3d6Ib6TonIKjNBqlUOuPuqazM3ZN7r17cRIhhG9Nk5vclH2fN7n33MM9h0+ecw+ce67luq6LUkDA7w7UKhqMgAYjoMEIaDACGoyABiOgwQhoMAKhYituDeyuZD+qym/OuZfWUWMENBgBDUag6GdMrRKMRs1B00wA3IdDADijo2XdV40RmBLGWCHTzeDc1wAYWTufU19+BcAb4ekALO/5AIAFh+YA8O+GpQBEfr5aUptqjMDUMCYSAcBtbABgcJPzwpQc4V9mAWDHrwAQiSfKalONEfDEmNwzwM1kvLhdAc6zZwDEPzVW3HvvVEGd2SeveNqmGiPgiTGVMiVHMLYYyDcl5YwDsOH6hwBE6fO0TTVGYErMSnsvXM477x1P0XG7HYDou96akqOmg0kcawPgo8YuAP4cfw7ApVQL9fvM8K3UINahJFCTxlirlgNwp92YcnnMlP84vBGA+MctuAO3irtXKFTS5KDGCPhuTLC5GQA7mcRd+yYA3ee/BSBppwBorTP7Di72rAQgdq2n+Aas0v57NUbAd2PsZNIcWBY/nTudLQ0D0BysB6DjvpmdYp9MwpQsbnq8pH6pMQK+G5Nj+qU5pLEB+DVVB0DYMueJVWNV74/vwbRcMyEciJ5nVmAGAJ/37QDg1cO57vVWvV86lAR8M+Z25xoAul8/mS2ZwdcjCwCwvp8NgHPT2zWWyaDGCFTdmMRRM/Xe3dlVcO3M8W0ARM/6Z0oONUagasakt7wNwA87j2dLwnnX99zdTLQr35RKryVPhBojUBVjUu+vobuzE4D6QL4psbP7ze+ROM47bwEQHhgGwHn8BFBjaoqKGuOuXwHAF8dOFJjSdmMXAEu++xuAkS0xmn5/DECmf8BUcuxKdm9CKhKMu86sqwx+lgZg/bRCMSMnXgHg+Vyz1tJ0fQg70W8uehiIruB5TEWMGVphXr7favvmRdmDzD8ArLt4EIBlvWaDj/vIPGDtMjf6SLh2afapMQKeGjO2fTUAqXmF34b1pRsBaD3yCAD3qTGk3C1hL6XE79TUGAFPjZk2ZFbaLLshr3z1H7sZ7TGbCBfVGWPsv4a9bNpz1BgBq9iPRSfaMh9atNAcpLPvkwfui3UDDcam3GYgP9At82XgyTMmc6+/6Lp+mjIZ1BgBDUZAgxHQYASKnq7/b6gxAhqMgAYjoMEIaDACGoyABiOgwQhoMAL/AYxs/x1wl0CYAAAAAElFTkSuQmCC\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Render the first three GMM samples as 28x28 thumbnails and save each\n",
        "# one as an SVG panel for the composed figure.\n",
        "\n",
        "with mpl.rc_context(fname=\"matplotlibrc\"):\n",
        "    for idx in range(3):\n",
        "        fig, ax = plt.subplots(1, 1, figsize=(0.65, 0.65))\n",
        "        ax.imshow(gmm_samples[idx].reshape(28, 28), clim=[0, 1])\n",
        "        # strip the bottom/left spines and all ticks so only the image remains\n",
        "        for side in (\"bottom\", \"left\"):\n",
        "            ax.spines[side].set_visible(False)\n",
        "        ax.set_xticks([])\n",
        "        ax.set_yticks([])\n",
        "        plt.savefig(f\"svg/fig2_gmm_mnist_{idx}.svg\", bbox_inches=\"tight\")\n",
        "        plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 29,
      "id": "6252d5f8-ea29-4a3b-8db0-e174f363d598",
      "metadata": {
        "id": "6252d5f8-ea29-4a3b-8db0-e174f363d598"
      },
      "outputs": [],
      "source": [
        "# make single gaussian approximation\n",
        "# Fit one multivariate Gaussian (empirical mean + full covariance) to the\n",
        "# training images. The pixel covariance can be singular (e.g. pixels that\n",
        "# never vary), so a small diagonal jitter keeps it positive definite.\n",
        "\n",
        "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)\n",
        "X_train_t = torch.as_tensor(X_train)\n",
        "cov = torch.cov(X_train_t.T)\n",
        "# build the jitter in the covariance's own dtype rather than the float32 default\n",
        "cov += torch.eye(len(cov), dtype=cov.dtype) * 1e-6\n",
        "mean = torch.mean(X_train_t, dim=0)\n",
        "gaussian_fit = MultivariateNormal(mean, cov)\n",
        "\n",
        "# seed the sampler so the plotted digits and saved SVG panels are reproducible\n",
        "torch.manual_seed(0)\n",
        "gaussian_samples = gaussian_fit.sample((100,))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 30,
      "id": "8e99ad03-7d04-4c62-ba66-220bb515a0b6",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 227
        },
        "id": "8e99ad03-7d04-4c62-ba66-220bb515a0b6",
        "outputId": "781d51d6-08bb-4654-8895-09a3b2e13015"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 65x65 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAGcUlEQVR4nO2cXWxURRTHf3Pv7rZb2rItHwVKQT6EKlUREPwgkQQ1Eh5MjGJijMZEjcYn33zxyUd5MGpCfNAXw4sfGGPCg4gJMYGIIH7GigryJaVAv3a77bZ77/XhzL3tbncKbe9SQ+b/cvfemTl37pn/nHPmzGRVEAQBFhPgzHYH/q+wijHAKsYAqxgDrGIMsIoxwCrGAKsYA6xiDEhcb8WHnSer2Y8bigP+J9esYxljgFWMAdc9lW4EnIYG+eF5APj5/Kz1ZdYU49TWAhAu7t2FC/Cb6qXw9AV51tgIgDcwUNJWJRI4mbnSPj8EgD9ckELfi6d/sUi5CTF9xigF5akcpUrvdbnT0IBy9RgouQ5uvRWAOd/+AUDXjjZ6No8C0PjbHVK1KE1a3j1cInbgiU24BZFd9/l30/6EyWAZY8D0GVMp8aefqYSIVTU1ADjNGfB8ALyWDADpLrENLJwPwMAq2NYh7Hn70a8AOFYQm/Ny60sArHz9CAA9tysaTwk766b9AZPDMsaAqnilQLtbpa+4Lt588TB+Sl45sDINQPOJgq4LCSWsqlMpALanpX2xSYyN/+DdAIzMDZh7ulCNrkewjDGgOnFMaH9cV24HsqgGYYg3V15ZTIuNyLY3AeAvH6YYyDgllVsibkvH3wAc3SWeLKgp4hw6UZWuh7CMMSBWxoReCF8YE4xIXKJcFzUidsKrkbGouyL2o2etdKH2lzQ/NSyR9stK5b646BAAuXUiv/+dsgpVQDyKCQM7rRCVFLHByIg8TyUJ9LRCV53zT05ui+KSz2938AqpiuJDI/xCZxsAa/ZVJ6gbDzuVDIiHMdrYhm46QsiSIEBd7AYgrYuUnmb1vwpz6tqX4rVNvlvcekBNWh4nLGMMiIcxjjBDOXpEAwnUVEr4oWprKV74V8qu9pS0GX1IgjYVwJalZyqKf6tnFQBzPhtnW7Rdc7TB94eHY/iQMVjGGBCruw6KOk+g2cCQLBS94Ynhe6JNXHP3GvFE/gP9zHFHKsr9YN8jACznSPRMJZLSrqBlh54xplMtljEGzIwx5YmpMkQMGocor6vb5u6XvO7GlotkkpVzvK2HypjkuJEdi7qiGRSMVmbdVDEjxahwLVSugEnyrn42C4CzYB4ALc2Sz11Yk2Vxqq9imzBaHt/ZKDQIp06ZomYKO5UMmBFjKk2VayHM7p16Rozv7tUfArCzbqK73Xh8FwDzevX0CI164E9g64TgsuSlUzfMljEGVCcfM8kIhSPsru8H4LVPnwdg57N72JsVu3O5KAY6lRAW5JfIHlRjoyw4A89HpbSxHcyXyI2Yo3cj8L1puXDLGAOquxNZgTlqUwcAg1dkudBweczlf3F5PQDvL/8SgD3HdgAwL6HTGdrVB1d7CIa1F3JkbJ062S8IA75gdOr2bzwsYwyobs53HNTGdQBculd2C9Jn5XnuljFv8nSLLBKfO/U4AMmczgu3ybVJ72b6+XyFbKH2XNq2hAvakvDGKc0lTwbLGAPiZYy2Ka62BZ6OcgkCcivEo/R1yNxPXZHRa+84FzXfWnsJgDf3rhA5Oqvl6evQ6gUA1KaS0H1Ff4H+hNAr6QWrP1IhrplCdByLYsKgLVwHeb29pRUcFy8pSkv2CUn1Tgn71+6PqiX1NEjo3dv0Vfm4/hWixOH54qIT+XoS/dmSd4eJ9yhryGj07kghNsCbOeJhTFpvppUzJSzfcBu5Vu1WR2TU3njq4wn1vs63AJD5SPIuwX13ATA0X1xxKisMSnQPjE2d0D2HbNDuO9qpGC3aAC9OxMIYf/AaZ+V+7C
S14R4AcuvErX6fEwPb5wkbXs2cI+MOlnbuz/MADG9bK/d56W5Ncz2JsgAu6OmTH6NiW6IAb5pHzyxjDIiFMU64oAv0vC6U5niDYpHsSvmdaRZW/JUV13u4S5iz++wOHtv8g8i7czUA3s+d0kntpeovCNvcXIGgVxah0cEBnV+eTiqk4jfFIuUmRDw2Ru/pREdUK9RpOyCjfSaRAeBSu8QWhW/kqNmyk0V+fyUc7c6StovfOwqASsmOQuD7se8jlcMyxoBYlwSTjWLi4HEAVh2U+yhiLZ6U+5oa3MWLRE5vny4rTT4FN/Ck+KydDC83kkGhQPFi1yz1ZiLsVDLAKsYAqxgDlP1vh8qwjDHAKsYAqxgDrGIMsIoxwCrGAKsYA6xiDLCKMeA/FDYJ8sEU4DkAAAAASUVORK5CYII=\n"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 65x65 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAFw0lEQVR4nO2cy28bVRSHv5lxHDt2Xs2blpS0FaEgApFQEatISHTDmh0Llmy6YMn/AAv+BoRArdghsSgI8aigQqCWUtFA2qTvV5LGeTh27Jlhce51Mo6v69iTGKT7bSaZuXN9feZ3z7nn3EmcMAxDLLtw2z2A/yrWMAasYQxYwxiwhjFgDWPAGsaANYwBaxgDiUYbvuW+s5/jOFDOB+ee2qZhwxwETiIROQbFolxoQ9Zip5KB9inG9QBwXAeAsFwm9H25lM2oJl7lGkCoFXQAtM0wbrIDgKBQ2HUtHB8DwCmKQZz7jwDwdxrGcVTj/ZlmdioZiFcxdZ6im0pJEzVNGDok59V0Cf68hvvKSQDWjncDkPnyYqSP0unXAHj0/ibj790GwF9dbXgMe8EqxkC8ijE8JTeVwu3rBSAY7gcgP94DQObaYzk/M829U2kAsncCALzJEwD4s3MApOcWAej9YnSXUrwe6c9f3wDA8badejNYxRg4kKgU+sHuD96U0OwPij/BD8nelXYPZuTY/4cX7ScnKsmeXaicc7vV/YdHAPBUBAsLEsGsYmKmNcWoCOB2dgLmJXxY2iJU6xXn/hIAqdU8AOWFWwB40y9R6OuSn3ukbfD3fKQff2l5e+DPjcu5flGMs6UWgWoNxNytVr6ZVYyJWHyMkxLFaCvXWs0GKlqE5RwAXt+xyHU/m2RlSp765KiKVKdelP4vXNrVnz8oUch7uCJte7MyhiXxQ+VCa+lDa4ZRUyZsYBDVTtD/50bk983hJG++ehWA411imJ+urEnbqr4SY6PkB2TBmM4lAXAeKGPmN/fwBczYqWQglqlUyX6DvS/DvaEhAHrO/8WZj76NXPt+9fWa9/ijA6RviCP255SD1g5fZe0E1TrbG1YxBuJRTBNKqdy7Jn4kKBT4dPkNAD4e+x0A78QEsK2KxNFnpe3sPEFJ+azqNKRFpWisYgzEE651pU1V4CrzPJSlvZPoICxt1bx3Z2j/cPhH9ZOUJir+Q7NVknvy+RhGXR+rGAPxJJFKGXp+Ox1JdV6l/ga17MQbGaZQ5S+qyw7l+w/iGG1DtGYYt2oKKRoxRDXBM0N8tT4JQJ+npornRj6npmOtmrZx1YDtVDLQmmL0E9R11hbqrRsTWX548jwA9zak2telnK2bkaw7UKE9glJKJcOvkac1g1WMgeYV4zjbymhAIXrbtTqZTEwcBWDzkMsvl0Qx6TviN456Uo3Dr7NoU58dl1I0VjEGmlfMHvyIm8ng9kilTYdcR/mE8vxNAPpnD/HuB98BcGXtCADzv0qUSi2qyt0BLOwqYz6wT/qfsS+7BFoNbpdEE6cnS/nm7Uib6g36J5MpPESFnxz5BoC301LB07VeT+1N+Su5HR+mI6KrDmpRqf3STmXrtg1gFWMgVsXoyOMNDQIQdqv1R2cH3Iy22VXqTDmc6ZdGn60dBqCclufmTr0g9+jtqR2KcdNp1XF051Fro9l9pXiya5UbeYOyUR92ST0WtdHmLubQ38k00MRGyIWCtPp66WUAlk+KYXovq3tqhG2daest2ko1sV6IbwA7lQw0rpgay/1KFq0cHmn1qkdRksj1Kdn8SuR9Enfu1u1+eSrkWEKe/s/XpXI3sBBdEjjFknFcerMvNFX2TOcMWMUYaFwxdaxdcahaOZ68RraVFbt3Lj+9DJGdyPH56hQAyTlxqP2z68B2PabeuOJ+P88qxkBLUUkXpPTRW1Ev84xIuO7YlCiTeJgjVK9rVJcOgplpAEq/pTl39jQAKXm3CKckkaXl0lMTe01WMQZiXeD5i/KKB+
qYvqrO12ir04atbvFH3QshxV7xUX3XVfS5PBvPwJrYa7KKMdD2vyXIXJS9o0yxWFmLOElZH4Vqv8pRCWIzRfZmaZthdHj1Hz82XmsndioZsIYxYA1jwLH/26E2VjEGrGEMWMMYsIYxYA1jwBrGgDWMAWsYA9YwBv4F3pkOxNJt05IAAAAASUVORK5CYII=\n"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 65x65 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAF8UlEQVR4nO2by28bVRTGfzNjJ847aWzipiVVaJoqJaitxLOgsqJFSBUbWHSJKhAgdYHECom/gFV3XSGxpRJCrFDZ0UIftKhdgFRatUqatCkkdZ62E3tmWJw7E4+T2zi2U0fq/TaT8dw5Pjn3O4977rXl+76PwRrYjVZgu8IYRgNjGA2MYTQwhtHAGEYDYxgNjGE0MIbRIFbpwHfsD7dSj6eKX7xzG47ZloyxYjGsWMVztiXYlobZDmjstGjgF4uNVqEGw1gWPKWFeWz3LgD8+QUA3Pn51Ye2Ex3se+pam27GlTSomjGW46ylvGWtP7hk9mJ7ngegOHY/+mq8icX3DwOQTcl85VIir+9aAYDWO49lcCljPLcq/TeCYYwGVTMmwpYypliO8nt19ZeXw2flTAnlFVYoJkTOJ6d/AuDHqUMAzN4VlrmjSQAefJZi6IvLT1YwiD2eu6rfJuKOYYwG9UnX5TNhKXt7m8sMi7vlvU+7JwH4NTMMwFJW5NgFubYPzlM49jIA8fPX1hdWGnuqyFCGMRpsSYHnF1Y2HOMkewGw2loBiT35lNQgZ2elblksNgPQNCfxbOaA3DfHi7Tc/k/eU/KCWqc4MVmH/8AwRouGLQm82TkAcm/sBSAxdp/+0UcAfDf2OgBxWxhkN8v8LbyaA6DnhyTuxHV51tYm8jKzddWvYYaxW8WFlvokrbZ1d9EaFxcc7vwXgAvfvgJAR5M4zHNJKewSEw5OcgcAnlomlJYEQFUpOqJfVW89A2icK6kZLnTIzM4dG+Fgxw0AOmN5AFI3swBMv9QCQO5yHwADc4t4c2pZoJhhtcgYf3FRPg+YYjtVLRsMYzRoGGOc/jQA+V6Z2WLC5kz/bwB8PnEUgPjDWQAKRyQedd5ThV6uEC43rOYmAPycsCwsLn3FkqANsUkYxmjQMMb4sxIjVvpTAOzYmSHjyaxfHH8BgAFZM+IoMthFYYw18Ygw1wSxxK1v+8EwRoOGMcbNZAD44JAUah/3XqTLlniRn5EM42SmAYgvtQPQc13u/VweLy9ZzUbiT9gGqVPjqnHNcNUv+SYthjk1fpwv0+cjQ9zbdwFITk7JfU4qX7u1dTWoenL1da5kCrz6omGMGf/6NfWXMOby5B5GBsQthvY9jA4OOoJq9r2lpbCw81cKkWfrwnTw6oetZUxp31XBP3IQgIGj4wBMu0sAJJoK4ZiZc7sBSCH9YW9hYa3ssORXS4Jm6dWEaVsVen6xsLboq0T1ikc+Y9haxpQwxenpAeDuuxJHhmPSezl+4yMAZm/v4KB7EoD02UsRMXYiIeKCeFKakhVD7I4OuVfZylfpHN/RZ6wnwDBGg61hzDqx5fF7++ULR2Up0NssseXOstT96Us+/tWuqHJpaTN4fdKUcu5JP9fL5VeZoQo7X9U4wUF3vyCfyx6XYUzdUBtjyveqw/0k5feqH2u1t5E5IGMdT64FT1i160xcFLnxN9m3RyLiilMShyzVz/XK25cl8PL5dT+v9sxDnTfcor0PKy7i544OstIngbMjJkb7/bq41r4LV8I3W37+EwD7RXnm/nVLxD/BIBuiyrWTcSUNamKMFRM3CDbYrCbVTVMz7AZbJEkbLJm5hRlxr+HTV9bIC9Jqdk8nAG0zEnz9LknF7q07tai7KRjGaFBbjCnrpwYpMrx/8xAAK10WVk7tH41Fj4Y5+4cAYcPMKdloW+5ROweDeyNj0w8kGK+7RKhxH6kchjEa1MSYNUfNyjJAsU3EZ0fydHRJAZa42h0ZE8SN2M40liLgco/M+oqq99rVWSN3VHrBsX
/uh/tKgQ7BYaX1TnxacRX7KjhsEMAwRoP61DEa/07clNZC6+EhTpz8A4DvD7wFQHeZiOLDKXx7EIDBr6KLSCclOwl+VnYm3ZIlQQDt2WDL2hRTAhjGaGBV+vPip/Eji+AwUXBExA+OqtX5yGolP7LYVkfm3emZRqsQwriSBsYwGhjDaFBx8H3WYBijgTGMBsYwGhjDaGAMo4ExjAbGMBoYw2hgDKPB/ymSCiGaZV+EAAAAAElFTkSuQmCC\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Render the first three single-Gaussian samples as 28x28 thumbnails and\n",
        "# save each one as an SVG panel for the composed figure.\n",
        "\n",
        "with mpl.rc_context(fname=\"matplotlibrc\"):\n",
        "    for idx in range(3):\n",
        "        fig, ax = plt.subplots(1, 1, figsize=(0.65, 0.65))\n",
        "        ax.imshow(gaussian_samples[idx].reshape(28, 28), clim=[0, 1])\n",
        "        # strip the bottom/left spines and all ticks so only the image remains\n",
        "        for side in (\"bottom\", \"left\"):\n",
        "            ax.spines[side].set_visible(False)\n",
        "        ax.set_xticks([])\n",
        "        ax.set_yticks([])\n",
        "        plt.savefig(f\"svg/fig2_gauss_mnist_{idx}.svg\", bbox_inches=\"tight\")\n",
        "        plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 31,
      "id": "226df20a-8dab-419c-9278-c0c747f56a42",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 227
        },
        "id": "226df20a-8dab-419c-9278-c0c747f56a42",
        "outputId": "ca96a337-901e-4902-cd6d-96173d362ea9"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 65x65 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAB+UlEQVR4nO3avWtTARSG8aeh1NDSoYh0kYtIqbjUNSB00sVBFHFzc3Ox4OruJm6F/gsFdVQMVMShowgiFMQuLYpgoQpKctvrcCGYti8N1OQc8P1BINwEcng4+b5jVVVV2CGN6AGychjBYQSHERxGcBjBYQSHERxGGB/0jlcbt4c5x0i92l899j7eGMFhBIcRHEZwGCE+TGsBWgu83H5H2S4o20X0RECGMEkN/Dlm2LrVHveK1wCstG7UB9ffh83jjRHCN6bxswPAx26X61M7ADx8UP8MXQR+2PbGCGOD/ksw7O9KZbvgxcVnALz5PQHA48tX6tu+fP2nj+XvSieQMsxis8NiswPNU/UlQMowGTiM4DCCwwhpwmy/PRs9Qp80YbJJE2Z6M9fZKGnCZJMmzPSdregR+qQJk43DCA4jOIzgMILDCA4jOIwQ/i/B+PlzANwt1mIHOcAbI4RvzO6lWQBuTn3vHfvQKesr3TJiJCAwTGNyEoDTS5uHbrv1/D4Ac1vroxypj59KQtzGzJ4BYHXuae/Y2q8mABeWvwGwN/qxerwxQviL798efb4GwMTGp+BJvDFS2MZUuz8AeLIzD8DSzEbUKEfyxghpTgMZJZ8GcgIOIziM4DCCwwgOIwz8dv2/8cYIDiM4jOAwgsMIDiM4jOAwgsMIfwC16V1XQIICrgAAAABJRU5ErkJggg==\n"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 65x65 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAACKElEQVR4nO3bv2sTcRzG8XfORutUA6atOEgdHERQikt3i4s4KG46+A9IJ/8B/QecdBEEySS4KjgrXVyKoLRUpRoQhKJE8AfmzuGaa095NIv3+YrPa8mXy0E+PHmO7xFyraIoCuwXWfQAqXIwgoMRHIzgYAQHIzgYwcEIDkaYGPfExezC35yjUY/ye388x40RHIzgYAQHIzgYYexdqQmnng0A6E6Ur7eunQNgqrfc+CxujJBMY7ITR1nq3AUgJwfg+nz53lQvYJ7mP/LfkE4w62+iJ6hJJ5jEJBNMPhhEj1CTTDCpcTBCMts1QEarWgHs/hD3vbkxQjKN2by8QM5TYPsGb+7OBgDfA+ZxY4RkGrN45XG13hx+LRd5HjSNGyOFN6ZYOA7Apc5NYA8Ap29cBeBA/0nUWG6MEt6Yz7OTABxut6tj+9Yj9qG68GBGMjLarV3RY1R8KQnhjXl7ZgiUN3XPv5Xb9N53XyJHAtwYKbwx3dmP1frBp2PlYnklaJptbowQ1phsstymj3TeR43wW26MENeYmS4Atw/dr471Xp0EYD+rITPt5MYIYY15sXQQKO94R+an+wBshExUF7ddbz3aMfq1DqB/cWZr9bL5eX7iS0kIv8HbabgW35QRN0ZIJpjza2ejR6hJJpjUJBPM64dz0SPUJBNMalrjPizqv8wb4GAkByM4GMHBCA5GGHu7/t+4MYKDERyM4GAEByM4GMHBCA5GcDDCD/vjW5gaPr9sAAAAAElFTkSuQmCC\n"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 65x65 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAACqklEQVR4nO2bT0hUURSHv/mXxUxWOg6STIipIUoSRkEEujEKhiAiqF0E0cqgRbRpWRC0qIXCoEEQRJFBuCpoERQZQdiQxeRQWYFBDTIRxrSYZlpMMzX4Do2Ld9+Fzrc858E7/Pjuu4/3x1cqlUooy/B7PYCtaDACGoyABiOgwQhoMAIajIAGI6DBCATrPXDYf8jNOYxyvzj5z2PUGAENRkCDEdBgBDQYgbp3JS/JXB0A4M2eiWpt94vyLtm4760r51RjBKw2JrB+HQADnR8AKPLnYeNiKgZAI2qMUaw2hlgUgBsdt2rKd7430TX+CYCCS6dWYwSsNMa3rReA7itzjv2JYwfwz6dcnUGNEbDSmKXNEQAutj517IfSH/np8gxWBeMLrQJg06mMY39otnxTt3ZpwfVZdCkJWGXM3Fg/AJn2ZE19+NVBACJ73wFg4p2yGiNglTFHdz52rH95tBGAOO+NzaLGCFhjTKB3C31r7tbUHuRXA9A+mgZwfYv+GzVGwHNjAj1dAAzenGF/OAfAfOEHAGcujQAQy00bn8vzYNInNwAw1fS6Wrv+dQcAsTHzgVTQpSTgmTGBlhYAZhKXf1cayBXLS2jy9iAAcdQY6zBujD8cBqDn3iIAEX9DtTeUPA1A/Lx3plRQYwSMG5M9vBWAC62jy3ptD/OmxxFRYwSMGVPZhfqPzzr2dz0/QvOTcs+GT9XVGAFjxmQTnQBMxWuvLde+tQEQPRukWHDrLdHKMRJMINpM34mXjr1z0wkAulPPTIxSN7qUBIwYk9/ewXg86dgLZkMmRlgxaoyA548dwgs+r0dwRI0R8N6Yz0WvR3BEjRHw1fuzqH4yrwAajIgGI6DBCGgwAhqMQN3b9f+GGiOgwQhoMAIajIAGI6DBCGgwAhqMgAYj8AtnBX0OD0TeCgAAAABJRU5ErkJggg==\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Render the first three real training digits as 28x28 thumbnails and\n",
        "# save each one as an SVG panel for the composed figure.\n",
        "\n",
        "with mpl.rc_context(fname=\"matplotlibrc\"):\n",
        "    for idx in range(3):\n",
        "        fig, ax = plt.subplots(1, 1, figsize=(0.65, 0.65))\n",
        "        ax.imshow(X_train[idx].reshape(28, 28), clim=[0, 1])\n",
        "        # strip the bottom/left spines and all ticks so only the image remains\n",
        "        for side in (\"bottom\", \"left\"):\n",
        "            ax.spines[side].set_visible(False)\n",
        "        ax.set_xticks([])\n",
        "        ax.set_yticks([])\n",
        "        plt.savefig(f\"svg/fig2_mnist_{idx}.svg\", bbox_inches=\"tight\")\n",
        "        plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Compose figure"
      ],
      "metadata": {
        "id": "rwkh-ZK3Oa6v"
      },
      "id": "rwkh-ZK3Oa6v"
    },
    {
      "cell_type": "code",
      "execution_count": 37,
      "id": "f1441a82-f838-410c-bb15-bb360514e339",
      "metadata": {
        "colab": {
          "resources": {
            "http://localhost:8080/fig/fig2.svg": {
              "data": "CjwhRE9DVFlQRSBodG1sPgo8aHRtbCBsYW5nPWVuPgogIDxtZXRhIGNoYXJzZXQ9dXRmLTg+CiAgPG1ldGEgbmFtZT12aWV3cG9ydCBjb250ZW50PSJpbml0aWFsLXNjYWxlPTEsIG1pbmltdW0tc2NhbGU9MSwgd2lkdGg9ZGV2aWNlLXdpZHRoIj4KICA8dGl0bGU+RXJyb3IgNDA0IChOb3QgRm91bmQpISExPC90aXRsZT4KICA8c3R5bGU+CiAgICAqe21hcmdpbjowO3BhZGRpbmc6MH1odG1sLGNvZGV7Zm9udDoxNXB4LzIycHggYXJpYWwsc2Fucy1zZXJpZn1odG1se2JhY2tncm91bmQ6I2ZmZjtjb2xvcjojMjIyO3BhZGRpbmc6MTVweH1ib2R5e21hcmdpbjo3JSBhdXRvIDA7bWF4LXdpZHRoOjM5MHB4O21pbi1oZWlnaHQ6MTgwcHg7cGFkZGluZzozMHB4IDAgMTVweH0qID4gYm9keXtiYWNrZ3JvdW5kOnVybCgvL3d3dy5nb29nbGUuY29tL2ltYWdlcy9lcnJvcnMvcm9ib3QucG5nKSAxMDAlIDVweCBuby1yZXBlYXQ7cGFkZGluZy1yaWdodDoyMDVweH1we21hcmdpbjoxMXB4IDAgMjJweDtvdmVyZmxvdzpoaWRkZW59aW5ze2NvbG9yOiM3Nzc7dGV4dC1kZWNvcmF0aW9uOm5vbmV9YSBpbWd7Ym9yZGVyOjB9QG1lZGlhIHNjcmVlbiBhbmQgKG1heC13aWR0aDo3NzJweCl7Ym9keXtiYWNrZ3JvdW5kOm5vbmU7bWFyZ2luLXRvcDowO21heC13aWR0aDpub25lO3BhZGRpbmctcmlnaHQ6MH19I2xvZ297YmFja2dyb3VuZDp1cmwoLy93d3cuZ29vZ2xlLmNvbS9pbWFnZXMvbG9nb3MvZXJyb3JwYWdlL2Vycm9yX2xvZ28tMTUweDU0LnBuZykgbm8tcmVwZWF0O21hcmdpbi1sZWZ0Oi01cHh9QG1lZGlhIG9ubHkgc2NyZWVuIGFuZCAobWluLXJlc29sdXRpb246MTkyZHBpKXsjbG9nb3tiYWNrZ3JvdW5kOnVybCgvL3d3dy5nb29nbGUuY29tL2ltYWdlcy9sb2dvcy9lcnJvcnBhZ2UvZXJyb3JfbG9nby0xNTB4NTQtMngucG5nKSBuby1yZXBlYXQgMCUgMCUvMTAwJSAxMDAlOy1tb3otYm9yZGVyLWltYWdlOnVybCgvL3d3dy5nb29nbGUuY29tL2ltYWdlcy9sb2dvcy9lcnJvcnBhZ2UvZXJyb3JfbG9nby0xNTB4NTQtMngucG5nKSAwfX1AbWVkaWEgb25seSBzY3JlZW4gYW5kICgtd2Via2l0LW1pbi1kZXZpY2UtcGl4ZWwtcmF0aW86Mil7I2xvZ297YmFja2dyb3VuZDp1cmwoLy93d3cuZ29vZ2xlLmNvbS9pbWFnZXMvbG9nb3MvZXJyb3JwYWdlL2Vycm9yX2xvZ28tMTUweDU0LTJ4LnBuZykgbm8tcmVwZWF0Oy13ZWJraXQtYmFja2dyb3VuZC1zaXplOjEwMCUgMTAwJX19I2xvZ297ZGlzcGxheTppbmxpbmUtYmxvY2s7aGVpZ2h0OjU0cHg7d2lkdGg6MTUwcHh9CiAgPC9zdHlsZT4KICA8YSBocmVmPS8vd3d3Lmdvb2dsZS5jb20vPjxzcGFuIGlkPWxvZ28gYXJpYS1sYWJlbD1Hb29nbGU+PC9zcGFuPjwvYT4KICA8cD48Yj40MDQuPC9iPiA8aW5zPlRoYXTigJlzIGFuIGVycm9yLjwvaW5zPgogIDxwPiAgPGlucz5UaGF04oCZcyBhbGwgd2Uga25vdy48L2lucz4K",
              "ok": false,
              "headers": [
                [
                  "content-length",
                  "1449"
                ],
                [
                  "content-type",
                  "text/html; charset=utf-8"
                ]
              ],
              "status": 404,
              "status_text": ""
            }
          },
          "base_uri": "https://localhost:8080/",
          "height": 37
        },
        "id": "f1441a82-f838-410c-bb15-bb360514e339",
        "outputId": "92a402d4-be07-4091-e5a6-48a88aad3660"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "<img src=\"fig/fig2.svg\" / >"
            ]
          },
          "metadata": {}
        }
      ],
      "source": [
        "def svg(img):\n",
        "    \"\"\"Display an image file inline via an HTML <img> tag.\n",
        "\n",
        "    A timestamp query string is appended so the browser requests a fresh copy\n",
        "    on every call instead of showing a cached version of an older figure.\n",
        "    \"\"\"\n",
        "    IPd.display(IPd.HTML('<img src=\"{}?modified={}\" />'.format(img, time.time())))\n",
        "\n",
        "# > Inkscape pixel is 1/90 of an inch, other software usually uses 1/72.\n",
        "# > http://www.inkscapeforum.com/viewtopic.php?f=6&t=5964\n",
        "svg_scale = 1.25  # set this to 1.25 for Inkscape, 1.0 otherwise\n",
        "\n",
        "# Panel letters in Arial, 10pt, weight 800 (extra-bold); other text in Arial, 7.75pt\n",
        "kwargs_caption = {'size': '10pt', 'font': 'Arial', 'weight': '800'}\n",
        "kwargs_text = {'size': '7.75pt', 'font': 'Arial'}\n",
        "\n",
        "# NOTE(review): the C2ST values in the MNIST row labels below are hard-coded;\n",
        "# confirm they match the scores computed in this notebook.\n",
        "f = Figure(\n",
        "    \"16.6cm\",\n",
        "    \"5.6cm\",\n",
        "    Panel(SVG(\"svg/fig2_panel_a.svg\").scale(svg_scale)).move(102, 26.4),\n",
        "    Panel(Text(\"Data\", 23, 32.0, **kwargs_text)),\n",
        "    Panel(SVG(\"svg/fig2_panel_b.svg\").scale(svg_scale)).move(102, 111.5),\n",
        "    Panel(Text(\"Model\", 23, 117.0, **kwargs_text)),\n",
        "    Panel(SVG(\"svg/fig2_illustration.svg\").scale(svg_scale).move(0, 5)).move(-3, 10),\n",
        "    Panel(Text(\"a\", 5, 12.0, **kwargs_caption), Text(\"Failure of C2ST to discriminate classes\", 45, 12.0, **kwargs_text)).move(-4, 0),\n",
        "\n",
        "    Panel(Text(\"b\", -35, -8.5, **kwargs_caption), Text(\"C2STs for high-D data\", -12, -8.5, **kwargs_text)).move(342, 20.5),\n",
        "    Panel(SVG(\"svg/fig2_panel_c1.svg\").scale(svg_scale)).move(305, 27.6),\n",
        "    Panel(SVG(\"svg/fig2_panel_c2.svg\").scale(svg_scale)).move(305, 111.8),\n",
        "\n",
        "    Panel(Text(\"c\", -25, 12, **kwargs_caption), Text(\"High C2ST on MNIST\", 3, 12, **kwargs_text)).move(502, 0.0),\n",
        "    Panel(SVG(\"svg/fig2_mnist_0.svg\")).move(485, 25.5),\n",
        "    Panel(SVG(\"svg/fig2_mnist_1.svg\")).move(532.5, 25.5),\n",
        "    Panel(SVG(\"svg/fig2_mnist_2.svg\")).move(580, 25.5),\n",
        "    Panel(Text(\"Data\", 491, 31, **kwargs_text)),\n",
        "\n",
        "    Panel(SVG(\"svg/fig2_gauss_mnist_0.svg\")).move(485, 78),\n",
        "    Panel(SVG(\"svg/fig2_gauss_mnist_1.svg\")).move(532.5, 78),\n",
        "    Panel(SVG(\"svg/fig2_gauss_mnist_2.svg\")).move(580, 78),\n",
        "    Panel(Text(\"Gaussian: C2ST=1.0\", 491, 83.5, **kwargs_text)),\n",
        "\n",
        "    Panel(SVG(\"svg/fig2_gmm_mnist_0.svg\")).move(485, 130.5),\n",
        "    Panel(SVG(\"svg/fig2_gmm_mnist_1.svg\")).move(532.5, 130.5),\n",
        "    Panel(SVG(\"svg/fig2_gmm_mnist_2.svg\")).move(580, 130.5),\n",
        "    Panel(Text(\"MoG: C2ST=1.0\", 491, 135.5, **kwargs_text)),\n",
        ")\n",
        "\n",
        "!mkdir -p fig\n",
        "f.save(\"fig/fig2.svg\")\n",
        "svg(\"fig/fig2.svg\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b6c2ca94-64fc-4f2c-bb1a-c3e054f6f7ed",
      "metadata": {
        "id": "b6c2ca94-64fc-4f2c-bb1a-c3e054f6f7ed"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.18"
    },
    "colab": {
      "provenance": []
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}