{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "06c0d564",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fd1982ef610>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import sys\n",
    "import os\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import random\n",
    "import ipywidgets as widgets\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from src.datasets import data, configs, utils\n",
    "from src.datasets.utils import PreGeneratedDataset\n",
    "\n",
    "from torchvision import transforms as transforms\n",
    "\n",
    "seed = 43\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1230396-0575-4a4b-aa80-7b7a826b3430",
   "metadata": {},
   "source": [
    "The dataset was created in the similar manner as you could see below. \n",
    "The exact version can be downloaded at [will be added later]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9554c463",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filtering code, to remove overlapping objects OOD\n",
    "def filter_objects(latents, max_samples=5000, threshold=0.2, sort=False):\n",
    "    \"\"\"\n",
    "    Filter objects based on their Euclidean distance.\n",
    "    Args:\n",
    "        latents: Tensor of shape (batch_size, n_slots, n_latents)\n",
    "        max_objects: Number of objects to keep at most\n",
    "        threshold: Distance threshold\n",
    "        sort: Whether to sort the objects by distance\n",
    "    \"\"\"\n",
    "    N, slots, _ = latents.size()\n",
    "    mask = torch.zeros(N, dtype=bool)\n",
    "\n",
    "    # Compute Euclidean distance for each pair of slots in each item\n",
    "    for n in range(N):\n",
    "        slots_distances = torch.cdist(latents[n, :, :2], latents[n, :, :2], p=2)\n",
    "        slots_distances.fill_diagonal_(float(\"inf\"))  # Ignore distance to self\n",
    "\n",
    "        # Only keep samples in which no two objects are closer than the threshold\n",
    "        min_distance = slots_distances.min().item()\n",
    "        if min_distance >= threshold:\n",
    "            mask[n] = True\n",
    "\n",
    "    # If all objects are \"close\", print a message and return\n",
    "    if not torch.any(mask):\n",
    "        print(\"No objects were found that meet the distance threshold.\")\n",
    "        return None, []\n",
    "\n",
    "    # Apply the mask to the latents\n",
    "    filtered_samples = latents[mask]\n",
    "    filtered_indices = torch.arange(N)[mask]\n",
    "\n",
    "    # If the number of filtered samples exceeds the maximum, truncate them\n",
    "    if filtered_samples.size(0) > max_samples:\n",
    "        filtered_samples = filtered_samples[:max_samples]\n",
    "        filtered_indices = filtered_indices[:max_samples]\n",
    "\n",
    "    # FIXME this part could be made more efficient and readable by saving the\n",
    "    #   min distances in the step above and reusing them here\n",
    "    # FIXME setting sort=True is throwing some errors for me\n",
    "    if sort:\n",
    "        # Sort the filtered objects by minimum distance to any other object\n",
    "        min_distances = torch.zeros(mask.sum().item())\n",
    "        for i, n in enumerate(torch.where(mask)[0]):\n",
    "            slots_distances = torch.cdist(latents[n], latents[n], p=2)\n",
    "            slots_distances.fill_diagonal_(float(\"inf\"))\n",
    "            min_distances[i] = slots_distances.min().item()\n",
    "\n",
    "        indices = torch.argsort(min_distances)\n",
    "        filtered_samples = filtered_samples[indices]\n",
    "        filtered_indices = filtered_indices[indices]\n",
    "\n",
    "    return filtered_samples, filtered_indices.tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "aba3c949",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating images (sampling: off_diagonal): 100%|██████████| 10000/10000 [01:31<00:00, 109.61it/s]\n"
     ]
    }
   ],
   "source": [
    "# Create a OOD dataset\n",
    "n_samples = 10000\n",
    "n_slots = 2\n",
    "default_cfg = configs.SpriteWorldConfig()\n",
    "sample_mode = \"off_diagonal\"\n",
    "no_overlap = True\n",
    "delta = 0.125\n",
    "\n",
    "off_diagonal_dataset = data.SpriteWorldDataset(\n",
    "    n_samples,\n",
    "    n_slots,\n",
    "    default_cfg,\n",
    "    sample_mode=sample_mode,\n",
    "    no_overlap=no_overlap,\n",
    "    delta=delta,\n",
    "    transform=transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()]),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "938e1eb6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1e7096e346324af2a986b6a9efd1724f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(IntSlider(value=0, description='index', max=9999), Output()), _dom_classes=('widget-inte…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<function __main__.<lambda>(index)>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Checking the generated dataset\n",
    "def display_data(index, dataset):\n",
    "    plt.imshow(dataset[index][0][-1].permute(1, 2, 0))\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "dispaly_off_diagonal = lambda index: display_data(index, dataset=off_diagonal_dataset)\n",
    "\n",
    "num_samples = len(off_diagonal_dataset)\n",
    "\n",
    "# slider\n",
    "widgets.interact(\n",
    "    dispaly_off_diagonal,\n",
    "    index=widgets.IntSlider(min=0, max=num_samples - 1, step=1, value=0),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "2da8741b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filter the dataset\n",
    "n_objects = 5000\n",
    "_, indicies = filter_objects(\n",
    "    off_diagonal_dataset.z, max_samples=n_objects, threshold=0.2\n",
    ")\n",
    "\n",
    "# save the filtered dataset\n",
    "ood_data_path = \"YOUR PATH\"\n",
    "\n",
    "os.makedirs(os.path.join(path, \"images\"), exists=True)\n",
    "os.makedirs(os.path.join(path, \"latents\"), exists=True)\n",
    "torch.save(off_diagonal_dataset.x[indicies], os.path.join(ood_data_path, \"images\", \"images.pt\"))\n",
    "torch.save(\n",
    "    torch.cat(\n",
    "        [\n",
    "            off_diagonal_dataset.z[indicies, :, :4],\n",
    "            off_diagonal_dataset.z[indicies, :, 5:-2],\n",
    "        ],\n",
    "        dim=-1,\n",
    "    ),\n",
    "    os.path.join(path, \"latents\", \"latents.pt\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "6c0d07dd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6090f8f9aafc4d05a6eb56109ad80e94",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(IntSlider(value=0, description='index', max=4999), Output()), _dom_classes=('widget-inte…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<function __main__.<lambda>(index)>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Checking the generated dataset\n",
    "no_overlaps_ood = PreGeneratedDataset(data_path)\n",
    "\n",
    "dispaly_no_overlaps_ood = lambda index: display_data(index, dataset=no_overlaps_ood)\n",
    "\n",
    "# slider\n",
    "widgets.interact(\n",
    "    dispaly_no_overlaps_ood,\n",
    "    index=widgets.IntSlider(min=0, max=n_objects - 1, step=1, value=0),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6fd22c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generating the ID dataset (test)\n",
    "delta = 0.125\n",
    "sample_mode = \"diagonal\"\n",
    "n_slots = 2\n",
    "n_samples = 5000\n",
    "no_overlap = True\n",
    "test_diagonal_dataset = data.SpriteWorldDataset(n_samples, n_slots, default_cfg, sample_mode=sample_mode, \n",
    "                                            no_overlap=no_overlap,\n",
    "                                            delta=delta)\n",
    "\n",
    "utils.dump_generated_dataset(test_diagonal_dataset, \"your path to test diagonal\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22476f69-2611-4451-ba20-7d5c32097130",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generating the ID dataset (train)\n",
    "delta = 0.125\n",
    "sample_mode = \"diagonal\"\n",
    "n_slots = 2\n",
    "n_samples = 100000\n",
    "no_overlap = True\n",
    "train_diagonal_dataset = data.SpriteWorldDataset(n_samples, n_slots, default_cfg, sample_mode=sample_mode, \n",
    "                                            no_overlap=no_overlap,\n",
    "                                            delta=delta)\n",
    "\n",
    "utils.dump_generated_dataset(train_diagonal_dataset, \"your path to train diagonal\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
