{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f651d2ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e05923f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from sklearn.tree import DecisionTreeRegressor\n",
    "\n",
    "from imodels import HSTreeRegressorCV\n",
    "\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acf69d56",
   "metadata": {},
   "outputs": [],
   "source": [
    "def simulation1(n=500, noise=\"gaussian\"):\n",
    "    X = np.random.uniform(0, 1, (n, 50))\n",
    "    y = X[:, :10].sum(axis=1)\n",
    "    if noise == \"gaussian\":\n",
    "        y += np.random.normal(0, 0.01, n)\n",
    "    elif noise == \"laplacian\":\n",
    "        y += np.random.laplace(0, 0.01, n)\n",
    "    return X, y\n",
    "\n",
    "\n",
    "def simulation2(n=500, noise=\"gaussian\"):\n",
    "    X = np.random.uniform(0, 1, (n, 50))\n",
    "    y = X[:, :10].sum(axis=1) + X[:, 0]*X[:, 1] + X[:, 4]*X[:, 5] + X[:, 10]*X[:, 11]\n",
    "    if noise == \"gaussian\":\n",
    "        y += np.random.normal(0, 0.01, n)\n",
    "    elif noise == \"laplacian\":\n",
    "        y += np.random.laplace(0, 0.01, n)\n",
    "    return X, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "802601bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def experiment(max_leaf_nodes, X_test, y_test, simulation_f, noise=\"gaussian\", N=100):\n",
    "    CART_mse = list()\n",
    "    CART_pred = list()\n",
    "    hsCART_mse = list()\n",
    "    hsCART_pred = list()\n",
    "    hsCART_lambda = list()\n",
    "\n",
    "    for _ in range(N):\n",
    "        X, y = simulation_f(noise=noise)\n",
    "\n",
    "        CART = DecisionTreeRegressor(max_leaf_nodes=max_leaf_nodes)\n",
    "        CART.fit(X, y)\n",
    "        hsCART = HSTreeRegressorCV(deepcopy(CART),  [0.1, 1, 10, 25, 50, 100])\n",
    "        hsCART.fit(X, y)\n",
    "        hsCART_lambda.append(hsCART.reg_param)\n",
    "\n",
    "        y_pred = CART.predict(X_test)\n",
    "        CART_mse.append(mean_squared_error(y_test, y_pred))\n",
    "        CART_pred.append(y_pred)\n",
    "        y_pred = hsCART.predict(X_test)\n",
    "        hsCART_mse.append(mean_squared_error(y_test, y_pred))\n",
    "        hsCART_pred.append(y_pred)\n",
    "\n",
    "    CART_mean_pred = np.array(CART_pred).mean(axis=0)\n",
    "    CART_MSE = np.mean(CART_mse)\n",
    "    CART_bias2 = np.power(CART_mean_pred-y_test, 2).mean()\n",
    "    CART_variance = np.array(CART_pred).var(axis=0).mean()\n",
    "    hsCART_mean_pred = np.array(hsCART_pred).mean(axis=0)\n",
    "    hsCART_MSE = np.mean(hsCART_mse)\n",
    "    hsCART_bias2 = np.power(hsCART_mean_pred-y_test, 2).mean()\n",
    "    hsCART_variance = np.array(hsCART_pred).var(axis=0).mean()\n",
    "    \n",
    "    return CART_MSE, CART_bias2, CART_variance, hsCART_MSE, hsCART_bias2, hsCART_variance, np.mean(hsCART_lambda)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54e7f339",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bias_variance(simulation_function, noise=\"gaussian\", save=False):\n",
    "    simulation_f = simulation1 if simulation_function == \"simulation_1\" else simulation2\n",
    "    \n",
    "    CART_MSE, CART_bias2, CART_variance = list(), list(), list()\n",
    "    hsCART_MSE, hsCART_bias2, hsCART_variance, hsCART_lambda = list(), list(), list(), list()\n",
    "\n",
    "    leaf_nodes = [2, 4, 8, 12, 16, 20, 24, 28, 30, 32] + list(range(40, 151, 10))\n",
    "\n",
    "    X_test, y_test = simulation_f(noise=None)\n",
    "\n",
    "    for max_leaf_nodes in tqdm(leaf_nodes):\n",
    "        (\n",
    "            cart_mse, cart_bias2, cart_variance,\n",
    "            hscart_mse, hscart_bias2, hscart_variance, hscart_lambda\n",
    "        ) = experiment(max_leaf_nodes, X_test, y_test, simulation_f, noise)\n",
    "        CART_MSE.append(cart_mse)\n",
    "        CART_bias2.append(cart_bias2)\n",
    "        CART_variance.append(cart_variance)\n",
    "        hsCART_MSE.append(hscart_mse)\n",
    "        hsCART_bias2.append(hscart_bias2)\n",
    "        hsCART_variance.append(hscart_variance)\n",
    "        hsCART_lambda.append(hscart_lambda)\n",
    "    \n",
    "    fig, ax = plt.subplots()\n",
    "    lns1 = ax.plot(leaf_nodes, CART_MSE, color=\"lightsalmon\", label=\"CART MSE\")\n",
    "    lns2 = ax.plot(leaf_nodes, CART_bias2, color=\"lightsalmon\", linestyle=\"dotted\", label=\"CART Bias2\")\n",
    "    lns3 = ax.plot(leaf_nodes, CART_variance, color=\"lightsalmon\", linestyle=\"--\", label=\"CART Variance\")\n",
    "    lns4 = ax.plot(leaf_nodes, hsCART_MSE, color=\"firebrick\", label=\"hsCART MSE\")\n",
    "    lns5 = ax.plot(leaf_nodes, hsCART_bias2, color=\"firebrick\", linestyle=\"dotted\", label=\"hsCART Bias2\")\n",
    "    lns6 = ax.plot(leaf_nodes, hsCART_variance, color=\"firebrick\", linestyle=\"--\", label=\"hsCART Variance\")\n",
    "    ax.set_xlabel(\"Number of Leaves\")\n",
    "    ax.set_ylabel(\"Error\")\n",
    "    ax2 = ax.twinx()\n",
    "    lns7 = ax2.plot(leaf_nodes, hsCART_lambda, color=\"skyblue\", linestyle=\"dashdot\", label=\"hsCART Lambda\")\n",
    "    ax2.set_ylabel(\"Lambda\")\n",
    "    lns = lns1+lns2+lns3+lns4+lns5+lns6+lns7\n",
    "    labs = [l.get_label() for l in lns]\n",
    "    ax.legend(lns, labs, prop={'size': 6})\n",
    "    if save:\n",
    "        plt.savefig(f\"../graphs/miscelenious/bias_variance/{simulation_function}_{noise}\", bbox_inches=\"tight\", facecolor=\"white\", edgecolor=\"auto\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c01441cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "bias_variance(\"simulation_1\", \"gaussian\", save=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bcb3a1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "bias_variance(\"simulation_1\", \"laplacian\", save=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7963d8fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "bias_variance(\"simulation_2\", \"gaussian\", save=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a443b5b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "bias_variance(\"simulation_2\", \"laplacian\", save=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c010313c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlds",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:16:33) [MSC v.1929 64 bit (AMD64)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "aaea2abca523dde69918745ab0433be939f84f870f3a8b0d9770b38003b437c2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
