{
    "cells": [
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Ground-truth test images for the r2e / r2r splits\n",
                "\n",
                "For each dataset (dSprites, 3D Shapes) and each split case, this notebook builds the `r2e` and `r2r` train/test splits with the builders from `source.dataset.r2e` / `source.dataset.r2r` and saves every **test** image as `GT_ROOT/<dataset>/<r2e|r2r>/<case>/<index>.png`. The train splits are not written to disk.\n",
                "\n",
                "- `r2e` / `r2r` presumably stand for recombination-to-element / recombination-to-range; confirm against `source.dataset`.\n",
                "- **TODO:** if the split builders are stochastic, seed them here so the exported ground truth is reproducible."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import os\n",
                "import sys\n",
                "from copy import deepcopy\n",
                "\n",
                "import numpy as np\n",
                "import torch\n",
                "from torchvision.utils import save_image\n",
                "\n",
                "from source.dataset.dsprites import DSprites\n",
                "from source.dataset.shapes3d import Shapes3d\n",
                "from source.dataset.r2e import r2e_dsprites, r2e_shape3d\n",
                "from source.dataset.r2r import r2r_dsprites, r2r_shape3d\n",
                "\n",
                "# NOTE(review): `sys`, `np`, `torch` and `BATCH_SIZE` are not referenced by any other cell\n",
                "# in this notebook; candidates for removal.\n",
                "BATCH_SIZE = 64\n",
                "# Root directory for the exported images, relative to the kernel's working directory.\n",
                "GT_ROOT = \"./gt/\""
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Configuration and export helpers\n",
                "\n",
                "Both datasets are exported the same way, so the per-case loop is defined once here instead of being copy-pasted per dataset. Dataset paths default to the original locations and can be overridden with the `DSPRITES_DATA_PATH` / `SHAPES3D_DATA_PATH` environment variables."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Dataset locations. The defaults are the original hard-coded paths; set the\n",
                "# environment variables to point at another copy without editing the notebook.\n",
                "DSPRITES_DATA_PATH = os.environ.get(\n",
                "    \"DSPRITES_DATA_PATH\", \"/data/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz\"\n",
                ")\n",
                "# NOTE(review): this default lives under `/data/data/` while dSprites is under `/data/` - confirm.\n",
                "SHAPES3D_DATA_PATH = os.environ.get(\"SHAPES3D_DATA_PATH\", \"/data/data/3D_shapes_fast\")\n",
                "\n",
                "# Each value in range(N_CASES) is passed as `case=` to the r2e/r2r split builders.\n",
                "N_CASES = 3"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def export_test_images(test_set, out_dir: str) -> int:\n",
                "    \"\"\"Write every image of `test_set` to `<out_dir>/<index>.png`.\n",
                "\n",
                "    Args:\n",
                "        test_set: iterable of `(image, target)` pairs; only the image is passed\n",
                "            to `torchvision.utils.save_image`, the target is ignored.\n",
                "        out_dir: destination directory. It is not created here, so the caller\n",
                "            must create it first.\n",
                "\n",
                "    Returns:\n",
                "        Number of images written.\n",
                "    \"\"\"\n",
                "    n_written = 0\n",
                "    for idx, (img, _) in enumerate(test_set):\n",
                "        save_image(img, os.path.join(out_dir, f\"{idx}.png\"))\n",
                "        n_written += 1\n",
                "    return n_written\n",
                "\n",
                "\n",
                "def export_ground_truth(dataset, split_builders: dict, out_root: str, n_cases: int) -> dict:\n",
                "    \"\"\"Export the test split of every (split builder, case) pair as PNG files.\n",
                "\n",
                "    Args:\n",
                "        dataset: full dataset object; every builder call receives its own\n",
                "            `deepcopy`, so `dataset` itself is never passed to a builder.\n",
                "        split_builders: ordered mapping of split name (used as a directory\n",
                "            name, e.g. \"r2e\") to a callable `builder(case=..., dataset=...)`\n",
                "            returning a `(train, test)` pair. Builders run in mapping order.\n",
                "        out_root: root directory; images are written to\n",
                "            `<out_root>/<name>/<case>/<index>.png`.\n",
                "        n_cases: cases `0 .. n_cases - 1` are exported.\n",
                "\n",
                "    Returns:\n",
                "        Dict mapping `(name, case)` to the number of images written.\n",
                "    \"\"\"\n",
                "    counts = {}\n",
                "    for case in range(n_cases):\n",
                "        case_dirs = {name: os.path.join(out_root, name, str(case)) for name in split_builders}\n",
                "        for case_dir in case_dirs.values():\n",
                "            os.makedirs(case_dir, exist_ok=True)\n",
                "\n",
                "        # Build every split for this case before exporting any of them, matching the\n",
                "        # original per-dataset loops. Only the test split is kept; the train split is\n",
                "        # dropped immediately instead of surviving until the next case.\n",
                "        test_sets = {}\n",
                "        for name, build in split_builders.items():\n",
                "            _, test_sets[name] = build(case=case, dataset=deepcopy(dataset))\n",
                "\n",
                "        for name, test_set in test_sets.items():\n",
                "            counts[(name, case)] = export_test_images(test_set, case_dirs[name])\n",
                "\n",
                "        # Drop this case's splits (and any dataset copies they reference) before the next case.\n",
                "        del test_sets\n",
                "    return counts"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## dSprites\n",
                "\n",
                "Exports the `r2e` and `r2r` test splits for every case to `GT_ROOT/dsprites/<split>/<case>/`. The cell displays the number of images written per `(split, case)`."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "dsprites = DSprites(DSPRITES_DATA_PATH)\n",
                "\n",
                "dsprites_counts = export_ground_truth(\n",
                "    dsprites,\n",
                "    {\"r2e\": r2e_dsprites, \"r2r\": r2r_dsprites},\n",
                "    os.path.join(GT_ROOT, \"dsprites\"),\n",
                "    n_cases=N_CASES,\n",
                ")\n",
                "dsprites_counts"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## 3D Shapes\n",
                "\n",
                "Exports the `r2e` and `r2r` test splits for every case to `GT_ROOT/shapes3d/<split>/<case>/`. The cell displays the number of images written per `(split, case)`."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "shapes3d = Shapes3d(SHAPES3D_DATA_PATH)\n",
                "\n",
                "shapes3d_counts = export_ground_truth(\n",
                "    shapes3d,\n",
                "    {\"r2e\": r2e_shape3d, \"r2r\": r2r_shape3d},\n",
                "    os.path.join(GT_ROOT, \"shapes3d\"),\n",
                "    n_cases=N_CASES,\n",
                ")\n",
                "shapes3d_counts"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "base",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.7.3"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}