{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "35af3558",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dsprites_dataset import DspritesDataset\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "62a00e5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data location and loader settings; edit these to point at another copy of dSprites.\n",
    "DATA_DIR = \"/datasets/dSprites\"\n",
    "BATCH_SIZE = 64\n",
    "NUM_WORKERS = 8\n",
    "SEED = 0  # fixes the shuffle order so the preview batch below is reproducible\n",
    "\n",
    "dataset = DspritesDataset(\n",
    "    h5_path=f\"{DATA_DIR}/sprites.h5\",\n",
    "    csv_path=f\"{DATA_DIR}/train.csv\",\n",
    ")\n",
    "\n",
    "# A seeded generator makes shuffle=True deterministic across kernel restarts.\n",
    "loader = DataLoader(\n",
    "    dataset,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    num_workers=NUM_WORKERS,\n",
    "    generator=torch.Generator().manual_seed(SEED),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "73c6ff85",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch shape: torch.Size([64, 1, 64, 64]), Labels shape: torch.Size([64])\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA9AAAAH6CAYAAADvBqSRAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAAHHlJREFUeJzt3XuQV3X9+PHXZ1l1uXkBRaVQqLwgjoWDok5KlI2ppUPmbTCU0bxUpswE6mjhrSxz1HIyrRRM1iuFZY4mJJmDeElMccRRSiATvAzgunJT9/37o2F/rgu7L5bLXr6Px8z+4Tnncz7v/aicfX7O8nlVSiklAAAAgBZVtfcCAAAAoDMQ0AAAAJAgoAEAACBBQAMAAECCgAYAAIAEAQ0AAAAJAhoAAAASBDQAAAAkCGgAAABIENAAAACQIKAhafny5XHmmWfGTjvtFD179oyRI0fGnDlzNvg877//fuyzzz5RqVTimmuuabb/Rz/6URxzzDGx8847R6VSiUsvvXQTrH7TWr16dVxwwQXRv3//6N69ewwfPjymT5+eeuy0adPiiCOOiP79+8c222wTn/zkJ+Mb3/hGvPDCC82OHTduXOy///7Rp0+f6NGjRwwePDguvfTSqK+v39TfEgBdxJa6Xn9UbW1tVCqV6NWrV1uXvVm09bVoaGiIyZMnxzHHHBMDBgyInj17xr777htXXnllrFq1qsmxkydPjkqlst6v2trazfXtQbuobu8FQGfQ0NAQRx99dDz33HMxfvz42HHHHePGG2+ML3zhC/HMM8/EHnvskT7XDTfcEIsWLVrv/ksuuSR22WWXGDp0aPzlL3/ZFMvf5E477bSYOnVqnH/++bHHHnvE5MmT46ijjoqZM2fG5z//+RYfO3fu3Nhhhx3ivPPOix133DGWLFkSt956axx44IExe/bs+OxnP9t47NNPPx2HHnpojB07NmpqauLZZ5+Nn/zkJzFjxoz4+9//HlVV3gME4P/bktfrterr62PChAnRs2fPjVn6Jrcxr8WKFSti7NixcdBBB8XZZ58d/fr1i9mzZ8fEiRPjr3/9azzyyCNRqVQiIuKwww6L22+/vdk5rrvuunjuuefiS1/60mb7HqFdFKBVd999d4mIcu+99zZue/PNN8v2229fTj755PR53njjjbLddtuVyy+/vERE+dnPftbsmFdffbWUUspbb71VIqJMnDhxY5e/ST355JPN1r5y5cry6U9/uhx88MFtOueSJUtKdXV1Oeuss1o99pprrikRUWbPnt2m5wKg69qS1+u1LrjggrLXXnuV0aNHl549e27U+jeljXktVq9eXWbNmtVs+2WXXVYiokyfPr3Fx69YsaL07t27fPnLX27b4qEDc/uGLmXlypWx9957x9577x0rV65s3L506dLYdddd45BDDokPP/wwIv73q1kvvfRSLF68uNXzTp06NXbeeef4+te/3rhtp512ihNOOCH++Mc/xurVq1Pru/DCC2OvvfaKU045Zb3HDBw4MHWujFJKDBw4MI499thm+1atWhXbbbddnHXWWRER8dJLL6XeaZ86dWp069YtzjzzzMZtNTU1cfrpp8fs2bPjP//5zwavs1+/ftGjR49Yvnx5q8eufX0yxwLQMXWF63VExCuvvBLXXXddXHvttVFd3fZf7Kyvr4+ePXvGeeed12zfa6+9Ft26dYurrrpqi70WW2+9dRxyyCHNto8aNSoiIubNm9fic99///3x7rvvxujRo1tdJ3Q2ApoupXv37nHbbbfF/Pnz4+KLL27c/p3vfCfeeeedmDx5cnTr1i0iIv773//G4MGD46KLLmr1vM8++2zsv//+zX5l+MADD4wVK1bEyy+/3Oo5nnrqqbjtttvi+uuvb/y1p82tUqnEKaecEg8++GAsXbq0yb77778/6urqGn84GDx4cIwZM6bVcz777LOx5557xrbbbttk+4EHHhgREf/85z9Ta1u+
fHm89dZbMXfu3DjjjDOirq5unb/m9cEHH8Tbb78dr7/+ejz88MNxySWXRO/evRufD4DOp6tcr88///wYOXJkHHXUUa2etyW9evWKUaNGxd133934xsFad955Z5RSYvTo0Vv8tfi4JUuWRETEjjvu2OJxtbW10b179ybxDl2FgKbLGT58eEyYMCF+/vOfx2OPPRZTp06Nu+66K6666qrYc88923TOxYsXx6677tps+9ptr7/+eouPL6XEueeeGyeeeGIcfPDBbVpDW40ZMybef//9uOeee5psnzJlSgwcOLDVv7P8cRv7Wqx10EEHRb9+/WK//faLe+65Jy655JI4/fTTmx33j3/8I3baaaf4xCc+EUcccUSUUuJPf/pT9OnTZ4PWDUDH0tmv1w888EA8/PDDce2117ZprR83ZsyYeOONN5p9KOeUKVPisMMOi912222DzreprtcfdfXVV8e2224bRx555HqPWbp0aTz00EPxta99LXr37r3BzwEdnQ8Ro0u69NJL489//nOceuqpUV9fHyNGjIjvfe97TY4ZOHBglFJS51u5cmVss802zbbX1NQ07m/J5MmTY+7cuTF16tTkd7Dp7LnnnjF8+PCora2Ns88+OyL+d3F78MEHY8KECY3vrm+p12KtSZMmRV1dXfz73/+OSZMmxcqVK+PDDz9s9k75PvvsE9OnT4/33nsvHn/88ZgxY4ZP4QboIjrr9XrNmjUxbty4OPvss2OfffZJra01hx9+ePTv3z9qa2vjK1/5SkREvPDCC/H888/Hb37zm4jYsq/Fx/34xz+OGTNmxI033hjbb7/9eo+bOnVqrFmzxq9v02UJaLqkrbfeOm699dY44IADoqamJiZNmrRRvzbdvXv3df5dobWjHLp3777ex9bV1cVFF10U48ePjwEDBrR5DRtjzJgx8d3vfjcWLlwYu+++e9x7773x/vvvxze/+c0NPtfGvBYf9dF39k866aQYPHhwRESzUSHbbrttHH744RERceyxx8Ydd9wRxx57bMyZM6fJJ3YD0Pl01uv1ddddF2+//XZcdtllbV7rx1VVVcXo0aPjV7/6VaxYsSJ69OgRtbW1UVNTE8cff/wGn29TXa8jIu6+++7G3xQ755xzWjy2trY2+vTp0+JdaujM/Ao3XdbaEVCrVq2KV155ZaPOteuuu67zAzvWbuvfv/96H3vNNdfEmjVr4sQTT4wFCxbEggUL4rXXXouIiGXLlsWCBQtizZo1G7W+1px00kmx1VZbNc5inDJlSgwbNiz22muvDT7XxrwW67PDDjvEF7/4xdSsyLV/n+quu+7a4OcBoOPpbNfrd955J6688sr41re+FXV1dY3H1tfXRyklFixYEG+++Wab1j9mzJior6+P++67L0opcccdd8RXv/rV2G677Tb4XJvqej19+vQYM2ZMHH300XHTTTe1eOyiRYvisccei+OPPz622mqrDV4zdAYCmi7p+eefj8svvzzGjh0bQ4cOjTPOOCPeeeedNp/vc5/7XMyZMycaGhqabH/yySejR48eLf5drUWLFsWyZctiyJAhMWjQoBg0aFAceuihEfG/X4caNGhQvPjii21eW0afPn3i6KOPjtra2li4cGHMmjWrTXefI/73Wrz88stRV1fXZPuTTz7ZuL8tVq5cmfp3tHr16mhoaNiof58AdAyd8Xq9bNmyqK+vj6uvvrrxuEGDBsXvf//7WLFiRQwaNKjJpIoNse+++8bQoUOjtrY2HnvssVi0aNFGXa/b+lp89NhRo0bFsGHD4p577mn1k8Y/+oFn0GW1z/Qs2HzWrFlThg4dWgYOHFjq6urKc889V7beeusyduzYZsfNmzevvP76662e86677mo2S/Gtt94q22+/fTnxxBObHDt//vwyf/78xn9+5plnyrRp05p83XzzzSUiymmnnVamTZtWli9f3uw5N/Uc6D/84Q8lIsrxxx9fqquryxtvvNFk/7x588rChQtbPc8T
TzzRbCbmqlWrymc+85kyfPjwJscuXLiwzJs3r8m2jz9vKf+bfd27d+9y6KGHNm5btmxZWbNmTbNj186BvuWWW1pdKwAdV2e9Xr/33nvNjps2bVoZOXJkqampKdOmTStPPPFEm1+Xa6+9tlRXV5dRo0aVvn37NrkWbqnXopRSXnzxxdK3b98yZMiQsnTp0tTa99tvv7LbbruVhoaG1PHQGQloupwf/vCHpVKplEceeaRx25VXXlkiojzwwAON21599dUSEeXUU09t9ZwffPBBOeigg0qvXr3KZZddVn75y1+WIUOGlN69e5eXXnqpybG777572X333Vs839rn/miErvW73/2uXHHFFeWiiy4qEVFGjhxZrrjiinLFFVeUBQsWNB43c+bMDQrs1atXl759+5aIKEceeWSz/RFRRowYkTrX2ggfP358ufnmm8shhxxSqqury6OPPtrkuBEjRpSPv0/Xr1+/cvLJJ5ef/vSn5de//nUZP3586dOnT6mpqSmzZs1qPG7atGllwIABZdy4ceXGG28s119/fTnuuONKpVIpw4YNK6tXr06tFYCOqbNfrz/u1FNPLT179my2fdKkSSUiyqRJk1o9RymlLFmypFRXV5eIKOecc84617O5X4u6uroyYMCAUlVVVX7yk5+U22+/vcnX448/3uz55s6dWyKiXHjhhanvEzorAU2X8swzz5Tq6upy7rnnNtn+wQcflAMOOKD079+/LFu2rJSyYRehUkpZunRpOf3000vfvn1Ljx49yogRI8rTTz/d7LiNvSCvjc51fc2cObPxuPvvv79ERLnppptS6y+llG9/+9slIsodd9zRbN+GBPTKlSvL97///bLLLruUbbbZphxwwAHloYceWu/38lETJ04sw4YNKzvssEOprq4u/fv3LyeddFJ5/vnnmxw3f/78MmbMmPKpT32qdO/evdTU1JQhQ4aUiRMnlvr6+vT3DEDH0xWu1x+3voC+4YYbSkSs8zq5PkcddVSJiGahuqVei7XPs76vdT3/hRdeWCKi2fUcuppKKcnPwgc6lAkTJsSdd94Z8+fPX+eYinUZN25c3HLLLbFkyZLo0aPHZl4hAHDCCSfEggUL4qmnnko/ZtSoUTF37tyYP3/+ZlwZ0BbGWEEnNXPmzPjBD36QjudVq1bFlClT4rjjjhPPALAFlFLib3/7W0yZMiX9mMWLF8cDDzwQF1988WZcGdBW7kBDF/fmm2/GjBkzYurUqXHffffFnDlz2vxJ2QDA5vHqq6/GrFmz4re//W08/fTT8a9//St22WWX9l4W8DHuQEMX9+KLL8bo0aOjX79+8Ytf/EI8A0AH9Oijj8bYsWNjt912i9tuu008QwflDjQAAAAkVLX3AgAAAKAzENAAAACQIKABAAAgIf0hYpVKZXOuAwC6rC39cSOu2ZuPj46B9uXPNza31v6cdwcaAAAAEgQ0AAAAJAhoAAAASBDQAAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkCGgAAABIENAAAACQIaAAAAEgQ0AAAAJAgoAEAACBBQAMAAECCgAYAAIAEAQ0AAAAJAhoAAAASBDQAAAAkCGgAAABIENAAAACQIKABAAAgQUADAABAgoAGAACABAENAAAACQIaAAAAEgQ0AAAAJAhoAAAASBDQAAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkCGgAAABIENAAAACQIaAAAAEgQ0AAAAJAgoAEAACBBQAMAAECCgAYAAIAEAQ0AAAAJAhoAAAASBDQAAAAkVLf3AqAzqlQq691XVdXy+1ItPbat+9rjOTfmse+++26L5wUAgI7IHWgAAABIENAAAACQIKABAAAgQUADAABAgoAGAACABAENAAAACcZYdXK9evVa777Fixe36ZwbMxKppcdurpFIm2u8Exunvr5+vft69+69BVcCAACbhjvQAAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkC
GgAAABKMserkSinr3dfSiCsAAAA2jDvQAAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkCGgAAABKMserkWhpjBe2pqsr7cwAAdC1+wgUAAIAEAQ0AAAAJAhoAAAASBDQAAAAkCGgAAABIENAAAACQIKABAAAgwRzoTs4caDqqSqXS3ksAAIBNyh1oAAAASBDQAAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAnGWHVyxljRUVVVeX8OAICuxU+4AAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkCGgAAABKMserkGhoa2nsJsE6VSqW9lwAAAJuUO9AAAACQIKABAAAgQUADAABAgoAGAACABAENAAAACQIaAAAAEoyx6uRKKe29BFgnY6wAAOhq3IEGAACABAENAAAACQIaAAAAEgQ0AAAAJAhoAAAASBDQAAAAkCCgAQAAIMEc6E7OHGg6KnOgAQDoatyBBgAAgAQBDQAAAAkCGgAAABIENAAAACQIaAAAAEgQ0AAAAJBgjFUnZ4wVHZUxVgAAdDXuQAMAAECCgAYAAIAEAQ0AAAAJAhoAAAASBDQAAAAkCGgAAABIMMaqk2toaGjvJcA6VVV5fw4AgK7FT7gAAACQIKABAAAgQUADAABAgoAGAACABAENAAAACQIaAAAAEgQ0AAAAJJgD3cmVUtp7CbBOlUqlvZcAAACblDvQAAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkCGgAAABKMsQIASGppfKTxfbD5tfb/mRGvbG7uQAMAAECCgAYAAIAEAQ0AAAAJAhoAAAASBDQAAAAkCGgAAABIMMaqCzNqg46qpf/+jJ8AAKCjcgcaAAAAEgQ0AAAAJAhoAAAASBDQAAAAkCCgAQAAIEFAAwAAQIIxVl1YQ0PDevd169ZtC64EmqqqWv97dx9++OEWXAkAAOS5Aw0AAAAJAhoAAAASBDQAAAAkCGgAAABIENAAAACQIKABAAAgQUADAABAgjnQXVgppb2XAOtUqVTaewkAALDB3IEGAACABAENAAAACQIaAAAAEgQ0AAAAJAhoAAAASBDQAAAAkGCMVRdmjBUdlTFWANA5+HkSmnIHGgAAABIENAAAACQIaAAAAEgQ0AAAAJAgoAEAACBBQAMAAECCMVZdmLEDdFTGWAGdVbdu3bb4c7qeA3Qc7kADAABAgoAGAACABAENAAAACQIaAAAAEgQ0AAAAJAhoAAAASDDGqgsz9oKOqqrKe3dA5+TaCvB/m59iAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkCGgAAABIENAAAACQIaAAAAEgwB7oLa2hoaO8lwDpVKpX2XgIAAGwwd6ABAAAgQUADAABAgoAGAACABAENAAAACQIaAAAAEgQ0AAAAJBhj1YWVUtp7CbBOxlgBANAZuQMNAAAACQIaAAAAEgQ0AAAAJAhoAAAASBDQAAAAkCCgAQAAIMEYqy6soaGhvZcA62SMFQAAnZE70AAAAJAgoAEAACBBQAMAAECCgAYAAIAEAQ0AAAAJAhoAAAASjLECtjhjrAAA6IzcgQYAAIAEAQ0AAAAJAhoAAAASBDQAAAAkCGgAAABIENAAAACQIKABAAAgwRzoLqyhoaG9l9BllVLWu6+1172lx26Ofa2taWO+l7Y+pznQAAB0Ru5AAwAAQIKABgAAgAQBDQAAAAkCGgAAABIENAAAACQIaAAAAEiolNbm36w90NiZTqdXr17r3bcxI5Ha+tj2GO+0Mc8JsKls6T9rXLMBoG1au2a7Aw0AAAAJAhoAAAASBDQAAAAkCGgAAABIENAAAACQIKABAAAgwRgrANjMjLECgM7BGCsAAADYBAQ0AAAAJAhoAAAASBDQAAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkCGgAAABIENAAAACQIaAAAAEgQ0AAAAJAgoAEAACBBQAMAAECCgAYAAIAEAQ0AAAAJAhoA
AAASBDQAAAAkCGgAAABIENAAAACQIKABAAAgQUADAABAgoAGAACABAENAAAACQIaAAAAEgQ0AAAAJAhoAAAASBDQAAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkCGgAAABIENAAAACQIaAAAAEgQ0AAAAJAgoAEAACBBQAMAAECCgAYAAIAEAQ0AAAAJAhoAAAASBDQAAAAkCGgAAABIENAAAACQIKABAAAgQUADAABAgoAGAACABAENAAAACQIaAAAAEgQ0AAAAJAhoAAAASBDQAAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkCGgAAABIENAAAACQIaAAAAEgQ0AAAAJAgoAEAACBBQAMAAECCgAYAAIAEAQ0AAAAJAhoAAAASBDQAAAAkCGgAAABIENAAAACQIKABAAAgQUADAABAgoAGAACABAENAAAACQIaAAAAEgQ0AAAAJAhoAAAASBDQAAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkCGgAAABIENAAAACQIaAAAAEgQ0AAAAJAgoAEAACBBQAMAAECCgAYAAIAEAQ0AAAAJAhoAAAASBDQAAAAkCGgAAABIENAAAACQIKABAAAgQUADAABAgoAGAACABAENAAAACQIaAAAAEgQ0AAAAJAhoAAAASBDQAAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkCGgAAABIENAAAACQIaAAAAEgQ0AAAAJAgoAEAACBBQAMAAECCgAYAAIAEAQ0AAAAJAhoAAAASBDQAAAAkCGgAAABIENAAAACQIKABAAAgQUADAABAgoAGAACABAENAAAACQIaAAAAEgQ0AAAAJAhoAAAASBDQAAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkCGgAAABIENAAAACQIaAAAAEgQ0AAAAJAgoAEAACBBQAMAAECCgAYAAIAEAQ0AAAAJAhoAAAASBDQAAAAkCGgAAABIENAAAACQIKABAAAgQUADAABAgoAGAACABAENAAAACQIaAAAAEgQ0AAAAJAhoAAAASBDQAAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkCGgAAABIENAAAACQIaAAAAEgQ0AAAAJAgoAEAACBBQAMAAECCgAYAAIAEAQ0AAAAJAhoAAAASBDQAAAAkCGgAAABIENAAAACQIKABAAAgQUADAABAgoAGAACABAENAAAACQIaAAAAEgQ0AAAAJAhoAAAASBDQAAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkCGgAAABIENAAAACQIaAAAAEgQ0AAAAJAgoAEAACBBQAMAAECCgAYAAIAEAQ0AAAAJAhoAAAASBDQAAAAkCGgAAABIENAAAACQIKABAAAgQUADAABAgoAGAACABAENAAAACQIaAAAAEgQ0AAAAJAhoAAAASBDQAAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkCGgAAABIENAAAACQIaAAAAEgQ0AAAAJAgoAEAACBBQAMAAECCgAYAAIAEAQ0AAAAJAhoAAAASBDQAAAAkCGgAAABIENAAAACQIKABAAAgQUADAABAgoAGAACABAENAAAACQIaAAAAEgQ0AAAAJAhoAAAASBDQAAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkCGgAAABIENAAAACQIaAAAAEgQ0AAAAJAgoAEAACBBQAMAAECCgAYAAIAEAQ0AAAAJAhoAAAASBDQAAAAkCGgAAABIENAAAACQIKABAAAgQUADAABAgoAGAACAhEoppbT3IgAAAKCjcwcaAAAAEgQ0AAAAJAhoAAAASBDQAAAAkCCgAQAAIEFAAwAAQIKABgAAgAQBDQAAAAkCGgAAABL+HyFggebkbohCAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 1000x500 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load one batch and visualize the first few images with their x/y values.\n",
    "N_SHOW = 2  # number of images to display from the batch\n",
    "\n",
    "batch = next(iter(loader))\n",
    "images = batch[\"img\"]\n",
    "x = batch[\"x\"]\n",
    "y = batch[\"y\"]\n",
    "print(f\"Batch shape: {images.shape}, Labels shape: {y.shape}\")\n",
    "\n",
    "n_show = min(N_SHOW, len(images))  # never index past the end of a short batch\n",
    "# squeeze=False keeps `axes` a 2-D array even when n_show == 1, so iteration is uniform.\n",
    "fig, axes = plt.subplots(1, n_show, figsize=(5 * n_show, 5), squeeze=False)\n",
    "for i, ax in enumerate(axes.flat):\n",
    "    # squeeze() drops singleton dims so imshow receives a 2-D array\n",
    "    # (assumes single-channel images, as the printed batch shape suggests)\n",
    "    ax.imshow(images[i].squeeze(), cmap=\"gray\")\n",
    "    ax.set_title(f\"x: {x[i].item():.2f}, y: {y[i].item():.2f}\")\n",
    "    ax.axis(\"off\")\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "disco",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
