{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Project root: set the PROJECT_ROOT environment variable, or replace the placeholder below.\n",
    "BASE_DIR = os.environ.get(\"PROJECT_ROOT\", \"ABSOLUTE_PATH_TO_THE_ROOT\")\n",
    "DATA_DIR = os.path.join(BASE_DIR, \"data\")\n",
    "CODE_DIR = os.path.join(BASE_DIR, \"code\")\n",
    "FC_DIR = os.path.join(BASE_DIR, \"models_save\", \"fc\")\n",
    "UC_DIR = os.path.join(BASE_DIR, \"models_save\", \"uc\")\n",
    "\n",
    "# Make the project's `code/` directory importable. The membership check keeps\n",
    "# re-runs of this cell from appending duplicate entries to sys.path.\n",
    "if CODE_DIR not in sys.path:\n",
    "    sys.path.append(CODE_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.decomposition import PCA, KernelPCA\n",
    "from sklearn.neighbors import NeighborhoodComponentsAnalysis\n",
    "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from models.uncertainty.dist_match.tree import DistMatchQRF\n",
    "from models.uncertainty.dist_match.utils import match_ks_raw\n",
    "\n",
    "from sklearn_quantile import RandomForestQuantileRegressor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "# IPython autoreload: re-import edited modules (e.g. the project's `models` package) before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def matcher(x1, x2, threshold=0.25):\n",
    "    \"\"\"Return whether two samples count as a distribution match for DistMatchQRF.\n",
    "\n",
    "    Used as the ``matcher`` callback passed to ``DistMatchQRF`` when building the model.\n",
    "\n",
    "    Args:\n",
    "        x1, x2: the two samples to compare; forwarded unchanged to ``match_ks_raw``.\n",
    "        threshold: the pair matches when the ``match_ks_raw`` score is strictly\n",
    "            greater than this value. The default 0.25 is the original hard-coded cut-off.\n",
    "\n",
    "    Returns:\n",
    "        ``match_ks_raw(x1, x2) > threshold``. NOTE(review): presumably the score is a\n",
    "        scalar KS-test p-value, making this a bool; confirm in dist_match.utils.\n",
    "    \"\"\"\n",
    "    return match_ks_raw(x1, x2) > threshold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Construct the distribution-matching QRF wrapper, then load previously saved trees into it.\n",
    "# NOTE(review): the checkpoint is a .pkl file, so load_trees presumably unpickles it -- only load trusted files.\n",
    "# NOTE(review): \"electricelectricity\" looks like a duplicated prefix; confirm it matches the file on disk.\n",
    "qrf_kwargs = dict(\n",
    "    alpha=0.1,\n",
    "    n_quantile_bins=10,\n",
    "    feature_dim=-1,\n",
    "    matcher=matcher,\n",
    "    match_mask=None,\n",
    "    n_trees=10,\n",
    "    bagging_ratio=0.9,\n",
    "    verbose=False,\n",
    ")\n",
    "tree_path = os.path.join(UC_DIR, \"qrf_ks_electricelectricity-normalized.pkl\")\n",
    "\n",
    "qrf = DistMatchQRF(**qrf_kwargs)\n",
    "qrf.load_trees(tree_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the last leaf of the first tree in the ensemble.\n",
    "first_tree = qrf.trees[0]\n",
    "leaf_node = first_tree.leaf_nodes[-1]\n",
    "# leaf_node.values unpacks into three items; only the first two (used below as features\n",
    "# and targets) are kept. NOTE(review): confirm what the discarded third item holds.\n",
    "xs, ys, _ = leaf_node.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_quantile(\n",
    "    xs,\n",
    "    ys,\n",
    "    x,\n",
    "    quantiles=None,\n",
    "    n_estimators=10,\n",
    "    max_depth=2,\n",
    "    random_state=None,\n",
    "):\n",
    "    \"\"\"Fit a small quantile random forest on (xs, ys) and predict quantiles at x.\n",
    "\n",
    "    Args:\n",
    "        xs: feature samples, reshaped to ``(n_samples, -1)``.\n",
    "        ys: one target per sample, reshaped to ``(n_samples,)``.\n",
    "        x: a single query point, reshaped to ``(1, -1)``; it must have as many\n",
    "            features as each row of ``xs``.\n",
    "        quantiles: quantile levels to predict. ``None`` means ``[0.05, 0.95]``\n",
    "            (a fresh list per call, avoiding a shared mutable default).\n",
    "        n_estimators, max_depth: forest hyper-parameters; the defaults are the\n",
    "            values previously hard-coded here.\n",
    "        random_state: seed forwarded to the forest. Pass an int for reproducible\n",
    "            results; ``None`` keeps the original unseeded behaviour.\n",
    "\n",
    "    Returns:\n",
    "        ``predict(x)`` with singleton axes squeezed out. NOTE(review): assuming the\n",
    "        forest returns ``(n_quantiles, n_samples)``, this is one value per quantile level.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``xs`` is empty or ``xs`` and ``ys`` differ in length.\n",
    "    \"\"\"\n",
    "    if quantiles is None:\n",
    "        quantiles = [0.05, 0.95]\n",
    "\n",
    "    n_samples = len(xs)\n",
    "    if n_samples == 0:\n",
    "        raise ValueError(\"xs must contain at least one sample\")\n",
    "    if n_samples != len(ys):\n",
    "        raise ValueError(\n",
    "            f\"xs and ys must have the same length, got {n_samples} and {len(ys)}\"\n",
    "        )\n",
    "\n",
    "    xs = np.asarray(xs).reshape(n_samples, -1)\n",
    "    ys = np.asarray(ys).reshape(n_samples)\n",
    "    x = np.asarray(x).reshape(1, -1)\n",
    "\n",
    "    forest = RandomForestQuantileRegressor(\n",
    "        n_estimators=n_estimators,\n",
    "        max_depth=max_depth,\n",
    "        criterion=\"squared_error\",\n",
    "        q=quantiles,\n",
    "        random_state=random_state,\n",
    "    )\n",
    "    return forest.fit(xs, ys).predict(x).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-2.2461395007371903, 1.8897021752595902, 0.0, 33.6)"
      ]
     },
     "execution_count": 114,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAloAAAG+CAYAAABLZqrjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAANsElEQVR4nO3dS4hcZRrH4a+qq7uT7iSdS3cYHAZnxEG8UONCFxFEXehSstC4FMSFqCAuBHGTZKWgIEIQFyq49AKKWxcqAV3owilMgjAwDAPD0OlEE03n0p0+s5iNRTXjVMq/55zkeZYv5OQNNMkv56tTp1NVVVUAAPjVdeteAADgaiW0AABChBYAQIjQAgAIEVoAACFCCwAgRGgBAIQILQCAEKEFABDSq3sBANjMpZMny5d79w7N9i0vl5mlpZo2gvG5owUAECK0AABChBYAQIjQAgAIEVoAACGeOgSgkaa2bSs3HjkyMoM26VRVVdW9BADA1cjRIQBAiNACAAgRWgAAIUILACBEaAEAhPh6BwAaaW1lpXx1881DsztPnCjTi4s1bQTjE1oANFJVVWVtZWVkBm3i6BAAIERoAQCECC0AgBChBQAQIrQAAEI8dUgrDAaD2LX7/X7s2gBc29zRAgAIEVoAACFCCwAgRGgBAIQILQCAEKEFABAitAAAQnyPFgCNNDU3V64/eHBkBm3SqaqqqnsJ+CW+sBSANnJ0CAAQIrQAAEKEFgBAiNACAAgRWgAAIb7eAYBGWjt9unxz991Ds9uPHi3Tu3fXtBGMT2gB0EjV5ctl9fjxkRm0iaNDAIAQoQUAECK0AABChBYAQIjQAgAIEVoAACFCCwAgRGgBAIQILQCAEKEFABAitAAAQoQWAECI0AIACOnVvQAAbKa7ZUu57sknR2bQJp2qqqq6l4BfMhgMYtfu9/uxawNwbXN0CAAQIrQAAEKEFgBAiNACAAgRWgAAIb7eAYBGWv/hh/Lt/v1Ds9s++qj0du6sZR+4EkILgEbaWFsrZz7/fGQGbeLoEAAgRGgBAIQILQCAEKEFABAitAAAQoQWAECI0AIACBFaAAAhQgsAIERoAQCECC0AgBChBQAQIrQAAEJ6dS8AAJvpzsyUxYceGplBmwgtABqpt7BQbn3//brXgIk4OgQACBFaAAAhQgsAIERoAQCECC0AgBBPHQLQSOtnzpTvHn98aHbTm2+W3sJCTRvB+IQWAI20celSWfngg6HZn19/vaZt4Mo4OgQACBFaAAAhQgsAIERoAQCECC0AgBChBQAQIrQAAEKEFgBAiNACAAgRWgAAIUILACBEaAEAhAgtAICQXt0LAMBmutPTZeGee0Zm0CZCC4BG6u3cWW7/7LO614CJODoEAAgRWgAAIUILACBEaAEAhAgtAIAQTx0C0EjrP/5Y/v7880OzP730Uult317TRjA+oQVAI21cuFD+9frrQ7PrDx0qRWjRIo4OAQBChBYAQIjQAgAIEVoAACFCCwAgRGgBAIQILQCAEKEFABAitAAAQoQWAECI0AIACBFaAAAhQgsAIKRX9wIAsJnO1FSZu+WWkRm0idACoJGmd+8udx47VvcaMBFHhwAAIUILACBEaAEAhAgtAIAQoQUAEOKpQwAa6fK5c+WfL788NPvDc8+Vqfn5mjaC8QktABrp8upq+cfhw0Oz6556SmjRKo4OAQBChBYAQIjQAgAIEVoAACFCCwAgRGgBAIQILQCAEKEFABAitAAAQoQWAECI0AIACBFaAAAhQgsAIKRX9wIAsJlOp1OmFxdHZtAmQguARppeXCx3nTxZ9xowEUeHAAAhQgsAIERoAQCECC0AgBChBQAQ4qlDABrp8vnz5d9vvz00+91jj5WprVtr2gjGJ7QAaKTLP/1U/vb000OzpQMHhBatIrTgKjcYDKLX7/f70esDtJnPaAEAhAgtAIAQoQUAECK0AABChBYAQIjQAgAIEVoAACFCCwAg
