{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Placeholder: set this to the absolute path of the project root before running.\n",
    "BASE_DIR = \"ABSOLUTE_PATH_TO_THE_ROOT\"\n",
    "# Handed to DataGenerator.get_data as ``replace_base_dir`` when datasets are loaded.\n",
    "DATA_DIR = os.path.join(BASE_DIR, \"data\")\n",
    "# Appended to sys.path below; presumably where the project packages imported\n",
    "# further down (loader, config, models, utils, ...) live - TODO confirm.\n",
    "CODE_DIR = os.path.join(BASE_DIR, \"code\")\n",
    "# Handed to ForcastService as ``persist_dir``.\n",
    "FC_DIR = os.path.join(BASE_DIR, 'models_save/fc')\n",
    "# NOTE(review): not referenced by any cell visible in this notebook.\n",
    "UC_DIR = os.path.join(BASE_DIR, 'models_save/uc')\n",
    "\n",
    "import sys\n",
    "\n",
    "# Must run before the project-local imports that follow.\n",
    "sys.path.append(CODE_DIR)\n",
    "\n",
    "import os\n",
    "import json\n",
    "from collections import defaultdict\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import matplotlib.gridspec as grid_spec\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from loader.generator import DataGenerator\n",
    "from config import TSDataConfig, TaskConfig\n",
    "from main_utils import _init_fc\n",
    "from omegaconf import DictConfig\n",
    "\n",
    "from models.forcast.darts import SimpleDartsModel\n",
    "from models.forcast.forcast_service import ForcastService\n",
    "from models.forcast.forcast_base import FCPredictionData\n",
    "from models.uncertainty.uc_service import UncertaintyService\n",
    "from models.uncertainty.dist_match.tree import DistMatchQRF\n",
    "from models.uncertainty.dist_match.utils import match_ks_stat\n",
    "from utils.calc_torch import calc_residuals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def matcher(x1, x2):\n",
    "    \"\"\"Tell whether ``match_ks_stat(x1, x2)`` falls strictly below 0.6.\"\"\"\n",
    "    ks_value = match_ks_stat(x1, x2)\n",
    "    return ks_value < 0.6\n",
    "\n",
    "\n",
    "def get_qrf(path: str) -> DistMatchQRF:\n",
    "    \"\"\"Create a ``DistMatchQRF`` with this notebook's fixed settings and load its trees.\n",
    "\n",
    "    Args:\n",
    "        path: Location handed to ``DistMatchQRF.load_trees``.\n",
    "\n",
    "    Returns:\n",
    "        The forest instance after ``load_trees(path)`` has been called on it.\n",
    "    \"\"\"\n",
    "    qrf_settings = {\n",
    "        \"alpha\": 0.1,\n",
    "        \"n_quantile_bins\": 10,\n",
    "        \"feature_dim\": -1,\n",
    "        \"matcher\": matcher,\n",
    "        \"match_mask\": None,\n",
    "        \"n_trees\": 10,\n",
    "        \"bagging_ratio\": 0.9,\n",
    "        \"verbose\": False,\n",
    "    }\n",
    "    forest = DistMatchQRF(**qrf_settings)\n",
    "    forest.load_trees(path)\n",
    "    return forest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Task settings used as the default task configuration by load_dataset and\n",
    "# load_forecast_service.\n",
    "TASK_CONFIG = DictConfig(\n",
    "    {\n",
    "        \"task_type\": \"PI\",\n",
    "        \"alpha\": 0.1,\n",
    "        # presumably train / calibration / test fractions - TODO confirm\n",
    "        \"data_splits\": [0.6, 0.15, 0.25],\n",
    "        \"fc_estimator_mode\": \"single\",\n",
    "        \"global_norm\": False,\n",
    "        \"add_config\": None,\n",
    "    }\n",
    ")\n",
    "\n",
    "\n",
    "def load_dataset(data_config: TSDataConfig, task_config: TaskConfig = None):\n",
    "    \"\"\"Load the datasets described by ``data_config`` through ``DataGenerator.get_data``.\n",
    "\n",
    "    Args:\n",
    "        data_config: Dataset description forwarded unchanged to the generator.\n",
    "        task_config: Task settings to load with. When omitted, the module-level\n",
    "            ``TASK_CONFIG`` is used.\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``DataGenerator.get_data`` returns; files are read relative to\n",
    "        ``DATA_DIR`` and no normalisation parameters are supplied.\n",
    "    \"\"\"\n",
    "    # The argument used to be resolved and then dropped: the call below always\n",
    "    # received the global TASK_CONFIG. Forward the caller's value instead, and\n",
    "    # keep the global as the default so existing calls behave as before.\n",
    "    if task_config is None:\n",
    "        task_config = TASK_CONFIG\n",
    "    return DataGenerator.get_data(\n",
    "        data_config=data_config,\n",
    "        task_config=task_config,\n",
    "        replace_base_dir=DATA_DIR,\n",
    "        X_norm_param=None,\n",
    "        Y_norm_param=None,\n",
    "        hydro_static_norm_param=None,\n",
    "    )\n",
    "\n",
    "\n",
    "def load_forecast_service(\n",
    "    fc_model=None, model_config=None, data_config=None, task_config=None\n",
    "):\n",
    "    \"\"\"Assemble a ``ForcastService`` that persists under ``FC_DIR``.\n",
    "\n",
    "    Args:\n",
    "        fc_model: Callable producing the forecast model when invoked with\n",
    "            ``**model_config``. When None, ``SimpleDartsModel`` is used together\n",
    "            with a built-in \"darts-forest\" configuration, which replaces any\n",
    "            ``model_config`` that was passed in.\n",
    "        model_config: Keyword arguments for ``fc_model``; also forwarded to the service.\n",
    "        data_config: Forwarded to the service unchanged.\n",
    "        task_config: Task settings; a falsy value selects ``TASK_CONFIG``.\n",
    "\n",
    "    Returns:\n",
    "        The constructed ``ForcastService``.\n",
    "    \"\"\"\n",
    "    if fc_model is None:\n",
    "        fc_model = SimpleDartsModel\n",
    "        model_config = {\n",
    "            \"model\": \"darts-forest\",\n",
    "            \"model_params\": {\"lags\": 50, \"lags_past_covariates\": 50},\n",
    "        }\n",
    "\n",
    "    def build_model():\n",
    "        # Deferred construction: the service decides when to instantiate.\n",
    "        return fc_model(**model_config)\n",
    "\n",
    "    service_kwargs = {\n",
    "        \"int_fc_model\": build_model,\n",
    "        \"task_config\": task_config or TASK_CONFIG,\n",
    "        \"model_config\": model_config,\n",
    "        \"data_config\": data_config,\n",
    "        \"persist_dir\": FC_DIR,\n",
    "    }\n",
    "    return ForcastService(**service_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_prototype(target, node):\n",
    "    \"\"\"Return the sample stored in ``node`` that lies closest to ``target``.\n",
    "\n",
    "    Candidates are the entries of ``node.values[0]``. A candidate is skipped\n",
    "    when it is ``np.allclose`` to the first split value of any node returned\n",
    "    by ``node.get_all_parents()``. Closeness is the mean absolute difference\n",
    "    between ``target`` and the candidate.\n",
    "\n",
    "    Args:\n",
    "        target: Array-like patch to compare against.\n",
    "        node: Tree node exposing ``values`` and ``get_all_parents()``.\n",
    "\n",
    "    Returns:\n",
    "        The closest remaining candidate (the earliest one on ties), or None\n",
    "        when no candidate remains after filtering.\n",
    "    \"\"\"\n",
    "    samples = node.values[0]\n",
    "    parent_split_values = [parent.split_value[0] for parent in node.get_all_parents()]\n",
    "\n",
    "    candidates = [\n",
    "        sample\n",
    "        for sample in samples\n",
    "        if not any(np.allclose(sample, split) for split in parent_split_values)\n",
    "    ]\n",
    "\n",
    "    # min() keeps the first of several equally distant candidates, which matches\n",
    "    # a strict \"<\" scan. The per-candidate debug print was removed.\n",
    "    return min(\n",
    "        candidates,\n",
    "        key=lambda sample: abs(target - sample).mean(),\n",
    "        default=None,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_residuals(forcast_service: ForcastService, data, is_calib: bool = False):\n",
    "    \"\"\"Forecast one segment of ``data`` and return the residuals against its targets.\n",
    "\n",
    "    ``data`` is first passed through ``forcast_service.prepare`` with the\n",
    "    service's configured alpha.\n",
    "\n",
    "    Args:\n",
    "        forcast_service: Service used for preparation and prediction.\n",
    "        data: Dataset object accepted by ``forcast_service.prepare``.\n",
    "        is_calib: When True, the prepared data is mapped with\n",
    "            ``UncertaintyService._map_to_calib_data`` and the calibration step is\n",
    "            forecast from the pre-calibration history. Otherwise the test step\n",
    "            is forecast from the calibration history.\n",
    "\n",
    "    Returns:\n",
    "        ``calc_residuals(...)`` of the point forecast, converted with ``.numpy()``.\n",
    "    \"\"\"\n",
    "    prepared = forcast_service.prepare(data, forcast_service._task_config.alpha)\n",
    "\n",
    "    if is_calib:\n",
    "        source = UncertaintyService._map_to_calib_data(prepared)\n",
    "        request = FCPredictionData(\n",
    "            ts_id=source.ts_id,\n",
    "            X_past=source.X_pre_calib,\n",
    "            Y_past=source.Y_pre_calib,\n",
    "            X_step=source.X_calib,\n",
    "            step_offset=source.step_offset,\n",
    "        )\n",
    "        truth_field = \"Y_calib\"\n",
    "    else:\n",
    "        source = prepared\n",
    "        request = FCPredictionData(\n",
    "            ts_id=source.ts_id,\n",
    "            X_past=source.X_calib,\n",
    "            Y_past=source.Y_calib,\n",
    "            X_step=source.X_test,\n",
    "            step_offset=source.test_step,\n",
    "        )\n",
    "        truth_field = \"Y_test\"\n",
    "\n",
    "    forecast = forcast_service.predict(request)\n",
    "    residuals = calc_residuals(Y_hat=forecast.point, Y=getattr(source, truth_field))\n",
    "    return residuals.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data_prototypes(\n",
    "    data,\n",
    "    qrf,\n",
    "    forcast_service,\n",
    "    patch_len,\n",
    "    start: int = 0,\n",
    "    stop: int = 1000,\n",
    "    step: int = 10,\n",
    "    group_colors: list = None,\n",
    "    n_samples: int = 4,\n",
    "    is_calib: bool = True,\n",
    "):\n",
    "    \"\"\"Plot the most frequently hit prototypes of the first tree next to example patches.\n",
    "\n",
    "    Residuals from ``get_residuals`` are flattened, cut to ``[start:stop]`` and\n",
    "    split into sliding windows of ``patch_len`` values, keeping every\n",
    "    ``step``-th window. Each window is routed to a leaf of ``qrf.trees[0]``.\n",
    "    Only leaves that are the ``right`` child of their parent are kept, and the\n",
    "    parent's first split value is used as that leaf's prototype. The leaves\n",
    "    that collected the most windows are drawn one per row: the prototype in\n",
    "    column 0, followed by ``n_samples`` randomly chosen windows.\n",
    "\n",
    "    Args:\n",
    "        data: Dataset object forwarded to ``get_residuals``.\n",
    "        qrf: Object exposing ``trees``; only ``trees[0]`` is used.\n",
    "        forcast_service: Forecast service forwarded to ``get_residuals``.\n",
    "        patch_len: Length of each sliding window.\n",
    "        start: First index of the flattened residuals to use.\n",
    "        stop: End index (exclusive) of the flattened residuals to use.\n",
    "        step: Stride between consecutive windows.\n",
    "        group_colors: One colour per plotted row; its length sets the number of\n",
    "            rows. Defaults to two colours.\n",
    "        n_samples: Number of example windows drawn per row (sampled with\n",
    "            replacement, unseeded).\n",
    "        is_calib: Forwarded to ``get_residuals``.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib figure.\n",
    "    \"\"\"\n",
    "    group_colors = group_colors or [\"#FFA332\", \"#9DC3FE\"]\n",
    "    residuals = np.ravel(get_residuals(forcast_service, data, is_calib))[start:stop]\n",
    "\n",
    "    tree = qrf.trees[0]\n",
    "    node_idx_map = {node: idx for idx, node in enumerate(tree.leaf_nodes)}\n",
    "\n",
    "    prototypes = {}\n",
    "    prototype_samples = defaultdict(list)\n",
    "\n",
    "    patches = np.lib.stride_tricks.sliding_window_view(residuals, patch_len)[::step]\n",
    "    for patch in tqdm(patches, leave=False):\n",
    "        node = tree.predict_single_node(patch)\n",
    "        node_idx = node_idx_map[node]\n",
    "        if node.parent.right != node:\n",
    "            continue\n",
    "        # The earlier find_prototype(patch, node) call was dropped: its result\n",
    "        # was always overwritten by the parent's split value.\n",
    "        if node_idx not in prototypes:\n",
    "            prototypes[node_idx] = node.parent.split_value[0]\n",
    "        prototype_samples[node_idx].append(patch)\n",
    "\n",
    "    n_groups = len(group_colors)\n",
    "    n_cols = n_samples + 1\n",
    "\n",
    "    ranked = sorted(prototype_samples.items(), key=lambda item: len(item[1]), reverse=True)\n",
    "    max_prototypes = [proto_id for proto_id, _ in ranked[:n_groups]]\n",
    "\n",
    "    # squeeze=False keeps ``axes`` two-dimensional even for a single row or column.\n",
    "    fig, axes = plt.subplots(\n",
    "        nrows=n_groups,\n",
    "        ncols=n_cols,\n",
    "        figsize=(2 * n_cols, 2 * n_groups),\n",
    "        squeeze=False,\n",
    "    )\n",
    "\n",
    "    for idx, proto_id in enumerate(max_prototypes):\n",
    "        samples = prototype_samples[proto_id]\n",
    "        sample_ids = np.random.choice(len(samples), n_samples)\n",
    "        samples = np.array(samples)[sample_ids]\n",
    "\n",
    "        color = group_colors[idx]\n",
    "        axes[idx, 0].plot(prototypes[proto_id], color=color, linewidth=2)\n",
    "        axes[idx, 0].tick_params(axis='x', labelsize=15)\n",
    "        axes[idx, 0].tick_params(axis='y', labelsize=15)\n",
    "\n",
    "        # Track a common y-range so every panel of the row is comparable.\n",
    "        y_min = prototypes[proto_id].min()\n",
    "        y_max = prototypes[proto_id].max()\n",
    "\n",
    "        for i, sample in enumerate(samples):\n",
    "            ax = axes[idx, i + 1]\n",
    "            ax.plot(sample, color=color, linewidth=2)\n",
    "            y_min = min(y_min, sample.min())\n",
    "            y_max = max(y_max, sample.max())\n",
    "\n",
    "            ax.set_yticklabels([])\n",
    "            for side in (\"right\", \"bottom\"):\n",
    "                ax.spines[side].set_visible(False)\n",
    "\n",
    "            ax.tick_params(axis='x', labelsize=15)\n",
    "\n",
    "        for ax in axes[idx]:\n",
    "            ax.set_ylim(y_min, y_max)\n",
    "\n",
    "    l_offset = 0.045\n",
    "    r_offset = 0.01\n",
    "    sep_color = \"#b7b7b7\"\n",
    "    # Vertical rule on the boundary after column 0, i.e. between the prototype\n",
    "    # and its samples. This used the leaked row index ``idx``, which only equals\n",
    "    # 1 when exactly two rows were plotted and is undefined when none were.\n",
    "    v_sep_pos = l_offset + 1 / n_cols * (1 - l_offset - r_offset)\n",
    "    v_sep_line = plt.Line2D(\n",
    "        [v_sep_pos, v_sep_pos],\n",
    "        [0.025, 0.975],\n",
    "        transform=fig.transFigure,\n",
    "        color=sep_color,\n",
    "        linewidth=2,\n",
    "    )\n",
    "    fig.add_artist(v_sep_line)\n",
    "\n",
    "    # NOTE(review): no plot call above sets a label, so this legend has no entries.\n",
    "    fig.legend()\n",
    "    fig.tight_layout()\n",
    "\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor([0.0517, 0.5529, 0.0033, 0.4921]), tensor([0.0565, 0.1078, 0.0033, 0.1544]), tensor([0.5001]), tensor([0.1856])]\n"
     ]
    }
   ],
   "source": [
    "# Window length passed as ``patch_len`` to get_data_prototypes.\n",
    "PATCH_LEN = 100\n",
    "# Display name -> 4-tuple, unpacked by the loop below as\n",
    "# (qrf_path, datatype, data_paths, file_idx).\n",
    "DATA_MAP = {\n",
    "    # NOTE(review): the key \"Elec\" is defined a second time further down. Python\n",
    "    # keeps the later value, so this tuple is never used (the key keeps this\n",
    "    # first position in iteration order).\n",
    "    \"Elec\": (\n",
    "        \"qrf_ks_error_normal_electricelectricity-normalized_darts-forest|100.pkl\",\n",
    "        \"darts_forest_dist_match_enbPI_electric_s10_290725_143222\",\n",
    "        \"electricelectricity-normalized\",\n",
    "        0,\n",
    "    ),\n",
    "    \"Solar\": (\n",
    "        \"qrf_ks_error_normal_solarSolar_Atl_data_aligned_darts-forest|100.pkl\",\n",
    "        # \"darts_forest_dist_match_enbPI_solar_atlanta_s30_210725_154023\",\n",
    "        \"darts_forest_dist_match_enbPI_solar_atlanta_280525_144732\",\n",
    "        \"solarSolar_Atl_data_aligned\",\n",
    "        # 0,\n",
    "        1,\n",
    "    ),\n",
    "    \"Wind\": (\n",
    "        \"qrf_ks_error_normal_windWind_Hackberry_Generation_2019_2020_darts-forest|100.pkl\",\n",
    "        \"darts_forest_dist_match_enbPI_wind_s30_070725_162757\",\n",
    "        \"windWind_Hackberry_Generation_2019_2020\",\n",
    "        0,\n",
    "    ),\n",
    "    # NOTE(review): duplicate key - this value replaces the earlier \"Elec\" tuple.\n",
    "    # Its 2nd and 3rd fields (a short type name and a list of CSV paths) have a\n",
    "    # different form from the other entries; confirm which layout is intended.\n",
    "    \"Elec\": (\n",
    "        \"qrf_ks_error_normal_electricelectricity-normalized_darts-forest|100.pkl\",\n",
    "        \"electric\",\n",
    "        [\"/some_base_dir/data/enbPI/electricity-normalized.csv\"],\n",
    "        0,\n",
    "    ),\n",
    "    \"NFLX\": (\n",
    "        \"qrf_ks_error_normal_stockNFLX_darts-forest|100.pkl\",\n",
    "        \"darts_forest_dist_match_stock_nflx_s30_170725_142303\",\n",
    "        \"stockNFLX\",\n",
    "        0,\n",
    "    ),\n",
    "}\n",
    "\n",
    "# Per dataset: load the data, print its stored normalisation attributes and\n",
    "# (after the break) build the forecast service and draw the prototype figure.\n",
    "for data_type, (qrf_path, datatype, data_paths, file_idx) in DATA_MAP.items():\n",
    "    data_config = DictConfig(\n",
    "        {\"dataset_type\": datatype, \"paths\": data_paths, \"add_config\": None}\n",
    "    )\n",
    "    datasets = load_dataset(TSDataConfig(**data_config))\n",
    "    # Shows the private _X_means/_X_stds/_Y_means/_Y_stds attributes of the first dataset.\n",
    "    print([getattr(datasets[0], k) for k in [\"_X_means\", \"_X_stds\", \"_Y_means\", \"_Y_stds\"]])\n",
    "    # NOTE(review): this break ends the loop after the first DATA_MAP entry, so\n",
    "    # none of the statements below are ever executed.\n",
    "    break\n",
    "    forcast_service = load_forecast_service(data_config=data_config)\n",
    "    # NOTE(review): ``qrf`` is not assigned in any cell visible here (get_qrf is\n",
    "    # never called with qrf_path); this would raise NameError if reached.\n",
    "    fig = get_data_prototypes(\n",
    "        datasets[file_idx],\n",
    "        qrf,\n",
    "        forcast_service,\n",
    "        PATCH_LEN,\n",
    "        start=0000,\n",
    "        stop=3000,\n",
    "        is_calib=False,\n",
    "    )\n",
    "    # NOTE(review): the same file name is used on every iteration.\n",
    "    fig.savefig(\"prototypes.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
