{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-19T21:29:58.457953Z",
     "start_time": "2021-07-19T21:29:56.288637Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import matplotlib\n",
    "from sklearn.metrics import zero_one_loss\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-19T21:29:58.474309Z",
     "start_time": "2021-07-19T21:29:58.460009Z"
    }
   },
   "outputs": [],
   "source": [
    "from utils.data_gen import compute_bayes_risk_binary, compute_bayes_risk_binary_label_shift,\\\n",
    "    generate_2d_example\n",
    "\n",
    "from utils.concentrations import hoeffding_ci_lower_limit,\\\n",
    "    hoeffding_ci_upper_limit\n",
    "from utils.concentrations import pm_bernstein_ci_upper_limit,\\\n",
    "    pm_bernstein_ci_lower_limit\n",
    "\n",
    "from utils.concentrations import pm_bernstein_lower_limit, pm_bernstein_upper_limit\n",
    "from utils.concentrations import pm_hoeffding_upper_limit, pm_hoeffding_lower_limit\n",
    "\n",
    "from utils.data_gen import LDA_predictor\n",
    "\n",
    "from utils.concentrations import betting_ci_lower_limit, betting_ci_upper_limit\n",
    "\n",
    "from utils.tests import Drop_tester"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-19T21:29:58.481141Z",
     "start_time": "2021-07-19T21:29:58.476872Z"
    }
   },
   "outputs": [],
   "source": [
    "# Global figure styling shared by every plot in this notebook:\n",
    "# seaborn white grid, enlarged fonts, thicker lines, wide default figure.\n",
    "sns.set(\n",
    "    rc={\"lines.linewidth\": 2, \"figure.figsize\": (12, 6)},\n",
    "    style=\"whitegrid\",\n",
    "    font_scale=1.4,\n",
    ")\n",
    "sns.set_palette(\"Set2\")\n",
    "\n",
    "# Turn on matplotlib's usetex text rendering for labels and legends.\n",
    "matplotlib.rcParams[\"text.usetex\"] = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-19T21:29:58.487549Z",
     "start_time": "2021-07-19T21:29:58.483449Z"
    }
   },
   "outputs": [],
   "source": [
    "# Class labels written as LaTeX math. Raw strings keep the backslashes\n",
    "# literal instead of relying on the invalid escape sequence '\\{', which\n",
    "# Python only tolerates with a warning; the resulting strings are unchanged.\n",
    "legend_dict = [r'$\\{0\\}$', r'$\\{1\\}$']\n",
    "\n",
    "# Per-class colours: integer RGB triples scaled by 1/256.\n",
    "clr1 = np.array([193, 142, 206]) / 256\n",
    "clr2 = np.array([125, 225, 125]) / 256\n",
    "colors = [clr1, clr2]\n",
    "\n",
    "# Additional colours, scaled the same way.\n",
    "color1 = np.array([244, 236, 118]) / 256\n",
    "color2 = np.array([110, 200, 235]) / 256\n",
    "color3 = np.array([233, 113, 183]) / 256"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-19T21:29:58.554349Z",
     "start_time": "2021-07-19T21:29:58.489201Z"
    }
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Estimate number of samples to reject"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-08-02T16:45:12.054449Z",
     "start_time": "2021-08-02T16:45:12.029387Z"
    }
   },
   "outputs": [],
   "source": [
    "# Number of independent simulations run per candidate target probability.\n",
    "num_of_repeats = 200\n",
    "\n",
    "# The target stream holds num_of_batches_to_sample * size_of_batch points\n",
    "# and is examined in growing prefixes, one extra batch at a time.\n",
    "num_of_batches_to_sample = 30\n",
    "size_of_batch = 50\n",
    "\n",
    "# Size of the source sample whose losses feed the source risk estimate.\n",
    "size_source_sample = 1_000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-08-02T16:45:12.370664Z",
     "start_time": "2021-08-02T16:45:12.345567Z"
    }
   },
   "outputs": [],
   "source": [
    "# Tolerance handed to every Drop_tester below via its eps_tol attribute.\n",
    "eps_tol = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-08-02T16:45:12.689093Z",
     "start_time": "2021-08-02T16:45:12.661969Z"
    }
   },
   "outputs": [],
   "source": [
    "# Grid of candidate class-1 probabilities for the target distribution;\n",
    "# one full set of simulations is run for each value.\n",
    "target_cand_probs = np.linspace(0.1, 0.9, num=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-08-02T16:45:12.968626Z",
     "start_time": "2021-08-02T16:45:12.942874Z"
    }
   },
   "outputs": [],
   "source": [
    "# Class means of the two-dimensional example.\n",
    "mu_0 = np.array([-1, 0])\n",
    "mu_1 = np.array([1, 0])\n",
    "\n",
    "# Class priors used to configure the classifier.\n",
    "prob_class_1 = 0.25\n",
    "prob_class_0 = 1 - prob_class_1\n",
    "\n",
    "# 2x2 identity matrix.\n",
    "cov = np.eye(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-08-02T16:45:13.231229Z",
     "start_time": "2021-08-02T16:45:13.205727Z"
    }
   },
   "outputs": [],
   "source": [
    "# The predictor is not fitted on data: its class means and priors are\n",
    "# assigned directly from the parameters defined above.\n",
    "clf = LDA_predictor()\n",
    "\n",
    "clf.mean_class_0 = mu_0\n",
    "clf.mean_class_1 = mu_1\n",
    "\n",
    "clf.class_0_prior = prob_class_0\n",
    "clf.class_1_prior = prob_class_1\n",
    "\n",
    "# NOTE(review): presumably makes predict_proba return scores for both\n",
    "# classes - confirm in utils.data_gen.\n",
    "clf.predict_both_classes = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-08-02T16:45:13.534218Z",
     "start_time": "2021-08-02T16:45:13.505837Z"
    }
   },
   "outputs": [],
   "source": [
    "# One tester per concentration inequality. All three share the same\n",
    "# tolerance and change type and differ only in the bounds they use.\n",
    "tester_hoeffding = Drop_tester()\n",
    "tester_bernstein = Drop_tester()\n",
    "tester_betting = Drop_tester()\n",
    "\n",
    "tester_hoeffding.eps_tol = eps_tol\n",
    "# Fix: this assignment used to target tester_bernstein (where it was\n",
    "# overwritten right below), so the Hoeffding tester never received its\n",
    "# source bound type.\n",
    "# NOTE(review): 'hoeffding' is assumed to be a source_conc_type accepted\n",
    "# by Drop_tester - confirm in utils.tests.\n",
    "tester_hoeffding.source_conc_type = 'hoeffding'\n",
    "tester_hoeffding.target_conc_type = 'pm_hoeffding'\n",
    "tester_hoeffding.change_type = 'relative'\n",
    "\n",
    "tester_bernstein.eps_tol = eps_tol\n",
    "tester_bernstein.source_conc_type = 'pm_bernstein'\n",
    "tester_bernstein.target_conc_type = 'pm_bernstein'\n",
    "tester_bernstein.change_type = 'relative'\n",
    "\n",
    "tester_betting.eps_tol = eps_tol\n",
    "tester_betting.source_conc_type = 'betting'\n",
    "tester_betting.target_conc_type = 'betting'\n",
    "tester_betting.change_type = 'relative'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-08-02T16:45:13.965133Z",
     "start_time": "2021-08-02T16:45:13.938437Z"
    }
   },
   "outputs": [],
   "source": [
    "# Rejection counters, one entry per candidate target probability.\n",
    "bern_num_rejects = []\n",
    "hoef_num_rejects = []\n",
    "bet_num_rejects = []\n",
    "\n",
    "# For every candidate probability, the sample counts recorded at rejection.\n",
    "bern_num_samples_to_reject = []\n",
    "hoef_num_samples_to_reject = []\n",
    "bet_num_samples_to_reject = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-08-02T16:45:14.260513Z",
     "start_time": "2021-08-02T16:45:14.233754Z"
    }
   },
   "outputs": [],
   "source": [
    "from utils.tests import brier_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2021-08-02T16:45:14.764Z"
    }
   },
   "outputs": [],
   "source": [
    "# Each entry pairs a tester with the two result lists it fills, in the\n",
    "# order the tests are run on every simulated stream.\n",
    "tester_records = [\n",
    "    (tester_hoeffding, hoef_num_rejects, hoef_num_samples_to_reject),\n",
    "    (tester_bernstein, bern_num_rejects, bern_num_samples_to_reject),\n",
    "    (tester_betting, bet_num_rejects, bet_num_samples_to_reject),\n",
    "]\n",
    "\n",
    "# Start from empty results so that re-running this cell does not append to\n",
    "# stale entries (cur_ind below must index the slot opened in this run).\n",
    "for tester, num_rejects, num_samples_to_reject in tester_records:\n",
    "    num_rejects.clear()\n",
    "    num_samples_to_reject.clear()\n",
    "\n",
    "for cur_ind, cur_target_prob in enumerate(target_cand_probs):\n",
    "\n",
    "    # Open a fresh slot for this candidate target probability.\n",
    "    for tester, num_rejects, num_samples_to_reject in tester_records:\n",
    "        num_rejects.append(0)\n",
    "        num_samples_to_reject.append([])\n",
    "\n",
    "    for cur_sim in range(num_of_repeats):\n",
    "\n",
    "        # Source sample drawn with the same class-1 prior the classifier is\n",
    "        # configured with (previously a hard-coded 0.25).\n",
    "        X_val_source, y_val_source = generate_2d_example(\n",
    "            prob_class_1, mu_0, mu_1, size_source_sample)\n",
    "\n",
    "        y_pred_val = clf.predict_proba(X_val_source)\n",
    "        ind_loss_source = brier_scores(y_val_source, y_pred_val)\n",
    "\n",
    "        for tester, num_rejects, num_samples_to_reject in tester_records:\n",
    "            tester.estimate_risk_source(ind_loss_source)\n",
    "\n",
    "        # Whole target stream for this simulation, generated up front.\n",
    "        X_new_target, y_new_target = generate_2d_example(\n",
    "            cur_target_prob, mu_0, mu_1,\n",
    "            size_of_batch * num_of_batches_to_sample)\n",
    "\n",
    "        y_pred_target = clf.predict_proba(X_new_target)\n",
    "        ind_loss_target = brier_scores(y_new_target, y_pred_target)\n",
    "\n",
    "        # Every tester sees the same stream: feed it growing prefixes and\n",
    "        # stop at the first rejection, recording the samples it used.\n",
    "        for tester, num_rejects, num_samples_to_reject in tester_records:\n",
    "            for cur_batch in range(num_of_batches_to_sample):\n",
    "                cur_losses = ind_loss_target[:(cur_batch + 1) * size_of_batch]\n",
    "                tester.estimate_risk_target(cur_losses)\n",
    "\n",
    "                if tester.test_for_drop():\n",
    "                    num_rejects[cur_ind] += 1\n",
    "                    num_samples_to_reject[cur_ind].append(\n",
    "                        tester.target_num_of_samples_used)\n",
    "                    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2021-08-02T16:45:15.723Z"
    }
   },
   "outputs": [],
   "source": [
    "# Share of the num_of_repeats simulations in which each test rejected,\n",
    "# per candidate target probability.\n",
    "hoef_fraction_rejected = [count / num_of_repeats for count in hoef_num_rejects]\n",
    "bern_fraction_rejected = [count / num_of_repeats for count in bern_num_rejects]\n",
    "bet_fraction_rejected = [count / num_of_repeats for count in bet_num_rejects]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2021-08-02T16:45:16.047Z"
    }
   },
   "outputs": [],
   "source": [
    "# Rejection rate of each test as a function of the target class-1 probability.\n",
    "plt.plot(target_cand_probs, hoef_fraction_rejected,\n",
    "         marker='*', label='Hoeffding')\n",
    "plt.plot(target_cand_probs, bern_fraction_rejected,\n",
    "         marker='*', label='Bernstein')\n",
    "plt.plot(target_cand_probs, bet_fraction_rejected,\n",
    "         marker='*', label='Betting')\n",
    "\n",
    "plt.legend(loc=2, markerscale=1.5, prop={'size': 20})\n",
    "plt.ylabel('Proportion of rejected', fontsize=25)\n",
    "plt.xlabel('Class 1 probability', fontsize=25)\n",
    "# plt.savefig('img/prop_rej_sim_brier.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Number of samples needed to reject"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2021-08-02T16:45:17.035Z"
    }
   },
   "outputs": [],
   "source": [
    "# How many rejections were recorded for each candidate target probability.\n",
    "hoef_lengths = list(map(len, hoef_num_samples_to_reject))\n",
    "bern_lengths = list(map(len, bern_num_samples_to_reject))\n",
    "bet_lengths = list(map(len, bet_num_samples_to_reject))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2021-08-02T16:45:17.291Z"
    }
   },
   "outputs": [],
   "source": [
    "# Mean number of samples used at rejection. Probabilities with no recorded\n",
    "# rejection are skipped, so these lists can be shorter than target_cand_probs.\n",
    "hoef_avgs = [\n",
    "    np.mean(samples_used)\n",
    "    for samples_used, num_recorded in zip(hoef_num_samples_to_reject, hoef_lengths)\n",
    "    if num_recorded > 0\n",
    "]\n",
    "\n",
    "bern_avgs = [\n",
    "    np.mean(samples_used)\n",
    "    for samples_used, num_recorded in zip(bern_num_samples_to_reject, bern_lengths)\n",
    "    if num_recorded > 0\n",
    "]\n",
    "\n",
    "bet_avgs = [\n",
    "    np.mean(samples_used)\n",
    "    for samples_used, num_recorded in zip(bet_num_samples_to_reject, bet_lengths)\n",
    "    if num_recorded > 0\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2021-08-02T16:45:17.815Z"
    }
   },
   "outputs": [],
   "source": [
    "# Boolean masks over target_cand_probs: True where at least one rejection\n",
    "# was recorded, i.e. where an average exists.\n",
    "hoef_pos_length = np.asarray(hoef_lengths) > 0\n",
    "bern_pos_length = np.asarray(bern_lengths) > 0\n",
    "bet_pos_length = np.asarray(bet_lengths) > 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2021-08-02T16:45:18.203Z"
    }
   },
   "outputs": [],
   "source": [
    "# Average number of samples used at rejection, drawn only at the target\n",
    "# probabilities where the corresponding test rejected at least once.\n",
    "plt.plot(target_cand_probs[hoef_pos_length], hoef_avgs,\n",
    "         marker='*', label='Hoeffding')\n",
    "plt.plot(target_cand_probs[bern_pos_length], bern_avgs,\n",
    "         marker='*', label='Bernstein')\n",
    "plt.plot(target_cand_probs[bet_pos_length], bet_avgs,\n",
    "         marker='*', label='Betting')\n",
    "\n",
    "plt.legend(loc=2, markerscale=1.5, prop={'size': 20})\n",
    "plt.ylabel('Number of samples to reject', fontsize=25)\n",
    "plt.xlabel('Class 1 probability', fontsize=25)\n",
    "# plt.savefig('img/number_of_samples_brier.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
