{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "\n",
    "# Root of the project checkout; replace the placeholder with the real absolute path.\n",
    "PROJECT_HOME = \"ABSOLUTE_PATH_TO_THE_ROOT\"\n",
    "# Make the packages under PROJECT_HOME/code importable. The membership check keeps\n",
    "# sys.path free of duplicate entries when this cell is re-run in the same kernel.\n",
    "_code_dir = os.path.join(PROJECT_HOME, \"code\")\n",
    "if _code_dir not in sys.path:\n",
    "    sys.path.append(_code_dir)\n",
    "\n",
    "import json\n",
    "import math\n",
    "import multiprocessing\n",
    "from functools import partial\n",
    "from collections import defaultdict\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from pprint import pprint\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from models.uncertainty.dist_match.tree import DistMatchQRF\n",
    "from models.uncertainty.dist_match.utils import match_ks_stat, match_ks_p_val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exclusive upper bound on match_ks_stat(x1, x2) for two samples to count as a match.\n",
    "MATCH_THRESHOLD = 0.6\n",
    "\n",
    "\n",
    "def matcher(x1, x2):\n",
    "    \"\"\"Return whether ``match_ks_stat(x1, x2)`` is strictly below ``MATCH_THRESHOLD``.\n",
    "\n",
    "    Passed to ``DistMatchQRF`` as its ``matcher`` argument in ``get_qrf``.\n",
    "    \"\"\"\n",
    "    ks_stat = match_ks_stat(x1, x2)\n",
    "    return ks_stat < MATCH_THRESHOLD\n",
    "\n",
    "\n",
    "def get_qrf(path: str) -> DistMatchQRF:\n",
    "    \"\"\"Build a ``DistMatchQRF`` with this notebook's fixed settings and load its trees.\n",
    "\n",
    "    Args:\n",
    "        path: Location handed to ``DistMatchQRF.load_trees``.\n",
    "\n",
    "    Returns:\n",
    "        The forest instance after ``load_trees(path)`` has been called on it.\n",
    "    \"\"\"\n",
    "    forest_settings = dict(\n",
    "        alpha=0.1,\n",
    "        n_quantile_bins=10,\n",
    "        feature_dim=-1,\n",
    "        matcher=matcher,\n",
    "        match_mask=None,\n",
    "        n_trees=10,\n",
    "        bagging_ratio=0.9,\n",
    "        verbose=False,\n",
    "    )\n",
    "    forest = DistMatchQRF(**forest_settings)\n",
    "    forest.load_trees(path)\n",
    "    return forest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_file_dir(data_key: str, table_key: str) -> str:\n",
    "    \"\"\"Return the ``Eval_<table_key>_plots`` table directory of one run.\n",
    "\n",
    "    The directory lives under ``PROJECT_HOME/outputs/<data_key>/wandb/latest-run``.\n",
    "    \"\"\"\n",
    "    relative_dir = f\"outputs/{data_key}/wandb/latest-run/files/media/table/Eval_{table_key}_plots\"\n",
    "    return os.path.join(PROJECT_HOME, relative_dir)\n",
    "\n",
    "\n",
    "def read_json(path: str) -> pd.DataFrame:\n",
    "    \"\"\"Load a JSON table stored as ``{\"columns\": [...], \"data\": [[...], ...]}``.\n",
    "\n",
    "    Args:\n",
    "        path: Path of the JSON file to read.\n",
    "\n",
    "    Returns:\n",
    "        A DataFrame with one row per entry of ``\"data\"`` and the columns listed\n",
    "        under ``\"columns\"``. An empty ``\"data\"`` list yields an empty frame that\n",
    "        still carries those columns.\n",
    "    \"\"\"\n",
    "    # JSON text is UTF-8; do not depend on the platform's default encoding.\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as file:\n",
    "        table = json.load(file)\n",
    "    cols = table[\"columns\"]\n",
    "    records = [dict(zip(cols, row)) for row in table[\"data\"]]\n",
    "    # dict.fromkeys drops repeated names (each record can hold a key only once)\n",
    "    # while keeping first-seen order.\n",
    "    return pd.DataFrame.from_records(records, columns=list(dict.fromkeys(cols)))\n",
    "\n",
    "\n",
    "def get_data_df(data_key: str, table_key: str, file_idx: int = 0) -> pd.DataFrame:\n",
    "    \"\"\"Read one table file of a run into a DataFrame.\n",
    "\n",
    "    Args:\n",
    "        data_key: Run directory name under ``outputs``.\n",
    "        table_key: Table identifier used in the ``Eval_<table_key>_plots`` directory name.\n",
    "        file_idx: Position of the file to read within the name-sorted directory listing.\n",
    "\n",
    "    Returns:\n",
    "        The table parsed by ``read_json``.\n",
    "    \"\"\"\n",
    "    file_dir = get_file_dir(data_key, table_key)\n",
    "    # os.listdir returns entries in arbitrary order; sort so that file_idx\n",
    "    # picks the same file on every run and platform.\n",
    "    filenames = sorted(os.listdir(file_dir))\n",
    "    return read_json(os.path.join(file_dir, filenames[file_idx]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def max_len_from_same_distro(values):\n",
    "    \"\"\"Return the largest number of entries of ``values`` that match a single entry.\n",
    "\n",
    "    For every entry, counts the entries (itself included) for which\n",
    "    ``match_ks_p_val(entry, other).item()`` exceeds 0.05, and returns the\n",
    "    maximum of these counts. Scanning stops early once a count equals\n",
    "    ``len(values)``. An empty ``values`` gives 0.\n",
    "    \"\"\"\n",
    "    total = len(values)\n",
    "    best = 0\n",
    "    for anchor in values:\n",
    "        n_matched = sum(\n",
    "            1 for other in values if match_ks_p_val(anchor, other).item() > 0.05\n",
    "        )\n",
    "        if n_matched > best:\n",
    "            best = n_matched\n",
    "        if best == total:\n",
    "            break\n",
    "    return best\n",
    "\n",
    "\n",
    "def study_qrf_ood_samples(data, qrf, patch_len: int = 100):\n",
    "    \"\"\"Summarise how many values in each tree's first leaf match one another.\n",
    "\n",
    "    For every tree in ``qrf.trees`` the values held by ``tree.leaf_nodes[0]`` are\n",
    "    passed to ``max_len_from_same_distro``; a ``matched / total`` line is printed\n",
    "    per tree and the statistics are averaged over the trees.\n",
    "\n",
    "    Args:\n",
    "        data: Table with a ``\"step\"`` column; rows with ``step >= 0`` are counted\n",
    "            for ``calib_len``.\n",
    "        qrf: Forest exposing ``trees``; each tree exposes ``leaf_nodes`` whose\n",
    "            nodes provide ``get_values()``.\n",
    "        patch_len: Unused; kept so existing calls keep working.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(n_values_ratio, n_leaves, calib_len, n_ood_values)``: mean matched\n",
    "        fraction, mean leaf count per tree, number of rows with ``step >= 0``, and\n",
    "        mean number of unmatched values per tree.\n",
    "\n",
    "    Raises:\n",
    "        ZeroDivisionError: If the forest has no trees or a first leaf holds no values.\n",
    "    \"\"\"\n",
    "    # Count the rows selected by the mask; len() of the mask would count every\n",
    "    # row of `data`, whether or not it satisfies the condition.\n",
    "    calib_mask = data[\"step\"] >= 0\n",
    "    calib_len = int(calib_mask.sum())\n",
    "\n",
    "    n_trees = len(qrf.trees)\n",
    "    n_values_ratio = 0\n",
    "    n_leaves = 0\n",
    "    n_ood_values = 0\n",
    "\n",
    "    for tree in qrf.trees:\n",
    "        node = tree.leaf_nodes[0]  # 0 -- leftmost node\n",
    "        values = node.get_values()[0]\n",
    "        max_len = max_len_from_same_distro(values)\n",
    "        n_values = len(values)\n",
    "\n",
    "        print(f\"{max_len} / {n_values}\")\n",
    "        n_values_ratio += max_len / n_values\n",
    "        n_leaves += len(tree.leaf_nodes)\n",
    "        n_ood_values += n_values - max_len\n",
    "\n",
    "    # Turn the per-tree sums into per-tree averages.\n",
    "    n_leaves /= n_trees\n",
    "    n_values_ratio /= n_trees\n",
    "    n_ood_values /= n_trees\n",
    "    return n_values_ratio, n_leaves, calib_len, n_ood_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "169 / 252\n",
      "166 / 247\n",
      "154 / 236\n",
      "163 / 257\n",
      "160 / 240\n",
      "163 / 235\n",
      "162 / 250\n",
      "168 / 249\n",
      "160 / 238\n",
      "160 / 246\n",
      "[Elec] same_distro: 0.66; n_leaves: 245.0; calib_len: 1378; n_ood: 82.5; ood_ratio: 0.06\n",
      "62 / 75\n",
      "61 / 69\n",
      "55 / 70\n",
      "55 / 67\n",
      "62 / 74\n",
      "56 / 69\n",
      "59 / 74\n",
      "59 / 73\n",
      "53 / 68\n",
      "48 / 65\n",
      "[Solar] same_distro: 0.81; n_leaves: 70.4; calib_len: 800; n_ood: 13.4; ood_ratio: 0.02\n",
      "364 / 721\n",
      "366 / 721\n",
      "369 / 720\n",
      "364 / 709\n",
      "347 / 708\n",
      "351 / 735\n",
      "356 / 729\n",
      "351 / 689\n",
      "354 / 718\n"
     ]
    }
   ],
   "source": [
    "# Passed to study_qrf_ood_samples as its patch_len argument.\n",
    "PATCH_LEN = 100\n",
    "# Directory that holds the saved forest files listed in DATA_MAP.\n",
    "TREE_DIR = os.path.join(PROJECT_HOME, \"models_save/uc/\")\n",
    "# One entry per dataset: display name -> (qrf_path, data_key, table_key, file_idx).\n",
    "#   qrf_path  -- forest file name, joined onto TREE_DIR and given to get_qrf\n",
    "#   data_key  -- run directory name under outputs/, used by get_data_df\n",
    "#   table_key -- table identifier inside the run's Eval_<table_key>_plots directory\n",
    "#   file_idx  -- index of the file to read from that directory\n",
    "# NOTE(review): the commented-out lines look like earlier alternative files/runs -- confirm\n",
    "# before removing them.\n",
    "DATA_MAP = {\n",
    "    \"Elec\": (\n",
    "        # \"qrf_ks_stat<0.1_error_normal_electricelectricity-normalized_darts-forest|100.pkl\",\n",
    "        \"qrf_ks_stat<0.01_error_normal_electricelectricity-normalized_darts-forest|100.pkl\",\n",
    "        #\"darts_forest_dist_match_enbPI_electric_s20_260126_140743\",\n",
    "        \"darts_forest_dist_match_enbPI_electric_s20_270126_170906\",\n",
    "        \"electricelectricity-normalized\",\n",
    "        0,\n",
    "    ),\n",
    "    \"Solar\": (\n",
    "        # \"qrf_ks_stat<0.1_error_normal_solarSolar_Atl_data_aligned_darts-forest|100.pkl\",\n",
    "        \"qrf_ks_stat<0.01_error_normal_solarSolar_Atl_data_aligned_darts-forest|100.pkl\",\n",
    "        # \"darts_forest_dist_match_enbPI_solar_atlanta_s20_270126_174558\",\n",
    "        \"darts_forest_dist_match_enbPI_solar_atlanta_s20_270126_174528\",\n",
    "        \"solarSolar_Atl_data_aligned\",\n",
    "        0,\n",
    "        # 1,\n",
    "    ),\n",
    "    \"Wind\": (\n",
    "        # \"qrf_ks_stat<0.1_error_normal_windWind_Hackberry_Generation_2019_2020_darts-forest|100.pkl\",\n",
    "        \"qrf_ks_stat<0.01_error_normal_windWind_Hackberry_Generation_2019_2020_darts-forest|100.pkl\",\n",
    "        # \"darts_forest_dist_match_enbPI_wind_s20_260126_140743\",\n",
    "        \"darts_forest_dist_match_enbPI_wind_s20_270126_170906\",\n",
    "        \"windWind_Hackberry_Generation_2019_2020\",\n",
    "        0,\n",
    "    ),\n",
    "    \"META\": (\n",
    "        # \"qrf_ks_stat<0.1_error_normal_stockMETA_5m_darts-forest|100.pkl\",\n",
    "        \"qrf_ks_stat<0.01_error_normal_stockMETA_5m_darts-forest|100.pkl\",\n",
    "        # \"darts_forest_dist_match_stock_meta_5m_s20_260126_140743\",\n",
    "        \"darts_forest_dist_match_stock_meta_5m_s20_270126_170906\",\n",
    "        \"stockMETA_5m\",\n",
    "        0,\n",
    "    ),\n",
    "    \"NVDA\": (\n",
    "        # \"qrf_ks_stat<0.1_error_normal_stockNVDA_5m_darts-forest|100.pkl\",\n",
    "        \"qrf_ks_stat<0.01_error_normal_stockNVDA_5m_darts-forest|100.pkl\",\n",
    "        # \"darts_forest_dist_match_stock_nvda_5m_s20_260126_140743\",\n",
    "        \"darts_forest_dist_match_stock_nvda_5m_s20_270126_170905\",\n",
    "        \"stockNVDA_5m\",\n",
    "        0,\n",
    "    ),\n",
    "    # \"PEMS\": (\n",
    "    #     \"qrf_ks_error_normal_data_darts-forest|100.pkl\",\n",
    "    #     \"darts_forest_dist_match_pems_s20_180725_164234\",\n",
    "    #     \"pems2022_01\",\n",
    "    #     0\n",
    "    # ),\n",
    "    \"rain\": (\n",
    "        # \"qrf_ks_stat<0.1_error_normal_raindaily_weather_darts-forest|100.pkl\",\n",
    "        \"qrf_ks_stat<0.01_error_normal_raindaily_weather_darts-forest|100.pkl\",\n",
    "        # \"darts_forest_dist_match_rain_s20_280126_161125\",\n",
    "        \"darts_forest_dist_match_rain_s20_280126_161121\",\n",
    "        \"raindaily_weather\",\n",
    "        0\n",
    "    )\n",
    "}\n",
    "\n",
    "# Load each dataset's forest and evaluation table, then print its summary line.\n",
    "for data_type, run_spec in DATA_MAP.items():\n",
    "    qrf_path, data_key, table_key, file_idx = run_spec\n",
    "    qrf = get_qrf(os.path.join(TREE_DIR, qrf_path))\n",
    "    data = get_data_df(data_key, table_key, file_idx)\n",
    "    stats = study_qrf_ood_samples(data, qrf, PATCH_LEN)\n",
    "    n_values_ratio, n_leaves, calib_len, n_ood_values = stats\n",
    "    print(f\"[{data_type}] same_distro: {n_values_ratio:.2}; n_leaves: {n_leaves}; calib_len: {calib_len}; n_ood: {n_ood_values}; ood_ratio: {n_ood_values / calib_len:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
