{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "da6638b8-5ca9-45bd-91d6-bd3fc8c425e2",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_34824/3822570784.py:28: MatplotlibDeprecationWarning: The seaborn styles shipped by Matplotlib are deprecated since 3.6, as they no longer correspond to the styles shipped by seaborn. However, they will remain available as 'seaborn-v0_8-<style>'. Alternatively, directly use the seaborn API instead.\n",
      "  plt.style.use('seaborn-whitegrid')\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tqdm\n",
    "\n",
    "from functools import partial, reduce\n",
    "\n",
    "from datasets2 import Recidivism, Diabetes, FICO, Schizo, Adults, Dataset, Readmission\n",
    "from sa import BaseAlgorithm, AlgorithmParams\n",
    "from rules import TwoWeightKnapsackRule, IntegerKnapsackRule, ORRule, Rule, Condition, Operator\n",
    "from neighbors import swap_high_rule, swap_low_rule, move_low_to_high, move_high_to_low\n",
    "from complement import Complement, ComplementParams\n",
    "from cart_rule_list import GreedyRuleList\n",
    "from consistency import Coverage, Consistency, ConsistencySoft, CoverageConsistencyParams\n",
    "from benchmark import BenchmarkRuleMiner\n",
    "\n",
    "from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor\n",
    "from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor\n",
    "from sklearn import tree\n",
    "from goal_cart import parse_tree_to_rules\n",
    "\n",
    "import time\n",
    "\n",
    "# Matplotlib 3.6 deprecated the bundled 'seaborn-*' style names in favour of\n",
    "# 'seaborn-v0_8-*'. Use the new name when this Matplotlib provides it and fall\n",
    "# back to the old one on older installs.\n",
    "plt.style.use(\n",
    "    'seaborn-v0_8-whitegrid'\n",
    "    if 'seaborn-v0_8-whitegrid' in plt.style.available\n",
    "    else 'seaborn-whitegrid'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2a9c7b95-6c62-459d-8391-11c6a92c6add",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded Dataset with 23 columns, 69 exploded columns => 39 rules\n",
      "Loaded Dataset with 23 columns => 74 rules\n"
     ]
    }
   ],
   "source": [
    "from datasets import Recidivism, FICO, Readmission, Diabetes, Schizo, Adults\n",
    "\n",
    "from datasets2 import Recidivism as Recidivism2\n",
    "from datasets2 import FICO as FICO2\n",
    "from datasets2 import Readmission as Readmission2\n",
    "from datasets2 import Diabetes as Diabetes2\n",
    "from datasets2 import Adults as Adults2\n",
    "from datasets2 import Schizo as Schizo2\n",
    "\n",
    "# FICO loaded through both dataset modules, using the same seed and Q.\n",
    "dataset = FICO(\n",
    "    random_seed=0,\n",
    "    Q=5,\n",
    "    verbose=True,\n",
    "    num_features_universe=40,\n",
    ")\n",
    "dataset2 = FICO2(\n",
    "    random_seed=0,\n",
    "    Q=5,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6ed01987-1029-491e-b9f4-58930b4fb15b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Search settings passed to the ConsistencySoft run in the next cell.\n",
    "param1 = CoverageConsistencyParams(\n",
    "    num_iter=1000,\n",
    "    N=7,\n",
    "    c=0.1,\n",
    "    p=0.0,\n",
    "    allow_high_low_switch=False,\n",
    "    should_validate=True,\n",
    ")\n",
    "\n",
    "# Disabled: the same run on `dataset` instead of `dataset2`.\n",
    "# alg1 = ConsistencySoft(dataset, param1, True)\n",
    "# print(alg1.score_rule(alg1.starting_rule))\n",
    "#\n",
    "# rule = alg1.run()\n",
    "# print(alg1.score_rule(rule))\n",
    "# rule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e37386ac-c60d-4e7f-917d-496bd7fd687b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8751434034416834\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:04<00:00, 247.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9201960304026264\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "---High-Weight Rules---\n",
       "\n",
       "---Low-Weight Rules---\n",
       "Rule: ExternalRiskEstimate <= 63.0 (weight=1.00)\n",
       "Rule: NetFractionRevolvingBurden >= 60.0 (weight=1.00)\n",
       "Rule: ExternalRiskEstimate <= 69.0 (weight=1.00)\n",
       "Rule: NumSatisfactoryTrades <= 12.0 (weight=1.00)\n",
       "Rule: PercentInstallTrades >= 48.0 (weight=1.00)\n",
       "Rule: MSinceOldestTradeOpen <= 127.0 (weight=1.00)\n",
       "Rule: NumBank2NatlTradesWHighUtilization >= 1.0 (weight=1.00)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the starting rule, run the search, then score the rule it returns.\n",
    "alg2 = ConsistencySoft(dataset2, param1, show_progress_bar=True)\n",
    "print(alg2.score_rule(alg2.starting_rule))\n",
    "\n",
    "rule2 = alg2.run()\n",
    "print(alg2.score_rule(rule2))\n",
    "\n",
    "rule2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "75327633-8809-4cdd-bbe1-0551f3829971",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 523,
   "id": "fc80f8cd-600f-45ca-9616-9f1955782fef",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>rule</th>\n",
       "      <th>mask</th>\n",
       "      <th>mask_test</th>\n",
       "      <th>support</th>\n",
       "      <th>abbr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>An AND Rule of 4 conditions (weight=1.00)\\nRul...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>527</th>\n",
       "      <td>An AND Rule of 4 conditions (weight=1.00)\\nRul...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>528</th>\n",
       "      <td>An AND Rule of 4 conditions (weight=1.00)\\nRul...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>529</th>\n",
       "      <td>An AND Rule of 4 conditions (weight=1.00)\\nRul...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>530</th>\n",
       "      <td>An AND Rule of 4 conditions (weight=1.00)\\nRul...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>270</th>\n",
       "      <td>An AND Rule of 4 conditions (weight=1.00)\\nRul...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>271</th>\n",
       "      <td>An AND Rule of 4 conditions (weight=1.00)\\nRul...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>272</th>\n",
       "      <td>An AND Rule of 4 conditions (weight=1.00)\\nRul...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>262</th>\n",
       "      <td>An AND Rule of 4 conditions (weight=1.00)\\nRul...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>799</th>\n",
       "      <td>An AND Rule of 4 conditions (weight=1.00)\\nRul...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>800 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                  rule  \\\n",
       "0    An AND Rule of 4 conditions (weight=1.00)\\nRul...   \n",
       "527  An AND Rule of 4 conditions (weight=1.00)\\nRul...   \n",
       "528  An AND Rule of 4 conditions (weight=1.00)\\nRul...   \n",
       "529  An AND Rule of 4 conditions (weight=1.00)\\nRul...   \n",
       "530  An AND Rule of 4 conditions (weight=1.00)\\nRul...   \n",
       "..                                                 ...   \n",
       "270  An AND Rule of 4 conditions (weight=1.00)\\nRul...   \n",
       "271  An AND Rule of 4 conditions (weight=1.00)\\nRul...   \n",
       "272  An AND Rule of 4 conditions (weight=1.00)\\nRul...   \n",
       "262  An AND Rule of 4 conditions (weight=1.00)\\nRul...   \n",
       "799  An AND Rule of 4 conditions (weight=1.00)\\nRul...   \n",
       "\n",
       "                                                  mask  \\\n",
       "0    0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...   \n",
       "527  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...   \n",
       "528  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...   \n",
       "529  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...   \n",
       "530  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...   \n",
       "..                                                 ...   \n",
       "270  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...   \n",
       "271  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...   \n",
       "272  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...   \n",
       "262  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...   \n",
       "799  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...   \n",
       "\n",
       "                                             mask_test  support  abbr  \n",
       "0    0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...      0.0   NaN  \n",
       "527  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...      0.0   NaN  \n",
       "528  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...      0.0   NaN  \n",
       "529  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...      0.0   NaN  \n",
       "530  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...      0.0   NaN  \n",
       "..                                                 ...      ...   ...  \n",
       "270  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...      0.0   NaN  \n",
       "271  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...      0.0   NaN  \n",
       "272  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...      0.0   NaN  \n",
       "262  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...      0.0   NaN  \n",
       "799  0       False\n",
       "1       False\n",
       "2       False\n",
       "3   ...      0.0   NaN  \n",
       "\n",
       "[800 rows x 5 columns]"
      ]
     },
     "execution_count": 523,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grl.rules_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 533,
   "id": "e29cd626-4a8f-4c29-beb1-540da4162cf8",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0       False\n",
       "1       False\n",
       "2       False\n",
       "3       False\n",
       "4       False\n",
       "        ...  \n",
       "5224    False\n",
       "5225    False\n",
       "5226    False\n",
       "5227    False\n",
       "5228    False\n",
       "Name: NetFractionRevolvingBurden, Length: 5229, dtype: bool"
      ]
     },
     "execution_count": 533,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grl.all_rules[0].conditions_list[0].get_mask(dataset2.get_X_train())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 530,
   "id": "d14d4463-f9a4-44cf-9030-8e69be4435a9",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Rule: NetFractionRevolvingBurden <= 36.5,\n",
       " Rule: MaxDelq2PublicRecLast12M <= 4.5,\n",
       " Rule: ExternalRiskEstimate <= 70.5,\n",
       " Rule: PercentTradesNeverDelq <= 88.5]"
      ]
     },
     "execution_count": 530,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grl.all_rules[0].conditions_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a0699f61-dcab-4d0b-9f8f-a39d4a152522",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:00<00:00, 59.82it/s]\n",
      "100%|██████████| 800/800 [00:01<00:00, 477.15it/s]\n",
      "/tmp/ipykernel_34824/4043553141.py:40: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  rules_table['best_abbr'] = rules_table['abbr'][::-1].cummax()[::-1]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>next_index</th>\n",
       "      <th>rule_list</th>\n",
       "      <th>train_cover</th>\n",
       "      <th>test_cover</th>\n",
       "      <th>train_abbr</th>\n",
       "      <th>test_abbr</th>\n",
       "      <th>eval_in</th>\n",
       "      <th>eval_out</th>\n",
       "      <th>cons_in</th>\n",
       "      <th>cons_out</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>6</td>\n",
       "      <td>An OR Rule of the 1 Rules\\nAn AND Rule of 4 co...</td>\n",
       "      <td>0.052209</td>\n",
       "      <td>0.055258</td>\n",
       "      <td>0.936596</td>\n",
       "      <td>0.940139</td>\n",
       "      <td>0.733480</td>\n",
       "      <td>0.736511</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>20</td>\n",
       "      <td>An OR Rule of the 2 Rules\\nAn AND Rule of 4 co...</td>\n",
       "      <td>0.089118</td>\n",
       "      <td>0.095220</td>\n",
       "      <td>0.919923</td>\n",
       "      <td>0.922107</td>\n",
       "      <td>0.870539</td>\n",
       "      <td>0.899306</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>16</td>\n",
       "      <td>An OR Rule of the 3 Rules\\nAn AND Rule of 4 co...</td>\n",
       "      <td>0.130618</td>\n",
       "      <td>0.129828</td>\n",
       "      <td>0.908017</td>\n",
       "      <td>0.911625</td>\n",
       "      <td>0.914414</td>\n",
       "      <td>0.920020</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   next_index                                          rule_list  train_cover  \\\n",
       "0           6  An OR Rule of the 1 Rules\\nAn AND Rule of 4 co...     0.052209   \n",
       "1          20  An OR Rule of the 2 Rules\\nAn AND Rule of 4 co...     0.089118   \n",
       "2          16  An OR Rule of the 3 Rules\\nAn AND Rule of 4 co...     0.130618   \n",
       "\n",
       "   test_cover  train_abbr  test_abbr   eval_in  eval_out  cons_in  cons_out  \n",
       "0    0.055258    0.936596   0.940139  0.733480  0.736511      1.0       1.0  \n",
       "1    0.095220    0.919923   0.922107  0.870539  0.899306      1.0       1.0  \n",
       "2    0.129828    0.908017   0.911625  0.914414  0.920020      1.0       1.0  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the rule list with max_depth=4 and show the per-step debug table.\n",
    "grl = GreedyRuleList(\n",
    "    dataset2,\n",
    "    max_depth=4,\n",
    "    n_estimators=50,\n",
    "    support_lb=0.025,\n",
    "    tolerance=0.9,\n",
    ")\n",
    "grl.generate_rule(0.1, debug=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 432,
   "id": "66e4b09e-7079-4a88-95ee-0331572bc111",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:00<00:00, 291.53it/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 585.48it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>next_index</th>\n",
       "      <th>rule_list</th>\n",
       "      <th>train_cover</th>\n",
       "      <th>test_cover</th>\n",
       "      <th>train_abbr</th>\n",
       "      <th>test_abbr</th>\n",
       "      <th>eval_in</th>\n",
       "      <th>eval_out</th>\n",
       "      <th>cons_in</th>\n",
       "      <th>cons_out</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>15</td>\n",
       "      <td>An OR Rule of the 1 Rules\\nAn AND Rule of 2 co...</td>\n",
       "      <td>0.23733</td>\n",
       "      <td>0.235564</td>\n",
       "      <td>0.837385</td>\n",
       "      <td>0.841012</td>\n",
       "      <td>0.836327</td>\n",
       "      <td>0.842663</td>\n",
       "      <td>0.997583</td>\n",
       "      <td>0.997565</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   next_index                                          rule_list  train_cover  \\\n",
       "0          15  An OR Rule of the 1 Rules\\nAn AND Rule of 2 co...      0.23733   \n",
       "\n",
       "   test_cover  train_abbr  test_abbr   eval_in  eval_out   cons_in  cons_out  \n",
       "0    0.235564    0.837385   0.841012  0.836327  0.842663  0.997583  0.997565  "
      ]
     },
     "execution_count": 432,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same settings as the previous cell except max_depth=2; this rebinds `grl`.\n",
    "grl = GreedyRuleList(\n",
    "    dataset2,\n",
    "    max_depth=2,\n",
    "    n_estimators=50,\n",
    "    support_lb=0.025,\n",
    "    tolerance=0.9,\n",
    ")\n",
    "grl.generate_rule(0.1, debug=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 433,
   "id": "3f0ece5d-8047-4270-99ca-5474b8ed559b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 25/25 [00:00<00:00, 59.30it/s]\n",
      "100%|██████████| 397/397 [00:00<00:00, 467.71it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>next_index</th>\n",
       "      <th>rule_list</th>\n",
       "      <th>train_cover</th>\n",
       "      <th>test_cover</th>\n",
       "      <th>train_abbr</th>\n",
       "      <th>test_abbr</th>\n",
       "      <th>eval_in</th>\n",
       "      <th>eval_out</th>\n",
       "      <th>cons_in</th>\n",
       "      <th>cons_out</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>7</td>\n",
       "      <td>An OR Rule of the 1 Rules\\nAn AND Rule of 4 co...</td>\n",
       "      <td>0.243832</td>\n",
       "      <td>0.243403</td>\n",
       "      <td>0.830079</td>\n",
       "      <td>0.839268</td>\n",
       "      <td>0.831161</td>\n",
       "      <td>0.837044</td>\n",
       "      <td>0.998431</td>\n",
       "      <td>0.999214</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   next_index                                          rule_list  train_cover  \\\n",
       "0           7  An OR Rule of the 1 Rules\\nAn AND Rule of 4 co...     0.243832   \n",
       "\n",
       "   test_cover  train_abbr  test_abbr   eval_in  eval_out   cons_in  cons_out  \n",
       "0    0.243403    0.830079   0.839268  0.831161  0.837044  0.998431  0.999214  "
      ]
     },
     "execution_count": 433,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Variant with 25 estimators and use_quantile=False, kept under its own name.\n",
    "grl_bench = GreedyRuleList(\n",
    "    dataset2,\n",
    "    max_depth=4,\n",
    "    n_estimators=25,\n",
    "    support_lb=0.025,\n",
    "    tolerance=0.9,\n",
    "    use_quantile=False,\n",
    ")\n",
    "grl_bench.generate_rule(0.1, debug=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 502,
   "id": "6949867f-57df-463a-aed6-f4922e49549a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "class RuleListSA(ConsistencySoft):\n",
    "    \"\"\"Search over OR-lists whose members come from a pool of mined rules.\n",
    "\n",
    "    ``BaseAlgorithm.__init__`` calls ``get_start()`` to build the starting rule.\n",
    "    ``get_neighbor()`` proposes a variant of a rule by swapping one member for a\n",
    "    random candidate; it is presumably driven by the inherited ``run()`` loop\n",
    "    (TODO confirm against ``sa.BaseAlgorithm``).\n",
    "    \"\"\"\n",
    "\n",
    "    def get_preds(\n",
    "        self,\n",
    "        knapsack_rule: TwoWeightKnapsackRule,\n",
    "        X: pd.DataFrame,\n",
    "        y: pd.Series,\n",
    "    ):\n",
    "        \"\"\"Return the entries of ``y`` selected by the rule on ``X``.\n",
    "\n",
    "        The rule's mask is passed through ``convert_to_quantile`` and an entry\n",
    "        is kept when the converted value is at least ``1 - params.c``.\n",
    "        \"\"\"\n",
    "        y_preds = self.convert_to_quantile(knapsack_rule.get_mask(X))\n",
    "        return y[y_preds >= 1 - self.params.c]\n",
    "\n",
    "    def get_start(self) -> ORRule:\n",
    "        \"\"\"Mine the candidate pool and return the initial ``ORRule``.\n",
    "\n",
    "        Side effect: stores the mined rules on ``self.rule_candidates``.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if no candidate rule is available for the starting rule\n",
    "                (empty Pareto set, or ``params.N`` selects nothing).\n",
    "        \"\"\"\n",
    "        grl = GreedyRuleList(\n",
    "            self.dataset,\n",
    "            max_depth=4,\n",
    "            n_estimators=50,\n",
    "            support_lb=0.025,\n",
    "            tolerance=0.9,\n",
    "        )\n",
    "        self.rule_candidates = list(grl.pareto_rules.rule)\n",
    "\n",
    "        # The previous version also called grl.generate_rule(self.params.c) here and\n",
    "        # immediately discarded the result; that call raised IndexError when the\n",
    "        # sampled rule pool was empty, so it is intentionally not made.\n",
    "        initial_rules = self.rule_candidates[:self.params.N]\n",
    "        if not initial_rules:\n",
    "            raise ValueError(\n",
    "                'No candidate rules available for the starting rule '\n",
    "                '(%d Pareto rules mined, N=%s)'\n",
    "                % (len(self.rule_candidates), self.params.N)\n",
    "            )\n",
    "\n",
    "        or_rule = ORRule(rule_list=initial_rules)\n",
    "        or_rule.name = 'initial'\n",
    "\n",
    "        if len(or_rule.rule_list) < 2:\n",
    "            # A single-rule start is padded with two copies of that rule.\n",
    "            or_rule.rule_list.append(or_rule.rule_list[0])\n",
    "            or_rule.rule_list.append(or_rule.rule_list[0])\n",
    "\n",
    "        return or_rule\n",
    "\n",
    "    def get_neighbor(self, rule: ORRule) -> ORRule:\n",
    "        \"\"\"Return a neighbor of ``rule``; the only move implemented is ``replace``.\"\"\"\n",
    "        return self.replace(rule)\n",
    "\n",
    "    def replace(self, rule: ORRule) -> ORRule:\n",
    "        \"\"\"Return a new rule with one random member swapped for a random candidate.\n",
    "\n",
    "        ``rule.rule_list`` is copied before the swap, so ``rule`` is not modified.\n",
    "        The new rule's ``name`` records the candidate index that was added and\n",
    "        the position that was replaced.\n",
    "        \"\"\"\n",
    "        index_replace = random.randint(0, len(rule.rule_list) - 1)\n",
    "        index_new = random.randint(0, len(self.rule_candidates) - 1)\n",
    "\n",
    "        rl = rule.rule_list.copy()\n",
    "        rl[index_replace] = self.rule_candidates[index_new]\n",
    "\n",
    "        new_rule = ORRule(rule_list=rl)\n",
    "        new_rule.name = 'Added %s, Removed %s' % (index_new, index_replace)\n",
    "        # Guards against ORRule altering the list it is given (e.g. dropping entries).\n",
    "        assert len(new_rule.rule_list) == len(rule.rule_list)\n",
    "        return new_rule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 503,
   "id": "0cbb811c-ff1a-40d2-a52e-a0cbd416aec2",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/50 [00:00<?, ?it/s]/nfs/home2/evanyao/paper/goal_cart.py:61: RuntimeWarning: Mean of empty slice.\n",
      "  leaf_node_rules.sort(key=lambda r: -dataset.get_y_train_quantile()[r.get_mask(dataset.get_X_train())].mean())\n",
      "/home/software/anaconda3/2023.07/lib/python3.11/site-packages/numpy/core/_methods.py:194: RuntimeWarning: invalid value encountered in scalar divide\n",
      "  ret = ret / rcount\n",
      "100%|██████████| 50/50 [00:00<00:00, 59.20it/s]\n",
      "  0%|          | 0/800 [00:00<?, ?it/s]/tmp/ipykernel_18671/1389361965.py:33: RuntimeWarning: Mean of empty slice.\n",
      "  'abbr': self.y_train_quantile[msk].mean(),\n",
      "100%|██████████| 800/800 [00:01<00:00, 467.00it/s]\n"
     ]
    },
    {
     "ename": "IndexError",
     "evalue": "index 0 is out of bounds for axis 0 with size 0",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[503], line 10\u001b[0m\n\u001b[1;32m      1\u001b[0m param1 \u001b[38;5;241m=\u001b[39m CoverageConsistencyParams(\n\u001b[1;32m      2\u001b[0m     num_iter\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1000\u001b[39m, \n\u001b[1;32m      3\u001b[0m     N\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m5\u001b[39m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m      7\u001b[0m     should_validate\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m      8\u001b[0m )\n\u001b[0;32m---> 10\u001b[0m rlsa \u001b[38;5;241m=\u001b[39m RuleListSA(dataset2, param1, \u001b[38;5;28;01mTrue\u001b[39;00m)\n",
      "File \u001b[0;32m/nfs/home2/evanyao/paper/sa.py:40\u001b[0m, in \u001b[0;36mBaseAlgorithm.__init__\u001b[0;34m(self, dataset, params, show_progress_bar, skeleton)\u001b[0m\n\u001b[1;32m     37\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrun_information \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m     39\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m skeleton:\n\u001b[0;32m---> 40\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstarting_rule \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_start()\n",
      "Cell \u001b[0;32mIn[502], line 8\u001b[0m, in \u001b[0;36mRuleListSA.get_start\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m      5\u001b[0m grl \u001b[38;5;241m=\u001b[39m GreedyRuleList(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset, max_depth\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4\u001b[39m, n_estimators\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m50\u001b[39m, support_lb\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.025\u001b[39m, tolerance\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.9\u001b[39m)\n\u001b[1;32m      6\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrule_candidates \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(grl\u001b[38;5;241m.\u001b[39mpareto_rules\u001b[38;5;241m.\u001b[39mrule)\n\u001b[0;32m----> 8\u001b[0m or_rule \u001b[38;5;241m=\u001b[39m grl\u001b[38;5;241m.\u001b[39mgenerate_rule(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparams\u001b[38;5;241m.\u001b[39mc)\n\u001b[1;32m     10\u001b[0m or_rule \u001b[38;5;241m=\u001b[39m ORRule(rule_list\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrule_candidates[:\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparams\u001b[38;5;241m.\u001b[39mN])\n\u001b[1;32m     12\u001b[0m or_rule\u001b[38;5;241m.\u001b[39mname \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124minitial\u001b[39m\u001b[38;5;124m'\u001b[39m\n",
      "Cell \u001b[0;32mIn[428], line 70\u001b[0m, in \u001b[0;36mGreedyRuleList.generate_rule\u001b[0;34m(self, c, debug)\u001b[0m\n\u001b[1;32m     67\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mgenerate_rule\u001b[39m(\u001b[38;5;28mself\u001b[39m, c: \u001b[38;5;28mfloat\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m, debug\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m     68\u001b[0m     rules_pool \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpareto_rules\u001b[38;5;241m.\u001b[39msample(frac\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.5\u001b[39m)\u001b[38;5;241m.\u001b[39msort_index()\n\u001b[0;32m---> 70\u001b[0m     indices \u001b[38;5;241m=\u001b[39m [rules_pool\u001b[38;5;241m.\u001b[39mindex[\u001b[38;5;241m0\u001b[39m]]\n\u001b[1;32m     72\u001b[0m     param_dummy \u001b[38;5;241m=\u001b[39m CoverageConsistencyParams(\n\u001b[1;32m     73\u001b[0m         num_iter\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m, \n\u001b[1;32m     74\u001b[0m         N\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m5\u001b[39m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     78\u001b[0m         should_validate\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m     79\u001b[0m     )\n\u001b[1;32m     81\u001b[0m     alg_dummy \u001b[38;5;241m=\u001b[39m ConsistencySoft(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset, param_dummy, \u001b[38;5;28;01mFalse\u001b[39;00m, skeleton\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
      "File \u001b[0;32m/home/software/anaconda3/2023.07/lib/python3.11/site-packages/pandas/core/indexes/base.py:5320\u001b[0m, in \u001b[0;36mIndex.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m   5317\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_integer(key) \u001b[38;5;129;01mor\u001b[39;00m is_float(key):\n\u001b[1;32m   5318\u001b[0m     \u001b[38;5;66;03m# GH#44051 exclude bool, which would return a 2d ndarray\u001b[39;00m\n\u001b[1;32m   5319\u001b[0m     key \u001b[38;5;241m=\u001b[39m com\u001b[38;5;241m.\u001b[39mcast_scalar_indexer(key, warn_float\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m-> 5320\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m getitem(key)\n\u001b[1;32m   5322\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(key, \u001b[38;5;28mslice\u001b[39m):\n\u001b[1;32m   5323\u001b[0m     \u001b[38;5;66;03m# This case is separated from the conditional above to avoid\u001b[39;00m\n\u001b[1;32m   5324\u001b[0m     \u001b[38;5;66;03m# pessimization com.is_bool_indexer and ndim checks.\u001b[39;00m\n\u001b[1;32m   5325\u001b[0m     result \u001b[38;5;241m=\u001b[39m getitem(key)\n",
      "\u001b[0;31mIndexError\u001b[0m: index 0 is out of bounds for axis 0 with size 0"
     ]
    }
   ],
   "source": [
    "param1 = CoverageConsistencyParams(\n",
    "    num_iter=1000, \n",
    "    N=5,\n",
    "    c=0.2,\n",
    "    p=0.,\n",
    "    allow_high_low_switch=False,\n",
    "    should_validate=True,\n",
    ")\n",
    "\n",
    "rlsa = RuleListSA(dataset2, param1, True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29e4eb9b-60a4-4ac6-a948-f954095ca8e2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Run the RuleListSA instance built above; the returned object is kept in\n",
    "# `result` because the cells further down call get_mask/apply on it.\n",
    "result = rlsa.run()\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 499,
   "id": "bd684412-20d1-407e-b673-4ed3cdb4a047",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Iteration</th>\n",
       "      <th>Operation</th>\n",
       "      <th>Evaluation (In-Sample)</th>\n",
       "      <th>Validation Score</th>\n",
       "      <th>Score (Out-of-Sample)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>initial</td>\n",
       "      <td>0.507068</td>\n",
       "      <td>0.507068</td>\n",
       "      <td>0.495814</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Iteration Operation  Evaluation (In-Sample)  Validation Score  \\\n",
       "0          0   initial                0.507068          0.507068   \n",
       "\n",
       "   Score (Out-of-Sample)  \n",
       "0               0.495814  "
      ]
     },
     "execution_count": 499,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tabulate `rlsa.run_information` (the saved output shows one row per logged\n",
    "# iteration with its operation and in-sample/validation/out-of-sample scores).\n",
    "pd.DataFrame(data=rlsa.run_information)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 495,
   "id": "0a5937e2-b934-4c04-a315-35f8945a3a99",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0       0\n",
       "1       0\n",
       "2       0\n",
       "3       0\n",
       "4       0\n",
       "       ..\n",
       "5224    0\n",
       "5225    0\n",
       "5226    0\n",
       "5227    0\n",
       "5228    0\n",
       "Length: 5229, dtype: int64"
      ]
     },
     "execution_count": 495,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mask that `result` yields for dataset2.get_X_train(); the saved output holds\n",
    "# one integer per row.\n",
    "result.get_mask(\n",
    "    dataset2.get_X_train()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62dc49ef-35b2-4b7e-9eb3-98ebd6d72c9f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# NOTE(review): the last saved run of this cell crashed inside rules.py `apply`\n",
    "# (line 133) with `NameError: name 'msk' is not defined`. The stale traceback dump\n",
    "# was cleared from the notebook; fix rules.py and re-run before relying on this cell.\n",
    "result.apply(dataset2.get_X_train(), dataset2.get_y_train_quantile())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b46cd39-c3a2-4554-a062-eacbba3f74bf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
