{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "26ScWNvYSgQg"
      },
      "source": [
        "Copyright 2017 Google Inc.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License.\n",
        "\n",
        "\n",
        "# dSprites - Disentanglement testing Sprites dataset\n",
        "\n",
        "## Description\n",
        "Procedurally generated 2D shapes dataset. This dataset uses 6 latents, controlling the color, shape, scale, rotation and position of a sprite (color does not vary here; its value is fixed).\n",
        "\n",
        "All possible combinations of the latents are present.\n",
        "\n",
        "The ordering of images in the dataset (i.e. along the first axis, shape[0], of every ndarray) is fixed and meaningful; see below.\n",
        "\n",
        "We chose the smallest changes in latent values that generated different pixel outputs at our 64x64 resolution after rasterization.\n",
        "\n",
        "No noise added, single image sample for a given latent setting.\n",
        "\n",
        "## Details about the ordering of the dataset\n",
        "\n",
        "The dataset was generated procedurally, and its order is deterministic.\n",
        "For example, the image at index 0 corresponds to the latents (0, 0, 0, 0, 0, 0).\n",
        "\n",
        "Then the image at index 1 increments the least significant \"bit\" of the latents:\n",
        "(0, 0, 0, 0, 0, 1)\n",
        "\n",
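        "And so on, until we reach index 32, where we get (0, 0, 0, 0, 1, 0): the last latent (y position) takes 32 values, so it wraps back to 0 and the latent before it is incremented.\n",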
        "\n",
        "Hence the dataset is addressable as a mixed-radix number, with a different base for every \"bit\".\n",
        "Using dataset['metadata']['latents_sizes'] makes this conversion trivial, see below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "cellView": "both",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "jJ02BsnqSa96"
      },
      "outputs": [],
      "source": [
        "from __future__ import absolute_import\n",
        "from __future__ import division\n",
        "from __future__ import print_function\n",
        "\n",
        "import os\n",
        "\n",
        "from matplotlib import pyplot as plt\n",
        "import numpy as np\n",
        "import seaborn as sns\n",
        "\n",
        "# Change figure aesthetics\n",
        "%matplotlib inline\n",
        "sns.set_context('talk', font_scale=1.2, rc={'lines.linewidth': 1.5})\n",
        "\n",
        "# Seed NumPy's global RNG: sample_latent draws from np.random, so seeding it\n",
        "# here makes every sampled batch below reproducible under Restart & Run All.\n",
        "SEED = 0\n",
        "np.random.seed(SEED)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "output_extras": [
            {
              "item_id": 2
            }
          ]
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 10952,
          "status": "ok",
          "timestamp": 1495021223246,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "uDL3Iw0WFw1L",
        "outputId": "1a3ce845-1add-41c3-ee3d-6018d09423bc"
      },
      "outputs": [
        {
          "ename": "FileNotFoundError",
          "evalue": "[Errno 2] No such file or directory: 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
            "Cell \u001b[0;32mIn[2], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[39m# Load dataset\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m dataset_zip \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39;49mload(\u001b[39m'\u001b[39;49m\u001b[39mdsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz\u001b[39;49m\u001b[39m'\u001b[39;49m, allow_pickle\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, encoding\u001b[39m=\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39mlatin1\u001b[39;49m\u001b[39m'\u001b[39;49m)\n\u001b[1;32m      4\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m'\u001b[39m\u001b[39mKeys in the dataset:\u001b[39m\u001b[39m'\u001b[39m, dataset_zip\u001b[39m.\u001b[39mkeys())\n\u001b[1;32m      5\u001b[0m imgs \u001b[39m=\u001b[39m dataset_zip[\u001b[39m'\u001b[39m\u001b[39mimgs\u001b[39m\u001b[39m'\u001b[39m]\n",
            "File \u001b[0;32m~/Desktop/USI/PhD/CCM_IJCAI/.venv/lib/python3.9/site-packages/numpy/lib/npyio.py:405\u001b[0m, in \u001b[0;36mload\u001b[0;34m(file, mmap_mode, allow_pickle, fix_imports, encoding, max_header_size)\u001b[0m\n\u001b[1;32m    403\u001b[0m     own_fid \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m\n\u001b[1;32m    404\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 405\u001b[0m     fid \u001b[39m=\u001b[39m stack\u001b[39m.\u001b[39menter_context(\u001b[39mopen\u001b[39;49m(os_fspath(file), \u001b[39m\"\u001b[39;49m\u001b[39mrb\u001b[39;49m\u001b[39m\"\u001b[39;49m))\n\u001b[1;32m    406\u001b[0m     own_fid \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n\u001b[1;32m    408\u001b[0m \u001b[39m# Code to distinguish from NumPy binary files and pickles.\u001b[39;00m\n",
            "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'"
          ]
        }
      ],
      "source": [
        "# Load dataset. The archive path defaults to the file name below (relative to\n",
        "# the working directory) and can be overridden via the DSPRITES_PATH env var.\n",
        "DSPRITES_PATH = os.environ.get(\n",
        "    'DSPRITES_PATH', 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')\n",
        "if not os.path.isfile(DSPRITES_PATH):\n",
        "  raise FileNotFoundError(\n",
        "      'dSprites archive not found at %r; download it from '\n",
        "      'https://github.com/deepmind/dsprites-dataset or set DSPRITES_PATH.'\n",
        "      % DSPRITES_PATH)\n",
        "\n",
        "# allow_pickle=True is required to read the object array stored under\n",
        "# 'metadata'; unpickling can execute code, so only load a trusted archive.\n",
        "# encoding='latin1' is what NumPy needs to read pickles written by Python 2.\n",
        "dataset_zip = np.load(DSPRITES_PATH, allow_pickle=True, encoding='latin1')\n",
        "\n",
        "print('Keys in the dataset:', dataset_zip.files)\n",
        "imgs = dataset_zip['imgs']\n",
        "latents_values = dataset_zip['latents_values']\n",
        "latents_classes = dataset_zip['latents_classes']\n",
        "# 'metadata' is presumably a 0-d object array; [()] unwraps the stored object.\n",
        "metadata = dataset_zip['metadata'][()]\n",
        "\n",
        "print('Metadata: \\n', metadata)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "9RWpIJtiHYUL"
      },
      "outputs": [],
      "source": [
        "# Per-latent cardinalities, and the mixed-radix place value of every latent:\n",
        "# latents_bases[i] is the product of latents_sizes[i+1:], so the last latent\n",
        "# has place value 1 (it varies fastest along the dataset's first axis).\n",
        "latents_sizes = metadata['latents_sizes']\n",
        "latents_bases = np.concatenate((np.cumprod(latents_sizes[:0:-1])[::-1],\n",
        "                                np.array([1,])))\n",
        "\n",
        "def latent_to_index(latents):\n",
        "  \"\"\"Map latent class vectors (rows of `latents`) to flat dataset indices.\"\"\"\n",
        "  flat_indices = np.dot(latents, latents_bases)\n",
        "  return flat_indices.astype(int)\n",
        "\n",
        "\n",
        "def sample_latent(size=1):\n",
        "  \"\"\"Draw `size` latent class vectors uniformly at random from np.random.\n",
        "\n",
        "  Returns a float array of shape (size, number of latents) whose column j\n",
        "  holds integers in [0, latents_sizes[j]).\n",
        "  \"\"\"\n",
        "  n_latents = latents_sizes.size\n",
        "  samples = np.zeros((size, n_latents))\n",
        "  for col in range(n_latents):\n",
        "    samples[:, col] = np.random.randint(latents_sizes[col], size=size)\n",
        "  return samples\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "W8LKpGjGKaiN"
      },
      "outputs": [],
      "source": [
        "# Helper functions to show images\n",
        "def show_images_grid(imgs_, num_images=25):\n",
        "  \"\"\"Plot the first `num_images` images of `imgs_` on a near-square grid.\n",
        "\n",
        "  The grid has ceil(sqrt(num_images)) rows and just enough columns to hold\n",
        "  every image; grid cells past the last image are hidden and ticks removed.\n",
        "  If `imgs_` holds fewer than `num_images` images, only those are shown;\n",
        "  nothing is drawn when there is no image to show.\n",
        "  \"\"\"\n",
        "  num_images = min(num_images, len(imgs_))\n",
        "  if num_images <= 0:\n",
        "    return\n",
        "  n_rows = int(np.ceil(num_images**0.5))\n",
        "  n_cols = int(np.ceil(num_images / n_rows))\n",
        "  # squeeze=False always yields a 2-D array of Axes, so flatten() also works\n",
        "  # for a single subplot (otherwise plt.subplots returns a bare Axes).\n",
        "  _, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3, n_rows * 3),\n",
        "                         squeeze=False)\n",
        "  axes = axes.flatten()\n",
        "\n",
        "  for ax_i, ax in enumerate(axes):\n",
        "    if ax_i < num_images:\n",
        "      ax.imshow(imgs_[ax_i], cmap='Greys_r', interpolation='nearest')\n",
        "      ax.set_xticks([])\n",
        "      ax.set_yticks([])\n",
        "    else:\n",
        "      ax.axis('off')\n",
        "\n",
        "\n",
        "def show_density(imgs):\n",
        "  \"\"\"Show the pixel-wise mean of `imgs` (averaged over axis 0) as an image.\"\"\"\n",
        "  _, ax = plt.subplots()\n",
        "  ax.imshow(imgs.mean(axis=0), interpolation='nearest', cmap='Greys_r')\n",
        "  # Use a real boolean: in Matplotlib 3.x grid('off') does not disable the\n",
        "  # grid, because the non-empty string is treated as truthy.\n",
        "  ax.grid(False)\n",
        "  ax.set_xticks([])\n",
        "  ax.set_yticks([])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "## Fix posX latent to the right edge\n",
        "# posX is column -2 of the latent vector; its last index,\n",
        "# latents_sizes[-2] - 1 (31 for this 32-position archive), is the rightmost.\n",
        "latents_sampled_dx = sample_latent(size=3000)\n",
        "latents_sampled_dx[:, -2] = latents_sizes[-2] - 1\n",
        "indices_sampled_dx = latent_to_index(latents_sampled_dx)\n",
        "imgs_sampled_dx = imgs[indices_sampled_dx]\n",
        "# 1.0 where the shape latent (column 1) is 0 or 2, else 0.0.\n",
        "y_dx = (latents_sampled_dx[:, 1] == 0) + (latents_sampled_dx[:, 1] == 2).astype(float)\n",
        "# One-hot of the shape latent over 4 slots (slot 3 is not set here).\n",
        "c_dx = (np.arange(4) == latents_sampled_dx[:, 1][:,None]).astype(np.float32)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "## Fix posX latent to left\n",
        "# posX is column -2 of the latent vector; index 0 is the leftmost position.\n",
        "latents_sampled_sx = sample_latent(size=3000)\n",
        "latents_sampled_sx[:, -2] = 0\n",
        "indices_sampled_sx = latent_to_index(latents_sampled_sx)\n",
        "imgs_sampled_sx = imgs[indices_sampled_sx]\n",
        "# 1.0 where the shape latent (column 1) is 0 or 2, else 0.0.\n",
        "y_sx = np.isin(latents_sampled_sx[:, 1], (0, 2)).astype(float)\n",
        "# One-hot of the shape latent over 4 slots (slot 3 is not set here).\n",
        "c_sx = (latents_sampled_sx[:, 1][:, None] == np.arange(4)).astype(np.float32)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAswAAALJCAYAAACgHHWpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWOUlEQVR4nO3dQXLiShYFULnDA+wFMDb7XxQeswCbmXrS9Yvmm2uBJPRSOmfU0VFF5S+cqhsvL8lL3/d9BwAA/Og/Sy8AAAAqE5gBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACB4HfKLPj4+utPp1O12u+5wOMy9Jmja8Xjszudzt9/vu8/Pz6WXc5N9DcPZ17A+9+zrlyHf9Pf+/t59f39PtT7YhLe3t+7r62vpZdxkX8P97GtYnyH7elAlY7fbTbIg2JLq+6b6+qCi6vum+vqgoiH7ZlBgdqwD96u+b6qvDyqqvm+qrw8qGrJvfOgPAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAACC16UXAHCPvu9//P9fXl6evBIAtsKEGQAAAoEZAAACgRkAAAIdZmAVLrvN+swATMmEGQAAAoEZAAAClQygtFvXyA39PeoZAIxlwgwAAIHADAAAgUoGUM4jNYxHXktdA4AhTJgBACAQmAEAIFDJAEqYsobxyJ+pngHALSbMAAAQCMwAABCoZACLWaKGcYt6BgC3mDADAEAgMAMAQCAwAwBAoMMMPFWl3vIt+swwjUf2uz1HRSbMAAAQCMwAABCoZACzaqGCkahnwH3G7vlbv9/+Y0kmzAAAEAjMAAAQqGQAk2u9hnHL9X+XI2J4HvUolmTCDAAAgcAMAACBSgYwibXWMBJHxLAM9SiezYQZAAACgRkAAAKVDIAJqGfAcuw/5mbCDAAAgcAMAACBwAwAAIEOM/CwLV4lN8StvxfdSoA2mTADAEAgMAMAQKCSAdxFDQOozBVzzMGEGQAAAoEZAAAClQzgV2oYAGyZCTMAAAQCMwAABCoZADPyKX225vJnfuk6lxszmIoJMwAABAIzAAAEKhnAryodsbbA0S/AupgwAwBAIDADAECgkgHcRT3jZ2oY8G+eF6yFCTMAAAQCMwAABAIzAAAEOszAw7beT9Rb/ss3qgFrZsIMAACBwAwAAIFKBjCJW8fwa6pqqBr8ld5X9Qx+svUKF20zYQYAgEBgBgCAQCUDmNX1kXxrR7EqBePcer/9vW5b688FtseEGQAAAoEZAAAClQzgqVr4pLy6wM+mfL+uX8vf+ba18Fxg20yYAQAgEJgBACBQyQAWU+kYViXgZ896X3zZCX9M+VzwszS/R96jFt8XE2YAAAgEZgAACARmAAAIdJiBEir1mVmWbwfkD+/5dKbsGo99Rrf4mQUTZgAACARmAAAIVDKAcp5Vz2jlKPDZqlZiWjzGhSVtsToxFxNmAAAIBGYAAAhUMoDSro8BffPXPKrWMG5J6/Uew/Rae0ZMzYQZAAACgRkAAAKVDKAp996g4XgegLFMmAEAIBCYAQAgUMkAmqVu8bg1feLdzwEwNxNmAAAIBGYAAAgEZgAACJrsMPumL4D7ram3DPzu3ms4n6XFHGbCDAAAgcAMAABBM5WMKY8SLl+rxWMBgK3z7AaeyYQZAAACgRkAAIJmKhlzUc8AtqLqJ+aH8oyGxy29/1vfvybMAAAQCMwAABBsvpJx6fqIovXjA4Bb0vOtUl3Dcximd2tfTb3317R/TZgBACAQmAEAIFDJAOD/LP1pemAZ1xWKe/f/mioY10yYAQAgEJgBACBoppKxxBGhLzUBtu5Zn6Yf8mcCz7XE/q/KhBkAAAKBGQAAAoEZAACCZjrMl/
SZAZY19vqp314PqGuL+9WEGQAAAoEZAACCZioZla4wUc+4zyPvnb9XaIvrp4A1M2EGAIBAYAYAgKB0JaPqUZ66wM+mfL+Gvpb3Amq7dauRvQu0xIQZAAACgRkAAILSlYylOTL83dK1mVt/vvcO6rEvgVaZMAMAQCAwAwBAoJJxwXEhAADXTJgBACAQmAEAIBCYAQAgKN1hvvUNUXP9GQAAcM2EGQAAAoEZAACC0pWMS2PrGaoX83hGbQYAYEkmzAAAEAjMAAAQNFPJuKReUdPS9Qw/FwDAHEyYAQAgEJgBACBospJBfUPrEUOqG6oWAMCSTJgBACAQmAEAIFDJYFHqFgBAdSbMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQDArMx+Nx7nXA6lTfN9XXBxVV3zfV1wcVDdk3gwLz+XwevRjYmur7pvr6oKLq+6b6+qCiIfvmdcgL7ff77nQ6dbvdrjscDqMXBmt2PB678/nc7ff7pZcS2dcwnH0N63PPvn7p+75/wpoAAKBJPvQHAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAADB65Bf9PHx0Z1Op26323WHw2HuNUHTjsdjdz6fu/1+331+fi69nJvsaxjOvob1uWdfv/R93//2gu/v79339/dU64NNeHt7676+vpZexk32NdzPvob1GbKvB1UydrvdJAuCLam+b6qvDyqqvm+qrw8qGrJvBgVmxzpwv+r7pvr6oKLq+6b6+qCiIfvGh/4AACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACB4XXoBAAAsr+/7u3/Py8vLDCupx4QZAAACgRkAAAKVDACAjXqkhnHr96+5nmHCDAAAgcAMAACBSgYAAKNd1zvWVNEwYQYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIDAtXIAAExuTd8CaMIMAACBwAwAAIFKBgDAhlx/Ix+/M2EGAIBAYAYAgEAlAwBgQy5vrFDPGMaEGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACt2QA5Tzyqe3LT30DUMvlc73F57UJMwAABAIzAAAEAjMAAAQ6zEAJY79tqvV+HMCatf5cNmEGAIBAYAYAgEAlA1idVO9o/VgQYEqXz8Sx1bg1M2EGAIBAYAYAgEAlA9gUt2kAcC8TZgAACARmAAAIVDKAEp71SW01DICfXT8f3ZrxlwkzAAAEAjMAAAQqGQAA/MvYqtyaKnAmzAAAEAjMAAAQqGRcuD5uWNNRAmyVfQww3pB6xpqftybMAAAQCMwAABAIzAAAEGymw/zIdShb7OjAGtijAPPZ4jPWhBkAAAKBGQAAglVVMh6pXQDrsMUjQgCew4QZAAACgRkAAILmKxlqGLA+6hUAVGLCDAAAgcAMAABB85
WMy6Nb9QwAAKZmwgwAAIHADAAAQfOVjDn5pD4AACbMAAAQCMwAABAIzAAAEKyqw3zdOXbNHAAAY5kwAwBAIDADAECwqkrG1C4rHa6YAwDYJhNmAAAIBGYAAAhWXcm4rFE8cmOGGgYAACbMAAAQCMwAABCsupJxaUg9QwUDAKYztA7p31+qM2EGAIBAYAYAgGAzlYxLjn4AoA5VSaozYQYAgEBgBgCAQGAGAIBgkx1mAGAej3yz7r2vpdvMs5kwAwBAIDADAECgksHTDTmuc9wGwC2+QZBnM2EGAIBAYAYAgEAlg9k98onpy9/jSA2gtilvxpiSWzaYigkzAAAEAjMAAAQqGQCUPVK/xZE6Y6Sfdz9b/MSEGQAAAoEZAAAClQzKuz46c1wGsKzWKjz3cLMGPzFhBgCAQGAGAIBAJYPJrPmIDgDYLhNmAAAIBGYAAAgEZgAACHSYAQB+4Co5/jBhBgCAQGAGAIBAJYNRXCUHsA2e92yZCTMAAAQCMwAABCoZjHL5CWLHdQDAGpkwAwBAIDADAECgkkF5Lo4HWMYWq3b+zeEnJswAABAIzAAAEKhkMBnHWADAGpkwAwBAIDADAEAgMAMAQKDDDAD8Y2tXyfn8DUOYMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBA4JYMAGBT3IzBvUyYAQAgEJgBACBQyQAA/nGrrrC1LzSBSybMAAAQCMwAABCoZAAAvxp6s4TqBmtkwgwAAIHADAAAgcAMAACBDjMAMJmq19L5dj/GMGEGAIBAYAYAgEAlAwDH1cwu/YwtXdeA35gwAwBAIDADAECgkgEALGpIJeiR2oaqEVMxYQYAgEBgBgCAQCUDACjPLRssyYQZAAACgRkAAAKVDACgaW7DYG4mzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAMCszH43HudcDqVN831dcHFVXfN9XXBxUN2TeDAvP5fB69GNia6vum+vqgour7pvr6oKIh++Z1yAvt9/vudDp1u92uOxwOoxcGa3Y8Hrvz+dzt9/ullxLZ1zCcfQ3rc8++fun7vn/CmgAAoEk+9AcAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAMHrkF/08fHRnU6nbrfbdYfDYe41QdOOx2N3Pp+7/X7ffX5+Lr2cm+xrGM6+hvW5Z1+/9H3f//aC7+/v3ff391Trg014e3vrvr6+ll7GTfY13M++hvUZsq8HVTJ2u90kC4Itqb5vqq8PKqq+b6qvDyoasm8GBWbHOnC/6vum+vqgour7pvr6oKIh+8aH/gAAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACB4XXoBAMC29X3/z/9+eXlZcCXwMxNmAAAIBGYAAAhUMgCA2V3WLh75da
oaLMmEGQAAAoEZAAAClQwAYBZDaxj3vpZ6Bs9mwgwAAIHADAAAgUoGADCZKWsYUIUJMwAABAIzAAAEAjMAAAQ6zABA0+7tTbuWjnuZMAMAQCAwAwBAoJIBADRl7NV1vjWQe5kwAwBAIDADAECgkgEAbNZ1vUNF43dbrLSYMAMAQCAwAwBAoJIBAEA09maS1pkwAwBAIDADAECgkgEATOby1oStH+O3bsj7t5UbM0yYAQAgEJgBACBQyQAAYHSFZs1fAmPCDAAAgcAMAACBwAwAAIEOMwCwWWvq2T5izqv/1nTlnAkzAAAEAjMAAAQqGQDALKp+61/r9YCxlngvWq9nmDADAEAgMAMAQKCSwWRuHfG0ePQCQPv8+/NXpUpMi0yYAQAgEJgBACBQyWB2Q4+BHJ0BQ6XnimfJ+nmPh6law2jxxgwTZgAACARmAAAIVDIoY+zRUSvHOtCSqke6rJdn+eNa3K+t1DNMmAEAIBCYAQAgEJgBACDQYWY1fNMg/FuLnUbWybOYlpkwAwBAIDADAECgksHq+UYwAGAME2YAAAgEZgAACFQyGKX1T+C38g1DAMByTJgBACAQmAEAIFDJgP9RzwCgZdf/dk1Zm7x87dbrmI8wYQYAgEBgBgCAQCUDAIB/UU/8y4QZAAACgRkAAAKBGQAAgmY6zJWuMNHpgXYMfXbY18Da3Hqujb1GdewVcy0+b02YAQAgEJgBACBoppJRya3jhxaPGB5RqR4Df/i5BBhmyryylexjwgwAAIHADAAAgUoG0JS5qhdjPzVOHd5LYGomzAAAEAjMAAAQqGTA/zi6XZZbLgCoyoQZAAACgRkAAILSlQxHtLA+LexrtywAcMmEGQAAAoEZAACC0pUM+MmQI/Jbx/6O15+jhdrFVlz+zHtfAB5jwgwAAIHADAAAgcAMAACBDjOrpKv8HFvoxLpiDgATZgAACARmAAAIVDK4m2uqAIAtMWEGAIBAYAYAgEAlg/LcTFCXeg4AW2DCDAAAgcAMAACBSgajOJJnS65/xtWFALbBhBkAAAKBGQAAApWMkRzJ/nXr72JoVcPfJQBQkQkzAAAEAjMAAAQCMwAABDrMzE43eRu2eMXg5X+nn3OA9TJhBgCAQGAGAIBAYAYAgEBgBgCAQGAGAICg9C0ZU37qfCuf2ocKtnhjxta5JQRYMxNmAAAIBGYAAAhKVzKm5LgQmFMLX2JSdV0A1ZkwAwBAIDADAECwmUoGsIzrGoBbMwBojQkzAAAEAjMAAAQCMwAABDrMAA9yTRvANpgwAwBAIDADAECgkgEQqF0AYMIMAACBwAwAAIFKBvBUlxWHpb/1T90CgCFMmAEAIBCYAQAgUMkAVk3tAoCxTJgBACAQmAEAIFDJABYz5Y0ZqhcAzMWEGQAAAoEZAAAClQygNFULAJZmwgwAAIHADAAAgcAMAACBDjNQgq4yAFWZMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAMCgwH4/HudcBq1N931RfH1RUfd9UXx9UNGTfDArM5/N59GJga6rvm+rrg4qq75vq64OKhuyb1yEvtN/vu9Pp1O12u+5wOIxeGKzZ8Xjszudzt9/vl15KZF/DcPY1rM89+/ql7/v+CWsCAIAm+dAfAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAE/wWokDlMH3cOFAAAAABJRU5ErkJggg==",
            "text/plain": [
              "<Figure size 900x900 with 9 Axes>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Superimpose each right-edge sprite on its left-edge counterpart. Assuming\n",
        "# binary 0/1 images, np.maximum is a pixel-wise OR and keeps the image dtype;\n",
        "# a plain difference would wrap around for unsigned dtypes (presumably uint8)\n",
        "# and would erase the left sprite for signed or float dtypes.\n",
        "image_fused = np.maximum(imgs_sampled_dx, imgs_sampled_sx)\n",
        "# Label: 1 if the right sprite's shape latent is 0 or the left sprite's is 2.\n",
        "y_fused = np.clip((latents_sampled_dx[:, 1] == 0) + (latents_sampled_sx[:, 1] == 2), a_max=1, a_min=0)\n",
        "# Concepts: union of both sprites' shape one-hots, with slot 3 always set\n",
        "# (presumably flagging the two-sprite composition - confirm with consumers).\n",
        "c_fused = np.clip(c_sx + c_dx, a_max=1, a_min=0)\n",
        "c_fused[:, 3] = 1\n",
        "show_images_grid(image_fused, 9)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAswAAALJCAYAAACgHHWpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAV+0lEQVR4nO3dS3IiSRYF0FCbBkgLYCz2vyg01gKUzKIHZdaFqcVNB+LzPOKcYVkm6QU4XHt+iXgZx3EcAACAX/1n7QUAAEBlAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAASvLX/o4+Nj+Pr6Gg6Hw3A6neZeE3TtfD4Pl8tlOB6Pw+fn59rLucm+hnb2NWzPPfv6peVOf+/v78OfP3+mWh/swtvb2/D9/b32Mm6yr+F+9jVsT8u+bqpkHA6HSRYEe1J931RfH1RUfd9UXx9U1LJvmgKzYx24X/V9U319UFH1fVN9fVBRy77xoz8AAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIXtdeAMBP4zje/XdeXl5mWAkAmDADAEAkMAMAQCAwAwBAoMMMlPBIb/nW39dnBmBKJswAABAIzAAAEKhkAJuT6h3qGgDcy4QZAAACgRkAAAKVDGBXXE0DgHuZMAMAQCAwAwBAoJIBlHBdj3j2Jiat/w4AtDBhBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIDAVTKATXNVDACeZcIMAACBwAwAAIHADAAAgQ4zsDl6ywBMyYQZAAACgRkAAAKVDGAT1DAAmIsJMwAABAIzAAAEKhlAOeoVAFRiwgwAAIHADAAAgUoGAACTG8fx1//eY+3OhBkAAAKBGQAAApUMAACa3apaPPL3e6lnmDADAEAgMAMAQCAwAwBAoMMMAMDT3eQtM2EGAIBAYAYAgEAlAwBgRypVL3q5xJwJMwAABAIzAAAEKhkAABtQqWqxNSbMAAAQCMwAABCoZAAAs+vlaghVqVusy4QZAAACgRkAAAKVDABgUT/rBSoaDEPt2o4JMwAABAIzAAAEAjMAAARddpifvbRKtV4MAOzZre9139f/un4uXGJueSbMAAAQCMwAABB0U8mY8vih8mVLAIB/tH73+y5nbibMAAAQCMwAABB0U8mYi3oGAMxjqas5uMrG9lS7G6QJMwAABAIzAAAEu69kXKs2/gdYSm83QvD5TIutVjXcxGR5JswAABAIzAAAEKhkAAC7kmoMvdc1tmrtq5qZMAMAQCAwAwBA0E0lY41fhK49/geA3vR+1Qbf/fzGhBkAAAKBGQAAAoEZAACCbjrM1/SZAQDc9W8pJswAABAIzAAAEHRTyah0zKCecZ9HXjvPK0A/Kn1HT8n3fU1rvC4mzAAAEAjMAAAQlK5kVD3icSzzuylfr9bH8loAwD9+fidWzVE9MmEGAIBAYAYAgKB0JWNtjvv/bu3jnlv/vtcOALZj7e91E2YAAAgEZgAACFQyrqw97geAHqxdx2M7esleJswAABAIzAAAEAjMAAAQlO4wX/da5upL9dKdAQCW1XtGWCJHter9uTRhBgCAQGAGAICgdCXj2rPHCr0fBVRV6bgHgH34+Z0+5fePvHCfvTxfJswAABAIzAAAEHRTybi2l/F/b9auZ3hfAOzTrc
//lu8i3x1t9v48mTADAEAgMAMAQNBlJYP6Wo9uHJcB9GfOq1RMyffHv27VJj1HbUyYAQAgEJgBACBQyWBVjoIAaOH7Yjqey/uZMAMAQCAwAwBAIDADAECgwwwAPGWuO73q2lKFCTMAAAQCMwAABCoZADj6ZjIt76WftQ3vP6ozYQYAgEBgBgCAQCUDAFiUCga9MWEGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAoCkwn8/nudcBm1N931RfH1RUfd9UXx9U1LJvmgLz5XJ5ejGwN9X3TfX1QUXV90319UFFLfvmteWBjsfj8PX1NRwOh+F0Oj29MNiy8/k8XC6X4Xg8rr2UyL6GdvY1bM89+/plHMdxgTUBAECX/OgPAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAACC15Y/9PHxMXx9fQ2Hw2E4nU5zrwm6dj6fh8vlMhyPx+Hz83Pt5dxkX0M7+xq25559/TKO4/i3B3x/fx/+/Pkz1fpgF97e3obv7++1l3GTfQ33s69he1r2dVMl43A4TLIg2JPq+6b6+qCi6vum+vqgopZ90xSYHevA/arvm+rrg4qq75vq64OKWvaNH/0BAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEDwuvYC4DfjOM7yuC8vL7M8LgCwXSbMAAAQCMwAABCoZLC4ueoWj/zbKhoAwN+YMAMAQCAwAwBAoJLB7NasYPzN9drUMwCA35gwAwBAIDADAECwy0rGsxUBR/e/q1y9AAB4lAkzAAAEAjMAAAQCMwAABLvpME/Zr3Upsn/pLQMAW2fCDAAAgcAMAADBpisZ6gLz8LwCAHtiwgwAAIHADAAAwaYqGWtUBfZwxQwVDABgz0yYAQAgEJgBACDovpKhLjCPvTyvW63RAADTMWEGAIBAYAYAgKDLSsZe6gKwJ7f2tdoMAGszYQYAgEBgBgCAQGAGAICgyw4z89hDN1wftj+PvC+9zgBMyYQZAAACgRkAAIJuKhl7qAswD8fz+/Ps54X3DADXTJgBACAQmAEAIOimklGVo9uavC48w10HAbhmwgwAAIHADAAAgUoGm+G4nLn9rGp4zwHsgwkzAAAEAjMAAASlKxluVsLfOBLvW+97/Hr93osA22XCDAAAgcAMAABB6UpGVY5el+X5pgfqGQDbZcIMAACBwAwAAIHADAAAgQ4z/3Ordznnpb90PWGfnv1c8dkBLMmEGQAAAoEZAAAClYxGez7+2/P/O/Cv3u/MCPAoE2YAAAgEZgAACFQygEU51l/WVp9vd1YElmTCDAAAgcAMAACBSkbgmA8AABNmAAAIBGYAAAhKVzKuKxFL/NJbBQN4lM8PgO0yYQYAgEBgBgCAQGAGAICgdId5CXqHsC9T/jaih8+PpX8LArBFJswAABAIzAAAEHRTybh19Nl6xNjD0SmwLJ8LALQwYQYAgEBgBgCAoJtKxi
2OVAEAmJMJMwAABAIzAAAE3VcygL64kQZT+/k+UtUDpmbCDAAAgcAMAACBSgawmiXqGY7nAXiWCTMAAAQCMwAABAIzAAAEOsxACVP2mfWWYRse+Syw/5mDCTMAAAQCMwAABCoZQDmOVOfhLotUNeX78fqxfJYwFRNmAAAIBGYAAAhUMgAANqi36lXlCo0JMwAABAIzAAAEKhkAwKJ6qwqACTMAAAQCMwAABCoZAGyKG1fUpIZBz0yYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIHBZOQBgFi4lx1aYMAMAQCAwAwBAoJIBsEM/74BX9ejcnfq4l/cMczBhBgCAQGAGAIBAJQOASTkSB7bGhBkAAAKBGQAAApUMANQo6Jr3L3MzYQYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIDAZeUAgK64jBxLM2EGAIBAYAYAgEAlAwAoTw2DNZkwAwBAIDADAECgkgEAzOK6RjGO41N/H9ZkwgwAAIHADAAAgUoGADA79Qp6ZsIMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcvKAQBskEv5TceEGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAIKmwHw+n+deB2xO9X1TfX1QUfV9U319UFHLvmkKzJfL5enFwN5U3zfV1wcVVd831dcHFbXsm9eWBzoej8PX19dwOByG0+n09MJgy87n83C5XIbj8bj2UiL7GtrZ17A99+zrl3EcxwXWBAAAXfKjPwAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACF5b/tDHx8fw9fU1HA6H4XQ6zb0m6Nr5fB4ul8twPB6Hz8/PtZdzk30N7exr2J579vXLOI7j3x7w/f19+PPnz1Trg114e3sbvr+/117GTfY13M++hu1p2ddNlYzD4TDJgmBPqu+b6uuDiqrvm+rrg4pa9k1TYHasA/ervm+qrw8qqr5vqq8PKmrZN370BwAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAwevaCwAAoJ5xHP/6Z15eXhZYyfpMmAEAIBCYAQAgUMkAAKCpgnHP39lSXcOEGQAAAoEZAAAClQwAgJ16pIbxyGP3Xs8wYQYAgEBgBgCAQCUDAGBH5qxhtPybPdYzTJgBACAQmAEAIBCYAQAg0GEGANi4NXrLW2LCDAAAgcAMAACBSgYAAIvp8RJzJswAABAIzAAAEKhkMLufv8zt5fgFAGAYTJgBACASmAEAIFDJYBbpAun3XjxdhQMA7udmJdMxYQYAgEBgBgCAQCWD8lqPlFQ3AKAvvdzExIQZAAACgRkAAAKBGQAAAh1mNqOXHhQA0BcTZgAACARmAAAIVDKYzNp3FFLD6MOU7xOvOQBLMGEGAIBAYAYAgEAlA5jEGpWcR/5NNQ4A7mXCDAAAgcAMAACBSgbwsLWvjPKIW2tW1QDgFhNmAAAIBGYAAAhUMoC79FjDaKGqAcAtJswAABAIzAAAEAjMAAAQ6DDzlLX7rPqlzO36Pe
79BrBPJswAABAIzAAAEKhk8JTrI+q16xkwt5/vcRUNgH0wYQYAgEBgBgCAQCWDybQeT6tusBWuoAGwDybMAAAQCMwAABCoZLC4e4+uXZmAHqhnAGyXCTMAAAQCMwAABCoZlOd4uxY3q/k79QygAp/X0zFhBgCAQGAGAIBAJQMAgFX0UlszYQYAgEBgBgCAQGAGAIBAhxl4mEsWAbAHJswAABAIzAAAEKhkAJNQzwCoy2f0c0yYAQAgEJgBACBQyQAm5+gPgFt6ubvfNRNmAAAIBGYAAAhUMoBZpaO3rdY1ejxuBPZjjdpc75+LJswAABAIzAAAEKhkAKu5dUTXY1Wj9+NGYJ9c1aiNCTMAAAQCMwAABAIzAAAEm+owT9290UmEdbTuvbX7dj4j4P89uy/tq/U8exnQLb92JswAABAIzAAAEHRfyZjzSPb6sbd8zAC9si9hHUt9916z39e19+ffhBkAAAKBGQAAgu4rGUtRzwCA9fgeZk0mzAAAEAjMAAAQdFnJWPtmBQCwN5W+e3+uRUWDuZkwAwBAIDADAEDQZSVjbX6pC2zN0sftPjuBnpgwAwBAIDADAEAgMAMAQNBlh/m6+1bpMjcAlVX6vGxdi64zLfy2iLmZMAMAQCAwAwBAUKqSkY7oHLEA3K9SDeMRt9bvO2F5vdQh1TOYgwkzAAAEAjMAAASrVDIeOcqpfPwDUMkePi8du6/r53Ne9T3nfcJUTJgBACAQmAEAIFisklH1uAaAvv38fnH0vrwerqChnsEzTJgBACAQmAEAICh14xIAeJaj9+VVrWHAVEyYAQAgEJgBACBQyQBgs9Qz5tF7BSOt3/uE35gwAwBAIDADAEAgMAMAQKDDDLAxPdx1jf5s9b2ks0wLE2YAAAgEZgAACFQyAIBfbamGoXrBM0yYAQAgEJgBACBQyQDYMFfMYM/UMJiKCTMAAAQCMwAABItVMhwLAqwrHU/7XGYr1DCYgwkzAAAEAjMAAASrXCVDPQOglq1+Ljue3yavK0szYQYAgEBgBgCAQGAGAIBg9Tv9tfaQqnbqrtelUwVsQctnWdXP5GHwWTylSt12rytrMmEGAIBAYAYAgGD1SkYrRzEAdTzymfzIkb7P/jrWqGd4/anChBkAAAKBGQAAgm4qGQD0zfH6dqTXsqWu4b1Ab0yYAQAgEJgBACBQyQAAJqNuwRaZMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQNAUmM/n89zrgM2pvm+qrw8qqr5vqq8PKmrZN02B+XK5PL0Y2Jvq+6b6+qCi6vum+vqgopZ989ryQMfjcfj6+hoOh8NwOp2eXhhs2fl8Hi6Xy3A8HtdeSmRfQzv7Grbnnn39Mo7juMCaAACgS370BwAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAwX8B/kEMST2z1YIAAAAASUVORK5CYII=",
            "text/plain": [
              "<Figure size 900x900 with 9 Axes>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Keep pairs whose _sx sprite is not a square (shape latent, column 1, == 0) and whose _dx sprite is not a heart (== 2).\n",
        "# The boolean mask is built once and reused for the images, y and c arrays so the three selections cannot drift apart.\n",
        "mask_no_sq_no_heart = (latents_sampled_sx[:, 1] != 0) & (latents_sampled_dx[:, 1] != 2)\n",
        "no_sq_no_heart = image_fused[mask_no_sq_no_heart]\n",
        "y_no_sq_no_heart = y_fused[mask_no_sq_no_heart]\n",
        "c_no_sq_no_heart = c_fused[mask_no_sq_no_heart]\n",
        "show_images_grid(no_sq_no_heart, 9)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "-FCACtAlqKTA"
      },
      "outputs": [
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAswAAALJCAYAAACgHHWpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAW7UlEQVR4nO3dQVbbwLYFUPEXDcMA3MbzH5RpMwBwT6/xVl4cPj6Ubcm6Je3dTkjFUOKsW8flp3EcxwEAAPjR/y29AAAAqExgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCA4LnlD729vQ0fHx/DbrcbDofD3GuCrh2Px+F0Og37/X54f39fejkX2dfQzr6G9blmXz+1fNLf6+vr8PX1NdX6YBNeXl6Gz8/PpZdxkX0N17OvYX1a9nVTJWO3202yINiS6vum+vqgour7pvr6oKKWfdMUmB3rwPWq75vq64OKqu+b6uuDilr2jTf9AQBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEAgMAMAQCAwAwBAIDADAEDwvPQCAOjPOI5X/52np6cZVgIwPxNmAAAIBGYAAAhUMgBocksN49LfV88AemLCDAAAgcAMAACBSgYAD/e93qGiAVRmwgwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAABBl9fK3fJpU64sAqjLpwAClZkwAwBAIDADAEDQTSXjlhpGy9939Adw2b3PXoA1MGEGAIBAYAYAgKCbSsZcvDMb4LLz56J6BrBVJswAABAIzAAAEGy+knHu+3GjigYAACbMAAAQCMwAABCoZARu0AB4PM9eoBoTZgAACARmAAAIBGYAAAh0mAEoRW8ZqMaEGQAAAoEZAAAClYxGrjkCtu782ff9k1EB1syEGQAAAoEZAAACgRkAAAKBGQAAAoEZAACCbm7JqPTubDdmAFv3/dm39HMZYE4mzAAAEAjMAAAQdFPJAKCue2tz6m1AZSbMAAAQCMwAABB0WcmodGMGAP9qeUarYAA9MWEGAIBAYAYAgEBgBgCAoMsOMwB90FUG1sCEGQAAAoEZAACC7isZrpgDAGBOJswAABAIzAAAEHRfyTj3/d3YKhoAAP/yCZzXM2EGAIBAYAYAgGBVlYzv3KABANCmNSttsbphwgwAAIHADAAAwaorGeemrGds8Sji0W75Hvm+AMD8tpijTJgBACAQmAEAIBCYAQAg2EyH+VyP3Zmqpuwa39uJOv/7vscAUFOPnzRowgwAAIHADAAAwSYrGdxHdQIAmFrKF0vnBRNmAAAIBGYAAAhUMljUvfUOAOB3vf++XbrOacIMAACBwAwAAIFKBlBC63Hh0u+UBmBZS9QzTJgBACAQmAEAIFDJALpyqbqhqgHAXEyYAQAgEJgBACAQmAEAINBh5mrnXdFKnxykw9qfKX9+dJuhBr8XWCMTZgAACARmAAAIVDKAVfMJgrBdW69qVarH9M6EGQAAAoEZAAAClQzusvSNGVs5VluTqkeEWz+6BeAyE2YAAAgEZgAACFQymMylo+upj+AdkfNI6efXzyLA4y3x7DVhBgCAQGAGAIBAJYPZfT86ubai4di7b1VvxZiCmzXgv9a8z2EYTJgBACASmAEAIFDJ4OEedZsGALCclnpaL9U2E2YAAAgEZgAACARmAAAIdJgpo1pfCa7h5xfger08O02YAQAgEJgBACBQyQAm54pAgOWd1x08l+9jwgwAAIHADAAAgUoGAAB36+XGi1uYMAMAQCAwAwBAoJIBTGKL78Be8/EjsC5uzLiPCTMAAAQCMwAABCoZAMDqqVD9dem1aK1qbPG1NGEGAIBAYA
YAgEBgBgCAQIcZuNnWribaYm8P2A7PuMtMmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgcEsGQOBd4/C7KffJ1m7foQ8mzAAAEAjMAAAQqGQAN7t0DOtIFbiVGhQVmTADAEAgMAMAQKCSAUyu9UhVdQOAHpgwAwBAIDADAEAgMAMAQKDDDCym6rV0rrUC4JwJMwAABAIzAAAEKhlAOakSsXRdA4DtMWEGAIBAYAYAgEAlA+hKyw0Wt9Q23IwBwCUmzAAAEAjMAAAQqGQAq+OWDQCmZMIMAACBwAwAAIFKBrApbsMA4FomzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABA0Bebj8Tj3OmB1qu+b6uuDiqrvm+rrg4pa9k1TYD6dTncvBram+r6pvj6oqPq+qb4+qKhl3zy3fKH9fj98fHwMu91uOBwOdy8M1ux4PA6n02nY7/dLLyWyr6GdfQ3rc82+fhrHcXzAmgAAoEve9AcAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAMFzyx96e3sbPj4+ht1uNxwOh7nXBF07Ho/D6XQa9vv98P7+vvRyLrKvoZ19Detzzb5+Gsdx/O0Lvr6+Dl9fX1OtDzbh5eVl+Pz8XHoZF9nXcD37GtanZV83VTJ2u90kC4Itqb5vqq8PKqq+b6qvDypq2TdNgdmxDlyv+r6pvj6oqPq+qb4+qKhl33jTHwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABM9LLwCAbRjHcZav+/T0NMvXBfjDhBkAAAKBGQAAApUMAO42V93iln9bRQOYmgkzAAAEAjMAAAQqGQBcbckKxm/O16aeAUzBhBkAAAKBGQAAApUMYBXurQg4uv9Z5eoFwKOYMAMAQCAwAwBAIDADAECgwwx0a8p+ravI/tJbBviXCTMAAAQCMwAABCoZQFfUBebhdQW4zIQZAAACgRkAAAKVDKC0JaoCW7gxQwUDoJ0JMwAABAIzAAAEKhlAOeoC89jK67rWGg2wHBNmAAAIBGYAAAhUMoAStlIXAKA/JswAABAIzAAAEAjMAAAQ6DADrNgWuuGukQPmZsIMAACBwAwAAEGXlYxLR4yO5aAvW6gLMA/Pe+CRTJgBACAQmAEAIOiyknFJOt51fAe08ryoyfcFWIoJMwAABAIzAAAE3VQy7n03vZs1APrjGQ1UYMIMAACBwAwAAEE3lYy5uFkDHsuHlfAbz16gGhNmAAAIBGYAAAhKVzKWPrpt+fcdHcI62MuP5fUGemLCDAAAgcAMAACBwAwAAEHpDnMPfIIgUNmlZ9Gc7xHx/APWxoQZAAACgRkAAIJSlYylr5Gbkk8QhPq2vBe3/H+Hrbg3V3lO/GXCDAAAgcAMAABBqUrGVrhZAwCYwpx1VnnlLxNmAAAIBGYAAAgWr2Ss6WaMe52/Fls87oBHsLcA7rPFqoYJMwAABAIzAAAEi1cy+GvNRxnwx/nP+SMqWfYVsDZV66zf17Wm568JMw
AABAIzAAAEAjMAAASLdJirdm+AdVhTbw6gV2u6LteEGQAAAoEZAAAC18otrPcjCrjHpZ//1tqW/QPQh94/HdCEGQAAAoEZAACCh1Uy3IwBtOrliA6A+/RS1TBhBgCAQGAGAIDALRkLqHbMAAD04zxHqLw+hgkzAAAEAjMAAASzVjIcEwAAzGdN9YzKlVUTZgAACARmAAAIBGYAAAhcK/cglXs5AED/Utao2m/uJR+ZMAMAQCAwAwBAMHklo+rIn+Xc8jPRyxENANdLvxc8/6nIhBkAAAKBGQAAgskrGWv6xJl7bP1I6d7vveM6gH5M+ftejW86VXNYj98vE2YAAAgEZgAACGb94JJLI/eqRwTcx/cVYBuqPe8vrafHo/97VPu+rIkJMwAABAIzAAAEs1YyLllrVWNrRz9LOf858ZoDPEaPv6NVNero/TU3YQYAgEBgBgCAYJFKxiVpXN/jURAAUI9q32Os6bU1YQYAgEBgBgCAQGAGAICgVIc5aenBLNFzXlM/p0d6aADA3EyYAQAgEJgBACDoppLRwrV0yzp//b3eAOuy1mf89/+Let/t1vzamTADAEAgMAMAQLCqSkZy6ZjglmOlNR85/GRNR28A3O/770G/J2pY4vuylUxkwgwAAIHADAAAwWYqGZe4WeNnW/6/A3Cdnm/QWHOlQB11OibMAAAQCMwAABBsvpKRbO34obdjNADqmbIGwDy2lm+mYMIMAACBwAwAAIHADAAAgQ4zADC7pa9x1dvlHibMAAAQCMwAABCoZLAa50d6jt4A+uGZTXUmzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQuCWD/zl/l/IjLpEHAOiBCTMAAAQCMwAABCoZ/Kj3eoYPMQEApmLCDAAAgcAMAACBwAwAAIEOM7/63gGestOsXwwAVGfCDAAAgcAMAACBSgZXU6MAALbEhBkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAACCpsB8PB7nXgesTvV9U319UFH1fVN9fVBRy75pCsyn0+nuxcDWVN831dcHFVXfN9XXBxW17Jvnli+03++Hj4+PYbfbDYfD4e6FwZodj8fhdDoN+/1+6aVE9jW0s69hfa7Z10/jOI4PWBMAAHTJm/4AACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACB4bvlDb29vw8fHx7Db7YbD4TD3mqBrx+NxOJ1Ow36/H97f35dezkX2NbSzr2F9rtnXT+M4jr99wdfX1+Hr62uq9cEmvLy8DJ+fn0sv4yL7Gq5nX8P6tOzrpkrGbrebZEGwJdX3TfX1QUXV90319UFFLfumKTA71oHrVd831dcHFVXfN9XXBxW17Btv+gMAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgOB56QUADMMwjON4199/enqaaCUA8C8TZgAACARmAAAIVDKAxdxbw2j5WqoaANzLhBkAAAKBGQAAApUM4KGmrGHc8u+paABwLRNmAAAIBGYAAAgEZgAACHSYgU057zTrMwPQwo
QZAAACgRkAAAKBGQAAAoEZAAACgRkAAIJSt2Q8+hPAruHd9HCbyvvajRkAtDBhBgCAQGAGAIBgkUpG5SPaSy6t2TEurIN6BgCXmDADAEAgMAMAQPCwSkaPNYwWqhoAAOtmwgwAAIHADAAAgcAMAADBrB3mtfaWW6T/u34zAEA/TJgBACAQmAEAIFjkk/62zieKQW32KFxnygqmPUdFJswAABAIzAAAEExeydjyzRi3cPQLQFVL/E6/5d/0+5O5mTADAEAgMAMAQOCWDADgf3qsVl5as6oGUzFhBgCAQGAGAIBg8krG+fFHj8c6S1rrjRnp52BN/09+9v177LkA9ax1X6pqMBUTZgAACARmAAAIBGYAAAhcK8csWvtwl3rbrX9fD425rfW9BbBl9jXXMmEGAIBAYAYAgGDWSoYr5pibYzUA7vE9n/hdwk9MmAEAIBCYAQAgeNgtGeoZv9v6MZCfi23wLAAqU/XjJybMAAAQCMwAABAs8sEljmT/ctwDADWpZ/CHCTMAAAQCMwAABItUMs6lI4611jW2cKzzqNrNFl5LgLmpSv5OPWPbTJgBACAQmAEAIFi8kpG0HHlUPTpyXDMfr+16OAYGoAcmzAAAEAjMAAAQCMwAABCU7jC30Get7/v36Nququ8xS/Lzx9Z4bwH8fybMAAAQCMwAABB0X8mgPy3HfY7Bt8cxMNRjX8J/mTADAEAgMAMAQKCSQRlqGPzhGBjqsS/ZMhNmAAAIBGYAAAhUMliUGga/ufeDb6b4N4F/pT2y1rqG58K2mTADAEAgMAMAQKCSAXRlrnfqO26FaVzaSz1WNTwX+MOEGQAAAoEZAAACgRkAAAIdZqBb+oXQj9b9unTX2XOFn5gwAwBAIDADAECgkgEAlKESQUUmzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQ+OASAIAVGsfxof/emj90xoQZAAACgRkAAAKBGQAAAh1mAIBOPbqnnLSupceuswkzAAAEAjMAAAQqGQAAHalUw7jFpfVXrmqYMAMAQCAwAwBAoJIBAFBc7zWMFuf/x2r1DBNmAAAIBGYAAAhUMgAAKOV7BWXpioYJMwAABAIzAAAEKhkAAJS29A0aJswAABAIzAAAEKhkAADQjSXqGSbMAAAQCMwAABAIzAAAEOgwAwAUd97V/f4peMzPhBkAAAKBGQAAApUMAP5x73HvEp/CBTAnE2YAAAgEZgAACFQyAJj0Xfe3fC01DmjnxozHM2EGAIBAYAYAgEAlA4DFXTpWVtWALO0RdY3pmDADAEAgMAMAQKCSAbBBjmph/dZ6m8YSVS0TZgAACARmAAAIBGYAAAh0mAEoxVVyML2WfVW557z0c8GEGQAAAoEZAAAClQyADfp+vFn5KBZ4jFtqD7c8O5auV9zChBkAAAKBGQAAApUMAFb7iWDAvHqsV9zChBkAAAKBGQAAApUMAP5x6YhVVQPYKhNmAAAIBGYAAAhUMgBY3FbeaQ/0yYQZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAp/0B8DD+WQ/oCcmzAAAEAjMAAAQqGQA0CTVKMZxvPrvAPTChBkAAAKBGQAAApUMAO6megGsmQkzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABE2B+Xg8zr0OWJ3q+6b6+qCi6vum+vqgopZ90xSYT6fT3YuBram+b6qvDyqqvm+qrw8qatk3zy1faL/fDx8fH8NutxsOh8PdC4M1Ox6Pw+l0Gvb7/dJLiexraGdfw/pcs6+fxnEcH7AmAADokjf9AQBAIDADAEAgMAMAQCAwAwBAID
ADAEAgMAMAQCAwAwBA8B9wMYRfPP5gcQAAAABJRU5ErkJggg==",
            "text/plain": [
              "<Figure size 900x900 with 9 Axes>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Keep pairs whose _sx sprite is not a heart (shape latent, column 1, == 2) and whose _dx sprite is not a square (== 0).\n",
        "# The boolean mask is built once and reused for the images, y and c arrays so the three selections cannot drift apart.\n",
        "mask_no_heart_no_sq = (latents_sampled_sx[:, 1] != 2) & (latents_sampled_dx[:, 1] != 0)\n",
        "no_heart_no_sq = image_fused[mask_no_heart_no_sq]\n",
        "y_no_heart_no_sq = y_fused[mask_no_heart_no_sq]\n",
        "c_no_heart_no_sq = c_fused[mask_no_heart_no_sq]\n",
        "show_images_grid(no_heart_no_sq, 9)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAswAAALJCAYAAACgHHWpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWBklEQVR4nO3dPXLbShpAUWrKAe0FMLb2vyg61gJsZpxoavRUj9fgLxrAObEldVlq8lbjI/B2Pp/POwAA4F/9Z+4FAADAyAQzAAAEwQwAAEEwAwBAEMwAABAEMwAABMEMAABBMAMAQPg25R/9/Plz9/Hxsdvv97v39/dnrwkW7Xg87k6n0+5wOOx+/fo193Iusq9hOvsa1ueaff025Ul/P3782P358+dR64NN+P79++73799zL+Mi+xquZ1/D+kzZ15NGMvb7/UMWBFsy+r4ZfX0wotH3zejrgxFN2TeTgtllHbje6Ptm9PXBiEbfN6OvD0Y0Zd/40B8AAATBDAAAQTADAEAQzAAAEAQzAAAEwQwAAEEwAwBAEMwAABAEMwAABMEMAABBMAMAQBDMAAAQBDMAAATBDAAAQTADAEAQzAAAEAQzAAAEwQwAAEEwAwBAEMwAABAEMwAABMEMAABBMAMAQBDMAAAQBDMAAATBDAAA4dvcCwAA1ul8Pl/9NW9vb09YCdzHCTMAAATBDAAAwUgGAPAwt4xhXPr6qeMZU36mUQ/u4YQZAACCYAYAgGAkAwC4y71jGK/4ebeMesD/OGEGAIAgmAEAIBjJAACG9KxRD+MZXMsJMwAABMEMAABBMAMAQBDMAAAQBDMAAATBDAAAQTADAEAQzAAAEAQzAAAET/oDFsUTugB4NSfMAAAQBDMAAAQjGcBwPo9d3PLvjGrAa33ec1P3LyyJE2YAAAiCGQAAgpEMYAiPvIzrThpA8brAtZwwAwBAEMwAABCMZACz8Wl6WJ9R75hhDIN7OGEGAIAgmAEAIAhmAAAIZpiBTbl2ptLcI9xu7nlm+5dHccIMAABBMAMAQDCSAazavZeBPTUQHuOZ4xn2Js/mhBkAAIJgBgCAYCQDYKKvl5FdBv47Iy08g78lXs0JMwAABMEMAADBSAYADzXHAypYlq8jFVP+ZoxhMCcnzAAAEAQzAAAEIxnAbJ75IANea8rvzx0zuJa/E0bhhBkAAIJgBgCAYCTjEw8lAJjm3hEar7d85vfP6JwwAwBAEMwAABAEMwAAhM3MMN8yb3fpa8xawTZtfe8/89Z/bjkHjMwJMwAABMEMAABhVSMZnhQGyzXqU/+2Ph4wx+/CeAYwGifMAAAQBDMAAITFj2SMdOkWWC6X/v/P6yrAPzlhBgCAIJgBACAsfiRj1E/WA2MwajHNqK+f7pgBjMAJMwAABMEMAABh8SMZz+TyH4zNHr3dqCMYxXgGMBcnzAAAEAQzAAAEwQwAAGFVM8xfZ9qWOKMHmE8FYCxOmAEAIAhmAAAIqxrJeDS3MAIAwAkzAAAEwQwAAGHVIxmfxyhuuWOGMQwAAJwwAwBAEMwAABBWPZLx2ZTxDCMYwFY880FP947DAYzGCTMAAATBDAAAYTMjGZ8ZvQB4LK+rsG5bf5ibE2YAAAiCGQAAgmAGAICwyRlmAP7p0kzivXOLnrgK67PFeWYnzAAAEAQzAAAEIxm83JTLslu5xAOje+RetK9hWaa8X29lPMMJMwAABMEMAADBSAZPd8sn47dyiQcARnHL+3V9/Zrev50wAwBAEMwAABCMZAAAbNS9YxhTv/fSxzOcMAMAQBDMAAAQjGQwvDV/6hYAtmDp4xlOmAEAIAhmAAAIRjJ4mGd+0hYAeIy536+XOJ7hhBkAAIJgBgCAIJgBACAIZgAACIIZAACCYAYAgOC2ctxl7lvTAAB/5/36Pk6YAQAgCGYAAAhGMrjL5yf0uNwDAGN65Pv1Ft/7nTADAEAQzAAAEIxkMLzPl34AgGVb4v
u6E2YAAAiCGQAAgpEMHmaJl1gAYGsuvV9/vePFlPf1rbz3O2EGAIAgmAEAIAhmAAAIZpgBANjMPPItnDADAEAQzAAAEAQzAAAEwQwAAEEwAwBAEMwAABAEMwAABMEMAABBMAMAQBDMAAAQBDMAAATBDAAAQTADAEAQzAAAEAQzAAAEwQwAAEEwAwBAEMwAABAEMwAABMEMAABBMAMAQBDMAAAQBDMAAATBDAAAQTADAEAQzAAAEAQzAAAEwQwAAEEwAwBAEMwAABAEMwAABMEMAABBMAMAQBDMAAAQBDMAAATBDAAAQTADAEAQzAAAEAQzAAAEwQwAAEEwAwBAEMwAABAEMwAABMEMAABBMAMAQBDMAAAQBDMAAATBDAAAQTADAEAQzAAAEAQzAAAEwQwAAEEwAwBAEMwAABAEMwAAhEnBfDwen70OWJ3R983o64MRjb5vRl8fjGjKvpkUzKfT6e7FwNaMvm9GXx+MaPR9M/r6YERT9s23Kd/ocDjsPj4+dvv9fvf+/n73wmDNjsfj7nQ67Q6Hw9xLSfY1TGdfw/pcs6/fzufz+QVrAgCARfKhPwAACIIZAACCYAYAgCCYAQAgCGYAAAiCGQAAgmAGAIAgmAEAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACIIZAACCYAYAgCCYAQAgCGYAAAiCGQAAgmAGAIAgmAEAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACIIZAACCYAYAgCCYAQAgCGYAAAiCGQAAgmAGAIAgmAEAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACN+m/KOfP3/uPj4+dvv9fvf+/v7sNcGiHY/H3el02h0Oh92vX7/mXs5F9jVMZ1/D+lyzr9/O5/P5b9/wx48fuz9//jxqfbAJ379/3/3+/XvuZVxkX8P17GtYnyn7etJIxn6/f8iCYEtG3zejrw9GNPq+GX19MKIp+2ZSMLusA9cbfd+Mvj4Y0ej7ZvT1wYim7Bsf+gMAgCCYAQAgCGYAAAiCGQAAgmAGAIAgmAEAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACIIZAACCYAYAgCCYAQAgCGYAAAiCGQAAgmAGAIAgmAEAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACIIZAACCYAYAgCCYAQAgCGYAAAiCGQAAwre5F8D2nM/nq/7929vbk1YCAPB3TpgBACAIZgAACEYyeLprRzD+9vVGNACAV3LCDAAAQTADAEAwksFT3DuGMfV7G88AAJ7NCTMAAATBDAAAQTADAEAQzAAAEAQzAAAEwQwAAMFt5XiYZ95KDgBgLk6YAQAgCGYAAAhGMlgcT/cDAF7JCTMAAATBDAAAwUgGi3PP3TiMcwAA13LCDAAAQTADAEAwksGmfB3nMKIBAPyNE2YAAAiCGQAAgmAGAIAgmAEAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACJ70B8BDfX2i5hSeugmMzAkzAAAEwQwAAMFIBpvisi88xy1jGABL4YQZAACCYAYAgGAkA4DZfR7pMDoFjMYJMwAABMEMAADBSAar5JIu3M4dLwD+yQkzAAAEwQwAAMFIBg9jDAIAWCMnzAAAEAQzAAAEwQwAAMEMMwBuJQcQnDADAEAQzAAAEIxkAAAw2aURrjXfXtYJMwAABMEMAADBSAYA/7iUOscdM9Z8KReW6trXgq//fk372gkzAAAEwQwAAMFIBgD/MPd4BvA69vg0TpgBACAIZgAACEYyALjo0qfcXcaFZbFn7+OEGQAAgmAGAIAgmAEAIJhhBuBqa3qCF/Acn+eml/6a4YQZAACCYAYAgGAkAwBghdxK7nGcMAMAQBDMAAAQNj+SUZcrlv6JTgBguz53zNzjGUu/Y4YTZgAACIIZAADCJkcypl6WuPTvlngpAQCA2zhhBgCAIJgBACBsZiTjkZ8OXfonPQGAbZlyx4xqmns7aum95IQZAACCYAYAgCCYAQAgDD
3DfO28zNLnYwBgTW6Ze/Ve/nz+j6/nhBkAAIJgBgCAMNRIxr23LPn69S45AMBrPfK9/NL7uFGP15tyW7r6mqVzwgwAAEEwAwBAGGok49Ee+XQ/AGA+rxj1YJot/v85YQYAgCCYAQAgzD6SYWwCAPg3GoFROGEGAIAgmAEAIMw+knHLjbABAOBVnDADAEAQzAAAEGYfyViiLd6wGwBgq5wwAwBAEMwAABAEMwAABDPMwCp8vi2lzxnAfNwuljVywgwAAEEwAwBAGGok4+tl1JEu5bjEC+O59BphPAP4ymsB93DCDAAAQTADAEAYaiTjq0uXT0Ya1QBey/6H5Zj7jhnGMHgUJ8wAABAEMwAAhKFHMi551aiGSzkwv3v39devt69hHs8cz7CveTYnzAAAEAQzAACERY5kXFKXZHyyHpbjmfvVQ01gfvc+qMze5dWcMAMAQBDMAAAQBDMAAIRVzTAXTw2E8cy9/8wzw3LYo8zJCTMAAATBDAAAYTMjGZe4xAOvNfcYxiXGM2A+l54CaC8yCifMAAAQBDMAAITNj2QAAOMwhsGInDADAEAQzAAAEIxkAE816l0xik/pA/CZE2YAAAiCGQAAgmAGAIAgmAEAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACJ70ByzW56fwLfGJggAsgxNmAAAIghkAAIKRDOCpPo9N7HZGJwBYHifMAAAQBDMAAAQjGcBL3Xtni68jHo/6vlN+BgDb5IQZAACCYAYAgGAkA5jNs0YfbhnPMIYBwCVOmAEAIAhmAAAIRjKAVTNqAcC9nDADAEAQzAAAEAQzAAAEwQwAAEEwAwBAEMwAABAEMwAABMEMAABBMAMAQBDMAAAQBDMAAATBDAAAQTADAEAQzAAAEAQzAAAEwQwAAEEwAwBAEMwAABAEMwAABMEMAABBMAMAQBDMAAAQBDMAAATBDAAAYVIwH4/HZ68DVmf0fTP6+mBEo++b0dcHI5qybyYF8+l0unsxsDWj75vR1wcjGn3fjL4+GNGUffNtyjc6HA67j4+P3X6/372/v9+9MFiz4/G4O51Ou8PhMPdSkn0N09nXsD7X7Ou38/l8fsGaAABgkXzoDwAAgmAGAIAgmAEAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACIIZAACCYAYAgCCYAQAgCGYAAAiCGQAAgmAGAIAgmAEAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACIIZAACCYAYAgCCYAQAgCGYAAAiCGQAAgmAGAIAgmAEAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACIIZAACCYAYAgCCYAQAgCGYAAAiCGQAAwrcp/+jnz5+7j4+P3X6/372/vz97TbBox+NxdzqddofDYffr16+5l3ORfQ3T2dewPtfs67fz+Xz+2zf88ePH7s+fP49aH2zC9+/fd79//557GRfZ13A9+xrWZ8q+njSSsd/vH7Ig2JLR983o64MRjb5vRl8fjGjKvpkUzC7rwPVG3zejrw9GNPq+GX19MKIp+8aH/gAAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACIIZAACCYAYAgCCYAQAgCGYAAAiCGQAAgmAGAIAgmAEAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACIIZAACCYAYAgCCYAQAgCGYAAAiCGQAAgmAGAIAgmAEAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACIIZAACCYAYAgCCYAQAgCGYAAAiCGQAAgmAGAIAgmAEAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACIIZAACCYAYAgCCYAQAgCGYAAAiCGQAAgmAGAIDwbe4FAAD8m/P5fNfXv729PWglbJ0TZgAACIIZAACCkQwAYBj3jmFc+l7GM7iHE2YAAAiCGQAAgpEMAGBWjxzDmPIzjGdwLSfMAAAQBDMAAATBDAAAwQwzAPBSr5hZhkdywgwAAEEwAwBAMJIBDMftn2B95h7D8FrCPZ
wwAwBAEMwAABCMZACzmXKJtv6NS6wwnrlHLz7zGsGjOGEGAIAgmAEAIBjJAF7qkZdrb7mbxrU/3yVdWBZ7lmdwwgwAAEEwAwBAMJIBPNXcn5i/9+d7iApc59I+edVrwaWfY/9yDyfMAAAQBDMAAAQjGcAqvOJy79ef4RIvTPd5v8w9qgXXcsIMAABBMAMAQBDMAAAQBDMAAATBDAAAQTADAEBwWzkA4KW+3pLRbeae75H/x1u8paYTZgAACIIZAACCkQwe5vPlnqmXa1wiAsBTAJfllvf7pXPCDAAAQTADAEAwksFTzHFJ7dLP3MrlolH5NPy2Tf1926f8z6W/hVteO7b+d/WK19utjGc4YQYAgCCYAQAgGMngYUb9lPNWLhfxfP5+prll/xup4m/8LYzv6z5e0+/MCTMAAATBDAAAQTADAEBY/Ayz28zAsow6636J14tplvC7BF5rTZ8hcsIMAABBMAMAQFjMSMYjL/e5fRGMYdTxDK8FwBqM+hq7RE6YAQAgCGYAAAiLGckA1m3uS4fGMK7j8i4sy9yvsUu/Y4YTZgAACIIZAADC0CMZLvnxCEu89MPz+Hu4nddkWK6R9u8SxzOcMAMAQBDMAAAQhh7JeLWvlysuXSaYclljKZcYlsj/7fo98tPc/l6Wx+8MHmOkMYxLlvIwOSfMAAAQBDMAAATBDAAAwQxzuGf2p752tLmcpVni7Wi43dffsc8QvNbcTwT7zO8V2hJmlne7Ze5lJ8wAABAEMwAABCMZPMUjbwsGUyzxEh/XMY4Fy7X0PeuEGQAAgmAGAIBgJGMGLivex//Ztvn9P4fRKeDR1vR67YQZAACCYAYAgGAkg6e75cETwGtde+nUPobxzPF+u6axi+KEGQAAgmAGAIBgJIOX28rlG1gzo1awXVt8H3fCDAAAQTADAEAYeiTj85G/y30A47p0ifaRr91bvAwM97h3X9pz/+eEGQAAgmAGAIAgmAEAIAw9w7xWZoKArfB6B+OxL6/nhBkAAIJgBgCAIJgBACAIZgAACIIZAADCYu6SsfSn/vlEKgDAMjlhBgCAIJgBACAsZiRjCYxdAACsjxNmAAAIghkAAIKRjBsYvQAA2A4nzAAAEAQzAAAEwQwAAGHzM8zmkQEAKE6YAQAgCGYAAAiLHMn4PEZxPp8n/TsAALiFE2YAAAiCGQAAwiJHMj4zdgEAwDM5YQYAgCCYAQAgCGYAAAiCGQAAgmAGAIAgmAEAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACIIZAACCYAYAgCCYAQAgCGYAAAiCGQAAgmAGAIAgmAEAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACIIZAACCYAYAgCCYAQAgCGYAAAiCGQAAgmAGAIAgmAEAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACIIZAACCYAYAgCCYAQAgCGYAAAiCGQAAgmAGAIAgmAEAIAhmAAAIghkAAIJgBgCAIJgBACAIZgAACIIZAACCYAYAgDApmI/H47PXAasz+r4ZfX0wotH3zejrgxFN2TeTgvl0Ot29GNia0ffN6OuDEY2+b0ZfH4xoyr75NuUbHQ6H3cfHx26/3+/e39/vXhis2fF43J1Op93hcJh7Kcm+hunsa1ifa/b12/l8Pr9gTQAAsEg+9AcAAEEwAwBAEMwAABAEMwAABMEMAABBMAMAQBDMAAAQ/gsnsD8MD8s+gQAAAABJRU5ErkJggg==",
            "text/plain": [
              "<Figure size 900x900 with 9 Axes>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Heart/heart pairs: both the _sx and _dx sprites have shape latent (column 1) == 2.\n",
        "# The boolean mask is built once and reused for the images and c arrays; every selected pair gets y = 1.\n",
        "mask_hh = (latents_sampled_sx[:, 1] == 2) & (latents_sampled_dx[:, 1] == 2)\n",
        "hh = image_fused[mask_hh]\n",
        "y_hh = np.ones(len(hh))\n",
        "c_hh = c_fused[mask_hh]\n",
        "show_images_grid(hh, 9)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAswAAALJCAYAAACgHHWpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAXRklEQVR4nO3dTVbzSLYFULkWDcMA3MbzH5RpMwBwT6/3PpcLHWSsnxvS3r3KRVKRlsM+68ZBOvR933cAAMCP/rP2AgAAoDKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAgpcxP/T+/t59fn52x+OxO5/Pc68Jmna5XLrr9dqdTqfu4+Nj7eUMsq9hPPsatueRfX0Y86S/t7e37vv7e6r1wS68vr52X19fay9jkH0Nj7OvYXvG7OtRlYzj8TjJgmBPqu+b6uuDiqrvm+rrg4rG7JtRgdmxDjyu+r6pvj6oqPq+qb4+qGjMvvFHfwAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAELysvYA59X3/4z8/HA4LrwQAgFaZMAMAQCAwAwBAIDADAECw6Q7zkNtusz4zAACJCTMAAAQCMwAABJuqZAzdRm7sv6OeAQDAPRNmAAAIBGYAAAiar2T8pYbxl9+lrgEAsE8mzAAAEAjMAAAQNFnJmLKG8Zf/T/UMAID9MGEGAIBAYAYAgKCZSsYaNYwh6hkAAPthwgwAAIHADAAAgcAMAABB6Q5zpd7yEH1mAIBtM2EGAIBAYAYAgKBUJaOFCkaingEAsD0mzAAAEAjMAAAQrF7JaL2GMeT+v0tFAwCgTSbMAAAQCMwAABCsUsnYag0jcQcNAIA2mTADAEAgMAMAQLD6XTL2SD0DAKAdJswAABAIzAAAEAjMAAAQLNZh3uOt5MYYel10mwEAajBhBgCAQGAGAIBg1kqGGgYAAK0zYQYAgEBgBgCAYPJKhhoGAABbYsIMAACBwAwAAMFiDy7hdx5WAgBQjwkzAAAEAjMAAASTVzJuawXumPE7NQwAgNpMmAEAIBCYAQAgmPUuGeoZP1PDAABohwkzAAAEAjMAAAQCMwAABIs96W/vfWa95X9ur7/XBQCozoQZAAACgRkAAILFKhm3ho7ht1TVUDX4J11X9QwAoDoTZgAACARmAAAIVqlkDLk/km+toqFS8Jyh6+11BQDWZMIMAACBwAwAAEGpSsa9Fh52oi7wsymv1/3v8poD6TPGZwQwNRNmAAAIBGYAAAhKVzJuVapnOO772VLXxcNOYDlrf94CVGDCDAAAgcAMAACBwAwAAEEzHeZblfrMrMvTASHzGQnwPBNmAAAIBGYAAAiarGTcWqqe4Yj/Z1WPe916DqAd6nVUZ8IMAACBwAwAAEHzlYxb90c3z9YFHAX9rGoNY0har2sMsLyx3yPqdVRhwgwAAIHADAAAwaYqGfcevYOG4x4AqEm9jjWZMAMAQCAwAwBAsOlKxi3HNX/X2l0xEu8DgHXM+V3ibhrMzYQZAAACgRkAAAKBGQAAgt10mHnMlnrL7NvY97LeI0xvje8SfWbmYMIMAACBwAwAAIFKBpvnSG4f1IiAe0OfC74XeJQJMwAABAIzAAAEKhn86Pa4qsWjbsdt2zXX+9Ff1m+Ha7muFr4zvEd4lAkzAAAEAjMAAAQqGfwqHVdVOnpzrNa2Su8l4DGt7V/fFzzKhBkAAAKBGQAAApUMntL63TRYXgvvE39BD1kL+/jeHvfy2tdpS6+5CTMAAAQCMwAABCoZTGbo6GXOI6EtHfdsydrHgPyjNgWs5ZnPnGrf7ybMAAAQCMwAABAIzAAAEOgwM7v7HtKzPcpqvaY920Mn1i3m4H95Aix7Y8
IMAACBwAwAAIFKBotb4/ZzACxj7VsZqmEwBxNmAAAIBGYAAAhUMihj6BjP8Vpdax+9ArWp4LEVJswAABAIzAAAEKhkUJIaBhXdHyN7n8LfeKDVPFRd5mPCDAAAgcAMAACBSgYAsKpH76ahgsHSTJgBACAQmAEAIBCYAQAg0GEGJrHHp/55IiXMa4+fK9RkwgwAAIHADAAAgUoGAFCe2hNrMmEGAIBAYAYAgEAlA5icv2zfH8flwJaZMAMAQCAwAwBAoJIBMIEWHmJSdV0A1ZkwAwBAIDADAECgkgHM6r4G4K4ZAPykcm3MhBkAAAKBGQAAAoEZAAACHWaAP6rctwNgOibMAAAQCMwAABCoZAAEahcAmDADAEAgMAMAQKCSASzqtuKw9lP/1C0AGMOEGQAAAoEZAAAClQxg09QugL2Y+vNu7dpcJSbMAAAQCMwAABCoZACrmfKOGaoXANPyufqPCTMAAAQCMwAABCoZQGmOBAFYmwkzAAAEAjMAAAQCMwAABDrMQAm6ygBUZcIMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAMGowHy5XOZeB2xO9X1TfX1QUfV9U319UNGYfTMqMF+v16cXA3tTfd9UXx9UVH3fVF8fVDRm37yM+UWn06n7/Pzsjsdjdz6fn14YbNnlcumu12t3Op3WXkpkX8N49jVszyP7+tD3fb/AmgAAoEn+6A8AAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAIKXMT/0/v7efX5+dsfjsTufz3OvCZp2uVy66/XanU6n7uPjY+3lDLKvYTz7GrbnkX196Pu+/+0Xvr29dd/f31OtD3bh9fW1+/r6WnsZg+xreJx9DdszZl+PqmQcj8dJFgR7Un3fVF8fVFR931RfH1Q0Zt+MCsyOdeBx1fdN9fVBRdX3TfX1QUVj9o0/+gMAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCA4GXtBSyl7/sf//nhcFh4JTzq/tq5ZgDAkkyYAQAgEJgBACDYdCVjqIYx9DOO+utI1841AwCWZMIMAACBwAwAAMGmKhljKhiP/PuO++f37DUDAJibCTMAAAQCMwAABAIzAAAEzXeY5+zAun3ZPKa8Zq4RADA3E2YAAAgEZgAACJqsZKxxKzJH/49xjQCArTBhBgCAQGAGAICgmUpGpSfCOfr/WaVrBAAwFRNmAAAIBGYAAAhKVzIc8dfUwnVRmwEApmLCDAAAgcAMAABB6UrG7VF6pRrAHo74K73ez1LPAGjH0PePz+917f26mDADAEAgMAMAQFC6knFr7XrGHo4ctlTDAKAdY75/1OuW57r8Y8IMAACBwAwAAIHADAAAQTMd5ltL9Zm32sXZc1f5/r99q9cYoLJnv4d8ls/DdRlmwgwAAIHADAAAQZOVjFtp3L/n6s
E9r8XP9nI7HIC1zfk95LP871yXcUyYAQAgEJgBACBovpKRPHo3jdaPCwCgkjXqgFuqAczFdXmcCTMAAAQCMwAABJuuZNwaqme0eCzwF0s97KVle3xfwFrS55D917ZK3zE+1/9xXZ5jwgwAAIHADAAAwW4qGbdaGf/PRT0DGMtnBGO08D5psQbwLNdlOibMAAAQCMwAABAIzAAAEOyywwy/aaVTBbCGFrqxyV4+41v4m6VWXn8TZgAACARmAAAIVDJ2roXjmrXt5egOqrDn6ru/Lq19f+zxfVXp+77F19+EGQAAAoEZAAAClQz+X6XjGgDa0cL3R4s1gLmscb1af/1NmAEAIBCYAQAgUMmAX7R+jASwpEr1DJ/fv5vzem3p9TdhBgCAQGAGAIBAJYMfVTpSW8qWjo4AKnA3hrY8+0CaLb/2JswAABAIzAAAEAjMAAAQ6DCzK1vuVwFU5vZl7Rlzzfby2pswAwBAIDADAECgksGvnr3NzNr2clwEc9jjLSaZn9uXtWfvr7kJMwAABAIzAAAEKhk8rOoR7d6PixhnzHvWewmW5W4MVGfCDAAAgcAMAACBSgZNc0S3P5VqQPzMvuQZ3j9UZMIMAACBwAwAAIFKBk9Z4o4Zjue2pYVKxe0avf/+8VoAe2XCDAAAgcAMAACBSgaTebae4bh3O1qoXQDAWCbMAAAQCMwAABAIzAAAEOgwsyq9ZQCgOhNmAAAIBGYAAAhUMpiFqgUAsBUmzAAAEAjMAAAQqGQAk3v2qY8AUIkJMwAABAIzAAAEKhkAwW2lxN1fAPbJhBkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAoEZAACCl7UXAGzb4XD4r//d9/1KKwGAvzFhBgCAQGAGAIBAJQNgpPs6yX3dBIBtMmEGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAwJP+gEXdPh3v/sl5VXiCHwC3TJgBACAQmAEAIFDJADZHpQKAKZkwAwBAIDADAECgkgGUo1IBQCUmzAAAEAjMAAAQqGQAq1G9AKAFJswAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQCMwAABAIzAAAEAjMAAAQjArMl8tl7nXA5lTfN9XXBxVV3zfV1wcVjdk3owLz9Xp9ejGwN9X3TfX1QUXV90319UFFY/bNy5hfdDqdus/Pz+54PHbn8/nphcGWXS6X7nq9dqfTae2lRPY1jGdfw/Y8sq8Pfd/3C6wJAACa5I/+AAAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgEJgBACAQmAEAIBCYAQAgeBnzQ+/v793n52d3PB678/k895qgaZfLpbter93pdOo+Pj7WXs4g+xrGs69hex7Z14e+7/vffuHb21v3/f091fpgF15fX7uvr6+1lzHIvobH2dewPWP29ahKxvF4nGRBsCfV90319UFF1fdN9fVBRWP2zajA7FgHHld931RfH1RUfd9UXx9UNGbf+KM/AAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAA
AIBGYAAAgEZgAACARmAAAIBGYAAAgEZgAACARmAAAIBGYAAAhe1l4AwNT6vh/1c4fDYeaVALAFJswAABAIzAAAEKhkAE0ZW7cAgKmYMAMAQCAwAwBAoJIBlKN2AUAlJswAABAIzAAAEKhkAIuqVLe4XYuHmAAwxIQZAAACgRkAAAKBGQAAAoEZAAACgRkAAAKBGQAAAreVA+jcYg6AYSbMAAAQCMwAABDsppIx9HQxR6+wrNs9V+mpfwAwxIQZAAACgRkAAIJNVzLGHPf6y3gAYK/koHFMmAEAIBCYAQAg2FQl49m/uL//9x1NwD45ogS2bCgv/SVH7eUz0oQZAAACgRkAAAKBGQAAguY7zHM+KUyPEeblqX8AbdtL79mEGQAAAoEZAACCJisZaxzdqmcAADxvKMdVzlcmzAAAEAjMAAAQNFPJqPQX9OoZML017phh/wIwhgkzAAAEAjMAAATNVDKqUs+A2uxLgFrV1lutfEabMAMAQCAwAwBA0EwlY42/oB+jlaME2CL7D4AlmDADAEAgMAMAQNBMJePW2vUMx8AwL3sMgEpMmAEAIBCYAQAgEJgBACBossN8a6k+s04lAMA+mTADAEAgMAMAQNB8JWNKahcAAPNqMW+ZMAMAQCAwAwBAsKlKxv2If8xdM1o8FgCArbr97vYd/Zw1noa8VSbMAAAQCMwAABBsqpJxb+ihJo54AKA+391UYcIMAACBwAwAAMGmKxm3HOUAQE1j7uagnvG4R18nd9UYZsIMAACBwAwAAIHADAAAwW46zABADc92Ze//fZ3mafzldRx7LVu/RibMAAAQCMwAABCoZAAAs5vzlmVuObeevbzeJswAABAIzAAAEKhkAACzWOPJceoZzMGEGQAAAoEZAAAClQwAYDJr1DCGqGcwFRNmAAAIBGYAAAhUMgCAp1SqYQxRz+AZJswAABAIzAAAEAjMAAAQ6DADAA9pobOc6DPzKBNmAAAIBGYAAAhUMgCAXVHD4FEmzAAAEAjMAAAQqGQA8LB0lwTH3dt3f41buGuG9yXPMGEGAIBAYAYAgEAlA4BRxh67D/2cI/HtGrq2a1c1vOeYigkzAAAEAjMAAAQqGQAMmvJI/fZ3OSrfh9vrvFQ9w3uLOZgwAwBAIDADAECgkgHA4jz4ZH/mrGd4zzA3E2YAAAgEZgAACARmAAAIdJgB+C9rP53NkwK37/5aPvqe815gaSbMAAAQCMwAABCoZLAqR6/AWD4vtmuNJwLCI0yYAQAgEJgBACBQyWBxY47bbn/GcSvMr+Vj8Pu1+8xo21A9w3VlTSbMAAAQCMwAABCoZDC7Z496HbcCj3CMvx2uH1WYMAMAQCAwAwBAoJLBLOb8i3vHrfC8lu+KAbA0E2YAAAgEZgAACARmAAAIdJiZzBqdSH1mGG8vvWWfBcDUTJgBACAQmAEAIFDJ4CmVjnjVMwCAOZgwAwBAIDADAECgkgGwYZVqU3NRwQLmZsIMAACBwAwAAIFKBk+5PQpd++jXsSzsh/0OLMmEGQAAAoEZAAAClQwms0Y9w7EsZJVqUwCtMmEGAIBAYAYAgEBgBgCAQIeZWczZm9Rbhr9pvc9s7wNrMWEGAIBAYAYAgEAlg9ndH6M+ehTsGBaml/ZVi3UNgDmZMAMAQCAwAwBAoJLB4sb8pb4aBqxnaP+tUdXwWQBUYMIMAACBwAwAAIFKBqty3ArtcGcNYK9MmAEAIBCYAQAgUMkA4Glj7n7zl98FUIEJMwAABAIzAAAEKhkATOovDz5RwwAqM2EGAIBAYAYAgEBgBgCAQIcZgEXoKQOtMmEGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgEBgBgCAQGAGAIBAYAYAgGBUYL5cLnOvAzan+r6pvj6oqPq+qb4+qGjMvhkVmK/X69OLgb2pvm+qrw
8qqr5vqq8PKhqzb17G/KLT6dR9fn52x+OxO5/PTy8MtuxyuXTX67U7nU5rLyWyr2E8+xq255F9fej7vl9gTQAA0CR/9AcAAIHADAAAgcAMAACBwAwAAIHADAAAgcAMAACBwAwAAMH/AY5rEVF43yIBAAAAAElFTkSuQmCC",
            "text/plain": [
              "<Figure size 900x900 with 9 Axes>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Keep only samples where both the sx and dx sprites have latent column 1 equal to 0.\n",
        "# NOTE(review): column 1 is presumably the shape latent and 0 is assumed to be 'square' (per the sqsq name) - confirm.\n",
        "# The mask is computed once and reused so images and concepts are selected identically.\n",
        "sqsq_mask = (latents_sampled_sx[:, 1] == 0) & (latents_sampled_dx[:, 1] == 0)\n",
        "sqsq = image_fused[sqsq_mask]\n",
        "y_sqsq = np.ones(len(sqsq))\n",
        "c_sqsq = c_fused[sqsq_mask]\n",
        "show_images_grid(sqsq, 9)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sanity check: images, labels and concepts must stay row-aligned within each group.\n",
        "assert len(imgs_sampled_dx) == len(y_dx) == len(c_dx), 'dx group arrays misaligned'\n",
        "assert len(imgs_sampled_sx) == len(y_sx) == len(c_sx), 'sx group arrays misaligned'\n",
        "print(len(imgs_sampled_dx), len(y_dx), len(c_dx), len(imgs_sampled_sx), len(y_sx), len(c_sx))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 206,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1363 1363 1363 1317 1317 1317\n"
          ]
        }
      ],
      "source": [
        "# Sanity check: images, labels and concepts must stay row-aligned within each group.\n",
        "assert len(no_sq_no_heart) == len(y_no_sq_no_heart) == len(c_no_sq_no_heart), 'no_sq_no_heart arrays misaligned'\n",
        "assert len(no_heart_no_sq) == len(y_no_heart_no_sq) == len(c_no_heart_no_sq), 'no_heart_no_sq arrays misaligned'\n",
        "print(len(no_sq_no_heart), len(y_no_sq_no_heart), len(c_no_sq_no_heart), len(no_heart_no_sq), len(y_no_heart_no_sq), len(c_no_heart_no_sq))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 207,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "353 353 353 316 316 316\n"
          ]
        }
      ],
      "source": [
        "# Sanity check: images, labels and concepts must stay row-aligned within each group.\n",
        "assert len(hh) == len(y_hh) == len(c_hh), 'hh arrays misaligned'\n",
        "assert len(sqsq) == len(y_sqsq) == len(c_sqsq), 'sqsq arrays misaligned'\n",
        "print(len(hh), len(y_hh), len(c_hh), len(sqsq), len(y_sqsq), len(c_sqsq))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 208,
      "metadata": {},
      "outputs": [],
      "source": [
        "# 80/20 train/test split applied independently to each (images, labels, concepts) group.\n",
        "# Each cut index is derived from that group's OWN image count. (The hh group used to be cut\n",
        "# at int(len(no_heart_no_sq)*0.8), which sent all of hh to train and none of it to test.)\n",
        "# The sqsq group is deliberately excluded here; it is appended only to the test set below.\n",
        "TRAIN_FRAC = 0.8\n",
        "split_groups = [\n",
        "    (imgs_sampled_dx, y_dx, c_dx),\n",
        "    (imgs_sampled_sx, y_sx, c_sx),\n",
        "    (no_sq_no_heart, y_no_sq_no_heart, c_no_sq_no_heart),\n",
        "    (no_heart_no_sq, y_no_heart_no_sq, c_no_heart_no_sq),\n",
        "    (hh, y_hh, c_hh),\n",
        "]\n",
        "\n",
        "def split_at(images):\n",
        "    # Train/test cut index for one group, computed from that group's image count.\n",
        "    return int(len(images) * TRAIN_FRAC)\n",
        "\n",
        "train_set_imgs = np.concatenate([imgs[:split_at(imgs)] for imgs, _, _ in split_groups], axis=0)\n",
        "y_train = np.concatenate([y[:split_at(imgs)] for imgs, y, _ in split_groups], axis=0)\n",
        "c_train = np.concatenate([c[:split_at(imgs)] for imgs, _, c in split_groups], axis=0)\n",
        "assert len(train_set_imgs) == len(y_train) == len(c_train), 'train arrays misaligned'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 209,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Test set: the remaining 20% of every split group, plus the entire held-out sqsq group.\n",
        "test_set_imgs = np.concatenate([imgs[split_at(imgs):] for imgs, _, _ in split_groups] + [sqsq], axis=0)\n",
        "y_test = np.concatenate([y[split_at(imgs):] for imgs, y, _ in split_groups] + [y_sqsq], axis=0)\n",
        "c_test = np.concatenate([c[split_at(imgs):] for imgs, _, c in split_groups] + [c_sqsq], axis=0)\n",
        "assert len(test_set_imgs) == len(y_test) == len(c_test), 'test arrays misaligned'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 210,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# Refuse to persist a split whose images, labels and concepts are not row-aligned.\n",
        "assert len(train_set_imgs) == len(y_train) == len(c_train), 'train arrays misaligned'\n",
        "assert len(test_set_imgs) == len(y_test) == len(c_test), 'test arrays misaligned'\n",
        "\n",
        "save_dir = './datasets/dsprites'\n",
        "os.makedirs(save_dir, exist_ok=True)\n",
        "\n",
        "# One .npy file per array; file names are what downstream loaders read.\n",
        "arrays_to_save = {\n",
        "    'train_images.npy': train_set_imgs,\n",
        "    'test_images.npy': test_set_imgs,\n",
        "    'train_labels.npy': y_train,\n",
        "    'test_labels.npy': y_test,\n",
        "    'train_concepts.npy': c_train,\n",
        "    'test_concepts.npy': c_test,\n",
        "}\n",
        "for file_name, array in arrays_to_save.items():\n",
        "    np.save(os.path.join(save_dir, file_name), array)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "default_view": {},
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "name": "deepmind_2d_shapes_dataset_public.ipynb",
      "provenance": [
        {
          "file_id": "/piper/depot/google3/experimental/deepmind/concepts/dataset2dshapes/public/deepmind_2d_shapes_dataset.ipynb?workspaceId=lmatthey:lmatthey-2dshapes-dataset:580:citc",
          "timestamp": 1493149332589
        },
        {
          "file_id": "0BxLiVtkN33-wbmVnbVQwcUhjY0U",
          "timestamp": 1493149291483
        }
      ],
      "version": "0.3.2",
      "views": {}
    },
    "kernelspec": {
      "display_name": ".venv",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
