{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3b80882f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Importting clipped adam...\n"
     ]
    }
   ],
   "source": [
    "# Imports: standard library, then third-party packages, then the local VaRT module.\n",
    "from time import sleep\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import style\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.neural_network import MLPRegressor\n",
    "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n",
    "from tqdm import tqdm\n",
    "from uci_datasets import Dataset, all_datasets\n",
    "from bartpy.sklearnmodel import SklearnModel\n",
    "\n",
    "from variationalRegressionTree import variationalRegressionTree, Node, CVTree\n",
    "\n",
    "style.use('ggplot')\n",
    "\n",
    "# Benchmark selection: every UCI dataset with n_obs_min < N < n_obs_max\n",
    "# observations, except the datasets listed in EXCLUDED_DATASETS.\n",
    "n_obs_min = 0\n",
    "n_obs_max = 5000\n",
    "EXCLUDED_DATASETS = {'gas', 'challenger', 'energy'}\n",
    "\n",
    "# Filter once and derive the three parallel lists from the same selection, so\n",
    "# datasets / n_obs / n_dims cannot fall out of alignment when they are zipped\n",
    "# together in the experiment loop. all_datasets maps name -> (N, d).\n",
    "_selected = [\n",
    "    (name, n_observations, n_dimensions)\n",
    "    for name, (n_observations, n_dimensions) in all_datasets.items()\n",
    "    if n_obs_min < n_observations < n_obs_max and name not in EXCLUDED_DATASETS\n",
    "]\n",
    "datasets = [name for name, _, _ in _selected]\n",
    "n_obs = [n_observations for _, n_observations, _ in _selected]\n",
    "n_dims = [n_dimensions for _, _, n_dimensions in _selected]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7179a882",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Experiments were run in batches over the seeds [0, 42, 13, 31415, 1].\n",
    "# Every dataset uses split index 0 regardless of the seed, so the seeds vary\n",
    "# the randomness of the models (VaRT, BART, MLP) rather than the train/test split.\n",
    "seeds = [0, 42, 13, 31415, 1]\n",
    "\n",
    "λ = 1e-3                    # VaRT.train is called with h1 = λ and h2 = -λ\n",
    "EPOCHS = 2000               # VaRT training epochs, identical for every dataset size\n",
    "N_PREDICTION_REPEATS = 100  # repeated VaRT.predict calls per dataset\n",
    "N_PREDICTION_SAMPLES = 10   # `samples` argument passed to VaRT.predict\n",
    "RMSE_INDEX = ['VaRT_median', 'VaRT_low', 'VaRT_high', 'BART', 'MLP']\n",
    "\n",
    "for seed in seeds:\n",
    "\n",
    "    print(f'Seed = {seed}')\n",
    "\n",
    "    rmse_df = pd.DataFrame(np.zeros((len(RMSE_INDEX), len(datasets))), columns=datasets, index=RMSE_INDEX)\n",
    "    for dataset, obs, _dim in zip(datasets, n_obs, n_dims):\n",
    "\n",
    "        # Re-seed per dataset so every (seed, dataset) result is reproducible on\n",
    "        # its own instead of depending on which datasets ran earlier in the batch.\n",
    "        torch.manual_seed(seed)\n",
    "        np.random.seed(seed)\n",
    "\n",
    "        # Tree depth grows with the number of observations.\n",
    "        if obs <= 120:\n",
    "            depth = 4\n",
    "        elif obs <= 600:\n",
    "            depth = 5\n",
    "        else:\n",
    "            depth = 6\n",
    "\n",
    "        # Datasets above 800 observations run on the GPU when one is available;\n",
    "        # otherwise fall back to the CPU instead of failing on .to('cuda').\n",
    "        DEVICE = 'cuda' if obs > 800 and torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "        data = Dataset(f'{dataset}')\n",
    "        x_train, y_train, x_test, y_test = data.get_split(split=0)\n",
    "\n",
    "        # Min-max scale features and target to [0, 1] using train-split statistics only.\n",
    "        scaler_features = MinMaxScaler(feature_range=(0, 1)).fit(x_train)\n",
    "        scaler_target = MinMaxScaler(feature_range=(0, 1)).fit(y_train)\n",
    "        x_train, x_test = scaler_features.transform(x_train), scaler_features.transform(x_test)\n",
    "        y_train, y_test = scaler_target.transform(y_train), scaler_target.transform(y_test)\n",
    "\n",
    "        # Prepend a column of ones as the intercept term.\n",
    "        x_train = np.hstack((np.ones((x_train.shape[0], 1)), x_train))\n",
    "        x_test = np.hstack((np.ones((x_test.shape[0], 1)), x_test))\n",
    "\n",
    "        # NumPy inputs for the scikit-learn style baselines (BART, MLP).\n",
    "        x_train_np, x_test_np, y_train_np = x_train, x_test, y_train.reshape(-1)\n",
    "\n",
    "        x_train = torch.from_numpy(x_train).to(DEVICE)\n",
    "        x_test = torch.from_numpy(x_test).to(DEVICE)\n",
    "        y_train = torch.from_numpy(y_train).to(DEVICE)\n",
    "        y_test = torch.from_numpy(y_test).to(DEVICE)\n",
    "\n",
    "        VaRT = variationalRegressionTree(depth, x_train, y_train, device=DEVICE)\n",
    "        VaRT.train(epochs=EPOCHS, lr0=1e-1, lrf=1e-3, h1=λ, h2=-λ)\n",
    "\n",
    "        # VaRT.predict is called repeatedly; its RMSE is summarised by the\n",
    "        # median and the 10% / 90% quantiles (rows VaRT_low / VaRT_high).\n",
    "        rmse_list = np.array([\n",
    "            VaRT.rmse(VaRT.predict(x_test, samples=N_PREDICTION_SAMPLES), y_test)\n",
    "            for _ in tqdm(range(N_PREDICTION_REPEATS))\n",
    "        ])\n",
    "        rmse_VaRT_median = np.quantile(rmse_list, .50)\n",
    "        rmse_VaRT_q10 = np.quantile(rmse_list, .10)\n",
    "        rmse_VaRT_q90 = np.quantile(rmse_list, .90)\n",
    "\n",
    "        BART = SklearnModel(n_trees=50)\n",
    "        MLP = MLPRegressor(hidden_layer_sizes=(100, 50))\n",
    "\n",
    "        MLP.fit(x_train_np, y_train_np)\n",
    "        BART.fit(x_train_np, y_train_np)\n",
    "\n",
    "        # Put the baseline predictions on the same device as y_test so that\n",
    "        # VaRT.rmse never mixes CPU and CUDA tensors on the GPU datasets.\n",
    "        yhat_BART = torch.tensor(BART.predict(x_test_np)).to(DEVICE)\n",
    "        yhat_MLP = torch.tensor(MLP.predict(x_test_np)).to(DEVICE)\n",
    "\n",
    "        rmse_BART = VaRT.rmse(yhat_BART, y_test)\n",
    "        rmse_MLP = VaRT.rmse(yhat_MLP, y_test)\n",
    "\n",
    "        rmse_df[dataset] = [rmse_VaRT_median, rmse_VaRT_q10, rmse_VaRT_q90, rmse_BART, rmse_MLP]\n",
    "        print(rmse_df[dataset])\n",
    "\n",
    "        # Rewrite the workbook after every dataset so the results gathered so\n",
    "        # far are on disk even if the run is interrupted.\n",
    "        rmse_df.T.to_excel(f'run_{seed}.xlsx')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f34e4454",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63204020",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
