{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the train/val/test index file; the path is relative to the working directory.\n",
    "indices_path = Path('tabularbench/data/train_val_test_indices.npy')\n",
    "# NOTE(review): allow_pickle=True unpickles arbitrary Python objects, so only load\n",
    "# this file from a trusted source.\n",
    "# .item() extracts the single stored object (presumably a dict wrapped in a 0-d\n",
    "# object array) so that `indices` supports .keys() / .items() below.\n",
    "indices = np.load(indices_path, allow_pickle=True).item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The keys are the OpenML dataset IDs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys([44132, 44133, 44134, 44136, 44137, 44138, 44139, 44140, 44141, 44142, 44143, 44144, 44145, 44146, 44147, 44148, 45032, 45033, 45034, 44089, 44120, 44121, 44122, 44123, 44125, 44126, 44128, 44129, 44130, 45022, 45021, 45020, 45019, 45028, 45026, 44055, 44056, 44059, 44061, 44062, 44063, 44065, 44066, 44068, 44069, 45041, 45042, 45043, 45045, 45046, 45047, 45048, 44156, 44157, 44159, 45035, 45036, 45038, 45039])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Top-level keys of the loaded mapping (one entry per dataset).\n",
    "indices.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Every OpenML dataset is available at the same dataset sizes (the intersection and the union of the size keys coincide):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{10000, 50000}\n",
      "{10000, 50000}\n"
     ]
    }
   ],
   "source": [
    "sizes_all = []\n",
    "\n",
    "for openml_id, indices_by_size in indices.items():\n",
    "    sizes = indices_by_size.keys()\n",
    "    sizes_all.append(sizes)\n",
    "\n",
    "print(set.intersection(*map(set, sizes_all)))\n",
    "print(set.union(*map(set, sizes_all)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Under every size key, there is a sequence with one element per split. The number of splits per dataset, shown as (splits at size 10000, splits at size 50000):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{44132: (3, 3),\n",
       " 44133: (2, 2),\n",
       " 44134: (2, 2),\n",
       " 44136: (3, 3),\n",
       " 44137: (3, 3),\n",
       " 44138: (1, 2),\n",
       " 44139: (1, 2),\n",
       " 44140: (1, 1),\n",
       " 44141: (3, 3),\n",
       " 44142: (2, 2),\n",
       " 44143: (1, 1),\n",
       " 44144: (1, 2),\n",
       " 44145: (3, 3),\n",
       " 44146: (1, 1),\n",
       " 44147: (3, 3),\n",
       " 44148: (1, 2),\n",
       " 45032: (3, 3),\n",
       " 45033: (5, 5),\n",
       " 45034: (1, 1),\n",
       " 44089: (2, 2),\n",
       " 44120: (1, 1),\n",
       " 44121: (1, 1),\n",
       " 44122: (3, 3),\n",
       " 44123: (3, 3),\n",
       " 44125: (3, 3),\n",
       " 44126: (3, 3),\n",
       " 44128: (1, 1),\n",
       " 44129: (1, 1),\n",
       " 44130: (3, 3),\n",
       " 45022: (1, 1),\n",
       " 45021: (1, 1),\n",
       " 45020: (3, 3),\n",
       " 45019: (5, 5),\n",
       " 45028: (1, 2),\n",
       " 45026: (3, 3),\n",
       " 44055: (5, 5),\n",
       " 44056: (3, 3),\n",
       " 44059: (1, 1),\n",
       " 44061: (5, 5),\n",
       " 44062: (3, 3),\n",
       " 44063: (2, 2),\n",
       " 44065: (1, 1),\n",
       " 44066: (1, 2),\n",
       " 44068: (1, 1),\n",
       " 44069: (1, 1),\n",
       " 45041: (3, 3),\n",
       " 45042: (5, 5),\n",
       " 45043: (1, 1),\n",
       " 45045: (1, 1),\n",
       " 45046: (1, 1),\n",
       " 45047: (1, 1),\n",
       " 45048: (1, 1),\n",
       " 44156: (1, 1),\n",
       " 44157: (3, 3),\n",
       " 44159: (1, 1),\n",
       " 45035: (1, 1),\n",
       " 45036: (3, 3),\n",
       " 45038: (1, 1),\n",
       " 45039: (3, 3)}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "splits_per_openmlid = { k: (len(v[10000]), len(v[50000])) for k, v in indices.items() }\n",
    "splits_per_openmlid"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For every split, we have a mapping with the keys `train`, `val` and `test`, holding the indices for the training, validation and test data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2923 376 878\n",
      "[0, 3, 6, 7, 8, 9, 11, 12, 13, 16] [2, 4, 5, 10, 14, 30, 33, 39, 70, 72] [1, 15, 17, 22, 23, 31, 34, 36, 42, 45]\n"
     ]
    }
   ],
   "source": [
    "idcs = indices[45033][10000][0]\n",
    "print(len(idcs['train']), len(idcs['val']), len(idcs['test']))\n",
    "print(idcs['train'][:10], idcs['val'][:10], idcs['test'][:10])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Within a split, the training, validation and test indices should not overlap:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "for openml_id, dataset in indices.items():\n",
    "    for size, splits in dataset.items():\n",
    "        for split in splits:\n",
    "            split_train = set(split['train'])\n",
    "            split_val = set(split['val'])\n",
    "            split_test = set(split['test'])\n",
    "\n",
    "            assert len(split_train & split_val) == 0\n",
    "            assert len(split_train & split_test) == 0\n",
    "            assert len(split_val & split_test) == 0        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Per dataset and size, we count for each split how many row indices (up to the largest index that occurs in the split) are not assigned to the training, validation or test data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{44132: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 44133: {10000: [0, 0], 50000: [0, 0]},\n",
       " 44134: {10000: [0, 0], 50000: [0, 0]},\n",
       " 44136: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 44137: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 44138: {10000: [0], 50000: [0, 0]},\n",
       " 44139: {10000: [0], 50000: [0, 0]},\n",
       " 44140: {10000: [0], 50000: [0]},\n",
       " 44141: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 44142: {10000: [0, 0], 50000: [0, 0]},\n",
       " 44143: {10000: [471834], 50000: [431835]},\n",
       " 44144: {10000: [0], 50000: [0, 0]},\n",
       " 44145: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 44146: {10000: [57145], 50000: [29146]},\n",
       " 44147: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 44148: {10000: [0], 50000: [0, 0]},\n",
       " 45032: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 45033: {10000: [0, 0, 0, 0, 0], 50000: [0, 0, 0, 0, 0]},\n",
       " 45034: {10000: [5355531], 50000: [5315569]},\n",
       " 44089: {10000: [0, 0], 50000: [0, 0]},\n",
       " 44120: {10000: [0], 50000: [0]},\n",
       " 44121: {10000: [456601], 50000: [416602]},\n",
       " 44122: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 44123: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 44125: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 44126: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 44128: {10000: [0], 50000: [0]},\n",
       " 44129: {10000: [830152], 50000: [790151]},\n",
       " 44130: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 45022: {10000: [0], 50000: [0]},\n",
       " 45021: {10000: [0], 50000: [0]},\n",
       " 45020: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 45019: {10000: [0, 0, 0, 0, 0], 50000: [0, 0, 0, 0, 0]},\n",
       " 45028: {10000: [0], 50000: [0, 0]},\n",
       " 45026: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 44055: {10000: [0, 0, 0, 0, 0], 50000: [0, 0, 0, 0, 0]},\n",
       " 44056: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 44059: {10000: [0], 50000: [0]},\n",
       " 44061: {10000: [0, 0, 0, 0, 0], 50000: [0, 0, 0, 0, 0]},\n",
       " 44062: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 44063: {10000: [0, 0], 50000: [0, 0]},\n",
       " 44065: {10000: [471834], 50000: [431835]},\n",
       " 44066: {10000: [0], 50000: [0, 0]},\n",
       " 44068: {10000: [284297], 50000: [244297]},\n",
       " 44069: {10000: [131596], 50000: [91600]},\n",
       " 45041: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 45042: {10000: [0, 0, 0, 0, 0], 50000: [0, 0, 0, 0, 0]},\n",
       " 45043: {10000: [0], 50000: [0]},\n",
       " 45045: {10000: [5355531], 50000: [5315569]},\n",
       " 45046: {10000: [78317], 50000: [46823]},\n",
       " 45047: {10000: [889963], 50000: [849989]},\n",
       " 45048: {10000: [57145], 50000: [29146]},\n",
       " 44156: {10000: [0], 50000: [0]},\n",
       " 44157: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 44159: {10000: [313672], 50000: [273673]},\n",
       " 45035: {10000: [0], 50000: [0]},\n",
       " 45036: {10000: [0, 0, 0], 50000: [0, 0, 0]},\n",
       " 45038: {10000: [21233], 50000: [0]},\n",
       " 45039: {10000: [0, 0, 0], 50000: [0, 0, 0]}}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n_no_split_values = {}\n",
    "for openml_id, dataset in indices.items():\n",
    "    n_no_split_values[openml_id] = {}\n",
    "    for size, splits in dataset.items():\n",
    "        n_no_split_values[openml_id][size] = []\n",
    "        for split in splits:\n",
    "            split_combined = split['train'] + split['val'] + split['test']\n",
    "            max_index = max(split_combined) + 1\n",
    "            n_no_split_values[openml_id][size].append(max_index - len(split_combined))\n",
    "\n",
    "\n",
    "n_no_split_values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These are the (dataset, size) combinations for which we don't use all the data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(44065, 10000),\n",
       " (44065, 50000),\n",
       " (44068, 10000),\n",
       " (44068, 50000),\n",
       " (44069, 10000),\n",
       " (44069, 50000),\n",
       " (44121, 10000),\n",
       " (44121, 50000),\n",
       " (44129, 10000),\n",
       " (44129, 50000),\n",
       " (44143, 10000),\n",
       " (44143, 50000),\n",
       " (44146, 10000),\n",
       " (44146, 50000),\n",
       " (44159, 10000),\n",
       " (44159, 50000),\n",
       " (45034, 10000),\n",
       " (45034, 50000),\n",
       " (45038, 10000),\n",
       " (45045, 10000),\n",
       " (45045, 50000),\n",
       " (45046, 10000),\n",
       " (45046, 50000),\n",
       " (45047, 10000),\n",
       " (45047, 50000),\n",
       " (45048, 10000),\n",
       " (45048, 50000)]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset_too_large = []\n",
    "for openml_id, dataset in n_no_split_values.items():\n",
    "    for size, splits in dataset.items():\n",
    "        if np.mean(splits) > 0:\n",
    "            dataset_too_large.append((openml_id, size))\n",
    "\n",
    "sorted(dataset_too_large)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tabularbench",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
