{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f620d4eb-4b2c-47c9-be1f-b30fa2d9a4cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: uncomment to install the PFN packages into the kernel's environment.\n",
    "#%pip install tabpfn\n",
    "#%pip install tabicl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64f26c53-d1ed-428f-a4c1-67fae28518b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Paths shown reflect the default Jupyter Docker Stacks user directory (/home/jovyan).\n",
    "code_path = '/home/jovyan/code/'\n",
    "\n",
    "# Utility scripts to source into the notebook namespace, in this order.\n",
    "utility_files = [\n",
    "    'utility_functions_implementing_tabpfn_generators_iclr.py',\n",
    "    'additional_utility_functions_for_tabpfn_generators_iclr.py',\n",
    "]\n",
    "\n",
    "for file_name in utility_files:\n",
    "    file_path = os.path.expanduser(os.path.join(code_path, file_name))\n",
    "    # Read as UTF-8 explicitly so the result does not depend on the platform locale.\n",
    "    with open(file_path, encoding='utf-8') as file:\n",
    "        # Compile under the real file name so tracebacks point at the script\n",
    "        # rather than at an anonymous '<string>'.\n",
    "        exec(compile(file.read(), file_path, 'exec'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e989899b-1412-4120-b33b-8d0a16f088fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os                              # filesystem paths and directory creation\n",
    "import pandas as pd                    # Pandas\n",
    "import openml                          # OpenML API client\n",
    "from tqdm import tqdm                  # progress bar\n",
    "from sklearn.model_selection import train_test_split  # random data splitting\n",
    "import re\n",
    "import pyarrow.feather as feather\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8f692d2-d57c-4551-9469-077fed6db079",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Narrow the OpenML-CC18 suite down to small, predominantly non-numeric datasets\n",
    "\n",
    "# Fetch the benchmark suite and pull out its task IDs\n",
    "suite = openml.study.get_suite(\"OpenML-CC18\")   # alias for study id 99\n",
    "task_ids = suite.tasks\n",
    "\n",
    "# Limits a dataset must respect to be retained\n",
    "n_max = 10000        # upper bound on the number of rows\n",
    "n_cols_max = 500     # upper bound on the number of columns\n",
    "n_class_max = 10     # upper bound on the level count of any non-numeric column\n",
    "\n",
    "tasks_to_keep = []\n",
    "\n",
    "for tsk in tqdm(task_ids, desc=\"Dataset\"):\n",
    "    # Resolve the task ID to its dataset\n",
    "    task = openml.tasks.get_task(tsk)\n",
    "    dataset = task.get_dataset()\n",
    "\n",
    "    # target=None keeps every column (features and target) together in X,\n",
    "    # so y comes back as None\n",
    "    X, y, categorical_mask, attr_names = dataset.get_data(\n",
    "        target=None, dataset_format=\"dataframe\"\n",
    "    )\n",
    "\n",
    "    n_rows, n_cols = X.shape\n",
    "\n",
    "    # Split the column count into non-numeric (category, string, object, bool, ...)\n",
    "    # and numeric\n",
    "    non_num_cols = X.select_dtypes(exclude=\"number\").columns\n",
    "    n_non_num = len(non_num_cols)\n",
    "    n_num = n_cols - n_non_num\n",
    "\n",
    "    # Largest number of distinct values in any non-numeric column (NaN not counted);\n",
    "    # zero when the dataset has no such column\n",
    "    if n_non_num == 0:\n",
    "        max_classes = 0\n",
    "    else:\n",
    "        levels_per_cat = X[non_num_cols].nunique(dropna=True)\n",
    "        max_classes = int(levels_per_cat.max())\n",
    "\n",
    "    # A dataset is rejected when it is too large, when numeric columns are at least\n",
    "    # as numerous as non-numeric ones, or when some column has too many levels\n",
    "    rejected = (\n",
    "        n_rows > n_max\n",
    "        or n_cols > n_cols_max\n",
    "        or n_num >= n_non_num\n",
    "        or max_classes > n_class_max\n",
    "    )\n",
    "    if not rejected:\n",
    "        tasks_to_keep.append(tsk)\n",
    "\n",
    "# Number of task IDs that passed the filter\n",
    "len(tasks_to_keep)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a0a2b1a-b037-4577-ab27-f8b0d341fe2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call summarize_openml_tasks on the filtered task IDs and display the result.\n",
    "# NOTE(review): summarize_openml_tasks is not defined in this notebook; presumably it\n",
    "# comes from the utility scripts sourced at the top -- confirm.\n",
    "summary = summarize_openml_tasks(tasks_to_keep)\n",
    "summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e567ed2-efe9-4212-bb65-5fcd537b38f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeds 1 through 10; passed as split_seeds to the build_all_* calls in the cells below.\n",
    "split_seeds = list(range(1, 11)) # 10 splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39bcf9b2-401b-4c37-ba35-14acae22561a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate one big file with all the original and holdout data splits\n",
    "\n",
    "output_file = \"/home/jovyan/selected_OpenMLCC18/outputs/openml_cc18_orig_hold_data_splits_categorical.feather\"\n",
    "\n",
    "# Create the output directory before the long-running step, so the final write\n",
    "# cannot fail on a missing directory once the computation has finished.\n",
    "os.makedirs(os.path.dirname(output_file), exist_ok=True)\n",
    "\n",
    "# Time only the split construction (the save below is not included)\n",
    "t0 = time.perf_counter()\n",
    "_, splits_long = build_all_splits_tasks_cat(tasks_to_keep, split_seeds)\n",
    "elapsed = time.perf_counter() - t0\n",
    "\n",
    "# Save the file\n",
    "feather.write_feather(splits_long, output_file)\n",
    "\n",
    "print(f\"elapsed running time: {elapsed/60:.2f} minutes (~{elapsed/3600:.2f} hours)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7273a67-b8ef-43b1-89da-c0fdba1b4457",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate one big file with all MIAV synthetic datasets (pfn_method = 'tabpfn')\n",
    "\n",
    "output_file = \"/home/jovyan/selected_OpenMLCC18/outputs/openml_cc18_syn_miav_tabpfn_categorical.feather\"\n",
    "\n",
    "# Create the output directory before the long-running step, so the final write\n",
    "# cannot fail on a missing directory once the computation has finished.\n",
    "os.makedirs(os.path.dirname(output_file), exist_ok=True)\n",
    "\n",
    "# Time only the generation (the save below is not included)\n",
    "t0 = time.perf_counter()\n",
    "syn_long, _ = build_all_synthetics_miav_tasks_cat(\n",
    "    tasks_to_keep=tasks_to_keep,\n",
    "    split_seeds=split_seeds,\n",
    "    pfn_method='tabpfn'\n",
    ")\n",
    "elapsed = time.perf_counter() - t0\n",
    "\n",
    "# Save the file\n",
    "feather.write_feather(syn_long, output_file)\n",
    "\n",
    "print(f\"elapsed running time: {elapsed/60:.2f} minutes (~{elapsed/3600:.2f} hours)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09cb434e-d860-4fa5-a467-477216896704",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate one big file with all MIAV synthetic datasets (pfn_method = 'tabicl')\n",
    "\n",
    "output_file = \"/home/jovyan/selected_OpenMLCC18/outputs/openml_cc18_syn_miav_tabicl_categorical.feather\"\n",
    "\n",
    "# Create the output directory before the long-running step, so the final write\n",
    "# cannot fail on a missing directory once the computation has finished.\n",
    "os.makedirs(os.path.dirname(output_file), exist_ok=True)\n",
    "\n",
    "# Time only the generation (the save below is not included)\n",
    "t0 = time.perf_counter()\n",
    "syn_long, _ = build_all_synthetics_miav_tasks_cat(\n",
    "    tasks_to_keep=tasks_to_keep,\n",
    "    split_seeds=split_seeds,\n",
    "    pfn_method='tabicl'\n",
    ")\n",
    "elapsed = time.perf_counter() - t0\n",
    "\n",
    "# Save the file\n",
    "feather.write_feather(syn_long, output_file)\n",
    "\n",
    "print(f\"elapsed running time: {elapsed/60:.2f} minutes (~{elapsed/3600:.2f} hours)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb75fb84-b2bc-41d3-ad96-842d29ad396c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate one big file with all JF synthetic datasets (pfn_method = 'tabpfn')\n",
    "\n",
    "output_file = \"/home/jovyan/selected_OpenMLCC18/outputs/openml_cc18_syn_jf_tabpfn_categorical.feather\"\n",
    "\n",
    "# Create the output directory before the long-running step, so the final write\n",
    "# cannot fail on a missing directory once the computation has finished.\n",
    "os.makedirs(os.path.dirname(output_file), exist_ok=True)\n",
    "\n",
    "# Time only the generation (the save below is not included)\n",
    "t0 = time.perf_counter()\n",
    "syn_long, _ = build_all_synthetics_jf_tasks_cat(\n",
    "    tasks_to_keep=tasks_to_keep,\n",
    "    split_seeds=split_seeds,\n",
    "    pfn_method='tabpfn'\n",
    ")\n",
    "elapsed = time.perf_counter() - t0\n",
    "\n",
    "# Save the file\n",
    "feather.write_feather(syn_long, output_file)\n",
    "\n",
    "print(f\"elapsed running time: {elapsed/60:.2f} minutes (~{elapsed/3600:.2f} hours)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c99f0f67-2144-44ec-b03e-8ebf3cb2291b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate one big file with all JF synthetic datasets (pfn_method = 'tabicl')\n",
    "\n",
    "output_file = \"/home/jovyan/selected_OpenMLCC18/outputs/openml_cc18_syn_jf_tabicl_categorical.feather\"\n",
    "\n",
    "# Create the output directory before the long-running step, so the final write\n",
    "# cannot fail on a missing directory once the computation has finished.\n",
    "os.makedirs(os.path.dirname(output_file), exist_ok=True)\n",
    "\n",
    "# Time only the generation (the save below is not included)\n",
    "t0 = time.perf_counter()\n",
    "syn_long, _ = build_all_synthetics_jf_tasks_cat(\n",
    "    tasks_to_keep=tasks_to_keep,\n",
    "    split_seeds=split_seeds,\n",
    "    pfn_method='tabicl'\n",
    ")\n",
    "elapsed = time.perf_counter() - t0\n",
    "\n",
    "# Save the file\n",
    "feather.write_feather(syn_long, output_file)\n",
    "\n",
    "print(f\"elapsed running time: {elapsed/60:.2f} minutes (~{elapsed/3600:.2f} hours)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b0e3d0f-dc8c-458d-8407-894367d00ddb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate one big file with all FC synthetic datasets (pfn_method = 'tabpfn')\n",
    "\n",
    "output_file = \"/home/jovyan/selected_OpenMLCC18/outputs/openml_cc18_syn_fc_tabpfn_categorical.feather\"\n",
    "\n",
    "# Create the output directory before the long-running step, so the final write\n",
    "# cannot fail on a missing directory once the computation has finished.\n",
    "os.makedirs(os.path.dirname(output_file), exist_ok=True)\n",
    "\n",
    "# Time only the generation (the save below is not included)\n",
    "t0 = time.perf_counter()\n",
    "syn_long, _ = build_all_synthetics_fc_tasks_cat(\n",
    "    tasks_to_keep=tasks_to_keep,\n",
    "    split_seeds=split_seeds,\n",
    "    pfn_method='tabpfn'\n",
    ")\n",
    "elapsed = time.perf_counter() - t0\n",
    "\n",
    "# Save the file\n",
    "feather.write_feather(syn_long, output_file)\n",
    "\n",
    "print(f\"elapsed running time: {elapsed/60:.2f} minutes (~{elapsed/3600:.2f} hours)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a02a698a-10d5-47b6-8c03-4f594bb80cd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate one big file with all FC synthetic datasets (pfn_method = 'tabicl')\n",
    "\n",
    "output_file = \"/home/jovyan/selected_OpenMLCC18/outputs/openml_cc18_syn_fc_tabicl_categorical.feather\"\n",
    "\n",
    "# Create the output directory before the long-running step, so the final write\n",
    "# cannot fail on a missing directory once the computation has finished.\n",
    "os.makedirs(os.path.dirname(output_file), exist_ok=True)\n",
    "\n",
    "# Time only the generation (the save below is not included)\n",
    "t0 = time.perf_counter()\n",
    "syn_long, _ = build_all_synthetics_fc_tasks_cat(\n",
    "    tasks_to_keep=tasks_to_keep,\n",
    "    split_seeds=split_seeds,\n",
    "    pfn_method='tabicl'\n",
    ")\n",
    "elapsed = time.perf_counter() - t0\n",
    "\n",
    "# Save the file\n",
    "feather.write_feather(syn_long, output_file)\n",
    "\n",
    "print(f\"elapsed running time: {elapsed/60:.2f} minutes (~{elapsed/3600:.2f} hours)\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
