{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "535ee8e5-91fc-4a09-95a6-49e788ba6bf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Location of the helper scripts. The default is the home directory of the\n",
    "# Jupyter Docker Stacks user (/home/jovyan); adjust it for other setups.\n",
    "code_path = '/home/jovyan/code/'\n",
    "\n",
    "# Execute each helper script in the notebook namespace so that the names it\n",
    "# defines are available to the cells below: the base utilities first, then\n",
    "# the additional ones. exec() runs the file contents as code, so code_path\n",
    "# must only point at trusted files.\n",
    "for _script_name in (\n",
    "    'utility_functions_implementing_tabpfn_generators_iclr.py',\n",
    "    'additional_utility_functions_for_tabpfn_generators_iclr.py',\n",
    "):\n",
    "    file_path = os.path.join(code_path, _script_name)\n",
    "    with open(os.path.expanduser(file_path)) as file:\n",
    "        exec(file.read())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "492461c9-f150-4e9e-995a-723370773973",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import openml\n",
    "import pandas as pd\n",
    "import pyarrow.feather as feather\n",
    "from sklearn.model_selection import train_test_split\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f3eae90-71d3-42f3-901e-25af57b3efc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_all_splits(X, split_seeds, ds_name, task_id):\n",
    "    \"\"\"\n",
    "    Create one 50/50 original/holdout split of X per seed and stack them.\n",
    "\n",
    "    For every seed in `split_seeds`, X is split in half with\n",
    "    `train_test_split(X, test_size=0.5, random_state=seed)`. Rows containing\n",
    "    NA values are then dropped from each half and its index is reset.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : pd.DataFrame\n",
    "        Dataset to split. It is not modified.\n",
    "    split_seeds : iterable of int\n",
    "        Random seeds, one per split. Splits are numbered 1, 2, ... in\n",
    "        iteration order.\n",
    "    ds_name : str\n",
    "        Value stored in the `__dataset__` column.\n",
    "    task_id : int\n",
    "        Value stored in the `__task_id__` column.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    splits_long : pd.DataFrame\n",
    "        Rows of both halves of all splits, with the metadata columns\n",
    "        [__dataset__, __task_id__, __split__, __role__] placed first and\n",
    "        followed by the columns of X (column names are left unchanged).\n",
    "        `__role__` is \"orig\" or \"hold\". If `split_seeds` is empty, an empty\n",
    "        frame holding only the four metadata columns is returned.\n",
    "    \"\"\"\n",
    "    long_chunks = []\n",
    "\n",
    "    for j, seed in enumerate(split_seeds, start=1):\n",
    "        X_orig, X_hold = train_test_split(X, test_size=0.5, random_state=seed)\n",
    "\n",
    "        for role, half in ((\"orig\", X_orig), (\"hold\", X_hold)):\n",
    "            # remove rows with NA values; dropna()/reset_index() return a new\n",
    "            # frame, so the metadata columns below never touch X itself\n",
    "            chunk = half.dropna().reset_index(drop=True)\n",
    "\n",
    "            # metadata columns: each insert goes to position 0, so they are\n",
    "            # added in reverse of the final left-to-right order\n",
    "            chunk.insert(0, \"__role__\", role)\n",
    "            chunk.insert(0, \"__split__\", j)\n",
    "            chunk.insert(0, \"__task_id__\", task_id)\n",
    "            chunk.insert(0, \"__dataset__\", ds_name)\n",
    "            long_chunks.append(chunk)\n",
    "\n",
    "    if not long_chunks:\n",
    "        # pd.concat raises ValueError on an empty list (no seeds were given)\n",
    "        return pd.DataFrame(columns=[\"__dataset__\", \"__task_id__\", \"__split__\", \"__role__\"])\n",
    "\n",
    "    return pd.concat(long_chunks, axis=0, ignore_index=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b15d07fb-ec81-406a-89b1-939c7f070559",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _build_all_synthetics(generate, X, split_seeds, generator_kwargs, task_id, ds_name):\n",
    "    \"\"\"\n",
    "    Shared implementation of the build_all_synthetics_* functions.\n",
    "\n",
    "    For each split seed:\n",
    "      - take a 50/50 split of X (use the 'orig' half),\n",
    "      - drop rows with NA (a split whose 'orig' half becomes empty is skipped),\n",
    "      - synthesize with generate(X_orig, **generator_kwargs),\n",
    "      - add metadata columns (no column prefixing),\n",
    "    then stack everything into one long DataFrame.\n",
    "\n",
    "    An exception raised while processing a split is not propagated: it is\n",
    "    recorded in `failures` and the loop moves on to the next seed.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    generate : callable\n",
    "        Called as generate(X_orig, **generator_kwargs). Its result is used as\n",
    "        a DataFrame (.copy() and .insert() are called on it).\n",
    "    X : pd.DataFrame\n",
    "        Dataset to split.\n",
    "    split_seeds : iterable of int\n",
    "        Random seeds, one per split. Splits are numbered 1, 2, ...\n",
    "    generator_kwargs : dict or None\n",
    "        Extra keyword arguments for `generate`; None means no extra arguments.\n",
    "    task_id : int\n",
    "        Value stored in the `__task_id__` column (as int32).\n",
    "    ds_name : str\n",
    "        Value stored in the `__dataset__` column.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    syn_long : pd.DataFrame\n",
    "        Rows from all splits, with metadata columns first:\n",
    "        [__dataset__, __task_id__, __split__, __role__] + the columns returned\n",
    "        by the generator. Holds only the four metadata columns when no split\n",
    "        produced data.\n",
    "    failures : list[dict]\n",
    "        Any split-level errors: {'task_id', 'split', 'role', 'error'}\n",
    "    \"\"\"\n",
    "    if generator_kwargs is None:\n",
    "        generator_kwargs = {}\n",
    "\n",
    "    # materialize the seeds so that len() below also works for one-shot iterables\n",
    "    split_seeds = list(split_seeds)\n",
    "\n",
    "    all_chunks = []\n",
    "    failures = []\n",
    "\n",
    "    for j, seed in tqdm(enumerate(split_seeds, start=1), desc='Data split', total=len(split_seeds)):\n",
    "        try:\n",
    "            X_orig, _ = train_test_split(X, test_size=0.5, random_state=seed)\n",
    "\n",
    "            # remove rows with NA values\n",
    "            X_orig = X_orig.dropna().reset_index(drop=True)\n",
    "            if X_orig.empty:\n",
    "                continue\n",
    "\n",
    "            # generate synthetic copy of the original half\n",
    "            X_syn = generate(X_orig, **generator_kwargs)\n",
    "\n",
    "            # add metadata (insert in reverse so final left-to-right order is desired)\n",
    "            chunk = X_syn.copy()\n",
    "            chunk.insert(0, \"__role__\", \"syn\")\n",
    "            chunk.insert(0, \"__split__\", np.int32(j))\n",
    "            chunk.insert(0, \"__task_id__\", np.int32(task_id))\n",
    "            chunk.insert(0, \"__dataset__\", str(ds_name))\n",
    "\n",
    "            all_chunks.append(chunk)\n",
    "\n",
    "        except Exception as e:\n",
    "            # keep going: one failing split must not lose the other splits\n",
    "            failures.append({\n",
    "                \"task_id\": task_id,\n",
    "                \"split\": j,\n",
    "                \"role\": \"syn\",\n",
    "                \"error\": repr(e),\n",
    "            })\n",
    "\n",
    "    if all_chunks:\n",
    "        syn_long = pd.concat(all_chunks, axis=0, ignore_index=True)\n",
    "    else:\n",
    "        syn_long = pd.DataFrame(columns=[\"__dataset__\", \"__task_id__\", \"__split__\", \"__role__\"])\n",
    "\n",
    "    return syn_long, failures\n",
    "\n",
    "\n",
    "\n",
    "def build_all_synthetics_miav(\n",
    "    X: pd.DataFrame,\n",
    "    split_seeds,\n",
    "    generator_kwargs=None,\n",
    "    *,\n",
    "    task_id: int = 0,\n",
    "    ds_name: str = \"\",\n",
    "):\n",
    "    \"\"\"\n",
    "    For each split seed:\n",
    "      - take a 50/50 split (use the 'orig' half),\n",
    "      - drop rows with NA,\n",
    "      - synthesize with miav_tabpfn_generator,\n",
    "      - add metadata columns (no column prefixing),\n",
    "    then stack everything into one long DataFrame.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    syn_long : pd.DataFrame\n",
    "        Rows from all splits, with metadata columns first:\n",
    "        [__dataset__, __task_id__, __split__, __role__] + generated columns\n",
    "    failures : list[dict]\n",
    "        Any split-level errors: {'task_id', 'split', 'role', 'error'}\n",
    "    \"\"\"\n",
    "    # The lambda defers the lookup of miav_tabpfn_generator to the call made\n",
    "    # inside the per-split try block, so an unavailable generator is recorded\n",
    "    # as a split failure instead of raising here.\n",
    "    return _build_all_synthetics(\n",
    "        lambda _orig_half, **kwargs: miav_tabpfn_generator(_orig_half, **kwargs),\n",
    "        X,\n",
    "        split_seeds,\n",
    "        generator_kwargs,\n",
    "        task_id,\n",
    "        ds_name,\n",
    "    )\n",
    "\n",
    "\n",
    "\n",
    "def build_all_synthetics_jf(\n",
    "    X: pd.DataFrame,\n",
    "    split_seeds,\n",
    "    generator_kwargs=None,\n",
    "    *,\n",
    "    task_id: int = 0,\n",
    "    ds_name: str = \"\",\n",
    "):\n",
    "    \"\"\"\n",
    "    For each split seed:\n",
    "      - take a 50/50 split (use the 'orig' half),\n",
    "      - drop rows with NA,\n",
    "      - synthesize with joint_factorization_tabpfn_generator,\n",
    "      - add metadata columns (no column prefixing),\n",
    "    then stack everything into one long DataFrame.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    syn_long : pd.DataFrame\n",
    "        Rows from all splits, with metadata columns first:\n",
    "        [__dataset__, __task_id__, __split__, __role__] + generated columns\n",
    "    failures : list[dict]\n",
    "        Any split-level errors: {'task_id', 'split', 'role', 'error'}\n",
    "    \"\"\"\n",
    "    # The lambda defers the lookup of joint_factorization_tabpfn_generator to\n",
    "    # the call made inside the per-split try block, so an unavailable\n",
    "    # generator is recorded as a split failure instead of raising here.\n",
    "    return _build_all_synthetics(\n",
    "        lambda _orig_half, **kwargs: joint_factorization_tabpfn_generator(_orig_half, **kwargs),\n",
    "        X,\n",
    "        split_seeds,\n",
    "        generator_kwargs,\n",
    "        task_id,\n",
    "        ds_name,\n",
    "    )\n",
    "\n",
    "\n",
    "\n",
    "def build_all_synthetics_fc(\n",
    "    X: pd.DataFrame,\n",
    "    split_seeds,\n",
    "    generator_kwargs=None,\n",
    "    *,\n",
    "    task_id: int = 0,\n",
    "    ds_name: str = \"\",\n",
    "):\n",
    "    \"\"\"\n",
    "    For each split seed:\n",
    "      - take a 50/50 split (use the 'orig' half),\n",
    "      - drop rows with NA,\n",
    "      - synthesize with full_conditionals_tabpfn_generator,\n",
    "      - add metadata columns (no column prefixing),\n",
    "    then stack everything into one long DataFrame.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    syn_long : pd.DataFrame\n",
    "        Rows from all splits, with metadata columns first:\n",
    "        [__dataset__, __task_id__, __split__, __role__] + generated columns\n",
    "    failures : list[dict]\n",
    "        Any split-level errors: {'task_id', 'split', 'role', 'error'}\n",
    "    \"\"\"\n",
    "    # The lambda defers the lookup of full_conditionals_tabpfn_generator to\n",
    "    # the call made inside the per-split try block, so an unavailable\n",
    "    # generator is recorded as a split failure instead of raising here.\n",
    "    return _build_all_synthetics(\n",
    "        lambda _orig_half, **kwargs: full_conditionals_tabpfn_generator(_orig_half, **kwargs),\n",
    "        X,\n",
    "        split_seeds,\n",
    "        generator_kwargs,\n",
    "        task_id,\n",
    "        ds_name,\n",
    "    )\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c07a528-9936-45b5-b3a0-84a4a41adcd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that receives every feather file written by the cells below.\n",
    "output_path = '/home/jovyan/baseline_comparisons/outputs/'\n",
    "\n",
    "# Create the directory up front: writing a feather file into a missing folder\n",
    "# fails, which would otherwise abort the first dataset cell on a fresh setup.\n",
    "os.makedirs(output_path, exist_ok=True)\n",
    "\n",
    "# One random orig/holdout split is generated per seed.\n",
    "split_seeds = list(range(1, 11))  # 10 splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89c1921e-9367-448d-b5c5-13a0c41e50cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===== Abalone =====\n",
    "# Fetch the data, then write four feather files: the orig/holdout splits and\n",
    "# one synthetic table per generator (MIAV, JF, FC).\n",
    "\n",
    "from sklearn.datasets import fetch_openml\n",
    "\n",
    "# Output files of this cell\n",
    "fname1 = os.path.join(output_path, \"abalone_orig_hold_splits.feather\")\n",
    "fname2 = os.path.join(output_path, \"abalone_syn_miav.feather\")\n",
    "fname3 = os.path.join(output_path, \"abalone_syn_jf.feather\")\n",
    "fname4 = os.path.join(output_path, \"abalone_syn_fc.feather\")\n",
    "\n",
    "# Fetch Abalone (version 1) as pandas objects and store the target in X\n",
    "abalone = fetch_openml(name=\"abalone\", version=1, as_frame=True)\n",
    "X, y = abalone.data, abalone.target\n",
    "X['target'] = y  # Rings\n",
    "\n",
    "# Positional column indices handed to enforce_dtypes\n",
    "num_idx = list(range(1, 9))  # 1-8 as num_variables\n",
    "cat_idx = [0]                # 0 as cat_variables\n",
    "X = enforce_dtypes(dat=X, num_variables=num_idx, cat_variables=cat_idx)\n",
    "\n",
    "# --- orig/holdout splits, one per seed in split_seeds ---\n",
    "print('generating data splits')\n",
    "splits_long = build_all_splits(X, split_seeds, ds_name='abalone', task_id=0)\n",
    "feather.write_feather(splits_long, fname1)\n",
    "\n",
    "# --- MIAV synthetic datasets ---\n",
    "print('running MIAV')\n",
    "syn_long_miav, failures_miav = build_all_synthetics_miav(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='abalone'\n",
    ")\n",
    "feather.write_feather(syn_long_miav, fname2)\n",
    "\n",
    "# --- JF synthetic datasets ---\n",
    "print('running JF')\n",
    "syn_long_jf, failures_jf = build_all_synthetics_jf(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='abalone'\n",
    ")\n",
    "feather.write_feather(syn_long_jf, fname3)\n",
    "\n",
    "# --- FC synthetic datasets ---\n",
    "print('running FC')\n",
    "syn_long_fc, failures_fc = build_all_synthetics_fc(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='abalone'\n",
    ")\n",
    "feather.write_feather(syn_long_fc, fname4)\n",
    "\n",
    "# --- failure summary per generator: count and first record ---\n",
    "print(len(failures_miav), failures_miav[:1])\n",
    "print(len(failures_jf), failures_jf[:1])\n",
    "print(len(failures_fc), failures_fc[:1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff8154e4-d52a-42e0-bb79-22ad5023a6de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===== Bank marketing (OpenML dataset id 44126) =====\n",
    "# Download the data, then write four feather files: the orig/holdout splits\n",
    "# and one synthetic table per generator (MIAV, JF, FC).\n",
    "\n",
    "import openml\n",
    "\n",
    "# Output files of this cell\n",
    "fname1 = os.path.join(output_path, \"bank_orig_hold_splits.feather\")\n",
    "fname2 = os.path.join(output_path, \"bank_syn_miav.feather\")\n",
    "fname3 = os.path.join(output_path, \"bank_syn_jf.feather\")\n",
    "fname4 = os.path.join(output_path, \"bank_syn_fc.feather\")\n",
    "\n",
    "# Features and default target; the target is stored in X as column 'target'\n",
    "dataset = openml.datasets.get_dataset(44126)\n",
    "X, y, _, attribute_names = dataset.get_data(\n",
    "    target=dataset.default_target_attribute\n",
    ")\n",
    "X['target'] = y\n",
    "\n",
    "# Positional column indices handed to enforce_dtypes\n",
    "num_idx = list(range(7))  # 0-6 as num_variables\n",
    "cat_idx = [7]             # 7 as cat_variables\n",
    "X = enforce_dtypes(dat=X, num_variables=num_idx, cat_variables=cat_idx)\n",
    "\n",
    "# --- orig/holdout splits, one per seed in split_seeds ---\n",
    "print('generating data splits')\n",
    "splits_long = build_all_splits(X, split_seeds, ds_name='bank', task_id=0)\n",
    "feather.write_feather(splits_long, fname1)\n",
    "\n",
    "# --- MIAV synthetic datasets ---\n",
    "print('running MIAV')\n",
    "syn_long_miav, failures_miav = build_all_synthetics_miav(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='bank'\n",
    ")\n",
    "feather.write_feather(syn_long_miav, fname2)\n",
    "\n",
    "# --- JF synthetic datasets ---\n",
    "print('running JF')\n",
    "syn_long_jf, failures_jf = build_all_synthetics_jf(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='bank'\n",
    ")\n",
    "feather.write_feather(syn_long_jf, fname3)\n",
    "\n",
    "# --- FC synthetic datasets ---\n",
    "print('running FC')\n",
    "syn_long_fc, failures_fc = build_all_synthetics_fc(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='bank'\n",
    ")\n",
    "feather.write_feather(syn_long_fc, fname4)\n",
    "\n",
    "# --- failure summary per generator: count and first record ---\n",
    "print(len(failures_miav), failures_miav[:1])\n",
    "print(len(failures_jf), failures_jf[:1])\n",
    "print(len(failures_fc), failures_fc[:1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ca5357f-7ebc-43d3-bbcd-e135e95152b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===== Credit (OpenML dataset id 44089) =====\n",
    "# Download the data, then write four feather files: the orig/holdout splits\n",
    "# and one synthetic table per generator (MIAV, JF, FC).\n",
    "\n",
    "# Output files of this cell\n",
    "fname1 = os.path.join(output_path, \"credit_orig_hold_splits.feather\")\n",
    "fname2 = os.path.join(output_path, \"credit_syn_miav.feather\")\n",
    "fname3 = os.path.join(output_path, \"credit_syn_jf.feather\")\n",
    "fname4 = os.path.join(output_path, \"credit_syn_fc.feather\")\n",
    "\n",
    "# Features and default target; the target is stored in X as column 'target'\n",
    "dataset = openml.datasets.get_dataset(44089)\n",
    "X, y, _, attribute_names = dataset.get_data(\n",
    "    target=dataset.default_target_attribute\n",
    ")\n",
    "X['target'] = y\n",
    "\n",
    "# Positional column indices handed to enforce_dtypes\n",
    "num_idx = list(range(10))  # 0-9 as num_variables\n",
    "cat_idx = [10]             # 10 as cat_variables\n",
    "X = enforce_dtypes(dat=X, num_variables=num_idx, cat_variables=cat_idx)\n",
    "\n",
    "# --- orig/holdout splits, one per seed in split_seeds ---\n",
    "print('generating data splits')\n",
    "splits_long = build_all_splits(X, split_seeds, ds_name='credit', task_id=0)\n",
    "feather.write_feather(splits_long, fname1)\n",
    "\n",
    "# --- MIAV synthetic datasets ---\n",
    "print('running MIAV')\n",
    "syn_long_miav, failures_miav = build_all_synthetics_miav(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='credit'\n",
    ")\n",
    "feather.write_feather(syn_long_miav, fname2)\n",
    "\n",
    "# --- JF synthetic datasets ---\n",
    "print('running JF')\n",
    "syn_long_jf, failures_jf = build_all_synthetics_jf(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='credit'\n",
    ")\n",
    "feather.write_feather(syn_long_jf, fname3)\n",
    "\n",
    "# --- FC synthetic datasets ---\n",
    "print('running FC')\n",
    "syn_long_fc, failures_fc = build_all_synthetics_fc(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='credit'\n",
    ")\n",
    "feather.write_feather(syn_long_fc, fname4)\n",
    "\n",
    "# --- failure summary per generator: count and first record ---\n",
    "print(len(failures_miav), failures_miav[:1])\n",
    "print(len(failures_jf), failures_jf[:1])\n",
    "print(len(failures_fc), failures_fc[:1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d71aa4ff-9e37-4f3f-a987-bd751c9294b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===== Eye movements (OpenML dataset id 44130) =====\n",
    "# Download the data, then write four feather files: the orig/holdout splits\n",
    "# and one synthetic table per generator (MIAV, JF, FC).\n",
    "\n",
    "# Output files of this cell\n",
    "fname1 = os.path.join(output_path, \"eye_orig_hold_splits.feather\")\n",
    "fname2 = os.path.join(output_path, \"eye_syn_miav.feather\")\n",
    "fname3 = os.path.join(output_path, \"eye_syn_jf.feather\")\n",
    "fname4 = os.path.join(output_path, \"eye_syn_fc.feather\")\n",
    "\n",
    "# Features and default target; the target is stored in X as column 'target'\n",
    "dataset = openml.datasets.get_dataset(44130)\n",
    "X, y, _, attribute_names = dataset.get_data(\n",
    "    target=dataset.default_target_attribute\n",
    ")\n",
    "X['target'] = y\n",
    "\n",
    "# Positional column indices handed to enforce_dtypes\n",
    "num_idx = list(range(20))  # 0-19 as num_variables\n",
    "cat_idx = [20]             # 20 as cat_variables\n",
    "X = enforce_dtypes(dat=X, num_variables=num_idx, cat_variables=cat_idx)\n",
    "\n",
    "# --- orig/holdout splits, one per seed in split_seeds ---\n",
    "print('generating data splits')\n",
    "splits_long = build_all_splits(X, split_seeds, ds_name='eye', task_id=0)\n",
    "feather.write_feather(splits_long, fname1)\n",
    "\n",
    "# --- MIAV synthetic datasets ---\n",
    "print('running MIAV')\n",
    "syn_long_miav, failures_miav = build_all_synthetics_miav(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='eye'\n",
    ")\n",
    "feather.write_feather(syn_long_miav, fname2)\n",
    "\n",
    "# --- JF synthetic datasets ---\n",
    "print('running JF')\n",
    "syn_long_jf, failures_jf = build_all_synthetics_jf(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='eye'\n",
    ")\n",
    "feather.write_feather(syn_long_jf, fname3)\n",
    "\n",
    "# --- FC synthetic datasets ---\n",
    "print('running FC')\n",
    "syn_long_fc, failures_fc = build_all_synthetics_fc(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='eye'\n",
    ")\n",
    "feather.write_feather(syn_long_fc, fname4)\n",
    "\n",
    "# --- failure summary per generator: count and first record ---\n",
    "print(len(failures_miav), failures_miav[:1])\n",
    "print(len(failures_jf), failures_jf[:1])\n",
    "print(len(failures_fc), failures_fc[:1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1af46279-c835-43d6-a075-ec212f00d036",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===== House_16H (OpenML dataset id 44123) =====\n",
    "# Download the data, then write four feather files: the orig/holdout splits\n",
    "# and one synthetic table per generator (MIAV, JF, FC).\n",
    "\n",
    "# Output files of this cell\n",
    "fname1 = os.path.join(output_path, \"house16h_orig_hold_splits.feather\")\n",
    "fname2 = os.path.join(output_path, \"house16h_syn_miav.feather\")\n",
    "fname3 = os.path.join(output_path, \"house16h_syn_jf.feather\")\n",
    "fname4 = os.path.join(output_path, \"house16h_syn_fc.feather\")\n",
    "\n",
    "# Features and default target; the target is stored in X as column 'target'\n",
    "dataset = openml.datasets.get_dataset(44123)\n",
    "X, y, _, attribute_names = dataset.get_data(\n",
    "    target=dataset.default_target_attribute\n",
    ")\n",
    "X['target'] = y\n",
    "\n",
    "# Positional column indices handed to enforce_dtypes\n",
    "num_idx = list(range(16))  # 0-15 as num_variables\n",
    "cat_idx = [16]             # 16 as cat_variables\n",
    "X = enforce_dtypes(dat=X, num_variables=num_idx, cat_variables=cat_idx)\n",
    "\n",
    "# --- orig/holdout splits, one per seed in split_seeds ---\n",
    "print('generating data splits')\n",
    "splits_long = build_all_splits(X, split_seeds, ds_name='house16h', task_id=0)\n",
    "feather.write_feather(splits_long, fname1)\n",
    "\n",
    "# --- MIAV synthetic datasets ---\n",
    "print('running MIAV')\n",
    "syn_long_miav, failures_miav = build_all_synthetics_miav(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='house16h'\n",
    ")\n",
    "feather.write_feather(syn_long_miav, fname2)\n",
    "\n",
    "# --- JF synthetic datasets ---\n",
    "print('running JF')\n",
    "syn_long_jf, failures_jf = build_all_synthetics_jf(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='house16h'\n",
    ")\n",
    "feather.write_feather(syn_long_jf, fname3)\n",
    "\n",
    "# --- FC synthetic datasets ---\n",
    "print('running FC')\n",
    "syn_long_fc, failures_fc = build_all_synthetics_fc(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='house16h'\n",
    ")\n",
    "feather.write_feather(syn_long_fc, fname4)\n",
    "\n",
    "# --- failure summary per generator: count and first record ---\n",
    "print(len(failures_miav), failures_miav[:1])\n",
    "print(len(failures_jf), failures_jf[:1])\n",
    "print(len(failures_fc), failures_fc[:1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ed89776-0da7-4048-8373-2c7d9166c121",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===== MagicTelescope (OpenML dataset id 44125) =====\n",
    "# Download the data, then write four feather files: the orig/holdout splits\n",
    "# and one synthetic table per generator (MIAV, JF, FC).\n",
    "\n",
    "# Output files of this cell\n",
    "fname1 = os.path.join(output_path, \"magic_orig_hold_splits.feather\")\n",
    "fname2 = os.path.join(output_path, \"magic_syn_miav.feather\")\n",
    "fname3 = os.path.join(output_path, \"magic_syn_jf.feather\")\n",
    "fname4 = os.path.join(output_path, \"magic_syn_fc.feather\")\n",
    "\n",
    "# Features and default target; the target is stored in X as column 'target'\n",
    "dataset = openml.datasets.get_dataset(44125)\n",
    "X, y, _, attribute_names = dataset.get_data(\n",
    "    target=dataset.default_target_attribute\n",
    ")\n",
    "X['target'] = y\n",
    "\n",
    "# Positional column indices handed to enforce_dtypes\n",
    "num_idx = list(range(10))  # 0-9 as num_variables\n",
    "cat_idx = [10]             # 10 as cat_variables\n",
    "X = enforce_dtypes(dat=X, num_variables=num_idx, cat_variables=cat_idx)\n",
    "\n",
    "# --- orig/holdout splits, one per seed in split_seeds ---\n",
    "print('generating data splits')\n",
    "splits_long = build_all_splits(X, split_seeds, ds_name='magic', task_id=0)\n",
    "feather.write_feather(splits_long, fname1)\n",
    "\n",
    "# --- MIAV synthetic datasets ---\n",
    "print('running MIAV')\n",
    "syn_long_miav, failures_miav = build_all_synthetics_miav(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='magic'\n",
    ")\n",
    "feather.write_feather(syn_long_miav, fname2)\n",
    "\n",
    "# --- JF synthetic datasets ---\n",
    "print('running JF')\n",
    "syn_long_jf, failures_jf = build_all_synthetics_jf(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='magic'\n",
    ")\n",
    "feather.write_feather(syn_long_jf, fname3)\n",
    "\n",
    "# --- FC synthetic datasets ---\n",
    "print('running FC')\n",
    "syn_long_fc, failures_fc = build_all_synthetics_fc(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='magic'\n",
    ")\n",
    "feather.write_feather(syn_long_fc, fname4)\n",
    "\n",
    "# --- failure summary per generator: count and first record ---\n",
    "print(len(failures_miav), failures_miav[:1])\n",
    "print(len(failures_jf), failures_jf[:1])\n",
    "print(len(failures_fc), failures_fc[:1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ea7a83f-0859-4610-b0e2-0c7bb29759bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===== Pol (OpenML dataset id 44122) =====\n",
    "# Download the data, then write four feather files: the orig/holdout splits\n",
    "# and one synthetic table per generator (MIAV, JF, FC).\n",
    "\n",
    "# Output files of this cell\n",
    "fname1 = os.path.join(output_path, \"pol_orig_hold_splits.feather\")\n",
    "fname2 = os.path.join(output_path, \"pol_syn_miav.feather\")\n",
    "fname3 = os.path.join(output_path, \"pol_syn_jf.feather\")\n",
    "fname4 = os.path.join(output_path, \"pol_syn_fc.feather\")\n",
    "\n",
    "# Features and default target; the target is stored in X as column 'target'\n",
    "dataset = openml.datasets.get_dataset(44122)\n",
    "X, y, _, attribute_names = dataset.get_data(\n",
    "    target=dataset.default_target_attribute\n",
    ")\n",
    "X['target'] = y\n",
    "\n",
    "# Positional column indices handed to enforce_dtypes\n",
    "num_idx = list(range(26))  # 0-25 as num_variables\n",
    "cat_idx = [26]             # 26 as cat_variables\n",
    "X = enforce_dtypes(dat=X, num_variables=num_idx, cat_variables=cat_idx)\n",
    "\n",
    "# --- orig/holdout splits, one per seed in split_seeds ---\n",
    "print('generating data splits')\n",
    "splits_long = build_all_splits(X, split_seeds, ds_name='pol', task_id=0)\n",
    "feather.write_feather(splits_long, fname1)\n",
    "\n",
    "# --- MIAV synthetic datasets ---\n",
    "print('running MIAV')\n",
    "syn_long_miav, failures_miav = build_all_synthetics_miav(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='pol'\n",
    ")\n",
    "feather.write_feather(syn_long_miav, fname2)\n",
    "\n",
    "# --- JF synthetic datasets ---\n",
    "print('running JF')\n",
    "syn_long_jf, failures_jf = build_all_synthetics_jf(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='pol'\n",
    ")\n",
    "feather.write_feather(syn_long_jf, fname3)\n",
    "\n",
    "# --- FC synthetic datasets ---\n",
    "print('running FC')\n",
    "syn_long_fc, failures_fc = build_all_synthetics_fc(\n",
    "    X=X, split_seeds=split_seeds, task_id=0, ds_name='pol'\n",
    ")\n",
    "feather.write_feather(syn_long_fc, fname4)\n",
    "\n",
    "# --- failure summary per generator: count and first record ---\n",
    "print(len(failures_miav), failures_miav[:1])\n",
    "print(len(failures_jf), failures_jf[:1])\n",
    "print(len(failures_fc), failures_fc[:1])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
