{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "26ScWNvYSgQg"
      },
      "source": [
        "Copyright 2017 Google Inc.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License.\n",
        "\n",
        "\n",
        "# dSprites - Disentanglement testing Sprites dataset\n",
        "\n",
        "## Description\n",
        "Procedurally generated 2D shapes dataset. This dataset uses 6 latents, controlling the color, shape, scale, rotation and position of a sprite (color isn't varying here, its value is fixed).\n",
        "\n",
        "All possible combinations of the latents are present.\n",
        "\n",
        "The ordering of images in the dataset (i.e. shape[0] in all ndarrays) is fixed and meaningful, see below.\n",
        "\n",
        "We chose the smallest changes in latent values that generated different pixel outputs at our 64x64 resolution after rasterization.\n",
        "\n",
        "No noise added, single image sample for a given latent setting.\n",
        "\n",
        "## Details about the ordering of the dataset\n",
        "\n",
        "The dataset was generated procedurally, and its order is deterministic.\n",
        "For example, the image at index 0 corresponds to the latents (0, 0, 0, 0, 0, 0).\n",
        "\n",
        "Then the image at index 1 increases the least significant \"bit\" of the latent:\n",
        "(0, 0, 0, 0, 0, 1)\n",
        "\n",
        "And similarly, till we reach index 32, where we get (0, 0, 0, 0, 1, 0). \n",
        "\n",
        "Hence the dataset is sequentially addressable using variable bases for every \"bit\".\n",
        "Using dataset['metadata']['latents_sizes'] makes this conversion trivial, see below."
      ]
    },
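    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The mixed-radix addressing described above can be checked with a small sketch. It assumes the latent sizes (1, 3, 6, 40, 32, 32) printed in the metadata output of this notebook. The names `example_sizes` and `to_latents` are introduced here for illustration only, and NumPy's `unravel_index` / `ravel_multi_index` stand in for the notebook's own conversion helper."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Assumed latent sizes, ordered as (color, shape, scale, orientation, posX, posY).\n",
        "example_sizes = (1, 3, 6, 40, 32, 32)\n",
        "\n",
        "def to_latents(index):\n",
        "  # Flat index -> tuple of latent classes (the last latent varies fastest).\n",
        "  return tuple(int(v) for v in np.unravel_index(index, example_sizes))\n",
        "\n",
        "assert to_latents(1) == (0, 0, 0, 0, 0, 1)\n",
        "assert to_latents(32) == (0, 0, 0, 0, 1, 0)\n",
        "\n",
        "# Latent classes -> flat index is the inverse mapping.\n",
        "assert np.ravel_multi_index((0, 0, 0, 0, 1, 0), example_sizes) == 32"
      ]
    },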
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "cellView": "both",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "jJ02BsnqSa96"
      },
      "outputs": [],
      "source": [
        "from __future__ import absolute_import\n",
        "from __future__ import division\n",
        "from __future__ import print_function\n",
        "from matplotlib import pyplot as plt\n",
        "import numpy as np\n",
        "import seaborn as sns\n",
        "\n",
        "# Change figure aesthetics\n",
        "%matplotlib inline\n",
        "sns.set_context('talk', font_scale=1.2, rc={'lines.linewidth': 1.5})\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "output_extras": [
            {
              "item_id": 2
            }
          ]
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 10952,
          "status": "ok",
          "timestamp": 1495021223246,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "uDL3Iw0WFw1L",
        "outputId": "1a3ce845-1add-41c3-ee3d-6018d09423bc"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Keys in the dataset: KeysView(NpzFile 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz' with keys: metadata, imgs, latents_classes, latents_values)\n",
            "Metadata: \n",
            " {'date': 'April 2017', 'description': 'Disentanglement test Sprites dataset.Procedurally generated 2D shapes, from 6 disentangled latent factors.This dataset uses 6 latents, controlling the color, shape, scale, rotation and position of a sprite. All possible variations of the latents are present. Ordering along dimension 1 is fixed and can be mapped back to the exact latent values that generated that image.We made sure that the pixel outputs are different. No noise added.', 'version': 1, 'latents_names': ('color', 'shape', 'scale', 'orientation', 'posX', 'posY'), 'latents_possible_values': {'orientation': array([0.        , 0.16110732, 0.32221463, 0.48332195, 0.64442926,\n",
            "       0.80553658, 0.96664389, 1.12775121, 1.28885852, 1.44996584,\n",
            "       1.61107316, 1.77218047, 1.93328779, 2.0943951 , 2.25550242,\n",
            "       2.41660973, 2.57771705, 2.73882436, 2.89993168, 3.061039  ,\n",
            "       3.22214631, 3.38325363, 3.54436094, 3.70546826, 3.86657557,\n",
            "       4.02768289, 4.1887902 , 4.34989752, 4.51100484, 4.67211215,\n",
            "       4.83321947, 4.99432678, 5.1554341 , 5.31654141, 5.47764873,\n",
            "       5.63875604, 5.79986336, 5.96097068, 6.12207799, 6.28318531]), 'posX': array([0.        , 0.03225806, 0.06451613, 0.09677419, 0.12903226,\n",
            "       0.16129032, 0.19354839, 0.22580645, 0.25806452, 0.29032258,\n",
            "       0.32258065, 0.35483871, 0.38709677, 0.41935484, 0.4516129 ,\n",
            "       0.48387097, 0.51612903, 0.5483871 , 0.58064516, 0.61290323,\n",
            "       0.64516129, 0.67741935, 0.70967742, 0.74193548, 0.77419355,\n",
            "       0.80645161, 0.83870968, 0.87096774, 0.90322581, 0.93548387,\n",
            "       0.96774194, 1.        ]), 'posY': array([0.        , 0.03225806, 0.06451613, 0.09677419, 0.12903226,\n",
            "       0.16129032, 0.19354839, 0.22580645, 0.25806452, 0.29032258,\n",
            "       0.32258065, 0.35483871, 0.38709677, 0.41935484, 0.4516129 ,\n",
            "       0.48387097, 0.51612903, 0.5483871 , 0.58064516, 0.61290323,\n",
            "       0.64516129, 0.67741935, 0.70967742, 0.74193548, 0.77419355,\n",
            "       0.80645161, 0.83870968, 0.87096774, 0.90322581, 0.93548387,\n",
            "       0.96774194, 1.        ]), 'scale': array([0.5, 0.6, 0.7, 0.8, 0.9, 1. ]), 'shape': array([1., 2., 3.]), 'color': array([1.])}, 'latents_sizes': array([ 1,  3,  6, 40, 32, 32]), 'author': 'lmatthey@google.com', 'title': 'dSprites dataset'}\n"
          ]
        }
      ],
      "source": [
        "# Load dataset\n",
        "dataset_zip = np.load('dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', allow_pickle=True, encoding='latin1')\n",
        "\n",
        "print('Keys in the dataset:', dataset_zip.keys())\n",
        "imgs = dataset_zip['imgs']\n",
        "latents_values = dataset_zip['latents_values']\n",
        "latents_classes = dataset_zip['latents_classes']\n",
        "metadata = dataset_zip['metadata'][()]\n",
        "\n",
        "print('Metadata: \\n', metadata)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "9RWpIJtiHYUL"
      },
      "outputs": [],
      "source": [
        "# Define number of values per latents and functions to convert to indices\n",
        "latents_sizes = metadata['latents_sizes']\n",
        "latents_bases = np.concatenate((latents_sizes[::-1].cumprod()[::-1][1:],\n",
        "                                np.array([1,])))\n",
        "\n",
        "def latent_to_index(latents):\n",
        "  return np.dot(latents, latents_bases).astype(int)\n",
        "\n",
        "\n",
        "def sample_latent(size=1):\n",
        "  samples = np.zeros((size, latents_sizes.size))\n",
        "  for lat_i, lat_size in enumerate(latents_sizes):\n",
        "    samples[:, lat_i] = np.random.randint(lat_size, size=size)\n",
        "\n",
        "  return samples\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "W8LKpGjGKaiN"
      },
      "outputs": [],
      "source": [
        "# Helper function to show images\n",
        "def show_images_grid(imgs_, num_images=25):\n",
        "  ncols = int(np.ceil(num_images**0.5))\n",
        "  nrows = int(np.ceil(num_images / ncols))\n",
        "  _, axes = plt.subplots(ncols, nrows, figsize=(nrows * 3, ncols * 3))\n",
        "  axes = axes.flatten()\n",
        "\n",
        "  for ax_i, ax in enumerate(axes):\n",
        "    if ax_i < num_images:\n",
        "      ax.imshow(imgs_[ax_i], cmap='Greys_r',  interpolation='nearest')\n",
        "      ax.set_xticks([])\n",
        "      ax.set_yticks([])\n",
        "    else:\n",
        "      ax.axis('off')\n",
        "\n",
        "def show_density(imgs):\n",
        "  _, ax = plt.subplots()\n",
        "  ax.imshow(imgs.mean(axis=0), interpolation='nearest', cmap='Greys_r')\n",
        "  ax.grid('off')\n",
        "  ax.set_xticks([])\n",
        "  ax.set_yticks([])"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "('color', 'shape', 'scale', 'orientation', 'posX', 'posY')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 36,
      "metadata": {},
      "outputs": [],
      "source": [
        "## Fix posX latent to left\n",
        "latents_sampled = sample_latent(size=10000)\n",
        "indices_sampled = latent_to_index(latents_sampled)\n",
        "latents_sampled[:, 2] = (latents_sampled[:, 2] >= 3.0).astype(np.float32)\n",
        "latents_sampled[:, 5] = (latents_sampled[:, 5] >= 16.0).astype(np.float32)\n",
        "latents_sampled[:, 4] = (latents_sampled[:, 4] >= 16.0).astype(np.float32)\n",
        "imgs_sampled = imgs[indices_sampled] \n",
        "c_dx = (np.arange(3) == latents_sampled[:, 1][:,None]).astype(np.float32) # `shape\n",
        "c_dx = np.concatenate([c_dx, \n",
        "                       (np.arange(2) == latents_sampled[:, 2][:,None]).astype(np.float32), # size\n",
        "                       (np.arange(2) == latents_sampled[:, 5][:,None]).astype(np.float32), # posY\n",
        "                       (np.arange(2) == latents_sampled[:, 4][:,None]).astype(np.float32), # posX\n",
        "                       (np.arange(3) == latents_sampled[:, 0][:,None]).astype(np.float32)], axis=1) # color"
      ]
    },
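    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The one-hot encoding used in the cell above relies on NumPy broadcasting: comparing `np.arange(k)` against a column of class indices yields one row per sample with a single 1. The following is a toy sketch only; `toy_classes` and `toy_one_hot` are made-up names and values, not part of the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical class indices for three samples, each in {0, 1, 2}.\n",
        "toy_classes = np.array([0, 2, 1])\n",
        "\n",
        "# Shape (3,) compared with shape (3, 1) broadcasts to a (3, 3) boolean matrix.\n",
        "toy_one_hot = (np.arange(3) == toy_classes[:, None]).astype(np.float32)\n",
        "\n",
        "assert toy_one_hot.tolist() == [[1, 0, 0], [0, 0, 1], [0, 1, 0]]"
      ]
    },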
    {
      "cell_type": "code",
      "execution_count": 37,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "array([0., 0., 1., 0., 1., 0., 1., 1., 0., 1., 0., 0.], dtype=float32)"
            ]
          },
          "execution_count": 37,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "c_dx[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 38,
      "metadata": {},
      "outputs": [],
      "source": [
        "filter_shape0_pos0 = ((c_dx[:, 0] == 1).astype(np.float32) *            # shape 0\n",
        "                 (c_dx[:, 3] == 1).astype(np.float32) *                 # size 0         \n",
        "                 (c_dx[:, 7] == 1).astype(np.float32)).astype(bool)     # posX 0\n",
        "latents_sampled_0 = latents_sampled[filter_shape0_pos0]\n",
        "c_dx_0 = c_dx[filter_shape0_pos0]\n",
        "imgs_sampled_0 = imgs_sampled[filter_shape0_pos0]\n",
        "\n",
        "filter_shape0_pos1 = ((c_dx[:, 0] == 1).astype(np.float32) *            # shape 0\n",
        "                    (c_dx[:, 3] == 1).astype(np.float32) *              # size 0\n",
        "                    (c_dx[:, 8] == 1).astype(np.float32)).astype(bool)  # posX 1\n",
        "latents_sampled_1 = latents_sampled[filter_shape0_pos1]\n",
        "c_dx_1 = c_dx[filter_shape0_pos1]\n",
        "imgs_sampled_1 = imgs_sampled[filter_shape0_pos1]\n",
        "\n",
        "filter_shape1_pos0 = ((c_dx[:, 1] == 1).astype(np.float32) *            # shape 1\n",
        "                    (c_dx[:, 4] == 1).astype(np.float32) *              # size 1              \n",
        "                    (c_dx[:, 8] == 1).astype(np.float32)).astype(bool)  # posX 1\n",
        "latents_sampled_2 = latents_sampled[filter_shape1_pos0]\n",
        "c_dx_2 = c_dx[filter_shape1_pos0]\n",
        "imgs_sampled_2 = imgs_sampled[filter_shape1_pos0]\n",
        "\n",
        "filter_shape1_pos1 = ((c_dx[:, 1] == 1).astype(np.float32) *            # shape 1\n",
        "                    (c_dx[:, 3] == 1).astype(np.float32) *              # size 0\n",
        "                    (c_dx[:, 7] == 1).astype(np.float32)).astype(bool)  # posX 0\n",
        "latents_sampled_3 = latents_sampled[filter_shape1_pos1]\n",
        "c_dx_3 = c_dx[filter_shape1_pos1]\n",
        "imgs_sampled_3 = imgs_sampled[filter_shape1_pos1]\n",
        "\n",
        "latent_sample = np.concatenate([latents_sampled_0, latents_sampled_1, latents_sampled_2, latents_sampled_3], axis=0)\n",
        "c_dx = np.concatenate([c_dx_0, c_dx_1, c_dx_2, c_dx_3], axis=0)\n",
        "imgs_sample = np.concatenate([imgs_sampled_0, imgs_sampled_1, imgs_sampled_2, imgs_sampled_3], axis=0)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 39,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "861"
            ]
          },
          "execution_count": 39,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "filter_shape1_pos1.sum()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 40,
      "metadata": {},
      "outputs": [],
      "source": [
        "filter_pos0_shape0 = ((c_dx[:, 0] == 1).astype(np.float32) *               # shape 0\n",
        "                      (c_dx[:, 5] == 1).astype(np.float32)).astype(bool)   # posY 0\n",
        "c_dx[filter_pos0_shape0, -3:] = np.array([1, 0, 0])                        # color 0\n",
        "\n",
        "filter_pos1_shape0 = ((c_dx[:, 0] == 1).astype(np.float32) *               # shape 0\n",
        "                      (c_dx[:, 6] == 1).astype(np.float32)).astype(bool)   # posY 1\n",
        "c_dx[filter_pos1_shape0, -3:] = np.array([1, 0, 0])                        # color 0\n",
        "\n",
        "filter_pos0_shape1 = ((c_dx[:, 1] == 1).astype(np.float32) *               # shape 1\n",
        "                      (c_dx[:, 5] == 1).astype(np.float32)).astype(bool)   # posY 0\n",
        "c_dx[filter_pos0_shape1, -3:] = np.array([1, 0, 0])                        # color 0\n",
        "\n",
        "filter_pos1_shape1 = ((c_dx[:, 1] == 1).astype(np.float32) *               # shape 1\n",
        "                      (c_dx[:, 6] == 1).astype(np.float32)).astype(bool)   # posY 1\n",
        "c_dx[filter_pos1_shape1, -3:] = np.array([0, 1, 0])                        # color 1\n",
        "\n",
        "\n",
        "y = np.zeros((c_dx.shape[0]))\n",
        "\n",
        "filter_col1_size0 = ((c_dx[:, 3] == 1).astype(np.float32) *               # size 0\n",
        "                     (c_dx[:, 9] == 1).astype(np.float32)).astype(bool)   # color 0\n",
        "y[filter_col1_size0] = 0                                                  # label 0\n",
        "\n",
        "filter_col2_size0 = ((c_dx[:, 4] == 1).astype(np.float32) *               # size 1\n",
        "                     (c_dx[:, 10] == 1).astype(np.float32)).astype(bool)  # color 1\n",
        "y[filter_col2_size0] = 1                                                  # label 1              \n",
        "\n",
        "filter_col1_size1 = ((c_dx[:, 4] == 1).astype(np.float32) *               # size 1\n",
        "                     (c_dx[:, 9] == 1).astype(np.float32)).astype(bool)   # color 0\n",
        "y[filter_col1_size1] = 0                                                  # label 0\n",
        "\n",
        "filter_col2_size1 = ((c_dx[:, 3] == 1).astype(np.float32) *               # size 0\n",
        "                     (c_dx[:, 10] == 1).astype(np.float32)).astype(bool)  # color 1\n",
        "y[filter_col2_size1] = 0                                                  # label 0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 41,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "array([1., 0., 0., 1., 0., 1., 0., 1., 0., 1., 0., 0.], dtype=float32)"
            ]
          },
          "execution_count": 41,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "c_dx[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 42,
      "metadata": {},
      "outputs": [],
      "source": [
        "# drop column if all 0\n",
        "c_dx = c_dx[:, [0, 1, 3, 4, 5, 6, 7, 8, 9, 10]]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": 43,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[1, 3, 5, 7, 9]"
            ]
          },
          "execution_count": 43,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "list(range(1, 10, 2))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 44,
      "metadata": {},
      "outputs": [],
      "source": [
        "c_dx = c_dx[:, list(range(1, 10, 2))]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 45,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "array([[0., 0., 0., 0., 0.],\n",
              "       [0., 0., 0., 0., 0.],\n",
              "       [0., 0., 0., 0., 0.],\n",
              "       ...,\n",
              "       [1., 0., 0., 0., 0.],\n",
              "       [1., 0., 1., 0., 1.],\n",
              "       [1., 0., 0., 0., 0.]], dtype=float32)"
            ]
          },
          "execution_count": 45,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "c_dx"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 46,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "(array([[0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 1., 0., 0.],\n",
              "        [0., 0., 1., 0., 0., 0.],\n",
              "        [0., 0., 1., 1., 0., 0.],\n",
              "        [1., 0., 0., 0., 0., 0.],\n",
              "        [1., 0., 1., 0., 1., 0.],\n",
              "        [1., 1., 0., 1., 0., 0.],\n",
              "        [1., 1., 1., 1., 1., 1.]]),\n",
              " array([422, 421, 399, 402, 433, 428, 398, 428]))"
            ]
          },
          "execution_count": 46,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "c_y = np.concatenate([c_dx, np.expand_dims(y, axis=-1)], axis=1)\n",
        "np.unique(c_y, axis=0, return_counts=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {},
      "outputs": [],
      "source": [
        "# random shuffle the data\n",
        "np.random.seed(0)\n",
        "# shuffle indices\n",
        "indices = np.arange(y.shape[0])\n",
        "np.random.shuffle(indices)\n",
        "# shuffle data\n",
        "c_dx = c_dx[indices]\n",
        "latent_sample = latent_sample[indices]\n",
        "imgs_sample = imgs_sample[indices]\n",
        "y = y[indices]\n",
        "\n",
        "# split data into train and test (80%)\n",
        "c_train = c_dx[:int(0.8*c_dx.shape[0])]\n",
        "c_test = c_dx[int(0.8*c_dx.shape[0]):]\n",
        "latent_sample_train = latent_sample[:int(0.8*latent_sample.shape[0])]\n",
        "latent_sample_test = latent_sample[int(0.8*latent_sample.shape[0]):]\n",
        "train_set_imgs = imgs_sample[:int(0.8*imgs_sample.shape[0])]\n",
        "test_set_imgs = imgs_sample[int(0.8*imgs_sample.shape[0]):]\n",
        "y_train = y[:int(0.8*y.shape[0])]\n",
        "y_test = y[int(0.8*y.shape[0]):]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "save_dir = './datasets/dsprites'\n",
        "os.makedirs(save_dir, exist_ok=True)\n",
        "\n",
        "train_images_file = os.path.join(save_dir, 'train_images.npy')\n",
        "test_images_file = os.path.join(save_dir, 'test_images.npy')\n",
        "train_labels_file = os.path.join(save_dir, 'train_labels.npy')\n",
        "test_labels_file = os.path.join(save_dir, 'test_labels.npy')\n",
        "train_concepts_file = os.path.join(save_dir, 'train_concepts.npy')\n",
        "test_concepts_file = os.path.join(save_dir, 'test_concepts.npy')\n",
        "\n",
        "np.save(train_images_file, train_set_imgs)\n",
        "np.save(test_images_file, test_set_imgs)\n",
        "np.save(train_labels_file, y_train)\n",
        "np.save(test_labels_file, y_test)\n",
        "np.save(train_concepts_file, c_train)\n",
        "np.save(test_concepts_file, c_test)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 159,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "array([[[0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        ...,\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0]],\n",
              "\n",
              "       [[0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        ...,\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0]],\n",
              "\n",
              "       [[0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        ...,\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0]],\n",
              "\n",
              "       ...,\n",
              "\n",
              "       [[0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        ...,\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0]],\n",
              "\n",
              "       [[0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        ...,\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0]],\n",
              "\n",
              "       [[0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        ...,\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0],\n",
              "        [0, 0, 0, ..., 0, 0, 0]]], dtype=uint8)"
            ]
          },
          "execution_count": 159,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "train_set_imgs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "default_view": {},
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "name": "deepmind_2d_shapes_dataset_public.ipynb",
      "provenance": [
        {
          "file_id": "/piper/depot/google3/experimental/deepmind/concepts/dataset2dshapes/public/deepmind_2d_shapes_dataset.ipynb?workspaceId=lmatthey:lmatthey-2dshapes-dataset:580:citc",
          "timestamp": 1493149332589
        },
        {
          "file_id": "0BxLiVtkN33-wbmVnbVQwcUhjY0U",
          "timestamp": 1493149291483
        }
      ],
      "version": "0.3.2",
      "views": {}
    },
    "kernelspec": {
      "display_name": ".venv",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
