{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import json\n",
    "import random\n",
    "from pathlib import Path\n",
    "from typing import Callable, List, Iterable, Tuple\n",
    "\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import Dataset as TorchDataset\n",
    "from collections import Counter, defaultdict\n",
    "from torch.utils.data import Subset\n",
    "import os\n",
    "import torch\n",
    "\n",
    "_DATA_ROOT = './'\n",
    "\n",
    "\n",
    "class ShuffleIterator:\n",
    "    \"\"\"Iterator that serves the items of `iterable` in repeated shuffled passes.\n",
    "\n",
    "    Each pass copies `iterable` into a list, shuffles it with the module-level\n",
    "    `random` generator and hands the items out one at a time; once a pass is\n",
    "    used up, the next call starts a new shuffled pass.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, iterable):\n",
    "        self.iterable = iterable\n",
    "        self.buffer = []  # items of the current shuffled pass\n",
    "        self.index = 0  # position of the next item inside `buffer`\n",
    "\n",
    "    def __iter__(self):\n",
    "        return self\n",
    "\n",
    "    def __next__(self):\n",
    "        \"\"\"Return the next item, starting a new shuffled pass when needed.\n",
    "\n",
    "        :raises StopIteration: when `iterable` yields no items at all.\n",
    "        \"\"\"\n",
    "        if self.index >= len(self.buffer):\n",
    "            self.buffer = list(self.iterable)\n",
    "            random.shuffle(self.buffer)\n",
    "            self.index = 0\n",
    "\n",
    "        if not self.buffer:\n",
    "            # Nothing to serve: end the iteration cleanly instead of failing\n",
    "            # with IndexError on the lookup below.\n",
    "            raise StopIteration\n",
    "\n",
    "        item = self.buffer[self.index]\n",
    "        self.index += 1\n",
    "        return item\n",
    "\n",
    "def MNIST_datasets():\n",
    "    \"\"\"Create the MNIST train and test datasets stored under `_DATA_ROOT`.\n",
    "\n",
    "    Both splits are requested with `download=True` and share one transform:\n",
    "    `ToTensor` followed by `Normalize((0.1307,), (0.3081,))`.\n",
    "\n",
    "    :return: dict with the keys \"train\" and \"test\".\n",
    "    \"\"\"\n",
    "    to_tensor = transforms.ToTensor()\n",
    "    normalize = transforms.Normalize((0.1307,), (0.3081,))\n",
    "    transform = transforms.Compose([to_tensor, normalize])\n",
    "\n",
    "    def load(is_train):\n",
    "        return torchvision.datasets.MNIST(\n",
    "            root=str(_DATA_ROOT), train=is_train, download=True, transform=transform\n",
    "        )\n",
    "\n",
    "    return {\"train\": load(True), \"test\": load(False)}\n",
    "\n",
    "def EMNIST_datasets():\n",
    "    \"\"\"Create train and test subsets of EMNIST 'balanced' limited to labels 0-15.\n",
    "\n",
    "    Both splits are requested with `download=True` and share one transform:\n",
    "    `ToTensor` followed by `Normalize((0.1307,), (0.3081,))`.\n",
    "\n",
    "    :return: dict with the keys \"train\" and \"test\", each a `Subset` of the\n",
    "        corresponding EMNIST split that keeps only samples labelled 0-15.\n",
    "    \"\"\"\n",
    "    transform = transforms.Compose(\n",
    "        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
    "    )\n",
    "\n",
    "    target_labels = set(range(16))\n",
    "\n",
    "    def load_subset(is_train):\n",
    "        full = torchvision.datasets.EMNIST(\n",
    "            root=_DATA_ROOT, split='balanced', train=is_train, download=True, transform=transform\n",
    "        )\n",
    "        # Read the labels from `targets` instead of enumerating the dataset:\n",
    "        # enumerating decodes and transforms every image only to discard it.\n",
    "        labels = full.targets.tolist()\n",
    "        keep = [i for i, label in enumerate(labels) if label in target_labels]\n",
    "        return Subset(full, keep)\n",
    "\n",
    "    return {\"train\": load_subset(True), \"test\": load_subset(False)}\n",
    "\n",
    "\n",
    "def get_datasets(dataset_name: str = \"MNIST\"):\n",
    "    \"\"\"Return the dict of \"train\"/\"test\" datasets registered under `dataset_name`.\n",
    "\n",
    "    NOTE(review): the name \"MNIST\" is served by `EMNIST_datasets()` (labels\n",
    "    0-15), not by `MNIST_datasets()`. The mapping is kept unchanged here;\n",
    "    confirm that it is intended.\n",
    "\n",
    "    :raises ValueError: for any other name (the function used to fall through\n",
    "        and return None, which only failed later when the result was indexed).\n",
    "    \"\"\"\n",
    "    if dataset_name == \"MNIST\":\n",
    "        return EMNIST_datasets()\n",
    "    raise ValueError(f\"Unsupported dataset name: {dataset_name!r}\")\n",
    "    \n",
    "def digits_to_number(digits: Iterable[int]) -> int:\n",
    "    \"\"\"Fold digits given most-significant first into a single base-10 integer.\"\"\"\n",
    "    value = 0\n",
    "    for digit in digits:\n",
    "        value = value * 10 + digit\n",
    "    return value\n",
    "\n",
    "def addition(n: int, dataset: str, seed=None, train: bool = True, z_list=None, sequence_num=30000):\n",
    "    \"\"\"Create a `DigitsOperator` dataset whose label is the sum of two `n`-digit numbers.\"\"\"\n",
    "    if n == 1:\n",
    "        function_name = \"addition\"\n",
    "    else:\n",
    "        function_name = \"multi_addition\"\n",
    "    return DigitsOperator(\n",
    "        dataset_name=dataset,\n",
    "        function_name=function_name,\n",
    "        operator=sum,\n",
    "        size=n,\n",
    "        arity=2,\n",
    "        seed=seed,\n",
    "        train=train,\n",
    "        sequence_num=sequence_num,\n",
    "        z_list=z_list,\n",
    "    )\n",
    "\n",
    "class DigitsOperator(TorchDataset):\n",
    "    \"\"\"Dataset of `operator(number, ...)` examples assembled from digit images.\n",
    "\n",
    "    Every example in `self.data` holds `arity` numbers, each one a list of\n",
    "    `size` indices into `self.dataset`; the label of an example is `operator`\n",
    "    applied to the numbers spelled by the labels of those images.\n",
    "    \"\"\"\n",
    "\n",
    "    def __getitem__(self, index: int) -> Tuple[list, list, int]:\n",
    "        \"\"\"Return `(images of argument 1, images of argument 2, label)`.\n",
    "\n",
    "        The example is unpacked into exactly two arguments, so this only works\n",
    "        for datasets built with `arity=2`.\n",
    "        \"\"\"\n",
    "        first, second = self.data[index]\n",
    "        label = self._get_label(index)\n",
    "        first_images = [self.dataset[x][0] for x in first]\n",
    "        second_images = [self.dataset[x][0] for x in second]\n",
    "        return first_images, second_images, label\n",
    "\n",
    "    def balance_indices(self):\n",
    "        \"\"\"Return dataset indices, in shuffled order, with equally many samples per class.\n",
    "\n",
    "        Labels are read with `self.dataset[i][1]`, the same access the rest of\n",
    "        this class uses, instead of a `labels` attribute (torchvision's\n",
    "        MNIST/EMNIST expose `targets` and `Subset` exposes `indices`, neither\n",
    "        has `labels`). Every class contributes as many samples as the rarest\n",
    "        class contains.\n",
    "\n",
    "        :return: list of indices into `self.dataset`; empty for an empty dataset.\n",
    "        \"\"\"\n",
    "        labels = [self.dataset[i][1] for i in range(len(self.dataset))]\n",
    "        if not labels:\n",
    "            return []\n",
    "        class_counts = Counter(labels)\n",
    "        # Cap every class at the size of the rarest one so that the quota of\n",
    "        # each class can be met within a single shuffled pass.\n",
    "        balance_size = min(class_counts.values())\n",
    "        target_size = balance_size * len(class_counts)\n",
    "        labels_dist = defaultdict(int)\n",
    "        sampler_iter = ShuffleIterator(list(range(len(labels))))\n",
    "        balanced_indices = []\n",
    "        while len(balanced_indices) < target_size:\n",
    "            sample = next(sampler_iter)\n",
    "            sampled_class = labels[sample]\n",
    "            if labels_dist[sampled_class] >= balance_size:\n",
    "                continue\n",
    "            balanced_indices.append(sample)\n",
    "            labels_dist[sampled_class] += 1\n",
    "        return balanced_indices\n",
    "\n",
    "    def indices(self):\n",
    "        \"\"\"Return every index of the underlying image dataset, in order.\"\"\"\n",
    "        return list(range(len(self.dataset)))\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dataset_name: str,\n",
    "        function_name: str,\n",
    "        operator: Callable[[List[int]], int],\n",
    "        size=1,\n",
    "        arity=2,\n",
    "        seed=None,\n",
    "        train: bool = True,\n",
    "        sequence_num: int = 30000,\n",
    "        z_list: List[int] = None,\n",
    "    ):\n",
    "        \"\"\"Generic dataset for operator(img, img) style datasets.\n",
    "\n",
    "        :param dataset_name: Name passed to `get_datasets` to load the image data\n",
    "        :param function_name: Name of the function this dataset represents;\n",
    "            it is only stored on the instance here\n",
    "        :param operator: Operator to generate correct examples\n",
    "        :param size: Size of numbers (number of digits)\n",
    "        :param arity: Number of arguments for the operator\n",
    "        :param seed: Seed for the RNG that pre-shuffles the image indices\n",
    "        :param train: Use the 'train' split when True, otherwise the 'test' split\n",
    "        :param sequence_num: Number of examples to generate\n",
    "        :param z_list: Labels to include in the dataset (None keeps every label)\n",
    "        \"\"\"\n",
    "        super(DigitsOperator, self).__init__()\n",
    "        assert size >= 1\n",
    "        assert arity >= 1\n",
    "        self.dataset_name = dataset_name\n",
    "        self.datasets = get_datasets(dataset_name)\n",
    "        self.dataset = self.datasets[\"train\" if train else \"test\"]\n",
    "        self.function_name = function_name\n",
    "        self.operator = operator\n",
    "        self.size = size\n",
    "        self.arity = arity\n",
    "        self.seed = seed\n",
    "        self.z_list = z_list\n",
    "        mnist_indices = self.indices()\n",
    "\n",
    "        # NOTE(review): ShuffleIterator reshuffles with the global `random`\n",
    "        # module, so `seed` alone does not make `self.data` reproducible.\n",
    "        if seed is not None:\n",
    "            rng = random.Random(seed)\n",
    "            rng.shuffle(mnist_indices)\n",
    "\n",
    "        # Filter indices based on z_list\n",
    "        if self.z_list is not None:\n",
    "            mnist_indices = [\n",
    "                idx for idx in mnist_indices if self.dataset[idx][1] in self.z_list\n",
    "            ]\n",
    "\n",
    "        dataset_iter = ShuffleIterator(mnist_indices)\n",
    "        # Build list of examples (mnist indices)\n",
    "        self.data = []\n",
    "        try:\n",
    "            while len(self.data) < sequence_num:\n",
    "                example = [\n",
    "                    [next(dataset_iter) for _ in range(self.size)]\n",
    "                    for _ in range(self.arity)\n",
    "                ]\n",
    "                self.data.append(example)\n",
    "        except StopIteration:\n",
    "            # The index source ran dry: keep the examples built so far.\n",
    "            pass\n",
    "\n",
    "    def to_file_repr(self, i):\n",
    "        \"\"\"Old file representation dump. Not a very clear format as multi-digit arguments are not separated\"\"\"\n",
    "        return f\"{tuple(itertools.chain(*self.data[i]))}\\t{self._get_label(i)}\"\n",
    "\n",
    "    def to_json(self):\n",
    "        \"\"\"\n",
    "        Convert to JSON, for easy comparisons with other systems.\n",
    "\n",
    "        Format is [EXAMPLE, ...]\n",
    "        EXAMPLE :- [ARGS, expected_result]\n",
    "        ARGS :- [MULTI_DIGIT_NUMBER, ...]\n",
    "        MULTI_DIGIT_NUMBER :- [mnist_img_id, ...]\n",
    "        \"\"\"\n",
    "        data = [(self.data[i], self._get_label(i)) for i in range(len(self))]\n",
    "        return json.dumps(data)\n",
    "\n",
    "    def _get_label(self, i: int):\n",
    "        \"\"\"Return `operator` applied to the numbers spelled by example `i`.\"\"\"\n",
    "        arguments = self.data[i]\n",
    "        # Figure out what the ground truth is, first map each parameter to the value:\n",
    "        ground_truth = [\n",
    "            digits_to_number(self.dataset[img_idx][1] for img_idx in number)\n",
    "            for number in arguments\n",
    "        ]\n",
    "        # Then compute the expected value:\n",
    "        expected_result = self.operator(ground_truth)\n",
    "        return expected_result\n",
    "\n",
    "    def _get_symbol_label(self, i: int):\n",
    "        \"\"\"Return the labels of all images of example `i` as one flat list.\"\"\"\n",
    "        arguments = self.data[i]\n",
    "        ground_truth = [\n",
    "            self.dataset[img_idx][1] for number in arguments for img_idx in number\n",
    "        ]\n",
    "        return ground_truth\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of generated examples.\"\"\"\n",
    "        return len(self.data)\n",
    "\n",
    "def get_mnist_add(\n",
    "        train=True,\n",
    "        get_pseudo_label=False,\n",
    "        n=2,\n",
    "        sequence_num=30000,\n",
    "        alpha=0.2,\n",
    "        min_sequence_num=32,\n",
    "    ):\n",
    "    \"\"\"Build a sequence of addition tasks that reveals the digits 0-9 step by step.\n",
    "\n",
    "    `sequence_num` addition examples are generated once. The digits are then\n",
    "    revealed in random groups of `int(alpha * 10)` (at least one digit per\n",
    "    step); after every step one task is emitted with all examples whose image\n",
    "    labels are among the digits revealed so far. A task with fewer than\n",
    "    `min_sequence_num` examples is topped up with newly generated examples\n",
    "    restricted to the revealed digits.\n",
    "\n",
    "    :return: list with one `(X, Z, Y)` tuple per step, where X holds the image\n",
    "        lists (`x1 + x2`), Y the example labels and Z the per-image labels;\n",
    "        Z is replaced by None when `get_pseudo_label` is False.\n",
    "    \"\"\"\n",
    "    mnistDataset = addition(n, \"MNIST\", train=train, sequence_num=sequence_num)\n",
    "    all_X, all_Y, all_Z = [], [], []\n",
    "    for idx in range(len(mnistDataset)):\n",
    "        x1, x2, y = mnistDataset[idx]\n",
    "        all_X.append(x1 + x2)\n",
    "        all_Y.append(y)\n",
    "        all_Z.append(mnistDataset._get_symbol_label(idx))\n",
    "\n",
    "    # NOTE(review): only labels 0-9 are ever revealed, while get_datasets()\n",
    "    # currently serves labels 0-15, so examples containing labels 10-15 never\n",
    "    # pass the filter below - confirm this is intended.\n",
    "    z_list = list(range(10))\n",
    "    # At least one digit per step, otherwise a small alpha would never finish.\n",
    "    step = max(1, int(alpha * len(z_list)))\n",
    "    sub_z_list = []\n",
    "    res = []\n",
    "\n",
    "    while len(sub_z_list) < len(z_list):\n",
    "        X, Y, Z = [], [], []\n",
    "        remaining = [elem for elem in z_list if elem not in sub_z_list]\n",
    "        # The last step may have fewer digits left than a full step.\n",
    "        new_elements = random.sample(remaining, min(step, len(remaining)))\n",
    "        sub_z_list.extend(new_elements)\n",
    "        for x, y, z in zip(all_X, all_Y, all_Z):\n",
    "            if all(elem in sub_z_list for elem in z):\n",
    "                X.append(x)\n",
    "                Y.append(y)\n",
    "                Z.append(z)\n",
    "        if len(X) < min_sequence_num:\n",
    "            extra = addition(n, \"MNIST\", train=train, sequence_num=min_sequence_num - len(X), z_list=sub_z_list)\n",
    "            for idx in range(len(extra)):\n",
    "                x1, x2, y = extra[idx]\n",
    "                X.append(x1 + x2)\n",
    "                Y.append(y)\n",
    "                Z.append(extra._get_symbol_label(idx))\n",
    "        # list.append takes exactly one argument, so each task is a single tuple.\n",
    "        res.append((X, Z if get_pseudo_label else None, Y))\n",
    "        print(sub_z_list, len(X))\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9\n",
      "14\n",
      "11\n",
      "3\n",
      "12\n",
      "7\n",
      "2\n",
      "6\n",
      "8\n",
      "4\n",
      "12\n",
      "12\n",
      "12\n",
      "1\n",
      "3\n",
      "5\n",
      "1\n",
      "14\n",
      "6\n",
      "0\n",
      "2\n",
      "3\n",
      "3\n",
      "12\n",
      "11\n",
      "13\n",
      "13\n",
      "1\n",
      "12\n",
      "5\n",
      "9\n",
      "11\n",
      "11\n",
      "11\n",
      "10\n",
      "10\n",
      "4\n",
      "4\n",
      "8\n",
      "8\n",
      "0\n",
      "3\n",
      "10\n",
      "11\n",
      "5\n",
      "15\n",
      "13\n",
      "10\n",
      "12\n",
      "5\n",
      "6\n",
      "12\n",
      "9\n",
      "12\n",
      "10\n",
      "9\n",
      "14\n",
      "5\n",
      "4\n",
      "2\n",
      "6\n",
      "10\n",
      "1\n",
      "14\n",
      "9\n",
      "13\n",
      "4\n",
      "3\n",
      "15\n",
      "13\n",
      "3\n",
      "0\n",
      "5\n",
      "6\n",
      "10\n",
      "11\n",
      "10\n",
      "0\n",
      "6\n",
      "5\n",
      "4\n",
      "12\n",
      "5\n",
      "10\n",
      "6\n",
      "10\n",
      "12\n",
      "0\n",
      "1\n",
      "9\n",
      "5\n",
      "13\n",
      "9\n",
      "2\n",
      "8\n",
      "10\n",
      "2\n",
      "15\n",
      "15\n",
      "3\n",
      "10\n",
      "8\n",
      "2\n",
      "15\n",
      "11\n",
      "5\n",
      "11\n",
      "2\n",
      "8\n",
      "15\n",
      "11\n",
      "7\n",
      "4\n",
      "8\n",
      "5\n",
      "2\n",
      "1\n",
      "13\n",
      "7\n",
      "2\n",
      "14\n",
      "2\n",
      "15\n",
      "12\n",
      "1\n",
      "11\n",
      "1\n",
      "5\n",
      "1\n",
      "7\n",
      "15\n",
      "5\n",
      "9\n",
      "6\n",
      "1\n",
      "9\n",
      "11\n",
      "5\n",
      "1\n",
      "7\n",
      "0\n",
      "0\n",
      "5\n",
      "14\n",
      "7\n",
      "15\n",
      "1\n",
      "6\n",
      "14\n",
      "4\n",
      "2\n",
      "0\n",
      "1\n",
      "10\n",
      "5\n",
      "1\n",
      "6\n",
      "5\n",
      "2\n",
      "7\n",
      "15\n",
      "0\n",
      "8\n",
      "6\n",
      "3\n",
      "9\n",
      "13\n",
      "6\n",
      "10\n",
      "11\n",
      "10\n",
      "4\n",
      "12\n",
      "1\n",
      "13\n",
      "4\n",
      "2\n",
      "8\n",
      "4\n",
      "2\n",
      "3\n",
      "9\n",
      "9\n",
      "15\n",
      "7\n",
      "11\n",
      "8\n",
      "10\n",
      "10\n",
      "1\n",
      "14\n",
      "9\n",
      "14\n",
      "15\n",
      "6\n",
      "2\n",
      "11\n",
      "4\n",
      "9\n",
      "4\n",
      "15\n",
      "4\n",
      "1\n",
      "8\n",
      "2\n",
      "11\n",
      "9\n",
      "2\n",
      "12\n",
      "6\n",
      "8\n",
      "11\n",
      "7\n",
      "1\n",
      "8\n",
      "4\n",
      "10\n",
      "11\n",
      "7\n",
      "6\n",
      "11\n",
      "11\n",
      "9\n",
      "12\n",
      "12\n",
      "2\n",
      "2\n",
      "7\n",
      "9\n",
      "14\n",
      "4\n",
      "6\n",
      "7\n",
      "3\n",
      "8\n",
      "10\n",
      "5\n",
      "4\n",
      "1\n",
      "1\n",
      "15\n",
      "4\n",
      "3\n",
      "15\n",
      "13\n",
      "14\n",
      "2\n",
      "6\n",
      "3\n",
      "11\n",
      "3\n",
      "9\n",
      "13\n",
      "3\n",
      "3\n",
      "9\n",
      "11\n",
      "11\n",
      "8\n",
      "1\n",
      "2\n",
      "1\n",
      "13\n",
      "3\n",
      "2\n",
      "4\n",
      "6\n",
      "4\n",
      "15\n",
      "7\n",
      "12\n",
      "11\n",
      "9\n",
      "12\n",
      "15\n",
      "3\n",
      "10\n",
      "3\n",
      "9\n",
      "10\n",
      "5\n",
      "14\n",
      "11\n",
      "8\n",
      "0\n",
      "14\n",
      "12\n",
      "12\n",
      "14\n",
      "9\n",
      "13\n",
      "9\n",
      "9\n",
      "15\n",
      "7\n",
      "14\n",
      "2\n",
      "2\n",
      "11\n",
      "7\n",
      "7\n",
      "3\n",
      "15\n",
      "6\n",
      "14\n",
      "0\n",
      "12\n",
      "1\n",
      "13\n",
      "12\n",
      "2\n",
      "6\n",
      "8\n",
      "2\n",
      "9\n",
      "10\n",
      "10\n",
      "9\n",
      "13\n",
      "13\n",
      "6\n",
      "11\n",
      "14\n",
      "4\n",
      "1\n",
      "7\n",
      "7\n",
      "9\n",
      "4\n",
      "13\n",
      "0\n",
      "13\n",
      "13\n",
      "3\n",
      "5\n",
      "2\n",
      "12\n",
      "2\n",
      "2\n",
      "12\n",
      "8\n",
      "15\n",
      "5\n",
      "12\n",
      "12\n",
      "4\n",
      "4\n",
      "3\n",
      "7\n",
      "6\n",
      "6\n",
      "0\n",
      "7\n",
      "5\n",
      "12\n",
      "15\n",
      "12\n",
      "13\n",
      "14\n",
      "14\n",
      "3\n",
      "4\n",
      "7\n",
      "9\n",
      "8\n",
      "7\n",
      "0\n",
      "15\n",
      "8\n",
      "1\n",
      "5\n",
      "10\n",
      "8\n",
      "7\n",
      "11\n",
      "7\n",
      "14\n",
      "13\n",
      "6\n",
      "6\n",
      "15\n",
      "11\n",
      "3\n",
      "10\n",
      "10\n",
      "10\n",
      "5\n",
      "13\n",
      "9\n",
      "5\n",
      "4\n",
      "8\n",
      "1\n",
      "5\n",
      "14\n",
      "1\n",
      "2\n",
      "2\n",
      "10\n",
      "3\n",
      "15\n",
      "11\n",
      "2\n",
      "11\n",
      "7\n",
      "0\n",
      "1\n",
      "3\n",
      "11\n",
      "7\n",
      "9\n",
      "11\n",
      "13\n",
      "0\n",
      "10\n",
      "2\n",
      "15\n",
      "1\n",
      "11\n",
      "10\n",
      "7\n",
      "1\n",
      "8\n",
      "3\n",
      "1\n",
      "1\n",
      "1\n",
      "8\n",
      "1\n",
      "11\n",
      "0\n",
      "1\n",
      "5\n",
      "10\n",
      "11\n",
      "14\n",
      "12\n",
      "4\n",
      "5\n",
      "8\n",
      "3\n",
      "0\n",
      "3\n",
      "4\n",
      "4\n",
      "15\n",
      "0\n",
      "9\n",
      "1\n",
      "9\n",
      "9\n",
      "9\n",
      "13\n",
      "0\n",
      "1\n",
      "8\n",
      "6\n",
      "11\n",
      "14\n",
      "2\n",
      "14\n",
      "10\n",
      "11\n",
      "6\n",
      "2\n",
      "15\n",
      "10\n",
      "14\n",
      "13\n",
      "0\n",
      "7\n",
      "5\n",
      "7\n",
      "3\n",
      "0\n",
      "2\n",
      "10\n",
      "3\n",
      "0\n",
      "2\n",
      "10\n",
      "3\n",
      "3\n",
      "9\n",
      "4\n",
      "15\n",
      "0\n",
      "3\n",
      "15\n",
      "15\n",
      "11\n",
      "15\n",
      "15\n",
      "6\n",
      "7\n",
      "11\n",
      "12\n",
      "15\n",
      "2\n",
      "12\n",
      "4\n",
      "2\n",
      "13\n",
      "11\n",
      "14\n",
      "6\n",
      "7\n",
      "11\n",
      "0\n",
      "14\n",
      "3\n",
      "14\n",
      "4\n",
      "1\n",
      "3\n",
      "2\n",
      "8\n",
      "6\n",
      "3\n",
      "8\n",
      "3\n",
      "4\n",
      "14\n",
      "15\n",
      "6\n",
      "10\n",
      "1\n",
      "11\n",
      "2\n",
      "9\n",
      "6\n",
      "14\n",
      "0\n",
      "4\n",
      "1\n",
      "0\n",
      "3\n",
      "6\n",
      "5\n",
      "0\n",
      "14\n",
      "15\n",
      "10\n",
      "8\n",
      "1\n",
      "9\n",
      "13\n",
      "15\n",
      "15\n",
      "0\n",
      "12\n",
      "1\n",
      "6\n",
      "13\n",
      "4\n",
      "3\n",
      "8\n",
      "2\n",
      "6\n",
      "11\n",
      "10\n",
      "5\n",
      "9\n",
      "9\n",
      "14\n",
      "13\n",
      "1\n",
      "8\n",
      "6\n",
      "6\n",
      "12\n",
      "9\n",
      "12\n",
      "4\n",
      "6\n",
      "3\n",
      "4\n",
      "2\n",
      "2\n",
      "6\n",
      "6\n",
      "1\n",
      "14\n",
      "7\n",
      "14\n",
      "13\n",
      "14\n",
      "3\n",
      "7\n",
      "5\n",
      "5\n",
      "13\n",
      "9\n",
      "5\n",
      "8\n",
      "10\n",
      "10\n",
      "14\n",
      "5\n",
      "15\n",
      "15\n",
      "0\n",
      "0\n",
      "15\n",
      "6\n",
      "14\n",
      "3\n",
      "4\n",
      "5\n",
      "10\n",
      "1\n",
      "7\n",
      "4\n",
      "2\n",
      "8\n",
      "2\n",
      "10\n",
      "8\n",
      "3\n",
      "4\n",
      "0\n",
      "1\n",
      "9\n",
      "6\n",
      "2\n",
      "2\n",
      "15\n",
      "2\n",
      "10\n",
      "3\n",
      "15\n",
      "11\n",
      "9\n",
      "1\n",
      "11\n",
      "14\n",
      "0\n",
      "6\n",
      "12\n",
      "12\n",
      "0\n",
      "11\n",
      "7\n",
      "13\n",
      "3\n",
      "5\n",
      "14\n",
      "14\n",
      "10\n",
      "5\n",
      "2\n",
      "14\n",
      "12\n",
      "2\n",
      "0\n",
      "9\n",
      "0\n",
      "5\n",
      "13\n",
      "6\n",
      "11\n",
      "2\n",
      "9\n",
      "1\n",
      "7\n",
      "8\n",
      "3\n",
      "10\n",
      "13\n",
      "9\n",
      "2\n",
      "14\n",
      "6\n",
      "11\n",
      "14\n",
      "2\n",
      "2\n",
      "0\n",
      "7\n",
      "6\n",
      "3\n",
      "10\n",
      "12\n",
      "12\n",
      "9\n",
      "8\n",
      "2\n",
      "8\n",
      "2\n",
      "15\n",
      "15\n",
      "13\n",
      "14\n",
      "1\n",
      "9\n",
      "11\n",
      "3\n",
      "10\n",
      "6\n",
      "15\n",
      "3\n",
      "1\n",
      "4\n",
      "13\n",
      "0\n",
      "2\n",
      "8\n",
      "1\n",
      "10\n",
      "1\n",
      "4\n",
      "1\n",
      "14\n",
      "9\n",
      "1\n",
      "6\n",
      "5\n",
      "0\n",
      "1\n",
      "11\n",
      "6\n",
      "8\n",
      "3\n",
      "8\n",
      "14\n",
      "7\n",
      "13\n",
      "0\n",
      "10\n",
      "10\n",
      "9\n",
      "7\n",
      "6\n",
      "4\n",
      "7\n",
      "15\n",
      "0\n",
      "5\n",
      "14\n",
      "8\n",
      "8\n",
      "14\n",
      "8\n",
      "0\n",
      "3\n",
      "13\n",
      "4\n",
      "15\n",
      "12\n",
      "7\n",
      "0\n",
      "5\n",
      "13\n",
      "6\n",
      "14\n",
      "11\n",
      "7\n",
      "1\n",
      "14\n",
      "14\n",
      "1\n",
      "0\n",
      "9\n",
      "8\n",
      "2\n",
      "5\n",
      "7\n",
      "9\n",
      "9\n",
      "11\n",
      "6\n",
      "5\n",
      "0\n",
      "6\n",
      "3\n",
      "12\n",
      "8\n",
      "11\n",
      "13\n",
      "14\n",
      "6\n",
      "8\n",
      "1\n",
      "0\n",
      "13\n",
      "9\n",
      "7\n",
      "5\n",
      "4\n",
      "15\n",
      "13\n",
      "9\n",
      "12\n",
      "15\n",
      "13\n",
      "13\n",
      "4\n",
      "12\n",
      "2\n",
      "4\n",
      "15\n",
      "11\n",
      "0\n",
      "0\n",
      "8\n",
      "14\n",
      "8\n",
      "12\n",
      "11\n",
      "0\n",
      "2\n",
      "15\n",
      "0\n",
      "3\n",
      "6\n",
      "13\n",
      "3\n",
      "7\n",
      "8\n",
      "13\n",
      "13\n",
      "11\n",
      "2\n",
      "7\n",
      "4\n",
      "10\n",
      "15\n",
      "6\n",
      "5\n",
      "12\n",
      "1\n",
      "9\n",
      "1\n",
      "11\n",
      "11\n",
      "3\n",
      "7\n",
      "15\n",
      "6\n",
      "7\n",
      "7\n",
      "2\n",
      "12\n",
      "1\n",
      "0\n",
      "3\n",
      "13\n",
      "8\n",
      "2\n",
      "11\n",
      "7\n",
      "10\n",
      "1\n",
      "13\n",
      "3\n",
      "2\n",
      "13\n",
      "15\n",
      "12\n",
      "15\n",
      "14\n",
      "0\n",
      "4\n",
      "8\n",
      "4\n",
      "11\n",
      "7\n",
      "2\n",
      "9\n",
      "7\n",
      "15\n",
      "11\n",
      "8\n",
      "2\n",
      "13\n",
      "5\n",
      "0\n",
      "12\n",
      "8\n",
      "6\n",
      "0\n",
      "6\n",
      "1\n",
      "14\n",
      "14\n",
      "0\n",
      "15\n",
      "9\n",
      "9\n",
      "9\n",
      "6\n",
      "15\n",
      "4\n",
      "9\n",
      "4\n",
      "1\n",
      "0\n",
      "2\n",
      "15\n",
      "10\n",
      "3\n",
      "8\n",
      "4\n",
      "7\n",
      "6\n",
      "5\n",
      "2\n",
      "3\n",
      "5\n",
      "12\n",
      "3\n",
      "4\n",
      "4\n",
      "15\n",
      "10\n",
      "0\n",
      "3\n",
      "13\n",
      "11\n",
      "3\n",
      "15\n",
      "1\n",
      "15\n",
      "7\n",
      "8\n",
      "4\n",
      "8\n",
      "12\n",
      "12\n",
      "13\n",
      "13\n",
      "10\n",
      "13\n",
      "5\n",
      "1\n",
      "4\n",
      "9\n",
      "0\n",
      "10\n",
      "3\n",
      "4\n",
      "3\n",
      "11\n",
      "0\n",
      "0\n",
      "4\n",
      "9\n",
      "1\n",
      "13\n",
      "10\n",
      "2\n",
      "12\n",
      "1\n",
      "1\n",
      "4\n",
      "6\n",
      "5\n",
      "3\n",
      "11\n",
      "11\n",
      "5\n",
      "6\n",
      "7\n",
      "14\n",
      "3\n",
      "4\n",
      "13\n",
      "8\n",
      "14\n",
      "9\n",
      "10\n",
      "7\n",
      "15\n",
      "9\n",
      "7\n",
      "3\n",
      "1\n",
      "11\n",
      "12\n",
      "11\n",
      "0\n",
      "3\n",
      "9\n",
      "9\n",
      "8\n",
      "2\n",
      "7\n",
      "0\n",
      "13\n",
      "6\n",
      "2\n",
      "0\n",
      "7\n",
      "15\n",
      "5\n",
      "13\n",
      "3\n",
      "15\n",
      "7\n",
      "11\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Counter({9: 67,\n",
       "         14: 58,\n",
       "         11: 69,\n",
       "         3: 70,\n",
       "         12: 53,\n",
       "         7: 60,\n",
       "         2: 72,\n",
       "         6: 64,\n",
       "         8: 58,\n",
       "         4: 61,\n",
       "         1: 70,\n",
       "         5: 53,\n",
       "         0: 63,\n",
       "         13: 59,\n",
       "         10: 58,\n",
       "         15: 65})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# Label histogram of the first samples of the test split. The labels are\n",
    "# counted directly instead of being printed one per line, which flooded the\n",
    "# cell output without adding anything the histogram does not show.\n",
    "datasets = get_datasets()['test']\n",
    "sample_size = min(1000, len(datasets))  # do not index past a short split\n",
    "data_counts = Counter(datasets[i][1] for i in range(sample_size))\n",
    "data_counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "38400"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of samples in `datasets`, which the previous cell assigns from\n",
    "# get_datasets()['test'].\n",
    "# NOTE(review): the saved output of this cell has execution_count 38, so it\n",
    "# may come from a different `datasets` value - re-run the notebook to confirm.\n",
    "len(datasets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "abl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
