{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AvwwEfWfDuea"
   },
   "source": [
    "# Generating Tabular Data via XGBoost Models with Flow-Matching\n",
    "\n",
    "This notebook is a self-contained example showing how to train the novel Forest-Flow method to generate tabular data. The idea behind Forest-Flow is to learn the vector field of Flow Matching with XGBoost models instead of neural networks. The motivation is that tree-based models such as gradient-boosted trees currently tend to outperform neural networks on tabular data. This idea comes with some difficulties, for instance how to approximate the Flow Matching loss with trees, and this notebook shows how to handle them on a minimal example. The method, its training procedure and the experiments are described in [(Jolicoeur-Martineau et al. 2023)](https://arxiv.org/abs/2309.09968). The full code can be found [here](https://github.com/SamsungSAILMontreal/ForestDiffusion)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nW9nJoK3wMWM"
   },
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zEBSd1b7HVVG"
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "from functools import partial\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import xgboost as xgb\n",
    "from joblib import Parallel, delayed\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "\n",
    "from torchcfm.conditional_flow_matching import ConditionalFlowMatcher"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set seed\n",
    "seed = 1980\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)\n",
    "torch.backends.cudnn.benchmark = False  # benchmark mode may select non-deterministic algorithms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9tIIxR1eHTsm"
   },
   "source": [
    "As an example, we use [Iris](https://en.wikipedia.org/wiki/Iris_flower_data_set), a classic tabular dataset about flowers with 150 observations, 4 continuous input variables (sepal length, sepal width, petal length, and petal width), and 1 categorical outcome variable (3 species of flowers: setosa, versicolor, and virginica)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4WJCRVrqDVIw"
   },
   "outputs": [],
   "source": [
    "# Iris: numpy dataset with 4 variables (all numerical) and 1 outcome (categorical; 3 categories)\n",
    "my_data = load_iris()\n",
    "X, y = my_data[\"data\"], my_data[\"target\"]\n",
    "\n",
    "# shuffle the observations\n",
    "new_perm = np.random.permutation(X.shape[0])\n",
    "np.take(X, new_perm, axis=0, out=X)\n",
    "np.take(y, new_perm, axis=0, out=y)\n",
    "\n",
    "# Keep a copy of the real data (features + label) for the comparisons at the end\n",
    "X_true, y_true = copy.deepcopy(X), copy.deepcopy(y)\n",
    "Xy_true = np.concatenate((X_true, np.expand_dims(y_true, axis=1)), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "jjIE966SK2Kr",
    "outputId": "9e766c37-6153-4da1-cfc7-dc339db4aee7"
   },
   "outputs": [],
   "source": [
    "X[0:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "uy5sn6vaLsQq",
    "outputId": "7a55e318-b7d4-4d59-ea0e-2236ca2adcd8"
   },
   "outputs": [],
   "source": [
    "y[0:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "t9PiFTtrIWgl"
   },
   "source": [
    "We set the hyperparameters here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "TDYP6yz-IHBt"
   },
   "outputs": [],
   "source": [
    "# Main hyperparameters\n",
    "n_t = 50  # number of flow time levels (higher is better; 50 is the value used in this example)\n",
    "duplicate_K = 100  # number of different noise samples per real data sample (higher is better, at the cost of memory and training time)\n",
    "\n",
    "# XGBoost hyperparameters\n",
    "max_depth = 7\n",
    "n_estimators = 100\n",
    "eta = 0.3\n",
    "tree_method = \"hist\"\n",
    "reg_lambda = 0.0\n",
    "reg_alpha = 0.0\n",
    "subsample = 1.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xlDzllCzt-AP"
   },
   "source": [
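    "We preprocess the data: each variable is min-max scaled to $[-1, 1]$, the data is duplicated $K$ times (`duplicate_K`) so that every observation is paired with $K$ different noise samples, and the class frequencies and per-class masks are stored for later. The following cell then builds the inputs $x(t)$ and the regression targets $u_t$."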
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6xamj-lqIHG5"
   },
   "outputs": [],
   "source": [
    "# Save min/max of the values\n",
    "X_min = np.nanmin(X, axis=0, keepdims=True)\n",
    "X_max = np.nanmax(X, axis=0, keepdims=True)\n",
    "\n",
    "# Min-Max scaling of the variables\n",
    "scaler = MinMaxScaler(feature_range=(-1, 1))\n",
    "X_scaled = scaler.fit_transform(X)\n",
    "\n",
    "# Save shape\n",
    "b, c = X.shape\n",
    "\n",
    "# Duplicate the data K times so that each real sample is paired with K different random noise samples z\n",
    "X1 = np.tile(X_scaled, (duplicate_K, 1))\n",
    "\n",
    "# Generate noise data\n",
    "X0 = np.random.normal(size=X1.shape)\n",
    "\n",
    "# Save the frequency of each class and store per-class masks for later\n",
    "y_uniques, y_probs = np.unique(y, return_counts=True)\n",
    "y_probs = y_probs / np.sum(y_probs)\n",
    "mask_y = {}  # mask of the observations that have a specific value of y\n",
    "for i in range(len(y_uniques)):\n",
    "    mask_y[y_uniques[i]] = np.zeros(b, dtype=bool)\n",
    "    mask_y[y_uniques[i]][y == y_uniques[i]] = True\n",
    "    mask_y[y_uniques[i]] = np.tile(mask_y[y_uniques[i]], (duplicate_K))\n",
    "n_y = len(y_uniques)  # number of classes"
   ]
  },
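  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MinMaxScaler(feature_range=(-1, 1))` maps each column linearly so that its observed minimum goes to $-1$ and its maximum to $+1$, i.e. $x' = 2 (x - x_{min}) / (x_{max} - x_{min}) - 1$. The NumPy sketch below reproduces that mapping and its inverse on a small made-up array. The helpers `minmax_scale` and `minmax_inverse` are hypothetical names used only to illustrate what the scaler computes; the notebook itself keeps using `scaler`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def minmax_scale(x, lo=-1.0, hi=1.0):\n",
    "    # Per-column min and max, kept with shape (1, c) so they broadcast over rows\n",
    "    x_min = x.min(axis=0, keepdims=True)\n",
    "    x_max = x.max(axis=0, keepdims=True)\n",
    "    scaled = (x - x_min) / (x_max - x_min) * (hi - lo) + lo\n",
    "    return scaled, x_min, x_max\n",
    "\n",
    "\n",
    "def minmax_inverse(scaled, x_min, x_max, lo=-1.0, hi=1.0):\n",
    "    # Undo the affine map column by column\n",
    "    return (scaled - lo) / (hi - lo) * (x_max - x_min) + x_min\n",
    "\n",
    "\n",
    "x = np.array([[5.1, 3.5], [4.9, 3.0], [6.3, 2.5]])  # made-up rows\n",
    "x_scaled, x_min, x_max = minmax_scale(x)\n",
    "x_back = minmax_inverse(x_scaled, x_min, x_max)\n",
    "print(x_scaled.min(axis=0), x_scaled.max(axis=0))  # [-1. -1.] [1. 1.]\n",
    "print(np.allclose(x_back, x))  # True\n",
    "```\n",
    "\n",
    "Because the map is affine, `scaler.inverse_transform` recovers the original units (up to floating-point error), which is why it is applied to the generated samples at the end."
   ]
  },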
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SkAnNB_tn2Tt"
   },
   "outputs": [],
   "source": [
    "# Build the inputs x(t) and the regression targets u(t) at each of the n_t time levels\n",
    "\n",
    "# Define Independent Conditional Flow Matching (I-CFM)\n",
    "FM = ConditionalFlowMatcher(sigma=0.0)\n",
    "\n",
    "# Time levels\n",
    "t_levels = np.linspace(1e-3, 1, num=n_t)\n",
    "\n",
    "# Interpolation between x0 and x1 (xt)\n",
    "X_train = np.zeros((n_t, X0.shape[0], X0.shape[1]))  # [n_t, b*K, c]\n",
    "\n",
    "# Output to predict (ut)\n",
    "y_train = np.zeros((n_t, X0.shape[0], X0.shape[1]))  # [n_t, b*K, c]\n",
    "\n",
    "# Fill with xt and ut\n",
    "for i in range(n_t):\n",
    "    t = torch.ones(X0.shape[0]) * t_levels[i]  # current t\n",
    "    _, xt, ut = FM.sample_location_and_conditional_flow(\n",
    "        torch.from_numpy(X0), torch.from_numpy(X1), t=t\n",
    "    )\n",
    "    X_train[i], y_train[i] = xt.numpy(), ut.numpy()"
   ]
  },
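  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `sigma=0.0`, `ConditionalFlowMatcher` uses the straight-line path between a noise sample $x_0$ and a data sample $x_1$, namely $x_t = (1 - t) x_0 + t x_1$, whose velocity is the constant $u_t = x_1 - x_0$. The sketch below recomputes these quantities in plain NumPy, assuming this closed form; `icfm_pair` is a hypothetical helper, not part of `torchcfm`. With `sigma=0.0` the target $u_t$ is the same at every time level; only the input $x_t$ changes.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def icfm_pair(x0, x1, t):\n",
    "    # Straight-line interpolation between noise x0 and data x1 at time t\n",
    "    xt = (1.0 - t) * x0 + t * x1\n",
    "    # Velocity of that path; it does not depend on t when sigma = 0\n",
    "    ut = x1 - x0\n",
    "    return xt, ut\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "data = rng.uniform(-1, 1, size=(6, 4))  # stand-in for min-max scaled rows\n",
    "noise = rng.normal(size=data.shape)  # Gaussian prior samples, as in the notebook\n",
    "\n",
    "for t in np.linspace(1e-3, 1, num=5):\n",
    "    xt, ut = icfm_pair(noise, data, t)\n",
    "    # Following ut for the remaining time (1 - t) from xt lands exactly on the data\n",
    "    assert np.allclose(xt + (1.0 - t) * ut, data)\n",
    "print('straight-line check passed at every time level')\n",
    "```\n",
    "\n",
    "The check mirrors what the trees are asked to learn: from any point on the path, the target velocity points straight at the paired data sample."
   ]
  },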
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1f-NfQ1DuKYv"
   },
   "source": [
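    "We train $c k n_t$ XGBoost regressors, one for each (variable, class, time level) triple, where $c$ is the number of input variables (4), $k$ is the number of classes (3), and $n_t$ is the number of time levels (50), i.e. $4 \\times 3 \\times 50 = 600$ models for Iris."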
   ]
  },
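  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out, each regressor approximately solves a squared-error problem on the rows of one class at one time level. The formula below is our reading of the training code (`reg:squarederror` objective, no extra regularization since `reg_lambda` and `reg_alpha` are 0), with $i$ indexing the time level, $\\ell$ the class, and $d$ the predicted variable:\n",
    "\n",
    "$$\n",
    "f_{i,\\ell,d} \\approx \\arg\\min_{f} \\sum_{n \\,:\\, y_n = \\ell} \\Big( f\\big(x^{(n)}_{t_i}\\big) - \\big(x^{(n)}_{1,d} - x^{(n)}_{0,d}\\big) \\Big)^2, \\qquad x^{(n)}_{t_i} = (1 - t_i)\\, x^{(n)}_0 + t_i\\, x^{(n)}_1,\n",
    "$$\n",
    "\n",
    "where $n$ runs over the $b \\cdot K$ duplicated rows ($K$ = `duplicate_K` = 100), $x^{(n)}_1$ is a scaled real observation and $x^{(n)}_0$ its paired Gaussian noise. For Iris each class has 50 observations, so each of the $4 \\times 3 \\times 50 = 600$ models is fit on $50 \\times 100 = 5000$ rows."
   ]
  },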
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bB0kQI8zsSpB"
   },
   "outputs": [],
   "source": [
    "# Train one XGBoost regressor on the rows whose target is not missing\n",
    "\n",
    "\n",
    "def train_parallel(X_train, y_train):\n",
    "    model = xgb.XGBRegressor(\n",
    "        n_estimators=n_estimators,\n",
    "        objective=\"reg:squarederror\",\n",
    "        learning_rate=eta,  # eta is XGBoost's alias for the learning rate\n",
    "        max_depth=max_depth,\n",
    "        reg_lambda=reg_lambda,\n",
    "        reg_alpha=reg_alpha,\n",
    "        subsample=subsample,\n",
    "        random_state=666,\n",
    "        tree_method=tree_method,\n",
    "        device=\"cpu\",  # the device argument requires XGBoost >= 2.0\n",
    "    )\n",
    "\n",
    "    # Drop rows with a missing target (none for Iris, but useful for data with NaNs)\n",
    "    y_no_miss = ~np.isnan(y_train)\n",
    "    model.fit(X_train[y_no_miss, :], y_train[y_no_miss])\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "FR6zUk0asDQ5",
    "outputId": "878152dc-9b1d-432c-bc18-2b6bee5dfbc1"
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "# Train all n_t * n_classes * c models (600 for Iris); fast on a decent multi-core CPU, but very slow on Google Colab, which provides only a weak 2-core CPU\n",
    "\n",
    "\n",
    "regr = Parallel(n_jobs=-1)(  # using all cpus\n",
    "    delayed(train_parallel)(\n",
    "        X_train.reshape(n_t, b * duplicate_K, c)[i][mask_y[j], :],\n",
    "        y_train.reshape(n_t, b * duplicate_K, c)[i][mask_y[j], k],\n",
    "    )\n",
    "    for i in range(n_t)\n",
    "    for j in y_uniques\n",
    "    for k in range(c)\n",
    ")\n",
    "\n",
    "# Reorganize the flat list of fitted models into nested lists indexed as regr[class][time level][variable]\n",
    "regr_ = [[[None for k in range(c)] for i in range(n_t)] for j in y_uniques]\n",
    "current_i = 0\n",
    "for i in range(n_t):\n",
    "    for j in range(len(y_uniques)):\n",
    "        for k in range(c):\n",
    "            regr_[j][i][k] = regr[current_i]\n",
    "            current_i += 1\n",
    "regr = regr_"
   ]
  },
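  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Parallel` returns the fitted models as one flat list in the order of the generator (time level, then class, then variable), so the model for time index `i`, class position `j` and variable `k` sits at flat position `(i * n_y + j) * c + k`. The training generator iterates over the labels in `y_uniques` while the re-nesting loop uses positions `range(len(y_uniques))`; the two coincide for Iris because its labels are 0, 1, 2. The pure-Python sketch below checks the index formula with placeholder tuples standing in for fitted models; `nest_models` is a hypothetical helper name.\n",
    "\n",
    "```python\n",
    "def nest_models(flat, num_t, num_y, num_c):\n",
    "    # Rebuild nested[class][time][variable] from a flat list ordered (time, class, variable)\n",
    "    return [\n",
    "        [[flat[(i * num_y + j) * num_c + k] for k in range(num_c)] for i in range(num_t)]\n",
    "        for j in range(num_y)\n",
    "    ]\n",
    "\n",
    "\n",
    "num_t, num_y, num_c = 5, 3, 4\n",
    "# Placeholder models: tuples recording the (time, class, variable) they were built for,\n",
    "# generated in the same loop order as the Parallel call\n",
    "flat = [(i, j, k) for i in range(num_t) for j in range(num_y) for k in range(num_c)]\n",
    "nested = nest_models(flat, num_t, num_y, num_c)\n",
    "\n",
    "print(nested[2][4][1])  # (4, 2, 1)\n",
    "print(len(flat))  # 60\n",
    "```\n",
    "\n",
    "Writing the index formula once, as above, is an alternative to the running counter `current_i`; both give the same nesting as long as the generator order is unchanged."
   ]
  },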
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "v3arcvLGum4X"
   },
   "source": [
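    "We generate data by solving the ODE $dx/dt = v(t, x)$ from $t = 0$ (Gaussian noise) to $t = 1$ (data), where the vector field $v$ is given by the trained XGBoost models of each sample's class."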
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xk85Vtbvurx3"
   },
   "outputs": [],
   "source": [
    "batch_size = 150  # number of generated samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "o77nB480vQNO"
   },
   "outputs": [],
   "source": [
    "# Return the flow at time t using the XGBoost models\n",
    "\n",
    "\n",
    "def my_model(t, xt, mask_y=None):\n",
    "    # xt is [b*c]\n",
    "    xt = xt.reshape(xt.shape[0] // c, c)  # [b, c]\n",
    "\n",
    "    # Output from the models\n",
    "    out = np.zeros(xt.shape)  # [b, c]\n",
    "    i = int(round(t * (n_t - 1)))  # index of the nearest trained time level\n",
    "    for j, label in enumerate(y_uniques):\n",
    "        for k in range(c):\n",
    "            out[mask_y[label], k] = regr[j][i][k].predict(xt[mask_y[label], :])\n",
    "\n",
    "    out = out.reshape(-1)  # [b*c]\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XCSiPfy-topJ"
   },
   "outputs": [],
   "source": [
    "# Simple Euler ODE solver (nothing fancy)\n",
    "\n",
    "\n",
    "def euler_solve(x0, my_model, N=100):\n",
    "    h = 1 / (N - 1)\n",
    "    x_fake = x0\n",
    "    t = 0\n",
    "    # from t=0 to t=1\n",
    "    for i in range(N - 1):\n",
    "        x_fake = x_fake + h * my_model(t=t, xt=x_fake)\n",
    "        t = t + h\n",
    "    return x_fake"
   ]
  },
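  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`euler_solve` takes `N - 1` steps of size `h = 1 / (N - 1)` from $t = 0$ to $t = 1$. A quick way to sanity-check such a solver is a vector field with a known solution. The sketch below assumes the linear test field $dx/dt = -x$ (our choice, unrelated to the XGBoost models), whose exact value at $t = 1$ is $x_0 e^{-1}$; the solver is a copy of the cell above, renamed `euler_solve_demo` so it does not shadow the original.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def euler_solve_demo(x0, model, N=100):\n",
    "    # Same update rule as euler_solve: N - 1 steps of size h from t = 0 to t = 1\n",
    "    h = 1 / (N - 1)\n",
    "    x, t = x0, 0.0\n",
    "    for _ in range(N - 1):\n",
    "        x = x + h * model(t=t, xt=x)\n",
    "        t = t + h\n",
    "    return x\n",
    "\n",
    "\n",
    "def decay(t, xt):\n",
    "    # Test field dx/dt = -x (it ignores t)\n",
    "    return -xt\n",
    "\n",
    "\n",
    "x_init = np.array([1.0, -2.0, 0.5])\n",
    "exact = x_init * np.exp(-1.0)  # exact solution at t = 1\n",
    "for N in (11, 51, 201):\n",
    "    err = np.abs(euler_solve_demo(x_init, decay, N=N) - exact).max()\n",
    "    print(N, err)\n",
    "```\n",
    "\n",
    "Euler is a first-order method, so the error at $t = 1$ shrinks roughly linearly in $h$. In this notebook $h$ is tied to `n_t`, because `euler_solve` is called with `N=n_t`."
   ]
  },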
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ap2rI0zhs7Fp"
   },
   "outputs": [],
   "source": [
    "# Generate prior noise\n",
    "x0 = np.random.normal(size=(batch_size, c))\n",
    "\n",
    "# Generate random labels for the outcome\n",
    "label_y_fake = y_uniques[np.argmax(np.random.multinomial(1, y_probs, size=x0.shape[0]), axis=1)]\n",
    "mask_y_fake = {}  # mask of the generated observations that have a specific value of y\n",
    "for i in range(len(y_uniques)):\n",
    "    mask_y_fake[y_uniques[i]] = np.zeros(x0.shape[0], dtype=bool)\n",
    "    mask_y_fake[y_uniques[i]][label_y_fake == y_uniques[i]] = True\n",
    "\n",
    "# ODE solve\n",
    "ode_solved = euler_solve(\n",
    "    my_model=partial(my_model, mask_y=mask_y_fake), x0=x0.reshape(-1), N=n_t\n",
    ")  # [b*c], the state at t = 1\n",
    "solution = ode_solved.reshape(batch_size, c)  # [b, c]\n",
    "\n",
    "# invert the min-max normalization\n",
    "solution = scaler.inverse_transform(solution)\n",
    "\n",
    "# clip to min/max values\n",
    "small = (solution < X_min).astype(float)\n",
    "solution = small * X_min + (1 - small) * solution\n",
    "big = (solution > X_max).astype(float)\n",
    "solution = big * X_max + (1 - big) * solution\n",
    "\n",
    "# Concatenate the y label\n",
    "Xy_fake = np.concatenate((solution, np.expand_dims(label_y_fake, axis=1)), axis=1)"
   ]
  },
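  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above draws one class label per generated row with a single-trial multinomial and then clips each variable to the range seen in the real data. The sketch below shows two equivalent, arguably more idiomatic NumPy forms: `Generator.choice` with the class probabilities `p`, and `np.clip` with broadcast `(1, c)` bounds. It uses illustrative bounds, made-up samples, and a `default_rng` generator instead of the global seed, so its draws will not match the notebook's.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "classes = np.array([0, 1, 2])\n",
    "probs = np.array([1 / 3, 1 / 3, 1 / 3])  # Iris has 50 observations per class\n",
    "\n",
    "# Sample one label per generated row, in proportion to the real class frequencies\n",
    "labels = rng.choice(classes, size=150, p=probs)\n",
    "\n",
    "# Clip each column to illustrative per-column bounds of shape (1, c)\n",
    "lo = np.array([[4.3, 2.0]])\n",
    "hi = np.array([[7.9, 4.4]])\n",
    "batch = rng.normal(loc=5.0, scale=2.0, size=(150, 2))\n",
    "clipped = np.clip(batch, lo, hi)\n",
    "\n",
    "# Same result as the mask-based clipping used in the notebook\n",
    "small = (batch < lo).astype(float)\n",
    "tmp = small * lo + (1 - small) * batch\n",
    "big = (tmp > hi).astype(float)\n",
    "manual = big * hi + (1 - big) * tmp\n",
    "print(np.allclose(clipped, manual))  # True\n",
    "print(np.bincount(labels, minlength=3))\n",
    "```"
   ]
  },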
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "W0n4u7bc0uOB"
   },
   "source": [
    "We just generated fake tabular data! Let's now compare the two datasets (real vs. fake)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "-Vw__-M2yKK7",
    "outputId": "d7ca83b0-1c1e-474c-b5c2-f9cf86e60570"
   },
   "outputs": [],
   "source": [
    "Xy_true[0:10]  # Real data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "tRXq9n3QvwsW",
    "outputId": "2803303c-c42e-43d8-aaee-da38458bb385"
   },
   "outputs": [],
   "source": [
    "Xy_fake[0:10]  # Flow generated data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 455
    },
    "id": "teqKVOb-xgsk",
    "outputId": "2c426c9b-f9e9-46bc-f7fc-926869552a5d"
   },
   "outputs": [],
   "source": [
    "_, (ax1, ax2) = plt.subplots(2)\n",
    "# Real data\n",
    "scatter = ax1.scatter(Xy_true[:, 0], Xy_true[:, 1], c=Xy_true[:, -1])\n",
    "ax1.set(\n",
    "    xlabel=my_data.feature_names[0], ylabel=my_data.feature_names[1], xlim=(4, 8), ylim=(2, 4.5)\n",
    ")\n",
    "_ = ax1.legend(\n",
    "    scatter.legend_elements()[0], my_data.target_names, loc=\"lower right\", title=\"Classes\"\n",
    ")\n",
    "# Fake data\n",
    "scatter = ax2.scatter(Xy_fake[:, 0], Xy_fake[:, 1], c=Xy_fake[:, -1])\n",
    "ax2.set(\n",
    "    xlabel=my_data.feature_names[0], ylabel=my_data.feature_names[1], xlim=(4, 8), ylim=(2, 4.5)\n",
    ")\n",
    "_ = ax2.legend(\n",
    "    scatter.legend_elements()[0], my_data.target_names, loc=\"lower right\", title=\"Classes\"\n",
    ")"
   ]
  },
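  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scatter plots compare only two of the four variables. A rough numerical complement, sketched below under our own choice of statistics (it is not the evaluation protocol of the paper), is to compare the per-class mean and standard deviation of every variable in the real and generated arrays. `per_class_summary` and `max_mean_gap` are hypothetical helpers; they assume, as in `Xy_true` and `Xy_fake`, that the label is the last column.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def per_class_summary(Xy):\n",
    "    # Assumes the last column holds the class label and the other columns are features\n",
    "    labels = Xy[:, -1]\n",
    "    stats = {}\n",
    "    for label in np.unique(labels):\n",
    "        rows = Xy[labels == label, :-1]\n",
    "        stats[label] = (rows.mean(axis=0), rows.std(axis=0))\n",
    "    return stats\n",
    "\n",
    "\n",
    "def max_mean_gap(Xy_a, Xy_b):\n",
    "    # Largest absolute difference between per-class feature means of two datasets\n",
    "    a, b = per_class_summary(Xy_a), per_class_summary(Xy_b)\n",
    "    return max(np.abs(a[lab][0] - b[lab][0]).max() for lab in a if lab in b)\n",
    "\n",
    "\n",
    "# Toy stand-ins; in the notebook you would pass Xy_true and Xy_fake instead\n",
    "rng = np.random.default_rng(0)\n",
    "Xy_a = np.column_stack([rng.normal(size=(100, 2)), rng.integers(0, 2, size=100)])\n",
    "Xy_b = Xy_a.copy()\n",
    "Xy_b[:, 0] += 0.5  # shift the first feature of every row\n",
    "print(max_mean_gap(Xy_a, Xy_a))  # 0.0\n",
    "print(max_mean_gap(Xy_a, Xy_b))\n",
    "```\n",
    "\n",
    "In the notebook, `max_mean_gap(Xy_true, Xy_fake)` gives a single number in the original units; the standard deviations returned by `per_class_summary` can be compared the same way."
   ]
  },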
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qgMj9KxODUyP"
   },
   "source": [
    "Below we show how to do the same with the [ForestDiffusion pip package](https://github.com/SamsungSAILMontreal/ForestDiffusion)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "iwsuf2VP6HxI",
    "outputId": "6ac4e540-e272-4b1f-bae7-03ac014d2abc"
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "from ForestDiffusion import ForestDiffusionModel as ForestFlowModel\n",
    "\n",
    "forest_model = ForestFlowModel(\n",
    "    X,\n",
    "    label_y=y,\n",
    "    n_t=50,\n",
    "    duplicate_K=100,\n",
    "    bin_indexes=[],\n",
    "    cat_indexes=[],\n",
    "    int_indexes=[],\n",
    "    diffusion_type=\"flow\",\n",
    "    n_jobs=-1,\n",
    "    seed=1,\n",
    ")\n",
    "Xy_fake_ = forest_model.generate(batch_size=X.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 455
    },
    "id": "yGxkJT3g6TJ5",
    "outputId": "788ee11f-7cf8-4d9b-b300-a4b243559fad"
   },
   "outputs": [],
   "source": [
    "_, (ax1, ax2) = plt.subplots(2)\n",
    "# Real data\n",
    "scatter = ax1.scatter(Xy_true[:, 0], Xy_true[:, 1], c=Xy_true[:, -1])\n",
    "ax1.set(\n",
    "    xlabel=my_data.feature_names[0], ylabel=my_data.feature_names[1], xlim=(4, 8), ylim=(2, 4.5)\n",
    ")\n",
    "_ = ax1.legend(\n",
    "    scatter.legend_elements()[0], my_data.target_names, loc=\"lower right\", title=\"Classes\"\n",
    ")\n",
    "# Fake data\n",
    "scatter = ax2.scatter(Xy_fake_[:, 0], Xy_fake_[:, 1], c=Xy_fake_[:, -1])\n",
    "ax2.set(\n",
    "    xlabel=my_data.feature_names[0], ylabel=my_data.feature_names[1], xlim=(4, 8), ylim=(2, 4.5)\n",
    ")\n",
    "_ = ax2.legend(\n",
    "    scatter.legend_elements()[0], my_data.target_names, loc=\"lower right\", title=\"Classes\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "torchcfm",
   "language": "python",
   "name": "torchcfm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}