{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e373d41-1d9c-4a81-b206-047462bf061d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here I am \"monkey patching\" torch.nn to include my own RMSNorm before anything uses it.\n",
    "# This patching fixes an error coming from a synthcity dependency (opacus/grad_sample/rms_norm.py)\n",
    "# when we try to load a new plugin.\n",
    "# The problem is that the opacus library assumes that my environment has torch.nn.RMSNorm, but my\n",
    "# PyTorch version is older than 2.3, so RMSNorm doesn't exist.\n",
    "\n",
    "import torch  # needed for torch.ones below: `import torch.nn as nn` only binds `nn`\n",
    "import torch.nn as nn\n",
    "\n",
    "if not hasattr(nn, \"RMSNorm\"):\n",
    "    class RMSNorm(nn.Module):\n",
    "        \"\"\"Minimal stand-in for torch.nn.RMSNorm on PyTorch versions that lack it.\n",
    "\n",
    "        Computes scale * x / (rms(x) + eps), with rms taken over the last dimension.\n",
    "        It mainly exists so that `nn.RMSNorm` can be referenced when opacus is imported.\n",
    "        Unlike the upstream layer, eps is added to the RMS rather than inside the\n",
    "        square root, and the learnable gain is called `scale` instead of `weight`.\n",
    "        \"\"\"\n",
    "\n",
    "        def __init__(self, dim, eps=1e-8):\n",
    "            super().__init__()\n",
    "            self.normalized_shape = (dim,)  # same attribute name as the upstream layer\n",
    "            self.eps = eps\n",
    "            self.scale = nn.Parameter(torch.ones(dim))\n",
    "\n",
    "        def forward(self, x):\n",
    "            norm = x.norm(2, dim=-1, keepdim=True)\n",
    "            rms = norm / (x.shape[-1] ** 0.5)\n",
    "            return self.scale * x / (rms + self.eps)\n",
    "\n",
    "    nn.RMSNorm = RMSNorm  # patch torch.nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2180112a-eb1a-4bf5-a1e2-7972b2913f38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from tqdm.auto import tqdm\n",
    "import pyarrow.feather as feather\n",
    "\n",
    "from synthcity.plugins import Plugins\n",
    "from synthcity.plugins.core.dataloader import GenericDataLoader\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8431e96-2115-4571-ac79-4825f5f46a56",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Source the SMOTE-NC synthcity plugin (expected to define SmoteNCPlugin) into this namespace.\n",
    "# NOTE(review): exec() runs whatever this file contains -- only point it at trusted code.\n",
    "code_path = ''\n",
    "file_path = os.path.join(code_path, 'synthcity_plugin_for_smotenc_generator.py')\n",
    "plugin_source_path = os.path.expanduser(file_path)\n",
    "with open(plugin_source_path, encoding='utf-8') as file:\n",
    "    # compile() with the real path makes tracebacks point into the plugin file, not <string>\n",
    "    exec(compile(file.read(), plugin_source_path, 'exec'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0203f145-4542-41a3-a271-e6b009e09855",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register the SMOTE-NC generator under the name used by the generation runs below.\n",
    "# SmoteNCPlugin comes from the plugin file exec'd in the previous cell.\n",
    "from synthcity.plugins import Plugins\n",
    "\n",
    "SMOTENC_PLUGIN_NAME = \"smotenc\"\n",
    "\n",
    "generators = Plugins()\n",
    "generators.add(SMOTENC_PLUGIN_NAME, SmoteNCPlugin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37dd2893-9870-4b9e-b0cf-36c1f2692e30",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def build_all_synthetics_synthcity(\n",
    "    X: pd.DataFrame,\n",
    "    split_seeds,\n",
    "    plugin_name: str,                    # e.g. \"tvae\", \"ctgan\", \"arf\", \"ddpm\", ...\n",
    "    plugin_kwargs: dict | None = None,\n",
    "    *,\n",
    "    task_id: int = 0,\n",
    "    ds_name: str = \"\",\n",
    "    na_strategy: str = \"drop\",           # \"drop\" | \"ignore\"\n",
    "    set_split_seed: bool = True,         # inject random_state=seed if not provided\n",
    "    plugins=None,                        # plugin loader to use; defaults to Plugins()\n",
    "):\n",
    "    \"\"\"\n",
    "    Fit one synthcity generator per split seed and stack the synthetic rows.\n",
    "\n",
    "    For each seed (split numbers start at 1):\n",
    "      - 50/50 split with train_test_split(random_state=seed); use the first ('orig') half\n",
    "      - handle NAs per `na_strategy` (\"drop\": drop rows with any NA, \"ignore\": keep them)\n",
    "      - fit the chosen plugin: plugins.get(plugin_name, **plugin_kwargs), adding\n",
    "        random_state=seed when `set_split_seed` is True and the caller did not set it\n",
    "      - generate as many synthetic rows as the fitted half has and prepend the\n",
    "        metadata columns __dataset__, __task_id__, __split__, __role__\n",
    "\n",
    "    An exception raised while processing one split is recorded in `failures` and the\n",
    "    loop continues with the next seed.\n",
    "\n",
    "    `plugins` lets the caller pass the loader on which custom plugins were registered\n",
    "    with `.add(...)` (e.g. the `generators` instance).\n",
    "    NOTE(review): with the default, a fresh Plugins() must see such custom plugins;\n",
    "    confirm that synthcity shares its registry across Plugins() instances.\n",
    "\n",
    "    Returns (syn_long, failures):\n",
    "      syn_long -- all synthetic chunks concatenated (only the metadata columns\n",
    "                  if every split failed)\n",
    "      failures -- list of dicts with keys task_id, split, seed, role, error\n",
    "\n",
    "    Raises ValueError for an unknown `na_strategy`, before any model is fitted.\n",
    "    \"\"\"\n",
    "    if na_strategy not in (\"drop\", \"ignore\"):\n",
    "        raise ValueError(f\"Unknown na_strategy: {na_strategy}\")\n",
    "\n",
    "    # One loader for all splits (assumes .get() builds a new, unfitted plugin per call)\n",
    "    loader_registry = plugins if plugins is not None else Plugins()\n",
    "\n",
    "    all_chunks: list[pd.DataFrame] = []\n",
    "    failures: list[dict] = []\n",
    "\n",
    "    for j, seed in enumerate(split_seeds, start=1):\n",
    "        try:\n",
    "            # Split\n",
    "            X_orig, _ = train_test_split(X, test_size=0.5, random_state=seed)\n",
    "\n",
    "            # NA handling\n",
    "            if na_strategy == \"drop\":\n",
    "                X_orig = X_orig.dropna()\n",
    "            X_proc = X_orig.reset_index(drop=True)\n",
    "\n",
    "            if X_proc.empty:\n",
    "                failures.append({\n",
    "                    \"task_id\": task_id, \"split\": j, \"seed\": seed, \"role\": \"syn\",\n",
    "                    \"error\": \"empty_after_na_handling\",\n",
    "                })\n",
    "                continue\n",
    "\n",
    "            # Data loader (unconditional generation on all columns)\n",
    "            loader = GenericDataLoader(X_proc)\n",
    "\n",
    "            # Copy the kwargs so the caller's dict is never modified\n",
    "            params = dict(plugin_kwargs or {})\n",
    "            if set_split_seed and \"random_state\" not in params:\n",
    "                params[\"random_state\"] = seed  # per-split reproducibility\n",
    "\n",
    "            model = loader_registry.get(plugin_name, **params)\n",
    "\n",
    "            # Fit & generate as many rows as were used for fitting\n",
    "            model.fit(loader)\n",
    "            X_syn = model.generate(count=len(loader)).dataframe()\n",
    "\n",
    "            # Add metadata; each insert goes to position 0, so the final order is\n",
    "            # __dataset__, __task_id__, __split__, __role__, <data columns>\n",
    "            chunk = X_syn.copy()\n",
    "            chunk.insert(0, \"__role__\", \"syn\")\n",
    "            chunk.insert(0, \"__split__\", np.int32(j))\n",
    "            chunk.insert(0, \"__task_id__\", np.int32(task_id))\n",
    "            chunk.insert(0, \"__dataset__\", str(ds_name))\n",
    "\n",
    "            all_chunks.append(chunk)\n",
    "\n",
    "        except Exception as e:\n",
    "            # Best effort: record the failure and carry on with the remaining splits\n",
    "            failures.append({\n",
    "                \"task_id\": task_id, \"split\": j, \"seed\": seed, \"role\": \"syn\", \"error\": repr(e),\n",
    "            })\n",
    "\n",
    "    syn_long = (\n",
    "        pd.concat(all_chunks, axis=0, ignore_index=True)\n",
    "        if all_chunks else\n",
    "        pd.DataFrame(columns=[\"__dataset__\", \"__task_id__\", \"__split__\", \"__role__\"])\n",
    "    )\n",
    "    return syn_long, failures\n",
    "\n",
    "\n",
    "def enforce_dtypes(dat,\n",
    "                   num_variables,\n",
    "                   cat_variables):\n",
    "    \"\"\"\n",
    "    Enforce \"float64\" dtype for numeric variables and string values for the\n",
    "    categorical variables.\n",
    "\n",
    "    Parameters:\n",
    "        dat (pd.DataFrame): Input data matrix (numeric, categorical, or mixed).\n",
    "        num_variables (list | None): Column positions (iloc indices) of numeric variables.\n",
    "        cat_variables (list | None): Column positions (iloc indices) of categorical variables.\n",
    "\n",
    "    Columns listed in neither argument are dropped. With both arguments the\n",
    "    output columns keep their original relative order; with only one, they\n",
    "    appear in the order listed. Missing categorical values stay missing instead\n",
    "    of becoming the strings \"nan\" / \"None\".\n",
    "\n",
    "    Returns:\n",
    "    pd.DataFrame: with transformed data types\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if both num_variables and cat_variables are None.\n",
    "    \"\"\"\n",
    "    if num_variables is None and cat_variables is None:\n",
    "        raise ValueError(\"At least one of num_variables or cat_variables must be specified.\")\n",
    "\n",
    "    def as_float(idx):\n",
    "        return dat.iloc[:, idx].astype(\"float64\")\n",
    "\n",
    "    def as_str(idx):\n",
    "        block = dat.iloc[:, idx]\n",
    "        # astype(str) turns NaN/None into \"nan\"/\"None\"; mask them back to missing\n",
    "        return block.astype(str).where(block.notna())\n",
    "\n",
    "    if cat_variables is None:\n",
    "        return as_float(list(num_variables))\n",
    "    if num_variables is None:\n",
    "        return as_str(list(cat_variables))\n",
    "\n",
    "    # list() so ranges/tuples can be concatenated below\n",
    "    num_idx = list(num_variables)\n",
    "    cat_idx = list(cat_variables)\n",
    "    combined = pd.concat([as_float(num_idx), as_str(cat_idx)], axis=1)\n",
    "    # Reorder columns to match the order in the original data\n",
    "    return combined.iloc[:, np.argsort(num_idx + cat_idx)]\n",
    "\n",
    "\n",
    "def synth_smotenc(\n",
    "    dat: pd.DataFrame,\n",
    "    k: int,\n",
    "    round_int_vars: bool = True,\n",
    "    random_state: int | None = None,\n",
    ") -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    Python translation of SynthSMOTENC.\n",
    "\n",
    "    Implements a simple SMOTE-like synthesizer:\n",
    "      - k-NN is computed in the space of numeric columns only.\n",
    "      - Numeric columns are synthesized by interpolating each row toward one\n",
    "        randomly chosen neighbor among its k neighbors: x + u * (x_nn - x) with\n",
    "        u ~ U(0, 1). The ORIGINAL neighbor values are always used, never rows\n",
    "        that were already replaced earlier in the loop.\n",
    "      - Categorical / boolean / string columns are synthesized by majority vote\n",
    "        across {sample} U {its k numeric-nearest neighbors}; ties go to the\n",
    "        candidate with the lexicographically smallest str(). Votes are taken on\n",
    "        the original values, so each of these columns keeps its dtype.\n",
    "      - Columns of any other dtype are copied unchanged.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dat : pd.DataFrame\n",
    "        Input data. Numeric columns must not contain NaN (sklearn's\n",
    "        NearestNeighbors rejects them).\n",
    "    k : int\n",
    "        Number of nearest neighbors.\n",
    "    round_int_vars : bool, default True\n",
    "        If True, round back columns that are integer-valued in `dat`\n",
    "        (see `round_integer_variables`).\n",
    "    random_state : int, optional\n",
    "        Seed for the neighbor / interpolation draws. None (default) uses\n",
    "        NumPy's global random state.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pd.DataFrame\n",
    "        Synthetic data with the same shape/columns/index as `dat`.\n",
    "    \"\"\"\n",
    "\n",
    "    # ---- Identify numeric vs. categorical-like columns (like GetVariableTypes) ----\n",
    "    num_cols = dat.select_dtypes(include=\"number\").columns.tolist()\n",
    "    cat_cols = dat.select_dtypes(include=[\"category\", \"object\", \"string\", \"bool\", \"boolean\"]).columns.tolist()\n",
    "\n",
    "    n = len(dat)\n",
    "    if n == 0 or (len(num_cols) == 0 and len(cat_cols) == 0):\n",
    "        # Nothing to do (and NearestNeighbors cannot be fitted on zero rows)\n",
    "        return dat.copy()\n",
    "\n",
    "    # The np.random module exposes the same randint/rand API as a seeded RandomState\n",
    "    rng = np.random if random_state is None else np.random.RandomState(random_state)\n",
    "\n",
    "    # ---- kNN on numeric columns (if any) ----\n",
    "    X_num = None\n",
    "    nn_index = np.zeros((n, 0), dtype=int)\n",
    "    if num_cols:\n",
    "        # Imported here because the notebook's import cell does not import it\n",
    "        from sklearn.neighbors import NearestNeighbors\n",
    "\n",
    "        X_num = dat[num_cols].to_numpy(dtype=float)\n",
    "        # kneighbors on the training data normally returns each row itself as the\n",
    "        # first neighbor; ask for k+1 and drop that column (like R's FNN::get.knn).\n",
    "        # NOTE(review): with duplicate rows the self-match is not guaranteed to be first.\n",
    "        nbrs = NearestNeighbors(n_neighbors=min(k + 1, n), algorithm=\"auto\").fit(X_num)\n",
    "        nn_index = nbrs.kneighbors(X_num, return_distance=False)[:, 1:]\n",
    "\n",
    "    n_neighbors = nn_index.shape[1]\n",
    "    synthetic_dat = dat.copy()\n",
    "\n",
    "    # ---- Numeric columns: interpolate toward one random neighbor ----\n",
    "    if num_cols:\n",
    "        X_new = X_num.copy()\n",
    "        if n_neighbors > 0:\n",
    "            for i in range(n):\n",
    "                # Draw the neighbor slot first, then the interpolation weight\n",
    "                nb = nn_index[i, rng.randint(n_neighbors)]\n",
    "                lam = rng.rand()\n",
    "                # x + lambda * (x_nn - x), always from the unmodified data\n",
    "                X_new[i, :] = X_num[i, :] + lam * (X_num[nb, :] - X_num[i, :])\n",
    "        synthetic_dat[num_cols] = X_new\n",
    "\n",
    "    def majority_vote(values) -> object:\n",
    "        \"\"\"Most frequent value; ties go to the smallest str() (deterministic).\"\"\"\n",
    "        counts = pd.Series(values, dtype=object).value_counts(dropna=False)\n",
    "        top_count = counts.max()\n",
    "        # tie-break by label lexicographic to mimic R's sort(table(...), decreasing=TRUE)\n",
    "        return min((val for val, c in counts.items() if c == top_count), key=str)\n",
    "\n",
    "    # ---- Categorical-like columns: majority vote over sample + neighbors ----\n",
    "    for col in cat_cols:\n",
    "        values = dat[col].to_numpy(dtype=object)\n",
    "        voted = [majority_vote([values[i], *values[nn_index[i]]]) for i in range(n)]\n",
    "        # Rebuild with the original dtype (category / bool / string / object)\n",
    "        synthetic_dat[col] = pd.Series(voted, index=dat.index, dtype=dat[col].dtype)\n",
    "\n",
    "    # ---- Optionally round back integer-valued numeric columns (like RoundIntegerVariables) ----\n",
    "    if round_int_vars and num_cols:\n",
    "        synthetic_dat = round_integer_variables(df_ori=dat, df_syn=synthetic_dat)\n",
    "\n",
    "    return synthetic_dat\n",
    "\n",
    "\n",
    "def round_integer_variables(df_ori: pd.DataFrame, df_syn: pd.DataFrame, atol: float = 1e-8) -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    This function uses the original data to determine which variables\n",
    "    are integer-valued and then rounds the values of the corresponding\n",
    "    variables in the synthetic data to the nearest integer (np.rint).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df_ori : pd.DataFrame\n",
    "        DataFrame containing the original data.\n",
    "    df_syn : pd.DataFrame\n",
    "        DataFrame containing the synthetic data.\n",
    "        Assumes same schema/columns as `df_ori`; numeric columns of `df_ori`\n",
    "        that are absent from `df_syn` are skipped.\n",
    "    atol : float, optional\n",
    "        Absolute tolerance when testing whether original data is\n",
    "        integer-valued (default = 1e-8): x counts as an integer when\n",
    "        |x - round(x)| <= atol. No relative tolerance is applied, so large\n",
    "        non-integer values such as 100000.5 are not mistaken for integers\n",
    "        (similar to the `is.wholenumber` example in R's ?is.integer help).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pd.DataFrame\n",
    "        Synthetic data with rounded values for integer-valued variables.\n",
    "        Non-numeric entries in a rounded column become NaN.\n",
    "    \"\"\"\n",
    "    # Make sure df_ori and df_syn are pandas dataframes\n",
    "    df_ori = pd.DataFrame(df_ori)\n",
    "    # Work on a copy so the caller's synthetic data is not modified\n",
    "    syn = pd.DataFrame(df_syn).copy()\n",
    "\n",
    "    # Iterate only over numeric columns in the original data\n",
    "    for col in df_ori.select_dtypes(include=\"number\").columns:\n",
    "        if col not in syn.columns:\n",
    "            continue\n",
    "        # Drop missing values to avoid issues when checking integer-ness\n",
    "        vals = df_ori[col].dropna().to_numpy()\n",
    "\n",
    "        # rtol=0: np.isclose's default rtol=1e-5 scales with magnitude and would\n",
    "        # accept e.g. 100000.5 as integer-like\n",
    "        if vals.size and np.isclose(vals, np.round(vals), rtol=0.0, atol=atol).all():\n",
    "            syn[col] = np.rint(pd.to_numeric(syn[col], errors=\"coerce\"))\n",
    "\n",
    "    return syn\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "503db7af-afc3-4330-a66d-965498a6fa36",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_path = ''  # directory for the .feather files ('' = relative to the working directory)\n",
    "N_SPLITS = 10\n",
    "split_seeds = list(range(1, N_SPLITS + 1))  # seeds 1..10, one per 50/50 split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98903b4f-21f3-4b3d-a60d-aec46abac95b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------------- Abalone data --------------------------------------------------------------\n",
    "\n",
    "from sklearn.datasets import fetch_openml\n",
    "\n",
    "# Fetch the Abalone dataset\n",
    "abalone = fetch_openml(name=\"abalone\", version=1, as_frame=True)\n",
    "\n",
    "# Access the data and target\n",
    "X = abalone.data\n",
    "y = abalone.target\n",
    "\n",
    "X['target'] = y  # Rings\n",
    "\n",
    "num_idx = [1, 2, 3, 4, 5, 6, 7, 8]\n",
    "cat_idx = [0]\n",
    "\n",
    "X = enforce_dtypes(dat=X, num_variables=num_idx, cat_variables=cat_idx)\n",
    "\n",
    "ds_name = 'abalone'\n",
    "\n",
    "# ------------------ Generate synthetic datasets with every generator ---------------------------\n",
    "# One entry per generator: (label printed while running, synthcity plugin name,\n",
    "#                           tag used in the output file / variable name, plugin kwargs)\n",
    "generator_configs = [\n",
    "    ('bayesian_network', 'bayesian_network', 'bayesnet',\n",
    "     {'struct_learning_search_method': 'hillclimb',\n",
    "      'struct_learning_score': 'bic'}),\n",
    "    ('arf', 'arf', 'arf',\n",
    "     {'num_trees': 80,\n",
    "      'delta': 0,\n",
    "      'max_iters': 2,\n",
    "      'early_stop': False,\n",
    "      'min_node_size': 2}),\n",
    "    ('ctgan', 'ctgan', 'ctgan',\n",
    "     {'generator_n_layers_hidden': 1,\n",
    "      'generator_n_units_hidden': 100,\n",
    "      'generator_nonlin': 'elu',\n",
    "      'n_iter': 700,\n",
    "      'generator_dropout': 0.13836424598477665,\n",
    "      'discriminator_n_layers_hidden': 2,\n",
    "      'discriminator_n_units_hidden': 100,\n",
    "      'discriminator_nonlin': 'tanh',\n",
    "      'discriminator_n_iter': 5,\n",
    "      'discriminator_dropout': 0.023861565936528797,\n",
    "      'lr': 0.001,\n",
    "      'weight_decay': 0.0001,\n",
    "      'batch_size': 200,\n",
    "      'encoder_max_clusters': 8}),\n",
    "    ('tvae', 'tvae', 'tvae',\n",
    "     {'n_iter': 400,\n",
    "      'lr': 0.001,\n",
    "      'decoder_n_layers_hidden': 5,\n",
    "      'weight_decay': 0.0001,\n",
    "      'batch_size': 128,\n",
    "      'n_units_embedding': 200,\n",
    "      'decoder_n_units_hidden': 150,\n",
    "      'decoder_nonlin': 'tanh',\n",
    "      'decoder_dropout': 0.19964446358158816,\n",
    "      'encoder_n_layers_hidden': 4,\n",
    "      'encoder_n_units_hidden': 100,\n",
    "      'encoder_nonlin': 'relu',\n",
    "      'encoder_dropout': 0.0820245231222064}),\n",
    "    ('ddpm', 'ddpm', 'ddpm',\n",
    "     {'lr': 0.002991978123076162,\n",
    "      'batch_size': 970,\n",
    "      'num_timesteps': 407,\n",
    "      'n_iter': 7605,\n",
    "      'is_classification': False}),\n",
    "    ('smote', 'smotenc', 'smote', {'k': 5}),\n",
    "]\n",
    "\n",
    "syn_long_by_tag = {}\n",
    "for label, plugin, tag, kwargs in generator_configs:\n",
    "    print(f'running {label}')\n",
    "    syn_long, failures = build_all_synthetics_synthcity(\n",
    "        X=X,\n",
    "        split_seeds=split_seeds,\n",
    "        plugin_name=plugin,\n",
    "        plugin_kwargs=kwargs,\n",
    "        task_id=0,\n",
    "        ds_name=ds_name\n",
    "    )\n",
    "    # Save the combined table once per generator:\n",
    "    fname = os.path.join(output_path, f'{ds_name}_syn_{tag}.feather')\n",
    "    feather.write_feather(syn_long, fname)\n",
    "    print(len(failures), failures[:1])\n",
    "    syn_long_by_tag[tag] = syn_long\n",
    "\n",
    "# Keep the per-generator variable names\n",
    "syn_long_bayesnet = syn_long_by_tag['bayesnet']\n",
    "syn_long_arf = syn_long_by_tag['arf']\n",
    "syn_long_ctgan = syn_long_by_tag['ctgan']\n",
    "syn_long_tvae = syn_long_by_tag['tvae']\n",
    "syn_long_ddpm = syn_long_by_tag['ddpm']\n",
    "syn_long_smote = syn_long_by_tag['smote']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a74ba33-ff31-4510-b164-01f850707d87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------------------- Bank marketing data ------------------------------------------------------\n",
    "\n",
    "import openml  # also used by the dataset cells that follow\n",
    "\n",
    "# bank marketing\n",
    "dataset = openml.datasets.get_dataset(44126)\n",
    "\n",
    "X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n",
    "\n",
    "X['target'] = y\n",
    "\n",
    "num_idx = list(range(0, 7))\n",
    "cat_idx = [7]\n",
    "\n",
    "X = enforce_dtypes(dat=X, num_variables=num_idx, cat_variables=cat_idx)\n",
    "\n",
    "ds_name = 'bank'\n",
    "\n",
    "# ------------------ Generate synthetic datasets with every generator ---------------------------\n",
    "# One entry per generator: (label printed while running, synthcity plugin name,\n",
    "#                           tag used in the output file / variable name, plugin kwargs)\n",
    "generator_configs = [\n",
    "    ('bayesian_network', 'bayesian_network', 'bayesnet',\n",
    "     {'struct_learning_search_method': 'hillclimb',\n",
    "      'struct_learning_score': 'bic'}),\n",
    "    ('arf', 'arf', 'arf', {}),\n",
    "    ('ctgan', 'ctgan', 'ctgan',\n",
    "     {'generator_n_layers_hidden': 2,\n",
    "      'generator_n_units_hidden': 50,\n",
    "      'generator_nonlin': 'tanh',\n",
    "      'n_iter': 1000,\n",
    "      'generator_dropout': 0.0575,\n",
    "      'discriminator_n_layers_hidden': 4,\n",
    "      'discriminator_n_units_hidden': 150,\n",
    "      'discriminator_nonlin': 'relu'}),\n",
    "    ('tvae', 'tvae', 'tvae',\n",
    "     {'n_iter': 300,\n",
    "      'lr': 0.0002,\n",
    "      'decoder_n_layers_hidden': 4,\n",
    "      'weight_decay': 0.001,\n",
    "      'batch_size': 256,\n",
    "      'n_units_embedding': 200,\n",
    "      'decoder_n_units_hidden': 300,\n",
    "      'decoder_nonlin': 'elu',\n",
    "      'decoder_dropout': 0.194325119117226,\n",
    "      'encoder_n_layers_hidden': 1,\n",
    "      'encoder_n_units_hidden': 450,\n",
    "      'encoder_nonlin': 'leaky_relu',\n",
    "      'encoder_dropout': 0.04288563703094718}),\n",
    "    ('ddpm', 'ddpm', 'ddpm',\n",
    "     {'lr': 0.0009375080542687667,\n",
    "      'batch_size': 2929,\n",
    "      'num_timesteps': 998,\n",
    "      'n_iter': 1051,\n",
    "      'is_classification': True}),\n",
    "    ('smote', 'smotenc', 'smote', {'k': 5}),\n",
    "]\n",
    "\n",
    "syn_long_by_tag = {}\n",
    "for label, plugin, tag, kwargs in generator_configs:\n",
    "    print(f'running {label}')\n",
    "    syn_long, failures = build_all_synthetics_synthcity(\n",
    "        X=X,\n",
    "        split_seeds=split_seeds,\n",
    "        plugin_name=plugin,\n",
    "        plugin_kwargs=kwargs,\n",
    "        task_id=0,\n",
    "        ds_name=ds_name\n",
    "    )\n",
    "    # Save the combined table once per generator:\n",
    "    fname = os.path.join(output_path, f'{ds_name}_syn_{tag}.feather')\n",
    "    feather.write_feather(syn_long, fname)\n",
    "    print(len(failures), failures[:1])\n",
    "    syn_long_by_tag[tag] = syn_long\n",
    "\n",
    "# Keep the per-generator variable names\n",
    "syn_long_bayesnet = syn_long_by_tag['bayesnet']\n",
    "syn_long_arf = syn_long_by_tag['arf']\n",
    "syn_long_ctgan = syn_long_by_tag['ctgan']\n",
    "syn_long_tvae = syn_long_by_tag['tvae']\n",
    "syn_long_ddpm = syn_long_by_tag['ddpm']\n",
    "syn_long_smote = syn_long_by_tag['smote']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5b347be-2206-4e53-afa4-03654acd1e7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------------- Credit dataset -------------------------------------------------------------\n",
    "# Fit every synthcity generator on each split seed of this dataset and save one\n",
    "# combined (long-format) feather table per generator under `output_path`.\n",
    "\n",
    "dataset = openml.datasets.get_dataset(44089)\n",
    "\n",
    "X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n",
    "\n",
    "X['target'] = y\n",
    "\n",
    "# All predictors are treated as numeric and the appended target (last column) is the\n",
    "# only categorical variable. Deriving the indices from the frame replaces hard-coded\n",
    "# column counts that would silently mislabel columns if the dataset layout changed.\n",
    "num_idx = list(range(X.shape[1] - 1))\n",
    "cat_idx = [X.shape[1] - 1]\n",
    "\n",
    "X = enforce_dtypes(dat=X,\n",
    "                   num_variables=num_idx,\n",
    "                   cat_variables=cat_idx)\n",
    "\n",
    "ds_name = 'credit'\n",
    "\n",
    "# (synthcity plugin name, output file tag, plugin kwargs), run in this order.\n",
    "# NOTE(review): these settings are repeated verbatim in every dataset cell;\n",
    "# consider hoisting them into one shared config cell.\n",
    "plugin_runs = [\n",
    "    ('bayesian_network', 'bayesnet', {'struct_learning_search_method': 'hillclimb',\n",
    "                                      'struct_learning_score': 'bic'}),\n",
    "    ('arf', 'arf', {}),\n",
    "    ('ctgan', 'ctgan', {'generator_n_layers_hidden': 2,\n",
    "                        'generator_n_units_hidden': 50,\n",
    "                        'generator_nonlin': 'tanh',\n",
    "                        'n_iter': 1000,\n",
    "                        'generator_dropout': 0.0575,\n",
    "                        'discriminator_n_layers_hidden': 4,\n",
    "                        'discriminator_n_units_hidden': 150,\n",
    "                        'discriminator_nonlin': 'relu'}),\n",
    "    ('tvae', 'tvae', {'n_iter': 300,\n",
    "                      'lr': 0.0002,\n",
    "                      'decoder_n_layers_hidden': 4,\n",
    "                      'weight_decay': 0.001,\n",
    "                      'batch_size': 256,\n",
    "                      'n_units_embedding': 200,\n",
    "                      'decoder_n_units_hidden': 300,\n",
    "                      'decoder_nonlin': 'elu',\n",
    "                      'decoder_dropout': 0.194325119117226,\n",
    "                      'encoder_n_layers_hidden': 1,\n",
    "                      'encoder_n_units_hidden': 450,\n",
    "                      'encoder_nonlin': 'leaky_relu',\n",
    "                      'encoder_dropout': 0.04288563703094718}),\n",
    "    ('ddpm', 'ddpm', {'lr': 0.0009375080542687667,\n",
    "                      'batch_size': 2929,\n",
    "                      'num_timesteps': 998,\n",
    "                      'n_iter': 1051,\n",
    "                      'is_classification': True}),\n",
    "    ('smotenc', 'smote', {'k': 5}),\n",
    "]\n",
    "\n",
    "syn_long = {}         # file tag -> combined synthetic table for that generator\n",
    "failures_by_tag = {}  # file tag -> failures reported for that generator\n",
    "for plugin_name, tag, plugin_kwargs in plugin_runs:\n",
    "    print(f'running {plugin_name}')\n",
    "    syn_long[tag], failures = build_all_synthetics_synthcity(\n",
    "        X=X,\n",
    "        split_seeds=split_seeds,\n",
    "        plugin_name=plugin_name,\n",
    "        plugin_kwargs=plugin_kwargs,\n",
    "        task_id=0,\n",
    "        ds_name=ds_name\n",
    "    )\n",
    "    failures_by_tag[tag] = failures\n",
    "    # Save the combined table once per generator as {ds_name}_syn_{tag}.feather\n",
    "    fname = os.path.join(output_path, f'{ds_name}_syn_{tag}.feather')\n",
    "    feather.write_feather(syn_long[tag], fname)\n",
    "    print(len(failures), failures[:1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56d91b05-5a2d-4aad-beef-49d29fbb9e70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------------- Eye movements dataset ----------------------------------------------\n",
    "# Fit every synthcity generator on each split seed of this dataset and save one\n",
    "# combined (long-format) feather table per generator under `output_path`.\n",
    "\n",
    "dataset = openml.datasets.get_dataset(44130)\n",
    "\n",
    "X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n",
    "\n",
    "X['target'] = y\n",
    "\n",
    "# All predictors are treated as numeric and the appended target (last column) is the\n",
    "# only categorical variable. Deriving the indices from the frame replaces hard-coded\n",
    "# column counts that would silently mislabel columns if the dataset layout changed.\n",
    "num_idx = list(range(X.shape[1] - 1))\n",
    "cat_idx = [X.shape[1] - 1]\n",
    "\n",
    "X = enforce_dtypes(dat=X,\n",
    "                   num_variables=num_idx,\n",
    "                   cat_variables=cat_idx)\n",
    "\n",
    "ds_name = 'eye'\n",
    "\n",
    "# (synthcity plugin name, output file tag, plugin kwargs), run in this order.\n",
    "# NOTE(review): these settings are repeated verbatim in every dataset cell;\n",
    "# consider hoisting them into one shared config cell.\n",
    "plugin_runs = [\n",
    "    ('bayesian_network', 'bayesnet', {'struct_learning_search_method': 'hillclimb',\n",
    "                                      'struct_learning_score': 'bic'}),\n",
    "    ('arf', 'arf', {}),\n",
    "    ('ctgan', 'ctgan', {'generator_n_layers_hidden': 2,\n",
    "                        'generator_n_units_hidden': 50,\n",
    "                        'generator_nonlin': 'tanh',\n",
    "                        'n_iter': 1000,\n",
    "                        'generator_dropout': 0.0575,\n",
    "                        'discriminator_n_layers_hidden': 4,\n",
    "                        'discriminator_n_units_hidden': 150,\n",
    "                        'discriminator_nonlin': 'relu'}),\n",
    "    ('tvae', 'tvae', {'n_iter': 300,\n",
    "                      'lr': 0.0002,\n",
    "                      'decoder_n_layers_hidden': 4,\n",
    "                      'weight_decay': 0.001,\n",
    "                      'batch_size': 256,\n",
    "                      'n_units_embedding': 200,\n",
    "                      'decoder_n_units_hidden': 300,\n",
    "                      'decoder_nonlin': 'elu',\n",
    "                      'decoder_dropout': 0.194325119117226,\n",
    "                      'encoder_n_layers_hidden': 1,\n",
    "                      'encoder_n_units_hidden': 450,\n",
    "                      'encoder_nonlin': 'leaky_relu',\n",
    "                      'encoder_dropout': 0.04288563703094718}),\n",
    "    ('ddpm', 'ddpm', {'lr': 0.0009375080542687667,\n",
    "                      'batch_size': 2929,\n",
    "                      'num_timesteps': 998,\n",
    "                      'n_iter': 1051,\n",
    "                      'is_classification': True}),\n",
    "    ('smotenc', 'smote', {'k': 5}),\n",
    "]\n",
    "\n",
    "syn_long = {}         # file tag -> combined synthetic table for that generator\n",
    "failures_by_tag = {}  # file tag -> failures reported for that generator\n",
    "for plugin_name, tag, plugin_kwargs in plugin_runs:\n",
    "    print(f'running {plugin_name}')\n",
    "    syn_long[tag], failures = build_all_synthetics_synthcity(\n",
    "        X=X,\n",
    "        split_seeds=split_seeds,\n",
    "        plugin_name=plugin_name,\n",
    "        plugin_kwargs=plugin_kwargs,\n",
    "        task_id=0,\n",
    "        ds_name=ds_name\n",
    "    )\n",
    "    failures_by_tag[tag] = failures\n",
    "    # Save the combined table once per generator as {ds_name}_syn_{tag}.feather\n",
    "    fname = os.path.join(output_path, f'{ds_name}_syn_{tag}.feather')\n",
    "    feather.write_feather(syn_long[tag], fname)\n",
    "    print(len(failures), failures[:1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55040534-683a-4284-9b8b-a0f8b1cf6f6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------------- House_16H dataset --------------------------------------------------------\n",
    "# Fit every synthcity generator on each split seed of this dataset and save one\n",
    "# combined (long-format) feather table per generator under `output_path`.\n",
    "\n",
    "dataset = openml.datasets.get_dataset(44123)\n",
    "\n",
    "X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n",
    "\n",
    "X['target'] = y\n",
    "\n",
    "# All predictors are treated as numeric and the appended target (last column) is the\n",
    "# only categorical variable. Deriving the indices from the frame replaces hard-coded\n",
    "# column counts that would silently mislabel columns if the dataset layout changed.\n",
    "num_idx = list(range(X.shape[1] - 1))\n",
    "cat_idx = [X.shape[1] - 1]\n",
    "\n",
    "X = enforce_dtypes(dat=X,\n",
    "                   num_variables=num_idx,\n",
    "                   cat_variables=cat_idx)\n",
    "\n",
    "ds_name = 'house16h'\n",
    "\n",
    "# (synthcity plugin name, output file tag, plugin kwargs), run in this order.\n",
    "# NOTE(review): these settings are repeated verbatim in every dataset cell;\n",
    "# consider hoisting them into one shared config cell.\n",
    "plugin_runs = [\n",
    "    ('bayesian_network', 'bayesnet', {'struct_learning_search_method': 'hillclimb',\n",
    "                                      'struct_learning_score': 'bic'}),\n",
    "    ('arf', 'arf', {}),\n",
    "    ('ctgan', 'ctgan', {'generator_n_layers_hidden': 2,\n",
    "                        'generator_n_units_hidden': 50,\n",
    "                        'generator_nonlin': 'tanh',\n",
    "                        'n_iter': 1000,\n",
    "                        'generator_dropout': 0.0575,\n",
    "                        'discriminator_n_layers_hidden': 4,\n",
    "                        'discriminator_n_units_hidden': 150,\n",
    "                        'discriminator_nonlin': 'relu'}),\n",
    "    ('tvae', 'tvae', {'n_iter': 300,\n",
    "                      'lr': 0.0002,\n",
    "                      'decoder_n_layers_hidden': 4,\n",
    "                      'weight_decay': 0.001,\n",
    "                      'batch_size': 256,\n",
    "                      'n_units_embedding': 200,\n",
    "                      'decoder_n_units_hidden': 300,\n",
    "                      'decoder_nonlin': 'elu',\n",
    "                      'decoder_dropout': 0.194325119117226,\n",
    "                      'encoder_n_layers_hidden': 1,\n",
    "                      'encoder_n_units_hidden': 450,\n",
    "                      'encoder_nonlin': 'leaky_relu',\n",
    "                      'encoder_dropout': 0.04288563703094718}),\n",
    "    ('ddpm', 'ddpm', {'lr': 0.0009375080542687667,\n",
    "                      'batch_size': 2929,\n",
    "                      'num_timesteps': 998,\n",
    "                      'n_iter': 1051,\n",
    "                      'is_classification': True}),\n",
    "    ('smotenc', 'smote', {'k': 5}),\n",
    "]\n",
    "\n",
    "syn_long = {}         # file tag -> combined synthetic table for that generator\n",
    "failures_by_tag = {}  # file tag -> failures reported for that generator\n",
    "for plugin_name, tag, plugin_kwargs in plugin_runs:\n",
    "    print(f'running {plugin_name}')\n",
    "    syn_long[tag], failures = build_all_synthetics_synthcity(\n",
    "        X=X,\n",
    "        split_seeds=split_seeds,\n",
    "        plugin_name=plugin_name,\n",
    "        plugin_kwargs=plugin_kwargs,\n",
    "        task_id=0,\n",
    "        ds_name=ds_name\n",
    "    )\n",
    "    failures_by_tag[tag] = failures\n",
    "    # Save the combined table once per generator as {ds_name}_syn_{tag}.feather\n",
    "    fname = os.path.join(output_path, f'{ds_name}_syn_{tag}.feather')\n",
    "    feather.write_feather(syn_long[tag], fname)\n",
    "    print(len(failures), failures[:1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e08eca-a57e-4533-8f75-4085534b97c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------------------- MagicTelescope data ---------------------------------------------------\n",
    "# Fit every synthcity generator on each split seed of this dataset and save one\n",
    "# combined (long-format) feather table per generator under `output_path`.\n",
    "\n",
    "dataset = openml.datasets.get_dataset(44125)\n",
    "\n",
    "X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n",
    "\n",
    "X['target'] = y\n",
    "\n",
    "# All predictors are treated as numeric and the appended target (last column) is the\n",
    "# only categorical variable. Deriving the indices from the frame replaces hard-coded\n",
    "# column counts that would silently mislabel columns if the dataset layout changed.\n",
    "num_idx = list(range(X.shape[1] - 1))\n",
    "cat_idx = [X.shape[1] - 1]\n",
    "\n",
    "X = enforce_dtypes(dat=X,\n",
    "                   num_variables=num_idx,\n",
    "                   cat_variables=cat_idx)\n",
    "\n",
    "ds_name = 'magic'\n",
    "\n",
    "# (synthcity plugin name, output file tag, plugin kwargs), run in this order.\n",
    "# NOTE(review): these settings are repeated verbatim in every dataset cell;\n",
    "# consider hoisting them into one shared config cell.\n",
    "plugin_runs = [\n",
    "    ('bayesian_network', 'bayesnet', {'struct_learning_search_method': 'hillclimb',\n",
    "                                      'struct_learning_score': 'bic'}),\n",
    "    ('arf', 'arf', {}),\n",
    "    ('ctgan', 'ctgan', {'generator_n_layers_hidden': 2,\n",
    "                        'generator_n_units_hidden': 50,\n",
    "                        'generator_nonlin': 'tanh',\n",
    "                        'n_iter': 1000,\n",
    "                        'generator_dropout': 0.0575,\n",
    "                        'discriminator_n_layers_hidden': 4,\n",
    "                        'discriminator_n_units_hidden': 150,\n",
    "                        'discriminator_nonlin': 'relu'}),\n",
    "    ('tvae', 'tvae', {'n_iter': 300,\n",
    "                      'lr': 0.0002,\n",
    "                      'decoder_n_layers_hidden': 4,\n",
    "                      'weight_decay': 0.001,\n",
    "                      'batch_size': 256,\n",
    "                      'n_units_embedding': 200,\n",
    "                      'decoder_n_units_hidden': 300,\n",
    "                      'decoder_nonlin': 'elu',\n",
    "                      'decoder_dropout': 0.194325119117226,\n",
    "                      'encoder_n_layers_hidden': 1,\n",
    "                      'encoder_n_units_hidden': 450,\n",
    "                      'encoder_nonlin': 'leaky_relu',\n",
    "                      'encoder_dropout': 0.04288563703094718}),\n",
    "    ('ddpm', 'ddpm', {'lr': 0.0009375080542687667,\n",
    "                      'batch_size': 2929,\n",
    "                      'num_timesteps': 998,\n",
    "                      'n_iter': 1051,\n",
    "                      'is_classification': True}),\n",
    "    ('smotenc', 'smote', {'k': 5}),\n",
    "]\n",
    "\n",
    "syn_long = {}         # file tag -> combined synthetic table for that generator\n",
    "failures_by_tag = {}  # file tag -> failures reported for that generator\n",
    "for plugin_name, tag, plugin_kwargs in plugin_runs:\n",
    "    print(f'running {plugin_name}')\n",
    "    syn_long[tag], failures = build_all_synthetics_synthcity(\n",
    "        X=X,\n",
    "        split_seeds=split_seeds,\n",
    "        plugin_name=plugin_name,\n",
    "        plugin_kwargs=plugin_kwargs,\n",
    "        task_id=0,\n",
    "        ds_name=ds_name\n",
    "    )\n",
    "    failures_by_tag[tag] = failures\n",
    "    # Save the combined table once per generator as {ds_name}_syn_{tag}.feather\n",
    "    fname = os.path.join(output_path, f'{ds_name}_syn_{tag}.feather')\n",
    "    feather.write_feather(syn_long[tag], fname)\n",
    "    print(len(failures), failures[:1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63587d3e-346c-4f09-9a86-f00a032227ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------------- Pol data ----------------------------------------------------------------\n",
    "# Fit every synthcity generator on each split seed of this dataset and save one\n",
    "# combined (long-format) feather table per generator under `output_path`.\n",
    "\n",
    "dataset = openml.datasets.get_dataset(44122)\n",
    "\n",
    "X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n",
    "\n",
    "X['target'] = y\n",
    "\n",
    "# All predictors are treated as numeric and the appended target (last column) is the\n",
    "# only categorical variable. Deriving the indices from the frame replaces hard-coded\n",
    "# column counts that would silently mislabel columns if the dataset layout changed.\n",
    "num_idx = list(range(X.shape[1] - 1))\n",
    "cat_idx = [X.shape[1] - 1]\n",
    "\n",
    "X = enforce_dtypes(dat=X,\n",
    "                   num_variables=num_idx,\n",
    "                   cat_variables=cat_idx)\n",
    "\n",
    "ds_name = 'pol'\n",
    "\n",
    "# (synthcity plugin name, output file tag, plugin kwargs), run in this order.\n",
    "# NOTE(review): these settings are repeated verbatim in every dataset cell;\n",
    "# consider hoisting them into one shared config cell.\n",
    "plugin_runs = [\n",
    "    ('bayesian_network', 'bayesnet', {'struct_learning_search_method': 'hillclimb',\n",
    "                                      'struct_learning_score': 'bic'}),\n",
    "    ('arf', 'arf', {}),\n",
    "    ('ctgan', 'ctgan', {'generator_n_layers_hidden': 2,\n",
    "                        'generator_n_units_hidden': 50,\n",
    "                        'generator_nonlin': 'tanh',\n",
    "                        'n_iter': 1000,\n",
    "                        'generator_dropout': 0.0575,\n",
    "                        'discriminator_n_layers_hidden': 4,\n",
    "                        'discriminator_n_units_hidden': 150,\n",
    "                        'discriminator_nonlin': 'relu'}),\n",
    "    ('tvae', 'tvae', {'n_iter': 300,\n",
    "                      'lr': 0.0002,\n",
    "                      'decoder_n_layers_hidden': 4,\n",
    "                      'weight_decay': 0.001,\n",
    "                      'batch_size': 256,\n",
    "                      'n_units_embedding': 200,\n",
    "                      'decoder_n_units_hidden': 300,\n",
    "                      'decoder_nonlin': 'elu',\n",
    "                      'decoder_dropout': 0.194325119117226,\n",
    "                      'encoder_n_layers_hidden': 1,\n",
    "                      'encoder_n_units_hidden': 450,\n",
    "                      'encoder_nonlin': 'leaky_relu',\n",
    "                      'encoder_dropout': 0.04288563703094718}),\n",
    "    ('ddpm', 'ddpm', {'lr': 0.0009375080542687667,\n",
    "                      'batch_size': 2929,\n",
    "                      'num_timesteps': 998,\n",
    "                      'n_iter': 1051,\n",
    "                      'is_classification': True}),\n",
    "    ('smotenc', 'smote', {'k': 5}),\n",
    "]\n",
    "\n",
    "syn_long = {}         # file tag -> combined synthetic table for that generator\n",
    "failures_by_tag = {}  # file tag -> failures reported for that generator\n",
    "for plugin_name, tag, plugin_kwargs in plugin_runs:\n",
    "    print(f'running {plugin_name}')\n",
    "    syn_long[tag], failures = build_all_synthetics_synthcity(\n",
    "        X=X,\n",
    "        split_seeds=split_seeds,\n",
    "        plugin_name=plugin_name,\n",
    "        plugin_kwargs=plugin_kwargs,\n",
    "        task_id=0,\n",
    "        ds_name=ds_name\n",
    "    )\n",
    "    failures_by_tag[tag] = failures\n",
    "    # Save the combined table once per generator as {ds_name}_syn_{tag}.feather\n",
    "    fname = os.path.join(output_path, f'{ds_name}_syn_{tag}.feather')\n",
    "    feather.write_feather(syn_long[tag], fname)\n",
    "    print(len(failures), failures[:1])\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
