{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4d4939a0-55da-4946-ada6-96597c4f5290",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_21254/2947318579.py:20: MatplotlibDeprecationWarning: The seaborn styles shipped by Matplotlib are deprecated since 3.6, as they no longer correspond to the styles shipped by seaborn. However, they will remain available as 'seaborn-v0_8-<style>'. Alternatively, directly use the seaborn API instead.\n",
      "  plt.style.use('seaborn-whitegrid')\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tqdm\n",
    "\n",
    "from functools import partial\n",
    "\n",
    "from datasets import Recidivism, Diabetes, FICO, Schizo, Adults, Dataset, Readmission\n",
    "from sa import BaseAlgorithm, AlgorithmParams\n",
    "from rules import TwoWeightKnapsackRule, IntegerKnapsackRule, ORRule, Rule\n",
    "from neighbors import swap_high_rule, swap_low_rule, move_low_to_high, move_high_to_low\n",
    "from complement import Complement, ComplementParams\n",
    "\n",
    "from consistency import Coverage, Consistency, ConsistencySoft, CoverageConsistencyParams\n",
    "from benchmark import BenchmarkRuleMiner\n",
    "\n",
    "# Matplotlib deprecated its bundled seaborn styles in 3.6 and ships them under\n",
    "# the 'seaborn-v0_8-<style>' names instead. Prefer the new name when the installed\n",
    "# Matplotlib offers it and fall back to the legacy name on older installs.\n",
    "_whitegrid_style = 'seaborn-v0_8-whitegrid'\n",
    "if _whitegrid_style not in plt.style.available:\n",
    "    _whitegrid_style = 'seaborn-whitegrid'\n",
    "plt.style.use(_whitegrid_style)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7626d945-0354-486d-9886-7eef5a9175f8",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded Dataset with 48 columns, 107 exploded columns => 40 rules\n"
     ]
    }
   ],
   "source": [
    "# Build the Recidivism dataset; one keyword per line keeps the settings easy to scan.\n",
    "dataset = Recidivism(\n",
    "    random_seed=2,\n",
    "    Q=5,\n",
    "    num_features_universe=40,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2fe0c732-b82c-47cf-9f69-254d6974a23e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:04<00:00, 204.76it/s]\n"
     ]
    }
   ],
   "source": [
    "# Settings handed to the algorithm below.\n",
    "# NOTE(review): Consistency is configured with ComplementParams even though\n",
    "# CoverageConsistencyParams is imported from the consistency module - confirm this is intended.\n",
    "param = ComplementParams(\n",
    "    num_iter=1000,\n",
    "    N=5,\n",
    "    c=0.3,\n",
    "    p=0.3,\n",
    "    allow_high_low_switch=False,\n",
    "    should_validate=True,\n",
    ")\n",
    "\n",
    "# Run Consistency on the dataset and keep the rule that run() returns.\n",
    "alg_bb = Consistency(dataset, param)\n",
    "checklist_rule = alg_bb.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5a12120c-a63b-4d53-b413-b4d4bdcc5856",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7932732789721991"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the rule returned by alg_bb.run(); as the last expression, the value is displayed.\n",
    "alg_bb.score_rule(checklist_rule)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7b8c541e-86a6-410a-b2a2-5ec9a5a7ee35",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>rule</th>\n",
       "      <th>metric</th>\n",
       "      <th>coverage</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>An OR Rule of the 1 Rules\\nAn AND Rule of 2 co...</td>\n",
       "      <td>0.873322</td>\n",
       "      <td>0.023225</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>108</th>\n",
       "      <td>An OR Rule of the 1 Rules\\nAn AND Rule of 2 co...</td>\n",
       "      <td>0.873322</td>\n",
       "      <td>0.023225</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>140</th>\n",
       "      <td>An OR Rule of the 1 Rules\\nAn AND Rule of 2 co...</td>\n",
       "      <td>0.873322</td>\n",
       "      <td>0.023225</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>293</th>\n",
       "      <td>An OR Rule of the 1 Rules\\nAn AND Rule of 2 co...</td>\n",
       "      <td>0.873322</td>\n",
       "      <td>0.023225</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>428</th>\n",
       "      <td>An OR Rule of the 1 Rules\\nAn AND Rule of 2 co...</td>\n",
       "      <td>0.873322</td>\n",
       "      <td>0.023225</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>860</th>\n",
       "      <td>An OR Rule of the 3 Rules\\nAn AND Rule of 2 co...</td>\n",
       "      <td>0.770197</td>\n",
       "      <td>0.352188</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1310</th>\n",
       "      <td>An OR Rule of the 3 Rules\\nAn AND Rule of 2 co...</td>\n",
       "      <td>0.764542</td>\n",
       "      <td>0.368497</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>An OR Rule of the 3 Rules\\nAn AND Rule of 2 co...</td>\n",
       "      <td>0.756489</td>\n",
       "      <td>0.377168</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>480</th>\n",
       "      <td>An OR Rule of the 3 Rules\\nAn AND Rule of 2 co...</td>\n",
       "      <td>0.759400</td>\n",
       "      <td>0.394818</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1636</th>\n",
       "      <td>An OR Rule of the 3 Rules\\nAn AND Rule of 2 co...</td>\n",
       "      <td>0.747618</td>\n",
       "      <td>0.382638</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>433 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                   rule    metric  coverage\n",
       "8     An OR Rule of the 1 Rules\\nAn AND Rule of 2 co...  0.873322  0.023225\n",
       "108   An OR Rule of the 1 Rules\\nAn AND Rule of 2 co...  0.873322  0.023225\n",
       "140   An OR Rule of the 1 Rules\\nAn AND Rule of 2 co...  0.873322  0.023225\n",
       "293   An OR Rule of the 1 Rules\\nAn AND Rule of 2 co...  0.873322  0.023225\n",
       "428   An OR Rule of the 1 Rules\\nAn AND Rule of 2 co...  0.873322  0.023225\n",
       "...                                                 ...       ...       ...\n",
       "860   An OR Rule of the 3 Rules\\nAn AND Rule of 2 co...  0.770197  0.352188\n",
       "1310  An OR Rule of the 3 Rules\\nAn AND Rule of 2 co...  0.764542  0.368497\n",
       "36    An OR Rule of the 3 Rules\\nAn AND Rule of 2 co...  0.756489  0.377168\n",
       "480   An OR Rule of the 3 Rules\\nAn AND Rule of 2 co...  0.759400  0.394818\n",
       "1636  An OR Rule of the 3 Rules\\nAn AND Rule of 2 co...  0.747618  0.382638\n",
       "\n",
       "[433 rows x 3 columns]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Benchmark miner over the same dataset: call get_pareto_rules(zmax=2) first,\n",
    "# then get_or_rules(num=2000), whose return value is this cell's displayed output.\n",
    "brm_soft = BenchmarkRuleMiner(dataset)\n",
    "brm_soft.get_pareto_rules(zmax=2)\n",
    "brm_soft.get_or_rules(num=2000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6b476546-a22e-44e5-9e5f-a21b1bf13ac6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.7852071111624249, 0.7639377646946588)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ask the benchmark miner for its top 5 rules, passing alg_bb's two evaluation callables.\n",
    "benchmark_rules = brm_soft.get_top_rule(\n",
    "    alg_bb.evaluate_rule_train,\n",
    "    alg_bb.score_rule,\n",
    "    num=5,\n",
    ")\n",
    "\n",
    "# Displayed value: (mean evaluate_rule_train, mean score_rule) over the benchmark rules.\n",
    "(\n",
    "    np.mean([alg_bb.evaluate_rule_train(rule) for rule in benchmark_rules]),\n",
    "    np.mean([alg_bb.score_rule(rule) for rule in benchmark_rules]),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cb5e811-5a77-48af-8b68-db52d7313472",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