RGgBAIQILQCAEKEFABAitAAAQoQWAEBIr+4FAGAzM0tL5Z6qqnsNmIg7WgAAIUILACBEaAEAhAgtAIAQoQUAEOKpQwAaaePixXLq44+HZnsefLB0Z2dr2gjGJ7QAaKT1s2fL8QMHhmb7lpfLzNJSTRvB+BwdAgCECC0AgBChBQAQIrQAAEKEFgBAiNACAAgRWgAAIUILACBEaAEAhAgtAIAQoQUAECK0AABChBYAQIjQAgAI6dW9AABsZnrPnrJveXlkBm0itABopE63W2aWlupeAyYitLhig8Egev1+vx+9PgCk+YwWAECI0AIACBFaAAAhPqMFQCNtXLpUzn7xxdBsx113le7MTE0bwfiEFgCNtH7mTPnrffcNzfYtL3sSkVZxdAgAECK0AABChBYAQIjQAgAIEVoAACFCCwAgRGgBAIQILQCAEKEFABAitAAAQoQWAECI0AIACBFaAAAhQgsAIKRX9wIAsJnpXbvKHd9+OzKDNhFaADRSp9cr87feWvcaMBFHhwAAIUILACBEaAEAhAgtAIAQH4YHoJGq9fWy+t13Q7O5m24qnZ5/umgPP60ANNLa99+Xr2+7bWi2b3m5zCwt1bQRjM/RIQBAiNACAAgRWgAAIUILACBEaAEAhAgtAIAQoQUAECK0AABChBYAQIjQAgAIEVoAACFCCwAgRGgBAIQILQCAkF7dCwDAZnoLC+Uvn346MoM2EVoANFJ3ZqbsvPfeuteAiTg6BAAIEVoAACFCCwAgRGgBAIT4MDwAjVRtbJS1U6eGZtN79pRO1z0C2kNoAdBIa6dOlS/37h2a7VteLjNLSzVtBOPz3wIAgBB3tIBfzWAwiF6/3+9Hrw/wa3NHCwAgRGgBAIQILQCAEKEFABAitAAAQoQWAECI0AIACBFaAAAhQgsAIERoAQCECC0AgBDvOoSaeC8gwNVPaAHQSL0dO8ot7703MoM2EVoANFJ3drYsPfxw3WvARHxGCwAgRGgBAIQILQCAEKEFABAitAAAQjx1CEAjXTp5sny5d+/QbN/ycplZWqppIxifO1oAACFCCwAgRGgBAIQILQCAEKEFABAitAAAQoQWAECI0AIACBFaAAAhQgsAIERoAQCECC0AgBChBQAQIrQAAEJ6dS8AAJuZ2rat3HjkyMgM2kRoAdBIU1u3lt8/9VTda8BEHB0CAIS4owXQEIPBIHr9fr8fvT4wyh0tAIAQoQUAECK0AABCfEYLgEZaW1kpX91889DszhMnyvTiYk0bwfiEFgCNVFVVWVtZGZlBmzg6BAAIEVoAACFCCwAgRGgBAIQILQCAEKEFABDi6x0A/gfvHwQm4Y4WAECI0AIACBFaAAAhQgsAIERoAQCECC0AgBBf7wBAI03NzZXrDx4cmUGbCC0AGmlqfr788dChuteAiTg6BAAIEVoAACFCCwAgRGgBAIQILQCAEE8dAtBIa6dPl2/uvntodvvRo2V69+6aNoLxCS0AGqm6fLmsHj8+MoM2cXQIABAitAAAQoQWAECI0AIACBFaAAAhQgsAIERoAQCECC0AgBChBQAQIrQAAEK8ggeAmMFgcMW/duP06ZHZ8WPHSvdn7zrs9/tXfH34LbijBQAQIrQAAEIcHQLQSJ3Z2TL7yCMjM2gToQVAI3Xm58vcCy/UvQZMxNEhAECI0AIACBFaAAAhQgsAIERoAQCEeOoQgEbaOHu2nHv22aHZ/Kuvlu6OHTVtBOMTWgA00/p6Wf/665EZtImjQwCAEHe04GcmeQHu/8MLcAGuLe5oAQCECC0AgBChBQAQIrQAAEKEFgBAiNACAAgRWgAAIUILACBEaAEAhAgtAIAQoQUAEOJdhwA0Umd6ukzff//IDNpEaAHQSJ3t28u2V16pew2YiKNDAIAQoQUAECK0AABChBYAQIjQAgAI8dQhAI1U/fhjOXf48NBs/uDB0tm+vaaNYHxCC4BGqtbWytonnwzPXnihdGraB66Eo0MAgBChBQAQIrQA
AEKEFgBAiA/DX2UGg0H0+v1+P3p9gCvl7z+ayB0tAIAQoQUAECK0AABChBYAQIjQAgAIEVoAACFCCwAgRGgBAIT4wlIAmqnXK7077hiZQZv4iQWgkbo7dpTtb71V9xowEUeHAAAhQgsAIMTRIdBKXiAMtIE7WgAAIUILACDE0SEAjVSdO1fOv/ba0GzrM8+Uzvx8TRvB+IQWAI1UXbxYLr777tBsyxNPCC1axdEhAECI0AIACBFaAAAhQgsAIERoAQCECC0AgBChBQAQ4nu0AK4x3hNJna61nz93tAAAQoQWAECI0AIACBFaAAAhQgsAIMRThwA009RU6d5ww8gM2kRoAdBI3YWFsvDhh3WvARNxdAgAECK0AABChBYAQIjQAgAIEVoAACGNe+rwan3Z5NX65wJIqVZXy4V33hmabXn00dKZm6tpo//6Lf8+/61+L/9G5TQutACglFKqCxfKhTfeGJrNPvJI7aEF43B0CAAQIrQAAEKEFgBAiNACAAgRWgAAIUILACBEaAEAhAgtAIAQoQUAECK0AABChBYAQIjQAgAIEVoAACG9uhcAgE11OqWza9fIDNpEaAHQSN1du8rOzz6rew2YiKNDAIAQoQUAECK0AABChBYAQIjQAgAI8dQhAI1UXbhQLn700dBsdv/+0tmypZ6F4AoILQAaqVpdLedffHFoNvPAA0KLVnF0CAAQIrQAAEKEFgBAiNACAAgRWgAAIUILACBEaAEAhAgtAIAQoQUAECK0AABChBYAQIjQAgAIEVoAACFCCwAgpFNVVVX3EgAAVyN3tAAAQoQWAECI0AIACBFaAAAhQgsAIERoAQCECC0AgBChBQAQIrQAAEL+A6NseWSNEC+KAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "SAMPLE_IDX = 11  # row of the leaf data used as the query point\n",
    "test_quantiles = get_quantile(xs, ys, xs[SAMPLE_IDX])\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(ys, bins=20, color=\"#d1d1d1\", rwidth=0.9)\n",
    "# Dashed line at the second requested quantile (0.95 with get_quantile's default levels).\n",
    "# axvline's ymin/ymax are fractions of the axes height and must lie in [0, 1];\n",
    "# 0..1 spans the full height.\n",
    "ax.axvline(test_quantiles[1], ymin=0, ymax=1, color=\"#c00000\", linestyle=\"dashed\", linewidth=3)\n",
    "# axis(\"off\") hides ticks, tick labels and spines, so separate xticks([])/yticks([]) calls are unnecessary.\n",
    "ax.axis(\"off\")\n",
    "# Lay out after hiding the axis; returns None, so the cell shows only the figure.\n",
    "fig.tight_layout()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
