{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d989d404",
   "metadata": {},
   "source": [
    "# import useful libraries and functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "23393c85",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random as random\n",
    "import pandas as pd\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from scipy.special import expit\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d2889dfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(90)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19c8312e",
   "metadata": {},
   "source": [
    "# covariate shift"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae764f33",
   "metadata": {},
   "source": [
    "## helper functions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a726300",
   "metadata": {},
   "source": [
    "#### Generate the initial data distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "70732b3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_init_CS(N, sigma2, sigma3):\n",
    "    '''\n",
    "    generate the initial dataframe for covariate shift\n",
    "    \n",
    "    Input:\n",
    "        - N: total nunber of samples\n",
    "        - sigma2: variance of the noisy term for X2\n",
    "        - sigma3: variance of the noisy term for X3\n",
    "    \n",
    "    N: total number of points\n",
    "    X1 ~ unif(-1, 1)\n",
    "    X2 ~ 1.2X1 + N(0,sigma2)\n",
    "    X3 ~ -0.8X1^2 + N(0,sigma3)\n",
    "    Y is threshold based on the value of X2\n",
    "    \n",
    "    output:\n",
    "        - df: a data frame df with ['X1', 'X2', 'X3', 'Y']\n",
    "    '''\n",
    "    \n",
    "    # generate X1, X2, X3\n",
    "    X1 = np.random.uniform(-1, 1, N)\n",
    "    X2 = X1 * 1.2 + np.random.normal(0,sigma2, N)  \n",
    "    X3 = -X1**2 * 0.8 + np.random.normal(0,sigma3, N)\n",
    "    # Generate Y\n",
    "    Y = 1*(X2 > 0)\n",
    "    d = {'X1': list(X1), 'X2': list(X2), 'X3': list(X3),  'Y': list(Y)}\n",
    "    df = pd.DataFrame(data=d)\n",
    "    return df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fc38f76",
   "metadata": {},
   "source": [
    "#### Compute optimal threshold (hS) for the source distribution "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "1ac34ffd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# helper function: compute average square error of a threshold on a dataframe df\n",
    "def compute_error(df, threshold):\n",
    "    '''\n",
    "    Input: \n",
    "        - df: a dataset with a column for 'Q' already computed\n",
    "        - a threshold\n",
    "    \n",
    "    output:\n",
    "        - the mean squared error\n",
    "    '''\n",
    "    error = 0\n",
    "    # compute error\n",
    "    for i in range(len(df)):\n",
    "        # if df_result[\"x\"][i]'s value is greater than the treshold then h(x) = +1; otherwise h(x) = 0\n",
    "        error = error + ((df.iloc[i][\"Q\"] > threshold) - df.iloc[i][\"Y\"])**2\n",
    "    return error/len(df) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "e81ef135",
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute the optimal threshold for the original distribution\n",
    "def compute_optimal_threshold(df, M):\n",
    "    '''\n",
    "    A function to learn the optimal threshold on the original dataset \n",
    "    \n",
    "    Input:\n",
    "        - df: a dataset with a column for 'Q' already computed\n",
    "        - M is the number of potential thresholds.\n",
    "        \n",
    "    Output:\n",
    "        - the optimal threshold for df\n",
    "    '''\n",
    "    # compute the boundary of Q:\n",
    "    Q_min, Q_max = np.min(df[\"Q\"]), np.max(df[\"Q\"])\n",
    "    \n",
    "    # loop through potential thresholds:\n",
    "    error_list = []\n",
    "    for i in range(M):\n",
    "        threshold = Q_min + i*(Q_max - Q_min)/M\n",
    "        error_list.append(compute_error(df, threshold))\n",
    "    index_min = np.argmin(error_list)\n",
    "    optimal_threshold = Q_min + index_min*(Q_max - Q_min)/M\n",
    "    return optimal_threshold"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a450501c",
   "metadata": {},
   "source": [
    "#### Compute induced distribution given a threshold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "3f52be7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute the induced distribution\n",
    "def compute_new_distribution_CS(df, sigma2, sigma3, threshold, c):\n",
    "    '''\n",
    "    input:\n",
    "        - df: original dataframe (without 'Q')\n",
    "        - sigma2:\n",
    "        - sigma3:\n",
    "        - threshold: a given threshold\n",
    "        - c: coefficient for adjusting the magnitude of distribution shift\n",
    "    output:\n",
    "        - the new dataframe for holding the induced distribution\n",
    "    '''\n",
    "    N = len(df)\n",
    "    # compute the new data distribution\n",
    "    # get the old classifier result\n",
    "    hX = 1*(df[\"Q\"] > threshold)\n",
    "    # generate new X1 \n",
    "    X1_new = df['X1'] + c*(2*hX - 1)\n",
    "    \n",
    "    X2_new = X1_new * 1.2 + np.random.normal(0,sigma2, N)  \n",
    "    X3_new = -X1_new**2 * 0.8 + np.random.normal(0,sigma3, N)\n",
    "    # Generate Y\n",
    "    Y_new = 1*(X2_new > 0)\n",
    "    d = {'X1': list(X1_new), 'X2': list(X2_new), 'X3': list(X3_new),  'Y': list(Y_new)}\n",
    "    df_new = pd.DataFrame(data=d)\n",
    "    return df_new"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8d9a058",
   "metadata": {},
   "source": [
    "#### Compute optimal classifier achieve minimum induced risk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "201752ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_optimal_threshold_MinInduceRisk_CS(df, sigma2, sigma3, M, c):\n",
    "    '''\n",
    "    find the optimal classifier achieve min Err(h)(h)\n",
    "    \n",
    "    input:\n",
    "        -  df: original dataframe (without 'Q')\n",
    "        -  M: number of potential classifiers to search for\n",
    "        -  c: coefficient for adjusting the magnitude of distribution shift\n",
    "        \n",
    "    output: \n",
    "        -  optimal threshold  = argmin Err(h)(h)\n",
    "        -  the minimum error \n",
    "    '''\n",
    "    N = len(df)\n",
    "    \n",
    "    # compute the coefficients on the initial data distribution\n",
    "    clf_CS = LogisticRegression(solver='liblinear', fit_intercept=True)\n",
    "    clf_CS.fit(df[[\"X1\", \"X2\", \"X3\"]], df['Y'])\n",
    "    \n",
    "    # compute the boundary of Q:\n",
    "    Q_min, Q_max = np.min(df[\"Q\"]), np.max(df[\"Q\"])\n",
    "    \n",
    "    # loop through the potential thresholds\n",
    "    error_list = []\n",
    "    for j in range(M):\n",
    "        threshold = Q_min + j*(Q_max - Q_min)/M  \n",
    "        # compute the new data distribution\n",
    "        df_new = compute_new_distribution_CS(df, sigma2, sigma3, threshold, c)\n",
    "        # compute the new qualification\n",
    "        # using the coefficient from the logistic regression coefficient for the old dataset\n",
    "        df_new['Q'] = clf_CS.predict_proba(df_new[[\"X1\", \"X2\", \"X3\"]])[:, 1]\n",
    "        # compute new data distribution's loss\n",
    "        Err_h_h = compute_error(df_new, threshold)\n",
    "        error_list.append(Err_h_h)\n",
    "    index_min = np.argmin(error_list)\n",
    "    min_error = np.min(error_list)\n",
    "\n",
    "    # compute the optimal threshold on the dataset that it induced call it hT\n",
    "    optimal_threshold = Q_min + index_min*(Q_max - Q_min)/M\n",
    "    return optimal_threshold, min_error"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb92620d",
   "metadata": {},
   "source": [
    "#### compute varaince of coefficients between two dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "21bd2085",
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute the variance of w_h for two distribution\n",
    "def var_w(df, df_new, number_bin):\n",
    "    '''\n",
    "    Give two dataframe, compute w(x) = Pr_{df_new}(x)/Pr_{df}(x)\n",
    "    \n",
    "    Input:\n",
    "        - df: one dataframe\n",
    "        - df_new: another dataframe\n",
    "        - number_bin: # of bins\n",
    "    Output:\n",
    "        - w(x)\n",
    "    \n",
    "    '''\n",
    "    bin_init, _ = np.histogram(np.array(df[['X1', 'X2', 'X3']]), bins = number_bin)\n",
    "    bin_new, _ = np.histogram(np.array(df_new[['X1', 'X2', 'X3']]), bins = number_bin)\n",
    "    \n",
    "    return np.var(bin_init / bin_new)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6df881b",
   "metadata": {},
   "source": [
    "### Upper bound"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "994e34a2",
   "metadata": {},
   "source": [
    "Covariate Shift Upper Bound is:\n",
    "\n",
    "          Err(hS)(hS) - Err(hT)(hT) <= \\sqrt{Err(DS)(hT)} (\\sqrt{Var(w(hS))} + \\sqrt{Var(w(hT))})\n",
    "\n",
    "To characterize the CS upper bound, we need to compute the following quantities:\n",
    "  -  hS (optimal classifier trained on the original dataset)\n",
    "  \n",
    "  -  Err_hs_hs (error of hS on the distribution it induces D(hS)\n",
    "  \n",
    "  -  hT (optimal classifier considering induced data distribution)\n",
    "  \n",
    "  -  Err_hT_hT (error of hT on the distribution it induced D(hT))\n",
    "  \n",
    "  -  Err_S_hT (error of hT on the source distribution S)\n",
    "  \n",
    "  -  sqrt(var(whS))+sqrt(var(whT)): where whS and whT are the varaince of w(hS) and w(hT)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29aeca7e",
   "metadata": {},
   "source": [
    "#### Set Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "788ab243",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 500\n",
    "c = 0.25\n",
    "M = 100\n",
    "sigma2 = 0.15\n",
    "sigma3 = 0.1\n",
    "\n",
    "number_bin = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "67c9f909",
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate the initial data distribution\n",
    "df_init_CS = generate_init_CS(N, sigma2, sigma3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1254ad18",
   "metadata": {},
   "source": [
    "#### Train a logistic regression classifier based on the initial dataframe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "40a66e89",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LogisticRegression(solver='liblinear')"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf_CS = LogisticRegression(solver='liblinear', fit_intercept=True)\n",
    "clf_CS.fit(df_init_CS[[\"X1\", \"X2\", \"X3\"]], df_init_CS['Y'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "60760aa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute the \"qualification\" (linear combination of features) based on the logistic regression model we trained\n",
    "df_init_CS['Q'] = clf_CS.predict_proba(df_init_CS[[\"X1\", \"X2\", \"X3\"]])[:, 1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81029a49",
   "metadata": {},
   "source": [
    "#### plot Y vs Q for the initial dataset D(S)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "ba61bc19",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARIAAADgCAYAAADPGumFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAX+ElEQVR4nO3dfZRddX3v8fcnw4ROJDDBjNQMCUEMIAgpOpJIFbHU5gG9RBqvBKmF2lJ6RV1dvVyoteIDCppq0RU0F1mURamhWGmKGoxab6WrGGQikBAwGiKQByrhITwk0Uwm3/vHb09y5uScmZP8Zs+ZCZ/XWrPWOXv/9t7fc+bsz977t8/ZWxGBmVmOMc0uwMxGPweJmWVzkJhZNgeJmWVzkJhZNgeJmWVzkJhZNgfJKCbpdyX9QtJLkuYNwfwWS/rboWgr6aOSbmxwXjdLurrBtlMlhaRDGmk/2kjqkLRW0m810PZUSfcMR12DcZBUkfRPkm6qGvY2Sc9IenWJy22X9FVJ/y1pu6TVkv54kMk+BSyKiMMiYmluDRFxaUR8en/bSjpL0saq8Z+NiD/NrSlHrbqatZwiLHdKerH4e0jSNZKOqGp6JfAPEfHrYrqTJX1P0nOStkpaKWkuQESsArZKelcpL2w/OEj29WFgrqR3ABRbhq8BfxURT5axQEljgR8AxwBvBo4ALgc+L+nDA0x6DLDmAJd5UG7RR7jPR8R4oAO4GJgJ/JekVwBIOhT4Y+DWimm+BXwfOAp4Fenz+ULF+H8C/rz80gcREf6r+gPeA/wSeAVwDXBXnXYzgf8GWiqGvRtYVTw+Hegm/eN/BXyxznw+ADwFvKJq+HuLacfXmOZRYDewA3gJOBSYBNwJPAusA/6sov0ngH8hfUhfAP60xjxvBq4uHp8FbAT+qqjtSeDi6rbFe7SjqOWl4m9SsbxbK9p/o3ivngfuBk6utdwaNbUAfwc8DawHPggEcEgx/mLgEeDFYvyfF8Pr1XU68GNga/GaFgFji2kE/H3xep8HVgGvL8YdWtTxRPG/XAy01VvOQO9txbDxRQ2XFc/PBNZVjJ9YvNb2AT6rncXyD23mOuM9khoi4hvASmAJcAl1Ej8iVgDbgN+rGHwB8PXi8ZeAL0XE4cBxwO11FvkOUlhtqxr+TWAcKbCql30c6UP9rkiHNr8p6t1IWmHmA5+VdHbFZOeSwqSdtCUbzG+T9o46SWF3vaQJVXVsA+YAm4s6DouIzTXmdRcwjbRV/WmDywf4M+CdwGlAV/G6Kj1VjD+cFCp/L+kNA9TVC/wlaSV9M3A28L+Kef0BaWU+nvQevRd4phj3uWL47wCvLd6Tj+/H699HRLxI2tt4azHoFGBtRZNnSBuEWyXNk3RUjXlsAnqAExpZZlkcJPV9kBQQn4qIJwZotwRYACBpPDC3GAbpH/xaSRMj4qUieGqZSNoy9RMRu0hb4o7BipU0GXgLcEVE/DoiHgBuBP6ootmPI2JpROyOiB2DzbOo/1MR0RMRy0hb2wP6wEbETRHxYhF4nwCm1+gfqOV/AtdFxIaIeJa0h1g53+9ExKOR/Aj4HntXzFp1rIyIFRGxKyIeA/4v8LZidA9pL+FEQBHxSEQ8KUmkQPvLiHi2CIDPAufvx1tQz2bgyOJxO2nPqq/WAN4OPAZ8AXhS0t2SplXN48Vi2qZxkNQREb8ircSD9UF8HTivOL49D/hpRDxejPsAaSv2M0n3SXpnnXk8DezTkVv0Y0wEtjRQ8iSg70Pe53HSlrPPhgbmU+mZIsz6bAcO2895IKlF0rWSHpX0AmnFgPTaBjOJ/nU/XjlS0hxJKyQ9K2krKcjrzlfS8ZK+XXRqv0AKhIkAEfFD0qHO9cCvJN0g6XBSkI8DVhYdnluB79JAwDegk3QoCvAcKcj2iIiNEXFZsQd6DGkP+JaqeYwnHao1jYMkU0Q8TPpwz6H/YQ0R8YuIWEDanf8c8C99HWtVfgDM
qTHuD0lbyZ80UMpm4Mhir6jPFGBTZbkNzOdADDbfC0iHVb9POlSaWgxXA/N+Ephc8XxK34MivL9J6rs4KiLagWUV861V11eBnwHTikPOj1bWERFfjog3AieTNgKXk4J+B6lfp734OyIi+kL1gN5XSYeR3pP/LAatKpZZU0RsIIXc6yvmMQkYS/9DomHnIBkaXyf1pp9J6lQEQNKFkjoiYjd7txi9Nab/R1LfxjeK70m0SpoFfJnU0//8YAUUH7J7gGsk/ZakU0l7RI32ReT4FfDKAQ5VxgO/IR3zjyPtBTTqduDDko4u+meurBg3ltQJugXYJWkOqZ9joLrGkzqbX5J0IvAXfSMkvUnSDEmtpC3/r4He4v/3NVL/y6uKtp3F/6iR19+PpEMlvRFYStoL+Ydi1E+AdkmdRbsJkj4p6bWSxkiaCPwJUHmIfBbww+KQsWkcJENjCXv/oU9XDJ8NrJH0Eqnj9fwovh9QqfgQ/D5pF/5e0tbvu8B1wCf3o44FpK39ZuBfgasi4vv7+Vr2W0T8jPQerC92/SdVNbmFtNe2CXiY/ivCYL4GLAceJHXS3lGx3BdJAX47aYW8gHTWaqC6/nfR7sVi3v9csazDi2HPFfU+Q9rbAbiC1PG5ojgk+gFFf1EDr7/P/5H0IulQ5hZSh/4ZfZ3sEbGTdHbnwqL9TtL/8wek8HuIFMgXVczzfaQzSE2l4hSSjSDFFvEu0op3Ufif9LIhqYN0qHPaYB3ikk4BboiINw9LcQPV4s/oyFTsJn8EuL3Y4pmNWA4SM8vmPhIzy+YgMbNso+6HWxMnToypU6c2uwyzl52VK1c+HRE1v4Q36oJk6tSpdHd3N7sMs5cdSY/XG+dDGzPL5iAxs2ylHdoUVxl7J/BURLy+xniRvu05l/RjsIsi4qdl1WP5lt6/iYXL17J56w4mtbdx+az0Q+BGhs07rXOfeRzR1ooEW7f37NOukeUd0dZKT+9utu1Mvzpob2vlE//jZLoff5av3/sEuyu+2TBGcMGM9DOdJfduoDeCFonXdIzj0S3b+rWtnKbW8D6tY2DMmDH8ZtduIP1g54zjjuSBDc/vqQnS1np3nXlcOHMKV887BYCPLV29p7aBvGJsC9t29iLBUH9747Frzzmg6Ur7HomkM0k/O7+lTpDMBT5ECpIZpOt2zBhsvl1dXeE+kuG39P5N/PUdq9nRs3cFaR0jEPT07v0MtbYIAnoq1sC21hauOS+tLNXzqNTXbt5pnQ0v72Bw4cwUcLeuGOhqFcOnXphIWhkRXbXGlbZHEhF3S5o6QJNzSSETpN8vtEt6dZR0OUPLs3D52n0CoKfG5rrWSr6jp5eFy9fueVxPX7t5p3U2vLyDwZJ79/fqDiNPM8/adNL/OhMbi2H7BImkS0hXKmPKlCnVo20YbN7ayHWQ8qfva5e7vNFksEOZ0aCZna21rkVR8x2NiBsioisiujo6huJaMra/JrW3ZU/fyDz62uQubzRpkWhRI5dmGbmaGSQb6X/BmqNJP3+3EejyWSfQ1trSb1jrGKU+kcphLUp9GRXaWlu4fNYJNedRq93+LO9gsGDGZBbMmDx4wxGsmYc2dwKXSbqN1Nn6vPtHRq6+sym5Z20qxw101qaR5R1sZ20qaxvIy+2sTd/FfiaSriB1FdAKEBGLi9O/i0gX/9lOutXBoKdjfNbGrDmaddZmwSDjg3SldjMb5fzNVjPL5iAxs2wOEjPL5iAxs2wOEjPL5iAxs2wOEjPL5iAxs2wOEjPL5iAxs2wOEjPL5iAxs2wOEjPL5iAxs2wOEjPL5iAxs2wOEjPL5iAxs2wOEjPL5iAxs2ylBomk2ZLWSlon6coa44+Q9C1JD0paI+niMusxs3KUFiSSWoDrgTnAScACSSdVNfsg8HBETCfduuILksaWVZOZlaPMPZLTgXURsT4idgK3kW4cXimA8cU9bg4DngV2lViTmZWgzCCpd5PwSouA15Fu1bka+EhE1LspmZmN
UGUGSSM3CZ8FPABMAn4HWCTp8H1mJF0iqVtS95YtW4a6TjPLVGaQNHKT8IuBOyJZB/wSOLF6RhFxQ0R0RURXR0dHaQWb2YEpM0juA6ZJOrboQD2fdOPwSk8AZwNIOgo4AVhfYk1mVoIy7/27S9JlwHKgBbgpItZIurQYvxj4NHCzpNWkQ6ErIuLpsmoys3KUFiQAEbEMWFY1bHHF483AH5RZg5mVz99sNbNsDhIzy+YgMbNsDhIzy+YgMbNsDhIzy+YgMbNsDhIzy+YgMbNsDhIzy+YgMbNsDhIzy+YgMbNsDhIzy+YgMbNsDhIzy+YgMbNsDhIzy+YgMbNsDhIzy+YgMbNspQaJpNmS1kpaJ+nKOm3OkvSApDWSflRmPWZWjtJuRyGpBbgeeAfprnv3SbozIh6uaNMOfAWYHRFPSHpVWfWYWXnK3CM5HVgXEesjYidwG3BuVZsLSLfsfAIgIp4qsR4zK0mZQdIJbKh4vrEYVul4YIKk/5C0UtL7a83INxE3G9nKDBLVGBZVzw8B3gicA8wC/lbS8ftM5JuIm41oZd6ycyMwueL50cDmGm2ejohtwDZJdwPTgZ+XWJeZDbEy90juA6ZJOlbSWOB84M6qNv8GvFXSIZLGATOAR0qsycxKUNoeSUTsknQZsBxoAW6KiDWSLi3GL46IRyR9F1gF7AZujIiHyqrJzMqhiOpui5Gtq6sruru7m12G2cuOpJUR0VVrnL/ZambZHCRmls1BYmbZHCRmls1BYmbZHCRmls1BYmbZ6gaJpGWSpg5jLWY2Sg20R3Iz8D1JfyOpdZjqMbNRqO5X5CPidknfAT4OdEv6R9LX2PvGf3EY6jOzUWCw39r0ANuAQ4HxVASJmVmfukEiaTbwRdIvdt8QEduHrSozG1UG2iP5G+A9EbFmuIoxs9FpoD6Stw5nIWY2evl7JGaWzUFiZtkcJGaWzUFiZtkcJGaWzUFiZtkcJGaWrdQgkTRb0lpJ6yRdOUC7N0nqlTS/zHrMrBylBYmkFuB6YA5wErBA0kl12n2OdP8bMxuFytwjOR1YFxHrI2IncBtwbo12HwK+CTxVYi1mVqIyg6QT2FDxfGMxbA9JncC7gcUDzUjSJZK6JXVv2bJlyAs1szxlBolqDKu+rd91wBUR0TvQjCLihojoioiujo6OoarPzIZIaff+Je2BTK54fjSwuapNF3CbJICJwFxJuyJiaYl1mdkQKzNI7gOmSToW2AScD1xQ2SAiju17LOlm4NsOEbPRp7QgiYhdki4jnY1pAW6KiDWSLi3GD9gvYmajR5l7JETEMmBZ1bCaARIRF5VZi5mVx99sNbNsDhIzy+YgMbNsDhIzy+YgMbNsDhIzy+YgMbNsDhIzy+YgMbNsDhIzy+YgMbNsDhIzy+YgMbNsDhIzy+YgMbNsDhIzy+YgMbNsDhIzy+YgMbNsDhIzy9bUm4hLep+kVcXfPZKml1mPmZWj2TcR/yXwtog4Ffg0cENZ9ZhZeZp6E/GIuCciniueriDdjc/MRpmm3kS8ygeAu0qsx8xKUuYNshq5iXhqKL2dFCRvqTP+EuASgClTpgxVfWY2RMrcI2nkJuJIOhW4ETg3Ip6pNaOIuCEiuiKiq6Ojo5RizezAlRkke24iLmks6Sbid1Y2kDQFuAP4o4j4eYm1mFmJmn0T8Y8DrwS+IglgV0R0lVWTmZVDETW7LUasrq6u6O7ubnYZZi87klbW29D7m61mls1BYmbZHCRmls1BYmbZHCRmls1BYmbZHCRmls1BYmbZHCRmls1BYmbZHCRmls1BYmbZHCRmls1BYmbZHCRmls1BYmbZHCRmls1BYmbZHCRmls1BYmbZHCRmlq3MO+0haTbwJdLtKG6MiGurxqsYPxfYDlwUET/NWebUK7+TM/kBax0DrS1j2N6zu9+w3oDdAS0SM18zge7Ht/KbXbsHmFN/It2ecMK4ViLg+R09HNHWigRbt/cwqb2Ny2edwLzT+t8N9WNL
V7Pk3g30RtAisWDGZLqOOZKFy9eyeeuOPdMBLFy+lk1bd+xZVp8J41o559RX8/9+tqXmNH3D3n5iB99+8Em27ujpN+1V7zqZead1svT+TXzyW2t4bntPv9fUWWd+tV5Pn6X3b2q4bRnTW22l3Y5CUgvwc+AdpLvu3QcsiIiHK9rMBT5ECpIZwJciYsZA8x3odhTNCpFma2tt4ZrzTtmzQnxs6WpuXfHEPu1axoje3Xv/360tgoCe3Y1/BlrHCAQ9vYNP09oi3vumyfzzfRvqtq81v+rX02fp/Zv46ztWs6Ond9C2teRO/3LXrNtRnA6si4j1EbETuA04t6rNucAtkawA2iW9usSaDko7enpZuHztnudL7t1Qs11vVWD09MZ+hQik0GkkRPrmv+Te+iFSb37Vr6fPwuVr+4XAQG1ryZ3e6iszSDqByk/0xmLY/rZB0iWSuiV1b9myZcgLPRhs3rpjz+PeEXTTswOtpfL1DDRsoOFDPb3VV2aQqMaw6k9VI218E/EGTGpv2/O4RbXe1uY40FoqX89AwwYaPtTTW31lBslGYHLF86OBzQfQxgbR1tqyp9MSYMGMyTXbtYzpv1K3tij1UeyH1jFKfSuNtG1JnbwDta81v+rX0+fyWSfQ1trSUNtacqe3+soMkvuAaZKOlTQWOB+4s6rNncD7lcwEno+IJw90gY9de86BV5updQyMax2zz7C+9bRF4nePO5JDD9m/t7xvFZswrpX2tlYEtLe1MmFcetzZ3rZPZ+HV807hwplT9uwNtEhcOHMKX3jPdDrb2/ZMt3D+dBYWwyqXRcUyL5w5pf8075nOwvn953PhzCm0t7XuM+3C+dO5et4pLJw/nQnj9o7vW069+dXr/Jx3WifXnHdKQ21ryZ3e6iv1JuLFWZnrSKd/b4qIz0i6FCAiFhenfxcBs0mnfy+OiAHvEO6biJs1x0BnbUr9HklELAOWVQ1bXPE4gA+WWYOZlc/fbDWzbA4SM8tWah9JGSRtAR5voOlE4OmSy8nlGvON9Ppg5NfYaH3HRETN71+MuiBplKTueh1DI4VrzDfS64ORX+NQ1OdDGzPL5iAxs2wHc5Dc0OwCGuAa8430+mDk15hd30HbR2Jmw+dg3iMxs2Ey6oNE0mxJayWtk3RljfGS9OVi/CpJbxiBNb6vqG2VpHskTR9J9VW0e5OkXknzh7O+YtmD1ijpLEkPSFoj6UcjqT5JR0j6lqQHi/ouHub6bpL0lKSH6ozPW08iYtT+kX7D8yjwGmAs8CBwUlWbucBdpN+KzQTuHYE1ngFMKB7PGc4aG6mvot0PST95mD8C38N24GFgSvH8VSOsvo8CnysedwDPAmOHscYzgTcAD9UZn7WejPY9ktFwFbZBa4yIeyLiueLpCtLlFEZMfYUPAd8EnhrG2vo0UuMFwB0R8QRARAxnnY3UF8D44oeqh5GCZNdwFRgRdxfLrCdrPRntQTJkV2Er0f4u/wOkLcNwGbQ+SZ3Au4HFNEcj7+HxwARJ/yFppaT3D1t1jdW3CHgd6Xo7q4GPRETjVwEvX9Z6Uuqvf4fBkF2FrUQNL1/S20lB8pZSK6pabI1h1fVdB1wREb1qztXXGqnxEOCNwNlAG/BjSSsi4udlF0dj9c0CHgB+DzgO+L6k/4yIF0qurVFZ68loD5LRcBW2hpYv6VTgRmBORDwzTLVBY/V1AbcVITIRmCtpV0QsHZYKG/8/Px0R24Btku4GppPuZDAS6rsYuDZSh8Q6Sb8ETgR+Mgz1NSJvPRmuzp6SOpAOAdYDx7K3k+vkqjbn0L8T6ScjsMYpwDrgjJH4Hla1v5nh72xt5D18HfDvRdtxwEPA60dQfV8FPlE8PgrYBEwc5vdxKvU7W7PWk1G9RxIRuyRdBixn71XY1lRehY10lmEuaUXdTtoyjLQaPw68EvhKsdXfFcP0I68G62uqRmqMiEckfRdYBewm3ZCt5qnOZtQHfBq4WdJq0sp6RUQM2y+CJS0B
zgImStoIXAW0VtSXtZ74m61mlm20n7UxsxHAQWJm2RwkZpbNQWJm2RwkZpbNQWKlk3S0pH+T9AtJ6yUtknRos+uyoeMgsVIVP1K7A1gaEdOAaaSvsH++qYXZkPL3SKxUks4GroqIMyuGHU66pcjkiHipacXZkPEeiZXtZGBl5YBIP1R7DHhtMwqyoecgsbKJ2r8ibcrPiK0cDhIr2xrSr4f3KA5tjgLWNqUiG3IOEivbvwPj+i40JKkF+AKwKCJ2NLUyGzIOEitVpN78dwPzJf0CeAbYHRGfaW5lNpR81saGlaQzgCXAeRGxcrD2Njo4SMwsmw9tzCybg8TMsjlIzCybg8TMsjlIzCybg8TMsjlIzCzb/wdJxDJdn9CGLQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 288x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(1, figsize=(4, 3))\n",
    "plt.scatter(df_init_CS[\"Q\"], df_init_CS[\"Y\"])\n",
    "plt.ylabel(\"Y\")\n",
    "plt.xlabel(\"Q\")\n",
    "plt.title(\"Y vs Q for initial dataset D(S)\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36a04390",
   "metadata": {},
   "source": [
    "#### compute hS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "79fbb397",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hS: 0.5199982722524372\n"
     ]
    }
   ],
   "source": [
    "# compute the optimal threshold on the original dataset, name it \"hS_CS\"\n",
    "hS_CS = compute_optimal_threshold(df_init_CS, M)\n",
    "print(\"hS:\", hS_CS)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "53a2490b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Err_S_hS_CS: 0.014\n"
     ]
    }
   ],
   "source": [
    "# error on the original dataset S using classifier hS_CS\n",
    "Err_S_hS_CS = compute_error(df_init_CS, hS_CS)\n",
    "print(\"Err_S_hS_CS:\", Err_S_hS_CS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "073fceb7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Err_hS_hS_CS: 0.002\n"
     ]
    }
   ],
   "source": [
    "# compute the induced data distribution by hS call it df_hS_CS\n",
    "df_hS_CS = compute_new_distribution_CS(df_init_CS, sigma2, sigma3, hS_CS, c)\n",
    "\n",
    "\n",
    "# compute the new qualification\n",
    "df_hS_CS['Q'] = clf_CS.predict_proba(df_hS_CS[[\"X1\", \"X2\", \"X3\"]])[:, 1]\n",
    "# compute the error on the induced dataset using hS, call it Err_hS_hS_LS\n",
    "Err_hS_hS_CS = compute_error(df_hS_CS, hS_CS)\n",
    "print(\"Err_hS_hS_CS:\", Err_hS_hS_CS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "ae79ba30",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5, 1.0, 'Y vs Q for D(hS)')"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARIAAADgCAYAAADPGumFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAT0UlEQVR4nO3de5AdZZ3G8e/DJMFBkEEzWmYgBDVEUIjgiCyrgiKbENRAFqoAFaFYs1SJ+oebBZZa8bIKmoJCK0CKpTDrDQo0xrhc4m0RS0QyMUAIEIggkIlKAkRuo+Ty2z+6J5wczsyc5J33XJLnUzXlOd1vd//Scp7T/Z7ufhURmJml2K3ZBZhZ+3OQmFkyB4mZJXOQmFkyB4mZJXOQmFkyB4mZJXOQWDaS/lHSw5Kel3RiA7fbLWmVpFeV72+T9C87uK7LJJ0zuhXufBwkbUbS9yRdWzXtaElPSXpjxu12SbpK0p8lvShphaRPjLDYl4B5EbFnRCwahRoWSHpJ0nPl332SLpa0d1XT84FvRcTf6lzv2ZIeLNf5F0k3SdqrnD0XuFDSuNT6d2YOkvbzGWCGpOMAym/d/wY+FxF/yrHB8kP0c2B/4B+AvYE5wNclfWaYRfcHVu7gNscMMevrEbEX0A2cBRwJ/EbSq8vldgc+AXy3zu0cDXwVOK1c70HADYPzy336IPCRHfl37CocJG0mIp4CPg1cXX54LgL+EBELqttKOrI8guiomHaSpHvL10dI6pP0bPlNfNkQm/04MBE4JSIejYiNEXErRaj9V8W3d+W2/wC8CfhJeWqzu6QJkhZLelrSakmfrGj/BUk/kPRdSc8CZ46wH/4WEUspPuCvowgVgHcDGyJiTdUi+0v6TXnU8VNJ48vp7wJ+GxHLy/U+HRH/ExHPVSx7G3DCcPXs6hwkbSgibgSWAdcBs4F/HaLdncALwAcqJp8OfL98/Q3gGxHxGuDNVHwTVzkOuCUiXqia/kNgD4qjguptvxl4HPhweWrz97LeNcAE4GTgq5KOrVhsJvADoAv43hC1VG/nOeBnwHvLSYcAq2o0PZ0ibF4PjAP+rZz+O2CapC+WfTq711j2AWBqPfXsqhwk7etTFAHxpYh4fJh21wGnAZRHDjPKaQAbgbdIGh8Rz5fBU8t44BWnTRGxCVhPcZoxLEn7Ae8BziuPJu4GrqE42hn024hYFBFbImJgpHVWWAu8tnzdBTxXo823IuKhcr03AO8o/w2/BmYBhwM3AU+VHawdFcs+V67XhuAgaVMR8ReKD/FIfRDfB2aV37SzgN9HxGPlvLOBA4EHJS2V9KEh1rEeeEVHbtmPMR5YV0fJE4Cnq04ZHgN6Kt4/Ucd6aukBni5fPwO84lQL+HPF6xeBPQffRMQtEfFhijCaSXFaVfkrz17Ahh2sbZfgINnJRcT9FB/Y49n2tIaIeDgiTqM43P8a8IPBTssqPweOrzHvnymOau6qo5S1wGur+lMmAv2V5daxnm1I2hP4IPDrctK9FOG43cojoV8AvwTeXjHrIOCeHVnnrsJBsmv4PkXH6PuAGwcnSvqYpO6I2MLL37ibayz/HYq+jRslTZI0VtI04JsUv6L8daQCIuIJ4A7gYkmvknQoxRFRXX0h1crO23cCiyiOQr5VzroL6JLUM9SyVeuZKelUSfuocARwNFB5mnc0cMuO1LmrcJDsGq4DjgF+GRHrK6ZPB1ZKep6i4/XUWtdelB2lH6Q49fgdMADcClwOfHE76jgNmERxdPIj4KKI+Nl2/lv+XdJzFKcy36bodD5qsCM4Il4CFgAfq3N9zwCfBB4GnqX42XhuRHwPoLw252CKwLIhyE9Is+0laSzFN3Q/cGa02H9EkropTnUO285O21rrupTi5/UrR6W4nZSDxHZIeTXpZ4EbIuLBZtdjzeUgMbNk7iMxs2QOEjNLNtSNUS1r/PjxMWnSpGaXYbbLWbZs2fqIqHkVc9sFyaRJk+jr
62t2GWa7HEmPDTXPpzZmlsxBYmbJsp3alE/x+hDwZES8vcZ8UVxNOYPiJqozI+L3ueqx1rBoeT9zl6xi7YYBJnR1MmfaFE48rK6r2Xcag/ugf8MAHRKbI+gZYV8sWt7PhT9awQsv1bqDobD7mN34+6YtSbX98ZIde+xKziOSBRSXYA/leGBy+TcbuCpjLdYCFi3v54KFK+jfMEAA/RsGuGDhChYt7x9x2Z1F5T4A2FxexzXcvli0vJ/P3XjPsCECJIcIwKTzb9qh5bIFSUTczsu3dtcyE/h2FO6kuNEq2zNHrfnmLlnFwMZtPwwDGzczd0mt5xDtnGrtg0FD7Yu5S1axeUtrXzjazD6SHrZ9/sQatn02xVaSZpePBOxbt66eR19YK1q7ofZtL0NN3xmN9G+tNb8d9k8zg0Q1ptWM3Yi4OiJ6I6K3u3vEh3FZi5rQ1bld03dGI/1ba81vh/3TzCBZA+xX8X5fitvLbSc1Z9oUOsd2bDOtc2wHc6ZNaVJFjVdrHwwaal/MmTaFjt1qfe+2jmYGyWLgjPJhMkcCf801nIK1hhMP6+HiWYfQ09WJgJ6uTi6edcgu9atN5T4A6FAREMPtixMP6+HSU6by6nG1A2jQ7mPSP847+qtNtrt/JQ0+TGc88BeKYRPGAkTE/PLn33kUv+y8CJwVESNestrb2xu+stWs8SQti4jeWvOyXUdSPgt0uPlB8SR0M2tzvrLVzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsWdYgkTRd0ipJqyWdX2P+3pJ+IukeSSslnZWzHjPLI1uQSOoArqAYLPxg4DRJB1c1+xRwf0RMpRi64lJJ43LVZGZ55DwiOQJYHRGPRMRLwPUUA4dXCmCvcoybPSkGHd+UsSYzyyBnkNQzSPg84CCKoTpXAJ+NiC0ZazKzDHIGST2DhE8D7gYmAO8A5kl6zStWJM2W1Cepb926daNdp5klyhkk9QwSfhawMAqrgUeBt1avKCKujojeiOjt7u7OVrCZ7ZicQbIUmCzpgLID9VSKgcMrPQ4cCyDpDcAU4JGMNZlZBjnH/t0k6VxgCdABXBsRKyWdU86fD3wZWCBpBcWp0HkRsT5XTWaWR7YgAYiIm4Gbq6bNr3i9FvinnDWYWX6+stXMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMkmUNEknTJa2StFrS+UO0OUbS3ZJWSvpVznrMLI9sw1FI6gCuAI6jGHVvqaTFEXF/RZsu4EpgekQ8Lun1ueoxs3xyHpEcAayOiEci4iXgemBmVZvTKYbsfBwgIp7MWI+ZZZIzSHqAJyrerymnVToQ2EfSbZKWSTqj1oo8iLhZa8sZJKoxLarejwHeCZwATAP+U9KBr1jIg4ibtbScQ3auAfareL8vsLZGm/UR8QLwgqTbganAQxnrMrNRlvOIZCkwWdIBksYBpwKLq9r8GHivpDGS9gDeDTyQsSYzyyDbEUlEbJJ0LrAE6ACujYiVks4p58+PiAck3QrcC2wBromI+3LVZGZ5KKK626K19fb2Rl9fX7PLMNvlSFoWEb215vnKVjNL5iAxs2QOEjNL5iAxs2QOEjNL5iAxs2QOEjNLNmSQSLpZ0qQG1mJmbWq4I5IFwE8lXShpbIPqMbM2NOQl8hFxg6SbgM8DfZK+Q3EZ++D8yxpQn5m1gZHutdkIvADsDuxFRZCYmQ0aMkgkTQcuo7hj9/CIeLFhVZlZWxnuiORC4JSIWNmoYsysPQ3XR/LeRhZiZu3L15GYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklc5CYWbKsQSJpuqRVklZLOn+Ydu+S
tFnSyTnrMbM8sgWJpA7gCuB44GDgNEkHD9HuaxTj35hZG8p5RHIEsDoiHomIl4DrgZk12n0a+CHwZMZazCyjnEHSAzxR8X5NOW0rST3AScD84VYkabakPkl969atG/VCzSxNziBRjWnVw/pdDpwXEZuHW1FEXB0RvRHR293dPVr1mdkoyTb2L8URyH4V7/cF1la16QWulwQwHpghaVNELMpYl5mNspxBshSYLOkAoB84FTi9skFEHDD4WtIC4H8dImbtJ1uQRMQmSedS/BrTAVwbESslnVPOH7ZfxMzaR84jEiLiZuDmqmk1AyQizsxZi5nl4ytbzSyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMkjV1EHFJH5V0b/l3h6SpOesxszyaPYj4o8DREXEo8GXg6lz1mFk+TR1EPCLuiIhnyrd3UozGZ2ZtpqmDiFc5G7glYz1mlknOAbLqGUS8aCi9nyJI3jPE/NnAbICJEyeOVn1mNkpyHpHUM4g4kg4FrgFmRsRTtVYUEVdHRG9E9HZ3d2cp1sx2XM4g2TqIuKRxFIOIL65sIGkisBD4eEQ8lLEWM8uo2YOIfx54HXClJIBNEdGbqyYzy0MRNbstWlZvb2/09fU1uwyzXY6kZUN90fvKVjNL5iAxs2QOEjNL5iAxs2QOEjNL5iAxs2QOEjNL5iAxs2QOEjNL5iAxs2QOEjNL5iAxs2QOEjNL5iAxs2QOEjNL5iAxs2QOEjNL5iAxs2QOEjNL5iAxs2QOEjNLlnOkPSRNB75BMRzFNRFxSdV8lfNnAC8CZ0bE71O2Oen8m1IW30rAmN1g45aXp3V1juVDU9/I/z24jrUbBpjQ1cmcaVM48bBiJNJFy/uZu2QV/RsGkGDwAf1dnWP5wkfetrVdPQbXVWs7Zq0mW5BI6gCuAI6jGHVvqaTFEXF/RbPjgcnl37uBq8r/3SGjFSJQjC1aGSIAGwY28t07H9/6vn/DABcsXLH1/QULVzCwcXOxfGy73Jwb7wGoKwwWLe/fZl2V23GYWCvKeWpzBLA6Ih6JiJeA64GZVW1mAt+Owp1Al6Q3Zqxp1A1s3MzcJauYu2TV1g9+LRu3BHOXrKprnbXWNbgds1aUM0h6gCcq3q8pp21vGyTNltQnqW/dunWjXmiqtRsGWLthoK529a4vZXmzRssZJKoxrXpYv3ratPwg4hO6OpnQ1VlXu3rXl7K8WaPlDJI1wH4V7/cF1u5Am5bWObaDOdOmMGfaFDrHdgzZbuxuYs60KXWts9a6Brdj1opyBslSYLKkAySNA04FFle1WQycocKRwF8j4k87usE/XnLCjldbRcDYqr3T1TmWjx05kZ6uTgT0dHVy8axDOPGwHk48rIeLZx1CT3nUIG273NxTptbdUVq5rurtmLWirIOIS5oBXE7x8++1EfEVSecARMT88uffecB0ip9/z4qIYUcI9yDiZs0x3CDiWa8jiYibgZurps2veB3Ap3LWYGb5+cpWM0vmIDGzZFn7SHKQtA54rI6m44H1mctJ5RrTtXp90Po11lvf/hFR8/qLtguSeknqG6pjqFW4xnStXh+0fo2jUZ9PbcwsmYPEzJLtzEFydbMLqINrTNfq9UHr15hc307bR2JmjbMzH5GYWYO0fZBImi5plaTVks6vMV+SvlnOv1fS4S1Y40fL2u6VdIekqa1UX0W7d0naLOnkRtZXbnvEGiUdI+luSSsl/aqV6pO0t6SfSLqnrO+sBtd3raQnJd03xPy0z0lEtO0fxT08fwDeBIwD7gEOrmozA7iF4j68I4HftWCNRwH7lK+Pb2SN9dRX0e6XFLc8nNyC+7ALuB+YWL5/fYvV9x/A18rX3cDTwLgG1vg+4HDgviHmJ31O2v2IpB2ewjZijRFxR0Q8U769k+JxCi1TX+nTwA+B
JxtY26B6ajwdWBgRjwNERCPrrKe+APYqb1TdkyJINjWqwIi4vdzmUJI+J+0eJKP2FLaMtnf7Z1N8MzTKiPVJ6gFOAubTHPXswwOBfSTdJmmZpDMaVl199c0DDqJ43s4K4LMRUfVU4KZK+pxkvfu3AUbtKWwZ1b19Se+nCJL3ZK2oarM1plXXdzlwXkRslmo1z66eGscA7wSOBTqB30q6MyIeyl0c9dU3Dbgb+ADwZuBnkn4dEc9mrq1eSZ+Tdg+SdngKW13bl3QocA1wfEQ81aDaoL76eoHryxAZD8yQtCkiFjWkwvr/f14fES8AL0i6HZgKNCJI6qnvLOCSKDokVkt6FHgrcFcD6qtH2uekUZ09mTqQxgCPAAfwcifX26ranMC2nUh3tWCNE4HVwFGtuA+r2i+g8Z2t9ezDg4BflG33AO4D3t5C9V0FfKF8/QagHxjf4P04iaE7W5M+J219RBIRmySdCyzh5aewrax8ChvFrwwzKD6oL1J8M7RajZ8HXgdcWX7rb4oG3eRVZ31NVU+NEfGApFuBe4EtFAOy1fypsxn1AV8GFkhaQfFhPS8iGnZHsKTrgGOA8ZLWABcBYyvqS/qc+MpWM0vW7r/amFkLcJCYWTIHiZklc5CYWTIHiZklc5BYdpL2lfRjSQ9LekTSPEm7N7suGz0OEsuqvEltIbAoIiYDkykuYf96UwuzUeXrSCwrSccCF0XE+yqmvYZiSJH9IuL5phVno8ZHJJbb24BllROiuFHtj8BbmlGQjT4HieUmat9F2pTbiC0PB4nltpLi7uGtylObNwCrmlKRjToHieX2C2CPwQcNSeoALgXmRcRAUyuzUeMgsayi6M0/CThZ0sPAU8CWiPhKcyuz0eRfbayhJB0FXAfMiohlI7W39uAgMbNkPrUxs2QOEjNL5iAxs2QOEjNL5iAxs2QOEjNL5iAxs2T/D0XI9fxbLjuPAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 288x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(1, figsize=(4, 3))\n",
    "plt.clf()\n",
    "plt.scatter(df_hS_CS[\"Q\"], df_hS_CS[\"Y\"])\n",
    "plt.ylabel(\"Y\")\n",
    "plt.xlabel(\"Q\")\n",
    "plt.title(\"Y vs Q for D(hS)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "126ff233",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hT_CS: 0.40000035434460346\n",
      "Err_hT_hT_CS: 0.0\n"
     ]
    }
   ],
   "source": [
    "# compute the optimal classifier \"hT_CS\" consider induced data and the corresponding induced error \"Err_T_T_CS\"\n",
    "hT_CS, Err_hT_hT_CS = compute_optimal_threshold_MinInduceRisk_CS(df_init_CS, sigma2, sigma3, M, c)\n",
    "\n",
    "print(\"hT_CS:\", hT_CS)\n",
    "print(\"Err_hT_hT_CS:\",Err_hT_hT_CS)\n",
    "\n",
    "# compute the new data distribution induced by hT\n",
    "df_hT_CS = compute_new_distribution_CS(df_init_CS, sigma2, sigma3, hT_CS, c)\n",
    "df_hT_CS['Q'] = clf_CS.predict_proba(df_hT_CS[[\"X1\", \"X2\", \"X3\"]])[:, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "27c0df0f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Err_S_hT_CS 0.022\n"
     ]
    }
   ],
   "source": [
    "# compute Err_S_hT\n",
    "Err_S_hT_CS = compute_error(df_init_CS, hT_CS)\n",
    "print(\"Err_S_hT_CS\", Err_S_hT_CS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "927779b4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5, 1.0, 'Y vs Q for D(hT)')"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARIAAADgCAYAAADPGumFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAATO0lEQVR4nO3de5BcZZ3G8e+zQ4JhuQyY0TJDQhBDBAVER2BREEU2F9xNYLEKEBHKXZYqUf/YZcHVFdfLglJaYAVNsSymvIVCyca4BuJtEUsEMhFICBgI9yQqCReFEJck/PaPcwY6TfdMZ95++zJ5PlUpus95+z2/NOmnz3n7nPMqIjAzS/EX7S7AzLqfg8TMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8SykfQOSQ9Iek7S3BZut0/SGkmvKp/fLOnvR9nXIkkzm1vh2OMg6TKSviPp2qpl75L0pKTXZdxur6SvS/q9pOclrZL0oRFe9llgXkTsGRGLm1DDAkkvSHq2/HOPpEsl7VPV9GLgGxHx5xH6O64MueckbZYUFc+fkzQFuAz4QmrtY52DpPt8DJgt6SSA8lv3P4F/iojf5digpPHAT4EDgL8C9gEuBL4k6WPDvPQAYPUot7lbnVVfioi9gD7gXOAY4FeS/rJ83e7Ah4Bvj7SNiPhlGXJ7Am8qF/cOLYuIxyLiDmBvSQOj+XvsKhwkXSYingQ+ClxdfnguAR6MiAXVbSUdU+5B9FQsO0XSyvLxUZIGJf1J0h8kfaXOZj8ITAHeHxEPR8TWiLiJItQ+L2mvGtt+EHg98MPy2313SZMkLZH0lKS1kv6hov1nJH1f0rcl/Qk4Z4T34c8RsRz4W+DVFKECcDTwTESsq3rJAZJ+Ve7J/FjSxOH6r3IzcPJOtN/lOEi6UER8D1gBLATOA/6xTrvbgM3AeyoWnwl8t3x8JXBlROwNHARcX2eTJwE3RsTmquU3AHtQ7BVUb/sg4DHgb8pv9/8r610HTAJOA/5D0okVL5sDfB/oBb5Tp5bq7TwL/AQ4rlx0GLCmRtMzKcLmNcB44J8b6b90H3DETrTf5ThIutdHKALisxHx2DDtFgJnAJR7DrPLZQBbgTdImhgRz5XBU8tE4BWHTRGxDdhEcZgxLEmTgXcCF5V7E3cB11Ds7Qz5dUQsjogXI2LLSH1W2ADsVz7uBZ6t0eYbEXF/2e/1wFt2ov9ny36tDgdJl4qIP1B8iEcag/gucGo5dnAq8JuIeLRc92HgYOC3kpZLel+dPjYBrxjILccxJgIbGyh5EvBUuQcx5FGgv+L54w30U0s/8FT5+GngFYdawO8rHj8P7LkT/e8FPDOqynYRDpIxLiLupfjAzmLHwxoi4oGIOINid/+LwPeHBi2r/BSYVWPd31Hs1dzRQCkbgP2qxlOmAOsry22gnx1I2hN4L/DLctFKinBspkOAu5vc55jiINk1fJdiYPR44HtDCyWdJakvIl7k5W/c7TVe/y2KsY3vSZoqaZykGcBXKX5F+eNIBUTE48CtwKWSXiXpcIo9oobGQqqVg7dvAxZT7IV8o1x1B9Arqb/ea0fhXcCNTexvzHGQ7BoWAicAP4+ITRXLZwKrJT1HMfB6eq1zL8qB0vdSHHrcDmwBbgKuAP59J+o4A5hKsXfy38AlEfGTnfy7/IukZykOZb5JMeh87NBAcES8ACwAztrJfmuS9HZgc/kzsNUh3yHNdpakcRTf0OuBc6LD/hFJ6qM41DlyJwdta/V1A/BfEbG0KcWNUQ4SG5XybNKPA9dHxG/bXY+1l4PEzJJ5jMTMkjlIzCxZvQujOtbEiRNj6tSp7S7DbJezYsWKTRFR8yzmrguSqVOnMjg42O4yzHY5kh6tt86HNmaWzEFiZsmyHdqUd/F6H/BERLy5xnpRnE05m+IiqnMi4je56jFrlcV3rufyZWvY8MwWJvVO4N1v7ON/f7vxpecXzpjO3CNfPoP/U4tX
sfD2x9ne4KkYorgoqb/s+0crf8fTz29tSu2PXDa6265kO49E0vHAc8A36wTJbIob9MymuBnNlRFx9Ej9DgwMhMdIrFMtvnM9n1i0ii1ba12yVJgwrodLTz2MuUf286nFq/j2bcPdBaL16oWJpBURUfNOcdkObSLiFl6+tLuWORQhE+V9MHpz3nPUrBUuX7Zm2BAB2LJ1O5cvK+69tPD20d45obO0c4yknx3vP7GOHe9N8RJJ55W3BBzcuLGRW1+YtceGZxq7tGeoXaOHM52unUGiGstqvqsRcXVEDETEQF/fiDfjMmubSb0Tdqpdj2p9DLpPO4NkHTC54vn+FJeXm3WtC2dMZ8K4nmHbTBjXw4UzpgNwxtGTh23bLdoZJEuAs1U4BvhjrukUzFpl7pH9XHrqYfT3TkAUv6ycdcyUHZ4PDbQCfH7uYZx1zJSd2jMZajnU9757jGta/Z34q83QzXQmAn+gmDZhHEBEzC9//p1HcXOd54FzI2LEn2P8q41Zewz3q02280jKe4EOtz4o7oRuZl3OZ7aaWTIHiZklc5CYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklyxokkmZKWiNpraSLa6zfR9IPJd0tabWkc3PWY2Z5ZAsSST3AVcAs4FDgDEmHVjX7CHBvRBxBMXXFlyWNz1WTmeWRc4/kKGBtRDwUES8A11FMHF4pgL3KOW72pJh0fFvGmswsg5xB0sgk4fOAQyim6lwFfDwiXsxYk5llkDNIGpkkfAZwFzAJeAswT9Ler+hIOk/SoKTBjRs3NrtOM0uUM0gamST8XGBRFNYCDwNvrO4oIq6OiIGIGOjr68tWsJmNTs4gWQ5Mk3RgOYB6OsXE4ZUeA04EkPRaYDrwUMaazCyDnHP/bpN0AbAM6AGujYjVks4v188HPgcskLSK4lDooojYlKsmM8sjW5AARMRSYGnVsvkVjzcAf52zBjPLz2e2mlkyB4mZJXOQmFkyB4mZJXOQmFkyB4mZJXOQmFkyB4mZJXOQmFkyB4mZJXOQmFkyB4mZJXOQmFkyB4mZJXOQmFkyB4mZJXOQmFkyB4mZJXOQmFkyB4mZJXOQmFmyrEEiaaakNZLWSrq4TpsTJN0labWkX+Ssx8zyyDYdhaQe4CrgJIpZ95ZLWhIR91a06QW+BsyMiMckvSZXPWaWT849kqOAtRHxUES8AFwHzKlqcybFlJ2PAUTEExnrMbNMcgZJP/B4xfN15bJKBwP7SrpZ0gpJZ9fqyJOIm3W2nEGiGsui6vluwNuAk4EZwL9JOvgVL/Ik4mYdLeeUneuAyRXP9wc21GizKSI2A5sl3QIcAdyfsS4za7KceyTLgWmSDpQ0HjgdWFLV5gfAcZJ2k7QHcDRwX8aazCyDbHskEbFN0gXAMqAHuDYiVks6v1w/PyLuk3QTsBJ4EbgmIu7JVZOZ5aGI6mGLzjYwMBCDg4PtLsNslyNpRUQM1FrnM1vNLJmDxMySOUjMLJmDxMySOUjMLJmDxMySOUjMLFndIJG0VNLUFtZiZl1quD2SBcCPJX1S0rgW1WNmXajuKfIRcb2kHwGfBgYlfYviNPah9V9pQX1m1gVGutZmK7AZ2B3Yi4ogMTMbUjdIJM0EvkJxxe5bI+L5llVlZl1luD2STwLvj4jVrSrGzLrTcGMkx7WyEDPrXj6PxMySOUjMLJmDxMySOUjMLJmDxMySOUjMLJmDxMySZQ0SSTMlrZG0VtLFw7R7u6Ttkk7LWY+Z5ZEtSCT1AFcBs4BDgTMkHVqn3Rcp5r8xsy6Uc4/kKGBtRDwUES8A1wFzarT7KHAD8ETGWswso5xB0g88XvF8XbnsJZL6gVOA+cN1JOk8SYOSBjdu3Nj0Qs0sTc4gUY1l1dP6XQFcFBHbh+soIq6OiIGIGOjr62tWfWbWJNnm/qXYA5lc
8Xx/YENVmwHgOkkAE4HZkrZFxOKMdZlZk+UMkuXANEkHAuuB04EzKxtExIFDjyUtAP7HIWLWfbIFSURsk3QBxa8xPcC1EbFa0vnl+mHHRcyse+TcIyEilgJLq5bVDJCIOCdnLWaWj89sNbNkDhIzS+YgMbNkDhIzS+YgMbNkDhIzS+YgMbNkDhIzS+YgMbNkDhIzS+YgMbNkDhIzS+YgMbNkDhIzS+YgMbNkDhIzS+YgMbNkDhIzS+YgMbNkDhIzS9bWScQlfUDSyvLPrZKOyFmPmeXR7knEHwbeFRGHA58Drs5Vj5nl09ZJxCPi1oh4unx6G8VsfGbWZdo6iXiVDwM3ZqzHzDLJOUFWI5OIFw2ld1MEyTvrrD8POA9gypQpzarPzJok5x5JI5OII+lw4BpgTkQ8WaujiLg6IgYiYqCvry9LsWY2ejmD5KVJxCWNp5hEfEllA0lTgEXAByPi/oy1mFlG7Z5E/NPAq4GvSQLYFhEDuWoyszwUUXPYomMNDAzE4OBgu8sw2+VIWlHvi95ntppZMgeJmSVzkJhZMgeJmSVzkJhZMgeJmSVzkJhZMgeJmSVzkJhZMgeJmSVzkJhZMgeJmSVzkJhZMgeJmSVzkJhZMgeJmSVzkJhZMgeJmSVzkJhZMgeJmSVzkJhZspwz7SFpJnAlxXQU10TEZVXrVa6fDTwPnBMRv0nZ5tSLf5Ty8h0IOPag/XjkyS2sf2bLK9b3907gwhnTmXtkP4vvXM/ly9aw4ZktTKpYPlrN7q/Vctff7e/PWJMtSCT1AFcBJ1HMurdc0pKIuLei2SxgWvnnaODr5X9HpZkhAsX8or968Km669c/s4VPLFrF4KNPccOK9WzZun2H5cCo/nEvvnM9n1i0qmn9tVru+rv9/RmLch7aHAWsjYiHIuIF4DpgTlWbOcA3o3Ab0CvpdRlrarotW7ez8PbHX/pHXbn88mVrRtXn5cvWNLW/Vstdf7e/P2NRziDpBx6veL6uXLazbZB0nqRBSYMbN25seqGptteZZGxDjcOhRtR73Wj7a7Xc9Xf7+zMW5QwS1VhW/YlrpE3HTyLeo1p/DZjUO2FU/dV73Wj7a7Xc9Xf7+zMW5QySdcDkiuf7AxtG0aajTRjXwxlHT2bCuJ5XLL9wxvRR9XnhjOlN7a/Vctff7e/PWJQzSJYD0yQdKGk8cDqwpKrNEuBsFY4B/hgRvxvtBh+57OTRV1uDgHcctB/9db7p+nsncOmph/H5uYdx6amH0d87AVUsH+3A39wj+5vaX6vlrr/b35+xKOsk4pJmA1dQ/Px7bUR8QdL5ABExv/z5dx4wk+Ln33MjYtgZwj2JuFl7DDeJeNbzSCJiKbC0atn8iscBfCRnDWaWn89sNbNkDhIzS5Z1jCQHSRuBRxtoOhHYlLmcVK4xXafXB51fY6P1HRARNc+/6LogaZSkwXoDQ53CNabr9Pqg82tsRn0+tDGzZA4SM0s2loPk6nYX0ADXmK7T64POrzG5vjE7RmJmrTOW90jMrEW6PkgkzZS0RtJaSRfXWC9JXy3Xr5T01g6s8QNlbSsl3SrpiE6qr6Ld2yVtl3RaK+srtz1ijZJOkHSXpNWSftFJ9UnaR9IPJd1d1ndui+u7VtITku6psz7tcxIRXfuH4hqeB4HXA+OBu4FDq9rMBm6kuAbvGOD2DqzxWGDf8vGsVtbYSH0V7X5OccnDaR34HvYC9wJTyuev6bD6/hX4Yvm4D3gKGN/CGo8H3grcU2d90uek2/dIuuEubCPWGBG3RsTT5dPbKG6n0DH1lT4K3AA80cLahjRS45nAooh4DCAiWllnI/UFsFd5oeqeFEGyrVUFRsQt5TbrSfqcdHuQNO0ubBnt7PY/TPHN0Coj1iepHzgFmE97NPIeHgzsK+lmSSsknd2y6hqrbx5wCMX9dlYBH4+IF1tTXkOSPidZr/5tgabdhS2jhrcv6d0UQfLOrBVVbbbGsur6rgAuiojtqnM3uMwaqXE3
4G3AicAE4NeSbouI+3MXR2P1zQDuAt4DHAT8RNIvI+JPmWtrVNLnpNuDpBvuwtbQ9iUdDlwDzIqIJ1tUGzRW3wBwXRkiE4HZkrZFxOKWVNj4/+dNEbEZ2CzpFuAIoBVB0kh95wKXRTEgsVbSw8AbgTtaUF8j0j4nrRrsyTSAtBvwEHAgLw9yvamqzcnsOIh0RwfWOAVYCxzbie9hVfsFtH6wtZH38BDgZ2XbPYB7gDd3UH1fBz5TPn4tsB6Y2OL3cSr1B1uTPiddvUcSEdskXQAs4+W7sK2uvAsbxa8Msyk+qM9TfDN0Wo2fBl4NfK381t8WLbrIq8H62qqRGiPiPkk3ASuBFykmZKv5U2c76gM+ByyQtIriw3pRRLTsimBJC4ETgImS1gGXAOMq6kv6nPjMVjNL1u2/2phZB3CQmFkyB4mZJXOQmFkyB4mZJXOQWHaS9pf0A0kPSHpI0jxJu7e7LmseB4llVV6ktghYHBHTgGkUp7B/qa2FWVP5PBLLStKJwCURcXzFsr0pphSZHBHPta04axrvkVhubwJWVC6I4kK1R4A3tKMgaz4HieUmal9F2pbLiC0PB4nltpri6uGXlIc2rwXWtKUiazoHieX2M2CPoRsNSeoBvgzMi4gtba3MmsZBYllFMZp/CnCapAeAJ4EXI+IL7a3Mmsm/2lhLSToWWAicGhErRmpv3cFBYmbJfGhjZskcJGaWzEFiZskcJGaWzEFiZskcJGaWzEFiZsn+H93w/SnR+0SlAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 288x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(1, figsize=(4, 3))\n",
    "plt.clf()\n",
    "plt.scatter(df_hT_CS[\"Q\"], df_hT_CS[\"Y\"])\n",
    "plt.ylabel(\"Y\")\n",
    "plt.xlabel(\"Q\")\n",
    "plt.title(\"Y vs Q for D(hT)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "caee74f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.6034427237486937\n"
     ]
    }
   ],
   "source": [
    "var_whs_CS = var_w(df_init_CS, df_hS_CS, number_bin = number_bin)\n",
    "print(var_whs_CS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "7fa506c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6035303835998264\n"
     ]
    }
   ],
   "source": [
    "var_whT_CS = var_w(df_init_CS, df_hT_CS, number_bin = number_bin)\n",
    "print(var_whT_CS)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13277de5",
   "metadata": {},
   "source": [
    "Compute the upper bound"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "382f58dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Err_hs_hs_CS - Err_hT_hT_CS: 0.002\n",
      "the upper bound is: 0.3030471351781373\n"
     ]
    }
   ],
   "source": [
    "# compute Err_hs_hs - Err_hT_hT\n",
    "print(\"Err_hs_hs_CS - Err_hT_hT_CS:\", Err_hS_hS_CS - Err_hT_hT_CS)\n",
    "\n",
    "\n",
    "# compute Err_S_hT\n",
    "Err_S_hT_CS = compute_error(df_init_CS, hT_CS)\n",
    "#print(\"Err_S_hT_CS\", Err_S_hT_CS)\n",
    "\n",
    "# compute the upper bound (theorem 5)\n",
    "UB = np.sqrt(Err_S_hT_CS)*(np.sqrt(var_whs_CS)+np.sqrt(var_whT_CS))\n",
    "print(\"the upper bound is:\", UB)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "475423f8",
   "metadata": {},
   "source": [
    "## lower bound"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84df64dc",
   "metadata": {},
   "source": [
    "Covariate Shift lower bound is:\n",
    "\n",
    "                max{Err_S(h), Err_h(h)} >= (dTV(D_S(Y), D_h(Y)) - dTV(D_S(h), D_h(h)))/2\n",
    "\n",
    "\n",
    "To test the lower bound, we need to compute the following quantities for hS and hT:\n",
    "\n",
    "   - max{Err_S(h), Err_h(h)} \n",
    "   \n",
    "        \n",
    "   - dTV(D_S(Y), D_h(Y)) = |Pr_S(Y = 1) - Pr_h(Y = 1)|\n",
    "   \n",
    "    \n",
    "   - dTV(D_S(h), D_h(h)) = |Pr_S(h = 1) - Pr_h(h = 1)|"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8af2466f",
   "metadata": {},
   "source": [
    "#### compute the lower bound for hT (theorem 4.6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "f3e9af51",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max{Err_S(hT), Err_hT(hT)} 0.022\n",
      "the lower bound for using hT is: 0.009000000000000008\n"
     ]
    }
   ],
   "source": [
    "# compute Pr_Ds[Y = +1]\n",
    "Pr_S_Y1_CS = df_init_CS['Y'].value_counts()[1]/len(df_init_CS)\n",
    "\n",
    "#print(\"Pr_S_Y1_CS\", Pr_S_Y1_CS)\n",
    "\n",
    "# compute Pr_DhT[Y = +1]\n",
    "Pr_hT_Y1_CS = df_hT_CS['Y'].value_counts()[1]/len(df_hT_CS)\n",
    "\n",
    "#print(\"Pr_Dh_Y1_CS\", Pr_Dh_Y1_CS)\n",
    "\n",
    "# compute Pr_Ds[hT(X) = +1]\n",
    "Pr_S_hT1_CS = np.sum(df_init_CS['Q'] > hT_CS)/len(df_init_CS)\n",
    "\n",
    "#print(\"Pr_S_hT1_CS\", Pr_S_hT1_CS)\n",
    "\n",
    "# compute Pr_DhT[hT(X) = +1]\n",
    "Pr_hT_hT1_CS = np.sum(df_hT_CS['Q'] > hT_CS)/len(df_hT_CS)\n",
    "\n",
    "#print(\"Pr_hT_hT1_CS\", Pr_hT_hT1_CS)\n",
    "\n",
    "print(\"max{Err_S(hT), Err_hT(hT)}\", max(Err_S_hT_CS, Err_hT_hT_CS))\n",
    "\n",
    "LB_hT_CS = (np.abs(Pr_S_Y1_CS - Pr_hT_Y1_CS) - np.abs(Pr_S_hT1_CS - Pr_hT_hT1_CS)) /2\n",
    "print(\"the lower bound for using hT is:\", LB_hT_CS)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "131a8de1",
   "metadata": {},
   "source": [
    "#### compute the lower bound for hS (theorem 4.6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "18e3f93c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max{Err_S(hS), Err_hS(hS)} 0.014\n",
      "the lower bound for using hS is: 0.0\n"
     ]
    }
   ],
   "source": [
    "# compute Pr_DS[Y = +1]\n",
    "Pr_S_Y1_CS = df_init_CS['Y'].value_counts()[1]/len(df_init_CS)\n",
    "\n",
    "\n",
    "# compute Pr_DhS[Y = +1]\n",
    "Pr_hS_Y1_CS = df_hS_CS['Y'].value_counts()[1]/len(df_hS_CS)\n",
    "\n",
    "\n",
    "# compute Pr_DS[hS(X) = +1]\n",
    "Pr_S_hS1_CS = np.sum(df_init_CS['Q'] > hS_CS)/len(df_init_CS)\n",
    "\n",
    "\n",
    "\n",
    "# compute Pr_DhS[hS(X) = +1]\n",
    "Pr_hS_hS1_CS = np.sum(df_hS_CS['Q'] > hS_CS)/len(df_hS_CS)\n",
    "\n",
    "\n",
    "print(\"max{Err_S(hS), Err_hS(hS)}\", max(Err_S_hS_CS, Err_hS_hS_CS))\n",
    "\n",
    "LB_hS_CS = (np.abs(Pr_S_Y1_CS - Pr_hS_Y1_CS) - np.abs(Pr_S_hS1_CS - Pr_hS_hS1_CS)) /2\n",
    "print(\"the lower bound for using hS is:\", LB_hS_CS)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef5bb2be",
   "metadata": {},
   "source": [
    "# Target Shift"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c2b391b",
   "metadata": {},
   "source": [
    "## helper functions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8423505a",
   "metadata": {},
   "source": [
    "#### generate initial distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "7a64ec5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def trunc_gauss(mu, sigma, bottom, top):\n",
    "    '''\n",
    "    generate one data point from a truncated gaussian distribution N(mu, sigma)\n",
    "    return a data point within the range [bottom, top]\n",
    "    \n",
    "    Input:\n",
    "        - mu: mean\n",
    "        - sigma: variance\n",
    "        - bottom: lower bound for the truncated Gaussian variable\n",
    "        - top: upper bound for the truncated Gaussian variable\n",
    "        \n",
    "    Output:\n",
    "        - one random variable sampled from N_[bottom, top](mu, sigma)\n",
    "    '''\n",
    "    a = np.random.normal(mu, sigma, 1)\n",
    "    while (bottom <= a <= top) == False:\n",
    "        a = np.random.normal(mu, sigma,1)\n",
    "    return a\n",
    "\n",
    "\n",
    "def generate_trunc_gauss_N(mu, sigma, bottom, top, N):\n",
    "    '''\n",
    "    Generate a length of N truncated Gaussian points from N(mu, sigma)\n",
    "    '''\n",
    "    result = []\n",
    "    result = np.array(result)\n",
    "    for i in range(N):\n",
    "        result = np.append(result, trunc_gauss(mu, sigma, bottom, top))\n",
    "    return result "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "dce96fc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_init_LS(alpha, mu0, mu1, sigma, N, bottom = 0, top = 1):\n",
    "    '''\n",
    "    Generate N data points from a truncated Gaussian mixture model \n",
    "    (in our case, the truncated upper and lower bound are always 0 and 1)\n",
    "    with two distributions which are N0(mu0, sigma^2), N1(mu1, sigma^2). \n",
    "    The ratio between the two distribution is alpha: (1-alpha)\n",
    "    \n",
    "    N: total number of points\n",
    "    Y ~ bernoulli(alpha)\n",
    "    X1|Y = y ~ N (mu_y, sigma)\n",
    "    X2 ~ -0.8X1 + N(0, sigma2)\n",
    "    X3 ~ 0.2Y + N(0,sigma3)\n",
    "    Y is threshold based on the value of X2\n",
    "    \n",
    "    Input:\n",
    "        - alpha: ratio of Y = 1 instances\n",
    "        - mu0: mean for Y = 0 instances\n",
    "        - mu1: mean for Y = 1 instances\n",
    "        - sigma: variance for all instances\n",
    "        - bottom: lower bound for the truncated Gaussian variable\n",
    "        - top: upper bound for the truncated Gaussian variable\n",
    "    Output:\n",
    "        - a data frame df with ['X1', 'X2', 'X3', 'Y'].\n",
    "    '''\n",
    "    # generate |N*alpha| number of y = 0 instances between 0 and 1:\n",
    "    s1 = generate_trunc_gauss_N(mu0, sigma, bottom, top, int(N*alpha))\n",
    "    # generate |N*(1-alpha)| number of y = 1 instances between 0 and 1:\n",
    "    s2 = generate_trunc_gauss_N(mu1, sigma, bottom, top, int(N*(1-alpha)))\n",
    "    \n",
    "    X1 = np.array(list(s1)+list(s2))\n",
    "    \n",
    "    Y = np.array(list(np.zeros(int(N*alpha)))+list(np.ones(int(N*(1-alpha)))))\n",
    "    # generate X2, X3\n",
    "    X2 = -X1 * 0.8 + np.random.normal(0,0.15, N)  \n",
    "    X3 = Y* 0.2 +  np.random.normal(0,0.2, N)  \n",
    "    \n",
    "    # combine them into a dataframe\n",
    "    d = {'X1': list(X1), 'X2': list(X2), 'X3': list(X3), 'Y': list(Y)}\n",
    "    df = pd.DataFrame(data=d)\n",
    "    return df "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "d3524de9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute the induced distribution for a particular threshold\n",
    "def compute_new_distribution_LS(df, threshold, alpha, mu0, mu1, sigma, bottom = 0, top = 1):\n",
    "    '''\n",
    "    compute the induced distribution given an initial distribution (df) and a threshold\n",
    "    \n",
    "    Input:\n",
    "        - df: initial distribution with Q\n",
    "        - threshold: a threshold\n",
    "        - alpha: ratio of Y = 1 instances for the initial distribution\n",
    "        - mu0: mean for Y = 0 instances for the initial distribution\n",
    "        - mu1: mean for Y = 1 instances for the initial distribution\n",
    "        - sigma: variance for all instances for the initial distribution\n",
    "        - bottom: lower bound for the truncated Gaussian variable\n",
    "        - top: upper bound for the truncated Gaussian variable\n",
    "    Output:\n",
    "        - an induced dataframe df_new with ['X1', 'X2', 'X3', 'Y'].\n",
    "    '''\n",
    "    N = len(df)\n",
    "    # compute the new data distribution\n",
    "    # get the old classifier result\n",
    "    hX = 1*(df[\"Q\"] > threshold)\n",
    "    # generate new Y \n",
    "    Y_new = list()\n",
    "\n",
    "    for i in range(N):\n",
    "        # generate a random number in [0,1]\n",
    "        rn = random.random()\n",
    "        # when h(x) = +1, assume the true qualification won't change\n",
    "        if hX[i] == +1:\n",
    "            Y_new.append(df.iloc[i][\"Y\"])\n",
    "        # when h(x) = -1: \n",
    "        # if Y = +1, P(Y' = +1 | h(x) = -1, Y = +1) = 0.8\n",
    "        elif df.iloc[i][\"Y\"] == +1:\n",
    "            if rn < 0.8:\n",
    "                Y_new.append(df.iloc[i][\"Y\"])\n",
    "            else:\n",
    "                Y_new.append(1 - df.iloc[i][\"Y\"])\n",
    "        # if Y = -1, P(Y' = +1 | h(x) = -1, Y = -1) = 0.15\n",
    "        else:\n",
    "            if rn < 0.15:\n",
    "                Y_new.append(1 - df.iloc[i][\"Y\"])\n",
    "            else:     \n",
    "                Y_new.append(df.iloc[i][\"Y\"])    \n",
    "    \n",
    "    # compute alpha_new = P(Y'= +1)\n",
    "    alpha_new = Y_new.count(1)/N\n",
    "    Y_new = np.array(Y_new)\n",
    "\n",
    "    # generate new X1, X2, X3:\n",
    "    # generate |N*alpha| number of y = 0 instances between 0 and 1:\n",
    "    s1 = generate_trunc_gauss_N(mu0, sigma, bottom, top, int(N*alpha_new))\n",
    "    # generate |N*(1-alpha)| number of y = 1 instances between 0 and 1:\n",
    "    s2 = generate_trunc_gauss_N(mu1, sigma, bottom, top, N - int(N*alpha_new))\n",
    "    X1_new = np.array(list(s1)+list(s2))\n",
    "\n",
    "    # generate X2, X3\n",
    "    X2_new = -X1_new * 0.8 + np.random.normal(0,0.15, N)  \n",
    "    X3_new = Y_new* 0.2 +  np.random.normal(0,0.2, N) \n",
    "\n",
    "    d = {'X1': list(X1_new), 'X2': list(X2_new), 'X3': list(X3_new),  'Y': list(Y_new)}\n",
    "    df_new = pd.DataFrame(data=d)\n",
    "    return df_new"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "79730c30",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_optimal_threshold_MinInduceRisk_LS(df, M, alpha, mu0, mu1, sigma, bottom = 0, top = 1):\n",
    "    '''\n",
    "    find the optimal classifier achieve min Err(h)(h)\n",
    "    where M is the number of potential classifiers to search for\n",
    "    return the optimal threshold and the minimum error \n",
    "    \n",
    "    Input:\n",
    "        - df: initial distribution with 'Q'\n",
    "        - M: total number of threshold to loop through\n",
    "        - alpha: ratio of Y = 1 instances for the initial distribution\n",
    "        - mu0: mean for Y = 0 instances for the initial distribution\n",
    "        - mu1: mean for Y = 1 instances for the initial distribution\n",
    "        - sigma: variance for all instances for the initial distribution\n",
    "        - bottom: lower bound for the truncated Gaussian variable\n",
    "        - top: upper bound for the truncated Gaussian variable\n",
    "    Output:\n",
    "        - an optimal threshold\n",
    "        - the corresponding error\n",
    "    \n",
    "    '''\n",
    "    N = len(df)\n",
    "    # compute the boundary of Q:\n",
    "    Q_min, Q_max = np.min(df[\"Q\"]), np.max(df[\"Q\"])\n",
    "    #print(Q_min, Q_max)\n",
    "    \n",
    "    # compute the coefficients on the initial data distribution\n",
    "    clf = LogisticRegression(solver='liblinear', fit_intercept=True)\n",
    "    clf.fit(df[[\"X1\", \"X2\", \"X3\"]], df['Y'])\n",
    "    \n",
    "    # loop through the potential thresholds\n",
    "    error_list = []\n",
    "    for j in range(M):\n",
    "        threshold = Q_min + j*(Q_max - Q_min)/M\n",
    "        # compute the new data distribution\n",
    "        df_new = compute_new_distribution_LS(df, threshold, alpha = 0.5, mu0 = 0, mu1 = 1.5, sigma = 0.25, bottom = 0, top = 1 )\n",
    "        # compute the new qualification\n",
    "        # using the coefficient from the logistic regression coefficient for the\n",
    "        # old dataset\n",
    "        df_new['Q'] = clf.predict_proba(df_new[[\"X1\", \"X2\", \"X3\"]])[:, 1]\n",
    "        # compute new data distribution's loss\n",
    "        Err_h_h = compute_error(df_new, threshold)\n",
    "        error_list.append(Err_h_h)\n",
    "    index_min = np.argmin(error_list)\n",
    "    min_error = np.min(error_list)\n",
    "\n",
    "    # compute the optimal threshold on the dataset that it induced call it hT\n",
    "    optimal_threshold = Q_min + index_min*(Q_max - Q_min)/M\n",
    "    return optimal_threshold, min_error"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84f4d4e4",
   "metadata": {},
   "source": [
    "#### set parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "bc20d9f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 0.5\n",
    "mu0 = 0\n",
    "mu1 = 1\n",
    "sigma = 0.25\n",
    "N = 3000\n",
    "bottom = 0\n",
    "top = 1\n",
    "\n",
    "M = 100"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9237450",
   "metadata": {},
   "source": [
    "#### generate initial distribution df_init_LS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "id": "410d2e9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_init_LS = generate_init_LS(alpha = alpha, mu0 = mu0, mu1 = mu1, sigma = sigma, N = N, bottom = bottom, top = top)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09052f6a",
   "metadata": {},
   "source": [
    "#### compute optimal threshold classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "e8a93bcd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LogisticRegression(solver='liblinear')"
      ]
     },
     "execution_count": 91,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# compute the best logistic regression model\n",
    "clf = LogisticRegression(solver='liblinear', fit_intercept=True)\n",
    "clf.fit(df_init_LS[[\"X1\", \"X2\", \"X3\"]], df_init_LS['Y'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "0e953ba2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute the \"qualification\" (linear combination of features) based on the logistic regression model we trained\n",
    "df_init_LS['Q'] = clf.predict_proba(df_init_LS[[\"X1\", \"X2\", \"X3\"]])[:, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "id": "ac9b64ec",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5, 1.0, 'Y vs Q for DS_LS')"
      ]
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARIAAADgCAYAAADPGumFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAVGUlEQVR4nO3dfXRcdZ3H8fcn6QSSUgnQlLWFWsAiD0pZCNAFH0BgKUUXQVh5EuF4lmUX1D3ucWHVFVdFEA4InqIc5LAcVGBFsYIWWHzEFYtNXZ6KPJQCpVRtCxSkLTZNvvvHvQnTdJIM+c2dScLndU5OZu79zb3fTDufufd3H36KCMzMUjQ1ugAzG/scJGaWzEFiZskcJGaWzEFiZskcJGaWzEFiZskcJFZzkg6R9ISkVyR9oNH1WPEcJGOEpO9Ium7AtPdIel7Smwtcb7ukb0j6o6T1kh6S9JFhXvYFYF5EbBMR82tQw/WSNkr6c/7zsKSLJG1b1qZF0mWSVuQB9pSkr1ax7KclHTHIvE/ny3klX+5/p/4t45WDZOz4ODBX0pEAkrYGvgn8a0T8oYgVSmoBfgK8BfgbYFvgU8Alkj4+xEvfAiwZ4TonDDLrkoiYBHQAZwKzgV9LmpjP/3egEzgQmAQcBvzfSGrI6/gI8GHgiIjYJl/2T0e6vPHOQTJGRMTzwMeAa/IPzwXAkxFx/cC2kmbnWxDNZdOOk/Rg/vhASV2SXpb0J0mXD7LaDwPTgRMj4qmI6I6IO8lC7UuSJlVY95PArsDt+Tf5VpKmSrpN0guSlkr6h7L2n5f0PUnflvQycMYw78OrEbEI+DtgB7JQATgA+EFErIzM0xFxw1DLGsYBwF0R8WS+3j9GxDUJyxvXHCRjSETcAiwGbgLOAv5xkHYLgXXAe8smnwLcmD++ErgyIt4E7AZ8d5BVHgncERHrBkz/PtBGtlUwcN27AcuB9+e7Nn/J610BTAVOAL4s6fCylx0LfA9oB74zSC0D1/Nn4G7gXfmkhcAnJf2zpHdIUjXLGcJC4HRJn5LUWR7KtiUHydhzDllAfCEilg/R7ibgZIB8y2FuPg2gG3irpMkR8UoePJVMBrbYbYqITcAast2MIUnaGXgncF6+NXE/cC3Z1k6f30TE/IjojYgNwy2zzEpg+/zxRcBXgFOBLuC5KvpyBhUR3ybbAjwK+CWwStL5I13eeOcgGWMi4k9kH+Lh+iBuBI6XtBVwPPC7iHgmn/dRYHfgUUmLJL1vkGWsAbboyM37MSYDq6soeSrwQr4F0ecZYFrZ82erWE4l04AXACKiJyKuiohDyLZsLgSuk7TnCJdNRHwnIo7Il3c28AVJR410eeOZg2SciohHyD6wR7P5bg0R8UREnAxMIfsW/15Zp2W5nwBHV5j3QbKtmt9WUcpKYPsB/SnTgefKy61iOZuRtA1wBPCrgfMiYkNEXAW8COz1epddYXnd+W7lg8DbU5c3HjlIxrcbyTpG3w3c0jdR0mmSOiKiF1ibT+6p8PpvkfVt3CJphqRS/o38NbKjKC8NV0BEPAvcC1wkaWtJ+5BtEVXVFzJQ3nm7PzCfLCj+K5/+L5IOldQqaUK+WzOJ6o7clPLa+n4mSDpD0jGSJklqknQ0sDdw30jqHu8cJOPbTcChwM8iYk3Z9DnAEkmvkHW8nhQRrw58cd5RegTZrsd9wAbgTuAK4D9fRx0nAzPItk5+AFwQEXe/zr/l3yT9mWxX5gayTueDyzqCNwCXAX8k2yU7B/hgRCyrYtkL8tf3/XweeBn4NFnH8VrgEuCfIuJ/X2fdbwjyHdKsWpJKwB1kuyVnhP/zWM5bJFa1iOgm6x95Enhbg8uxUcRbJDauSZoOPDLI7L2GOYRuVXKQmFky79qYWbLBLpAatSZPnhwzZsxodBlmbziLFy9eExEVz2Yec0EyY8YMurq6Gl2G2RuOpGcGm+ddGzNL5iAxs2SF7drkd/N6H7AqIra4
PiG/zPtKsqtS15Od4PS7Wtbw2fkPceN9y+nND0y1lpr44P478fNHV7Ny7Qamtrdy2B4dmz2fsUMrC5e9SE+Fo1lNgt6AZmmL+aUm2NRL/zJ/9MAfWLuhG4C2UhM9AX/Z1Nvfvq3UxPF5Lc+tfe2CVwFtLc2s39jT/zvI1rlrRxtLV60b9sKUHSe1cOTef8VN9z1b8e8YjKB/XT0R/b/7plf7+j7btZW44P17c0vXcn795AtbtG8rNfHl4/fhqp8/wROrXrtTwcwpE7n7k4dy6jd/U/F1AM35TQJ6fNCxpp6++JgRva6ww7+S3g28AtwwSJDMJbtMey5wENn9MQ4abrmdnZ1RTR/JZ+c/xLcX+hSBRqs2hAbaulm86pRoiMHCRNLiiOisNK+wXZuIuIf8Eu9BHEsWMpHfD6O9lvcevem+kV6ZbrU00ihwiIwtjewjmcbm96FYweb3qOgn6az81oBdq1dXcwsMXtcmvZmlaWSQVLoVXsVPf0RcExGdEdHZ0THsTbmAbD/fzOqjkUGyAti57PlOZJeZ18TJB+08fCMr3EjjfOtmfxGMJY0MktvIbq4rSbOBl2o5rMKXPvAOTps9naay/4+tpSZOmz2dae2tCJjW3rrF80N2237QrZm+ZVWaX2pis2W2t5b657WVmthqwuZvdVtZLeUETGxp3ux33zpnTplY1Qdzx0ktnDZ7+uveKitfV/nvapcysN12bSW++qF9OWS37Su2bys1ccWH9mXmlM1vwDZzykQevXDuoK/LanvtyI3Vzmg8atN3U53JwJ/Ihk8oAUTE1fnh33lkN9lZD5wZEcMejqn2qI2Z1dZQR20KO48kvyfoUPOD7C5WZjbG+cxWM0vmIDGzZA4SM0vmIDGzZA4SM0vmIDGzZA4SM0vmIDGzZA4SM0vmIDGzZA4SM0vmIDGzZA4SM0vmIDGzZA4SM0vmIDGzZA4SM0vmIDGzZA4SM0vmIDGzZIUGiaQ5kh6TtFTS+RXmbyvpdkkPSFoi6cwi6zGzYhQWJJKagauAo4G9gJMl7TWg2TnAIxExi2zoissktRRVk5kVo8gtkgOBpRGxLCI2AjeTDRxeLoBJ+Rg325ANOr6pwJrMrABFBkk1g4TPA/YkG6rzIeATEdFbYE1mVoAig6SaQcKPAu4HpgL7AvMkvWmLBUlnSeqS1LV69epa12lmiYoMkmoGCT8TuDUyS4GngD0GLigiromIzojo7OjoKKxgMxuZIoNkETBT0i55B+pJZAOHl1sOHA4gaUfgbcCyAmsyswIUOfbvJknnAncBzcB1EbFE0tn5/KuBLwLXS3qIbFfovIhYU1RNZlaMwoIEICIWAAsGTLu67PFK4G+LrMHMiuczW80smYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsWaFBImmOpMckLZV0/iBtDpV0v6Qlkn5ZZD1mVozChqOQ1AxcBRxJNureIkm3RcQjZW3aga8DcyJiuaQpRdVjZsUpcovkQGBpRCyLiI3AzcCxA9qcQjZk53KAiFhVYD1mVpAig2Qa8GzZ8xX5tHK7A9tJ+oWkxZJOr7QgDyJuNroVGSSqMC0GPJ8A7A8cAxwF/Iek3bd4kQcRNxvVihyycwWwc9nznYCVFdqsiYh1wDpJ9wCzgMcLrMvMaqzILZJFwExJu0hqAU4CbhvQ5ofAuyRNkNQGHAT8vsCazKwAhW2RRMQmSecCdwHNwHURsUTS2fn8qyPi95LuBB4EeoFrI+Lhomoys2IoYmC3xejW2dkZXV1djS7D7A1H0uKI6Kw0z2e2mlkyB4mZJXOQmFkyB4mZJXOQmFkyB4mZJXOQmFmyQYNE0gJJM+pYi5mNUUNtkVwP/I+kz0gq1akeMxuDBj1FPiK+K+nHwOeALknfIjuNvW/+5XWoz8zGgOGu
tekG1gFbAZMoCxIzsz6DBomkOcDlZFfs7hcR6+tWlZmNKUNtkXwGODEiltSrGDMbm4bqI3lXPQsxs7HL55GYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklc5CYWbJCg0TSHEmPSVoq6fwh2h0gqUfSCUXWY2bFKCxIJDUDVwFHA3sBJ0vaa5B2XyEb/8bMxqAit0gOBJZGxLKI2AjcDBxbod3HgO8DqwqsxcwKVGSQTAOeLXu+Ip/WT9I04Djg6qEWJOksSV2SulavXl3zQs0sTZFBogrTBg7rdwVwXkT0DLWgiLgmIjojorOjo6NW9ZlZjRQ29i/ZFsjOZc93AlYOaNMJ3CwJYDIwV9KmiJhfYF1mVmNFBskiYKakXYDngJOAU8obRMQufY8lXQ/8yCFiNvYUFiQRsUnSuWRHY5qB6yJiiaSz8/lD9ouY2dhR5BYJEbEAWDBgWsUAiYgziqzFzIrjM1vNLJmDxMySOUjMLJmDxMySOUjMLJmDxMySOUjMLJmDxMySOUjMLJmDxMySOUjMLJmDxMySOUjMLJmDxMySOUjMLJmDxMySOUjMLJmDxMySOUjMLJmDxMySNXQQcUmnSnow/7lX0qwi6zGzYjR6EPGngPdExD7AF4FriqrHzIrT0EHEI+LeiHgxf7qQbDQ+MxtjGjqI+AAfBe4osB4zK0iRA2RVM4h41lA6jCxI3jnI/LOAswCmT59eq/rMrEaK3CKpZhBxJO0DXAscGxHPV1pQRFwTEZ0R0dnR0VFIsWY2ckUGSf8g4pJayAYRv628gaTpwK3AhyPi8QJrMbMCNXoQ8c8BOwBflwSwKSI6i6rJzIqhiIrdFqNWZ2dndHV1NboMszccSYsH+6L3ma1mlsxBYmbJHCRmlsxBYmbJHCRmlsxBYmbJHCRmlsxBYmbJHCRmlsxBYmbJHCRmlsxBYmbJHCRmlsxBYmbJHCRmlsxBYmbJHCRmlsxBYmbJHCRmlsxBYmbJHCRmlqzIkfaQNAe4kmw4imsj4uIB85XPnwusB86IiN+lrHPG+T9Oefmo0lpqYr/p7dz75AuVhygcB6a1t/Kpo94GwH/evoQX13cD2TCNKX/zIbttz9PPb2Dl2g20t5V4tbuHDd29m7XpW0d5DZfe9Rgr125gansrh+3Rwc8fXT3oMtpKTbRMaOalDd1s21pCgrXruyk1i409m1c/saWZ9Rt7aGtpZt3Gns3mtZaa2LrUzNr13Uwd5P3os9WEJjZu6iWAZoldO9pYtno9PWWjQTRL9ET0/+77+7qeeYGb7nuWnggkaJ3QxIbu3v51fuCvhxpRd2iFDUchqRl4HDiSbNS9RcDJEfFIWZu5wMfIguQg4MqIOGio5Q41HMV4CpE3klKT6AV6ehsXl6VmQUB3A2sor6WnN6hlKU1A7xDzW0vNXHT8O4YMk0YNR3EgsDQilkXERuBm4NgBbY4FbojMQqBd0psLrMlGoe7eaGiIAHT3xKgIEchqqXUpQ4UIwIbuHi6967ERL7/IIJkGPFv2fEU+7fW2QdJZkrokda1evbrmhZoZrFy7YcSvLTJIVGHawJytpo0HETerg6ntrSN+bZFBsgLYuez5TsDKEbSxca7UJJqbKn2n1LGGZlFqcA19Ss2i1qUM90FvLTX3d/IWsfwUi4CZknaR1AKcBNw2oM1twOnKzAZeiog/jHSFT198zMirHYVaS00cstv2FTfbxotp7a1ceuIsLjtxFtu1lfqnp/7Nh+y2PdPaWxGwXVuJ1tKW/9X71jGtvZVLT5jFpSfO6n/NtPZWTps9fchltJWaaG8tIaC9tcR2bdnjluYtq5/Y0ozy3wO1lpr6X9tXy+V/v+9m70efrSY09dfdLDFzykSatfn6+p73/Z7W3srlH9qX02ZP758mZfX3rXO4jtbhFDqIeH5U5gqyw7/XRcSFks4GiIir88O/84A5ZId/z4yIIUcI9yDiZo0x1FGbQs8jiYgFwIIB064uexzAOUXWYGbF
85mtZpbMQWJmyQrtIymCpNXAM1U0nQysKbicVKO9xtFeH7jGWqi2vrdERMXzL8ZckFRLUtdgHUOjxWivcbTXB66xFmpRn3dtzCyZg8TMko3nILmm0QVUYbTXONrrA9dYC8n1jds+EjOrn/G8RWJmdTLmg0TSHEmPSVoq6fwK8yXpa/n8ByXtN8rqOzWv60FJ90qaVc/6qqmxrN0BknoknVDP+vJ1D1ujpEMl3S9piaRfjqb6JG0r6XZJD+T1nVnn+q6TtErSw4PMT/ucRMSY/SG7hudJYFegBXgA2GtAm7nAHWTXaM0G7htl9R0MbJc/Prqe9VVbY1m7n5Fd8nDCaKsRaAceAabnz6eMsvo+DXwlf9wBvAC01LHGdwP7AQ8PMj/pczLWt0hG+13Yhq0vIu6NiBfzpwvJbqVQT9W8h5DdEvP7wKp6FperpsZTgFsjYjlARNSzzmrqC2BSfqHqNmRBsqleBUbEPfk6B5P0ORnrQVKzu7AV5PWu+6Nk3wr1NGyNkqYBxwFX0xjVvI+7A9tJ+oWkxZJOr1t11dU3D9iT7H47DwGfiIjh7oBYT0mfk0Kv/q2Dmt2FrSBVr1vSYWRB8s5CK6qw6grTBtZ4BXBeRPRIlZoXrpoaJwD7A4cDrcBvJC2MiMeLLo7q6jsKuB94L7AbcLekX0XEywXXVq2kz8lYD5LRfhe2qtYtaR/gWuDoiHi+TrX1qabGTuDmPEQmA3MlbYqI+XWpsPp/5zURsQ5YJ+keYBbZSAajob4zgYsj65BYKukpYA/gt3Worxppn5N6dfYU1IE0AVgG7MJrnVx7D2hzDJt3Iv12lNU3HVgKHDxa38MB7a+n/p2t1byPewI/zdu2AQ8Dbx9F9X0D+Hz+eEfgOWBynd/HGQze2Zr0ORnTWyQRsUnSucBdvHYXtiXld2EjO8owl+zDup7sm2E01fc5YAfg6/k3/qao4wVeVdbYUNXUGBG/l3Qn8CDZ6AvXRkTFQ52NqA/4InC9pIfIPqznRUTdrgiWdBNwKDBZ0grgAqBUVl/S58RntppZsrF+1MbMRgEHiZklc5CYWTIHiZklc5CYWTIHiRVO0k6SfijpCUnLJM2TtFWj67LacZBYofKL1G4F5kfETGAm2SnslzS0MKspn0dihZJ0OHBBRLy7bNqbyIYU2TkiXmlYcVYz3iKxou0NLC6fENmFak8Db21EQVZ7DhIrmqh8FWlDLiO2YjhIrGhLyK4e7pfv2uwIPNaQiqzmHCRWtJ8CbX03GpLUDFwGzIuIDQ2tzGrGQWKFiqw3/zjgBElPAM8DvRFxYWMrs1ryURurK0kHAzcBx0fE4uHa29jgIDGzZN61MbNkDhIzS+YgMbNkDhIzS+YgMbNkDhIzS+YgMbNk/w9SXTZj50cOHwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 288x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# plot Y vs Q in the initial distribution\n",
    "plt.figure(1, figsize=(4, 3))\n",
    "plt.clf()\n",
    "plt.scatter(df_init_LS[\"Q\"], df_init_LS[\"Y\"])\n",
    "plt.ylabel(\"Y\")\n",
    "plt.xlabel(\"Q\")\n",
    "plt.title(\"Y vs Q for DS_LS\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2dfda6a9",
   "metadata": {},
   "source": [
    "# compute upper and lower bound"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7628dcce",
   "metadata": {},
   "source": [
    "#### Compute best classifier on the source distribution hS_LS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "id": "d5bbf3e2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hs_LS: 0.49007415641382496\n",
      "Err_S_hs_LS: 0.043\n"
     ]
    }
   ],
   "source": [
    "# compute the optimal threshold hs based on the original dataset\n",
    "hS_LS =  compute_optimal_threshold(df_init_LS, M = M)\n",
    "print(\"hs_LS:\", hS_LS)\n",
    "\n",
    "# error on the original dataset S using classifier \"hs\"\n",
    "Err_S_hS_LS = compute_error(df_init_LS, hS_LS)\n",
    "print(\"Err_S_hs_LS:\", Err_S_hS_LS)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e482341",
   "metadata": {},
   "source": [
    "#### Compute new distribution induced by hS_LS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "id": "ad44cf80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute new distribution induced by hS\n",
    "df_hS_LS = compute_new_distribution_LS(df_init_LS, hS_LS, alpha = alpha, mu0 = mu0, mu1 = mu1, \n",
    "                                       sigma = sigma, bottom = bottom, top = top)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "id": "d969055c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# add Q for df_hS_LS\n",
    "# compute Q for df_hT_LS\n",
    "df_hS_LS['Q'] = clf.predict_proba(df_hS_LS[[\"X1\", \"X2\", \"X3\"]])[:, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "id": "190e7f1a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5, 1.0, 'Y vs Q for df_hS_LS')"
      ]
     },
     "execution_count": 97,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARIAAADgCAYAAADPGumFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAUXklEQVR4nO3dfZBcVZ3G8e/DZIITQSYhgyVDQhBBRCGAA2QFAWWREMQAhpUgsrCskSpx3SrXglprwRcUNCULVtAUS2F8A1aERVwCqFCKKwYy4T1oIAQJCSoTYkBClGTy2z/uHeh0umd65vTtnk6eT9VUdd8+955fd9JP3/ejiMDMLMUOzS7AzFqfg8TMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8SGTdIRkp6U9LKkkwtY/tmS/m+4/Un6haR/rnc9NjQHySgh6QeSri2bdrSkFyS9pcB+OyV9S9IfJb0i6VFJ/zjEbF8E5kXEThFxS1G11bs/SedK+p2kv0j6k6TbJO08xDwLJF1S5bWZkh6S9JKkNZLukjRlpPW1sjHNLsBe8y/AUknHRcTPJL0B+C/gMxHxhyI6lDQW+DnwPPB3wCrgWOA7knaJiG9UmXVPYOkI+xwTEZuGOduI+yvp92jgK8D0iHhQ0gTgpITlvQ34LnAqcDewE/ABYHNKna3KaySjRES8AHwKuFrSG4GLgaciYkF5W0nT8jWItpJpp0h6JH98mKTe/JfyT5Iur9Ltx4DJwGkR8XREbIyIO8hC7ZJKv9aSngLeCvwk39TYUdLukm6VtFbSckkfL2n/eUk/kvR9SS8BZ1dY5q75/C9Juh/Ye7D+hvgo95T063yt46eSJubTDwV+ExEPAkTE2oj4TkT8ZYjlVXMQ8HRE3BWZv0TETRGxcoTLa2kOklEkIm4ElgDXA3OAT1RptwhYD7y/ZPIZwHX54yuBKyPiTWRfyh9W6fI44PaIWF82/SZgHDCtQt97AyuBk/JNjb/l9a4CdgdmAV+RdGzJbDOBHwGdwA8q1HEV8FfgLcA/5X+D9TeYM4BzgN2AscC/5dPvA46X9IV8n8tQgTSUB4D9JP2npPdJ2ilxeS3NQTL6fJIsIL44xK/b9cBsgHzNYUY+DWAj8DZJEyPi5Tx4KpkIbLXZlG96rAG6hipW0iTgSOCCiPhrRDwEXEO2tjPgNxFxS0RsjogNZfO3AR8GLoqI9RHxGPCdofodxLcj4om8nx+SrTkQEb8i2ww5BLgNeEHS5aVrdcMRESuAY4DuvJ81+f6U7TJQHCSjTET8iexLPNQ+geuAU/Nf1lOBByLimfy1c4F9gd9JWizpg1WWsYZsLWALksaQhUxfDSXvDqwt20R4huwLNuDZQebvIttXV9rmmSpta/HHksevkO27ACAibo+Ik4AJZGtJZwMjPsoTEYsi4h8iogt4L3AU8LmRLq+VOUhaVEQ8TvaFO4EtN2uIiCcjYjbZ6v1XgR/l+13K/Rw4ocJrHyZbq7m/hlKeAyaU7U+ZDKwuLXeQ+fuATcCksvkLk68Z3UW2k/RddVrmYuDmei2v1ThIWtt1ZDtGjwJuHJgo6UxJXRGxGViXT+6vMP/3yPZt3ChpiqR2SccD3wC+FhEvDlVARDwL3AtcKukNkg4kWyOqtC+k0vz9ZF/Az0saJ2l/YKjDz8OWH6o9XdJ4ZQ4DjgaqbfaVasvf28DfWElHSvq4pN3y5e8HfKjG5W1zHCSt7Xqy7fS7I2JNyfTpZIeSXybb8Xp6RPy1fOZ8x+Xfk21W3AdsAO4ArgC+MIw6ZgNTyNZO/ge4OCJ+Noz5zyfbBPkjsAD49jDmrdWfgY8DTwIvAd8H5kZELYF3IdlnM/B3N1lAfwh4NP+c7yB771+re+UtQL5Dmg2Q1A7cTrZZcnb4P4fVyGsk9pqI2Ei2f+Qp4O1NLsdaiNdIrKXkmxGVnJAf4h3JMpeSnT1b7hM1bvps9xwkZpbMmzZmlqzl
LtqbOHFiTJkypdllmG13lixZsiY/+W4rLRckU6ZMobe3t9llmG13JFU949ibNmaWzEFiZskK27TJ7/b1QeD5iNjq+gNJIjvrcgbZxVVnR8QDqf3e8uBq/vW/H0pdzDbniL0n0PvMOv62qfJ9d3Zg8DvytEnMPnwSPXtOYO6dy1i9bgNi8Itosvmg3wcGW8bvLztxRPMVuUaygOxU7WpOAPbJ/+YA30rt0CFS3a+fWls1RGDo23r1R/D9RSv5zI0Ps3pddieAWvLBIdJaplx424jmKyxIIuIeYO0gTWYC383vLrUI6Ey9N+ncO5elzG416N/sZLCtNXMfSTdb3oNiFVvew+I1kubktw7s7eurfouM59ZtqPqamRWnmUGiCtMq/txFxNUR0RMRPV1d1W/atXtnR71qM7NhaGaQrGLLm9nsQXYZ+oh99nhfZ1a0th0q5b9t75oZJLcCZ+U3mZkGvJg67MLJB3dzxUcOqktx25oj9p7AjmOq/3MP9R+hTeLMaZP5+mlT6c7X/GqJlDbnTksZ6VGbIg//Dtx0Z6KkVWTDK7QDRMR8YCHZod/lZId/z6lHvycf3M3JB1fc1WJ14s/XyhUWJPk9Qwd7PcjumG5mLc5ntppZMgeJmSVzkJhZMgeJmSVzkJhZMgeJmSVzkJhZMgeJmSVzkJhZMgeJmSVzkJhZMgeJmSVzkJhZMgeJmSVzkJhZMgeJmSVzkJhZMgeJmSVzkJhZMgeJmSUrNEgkTZe0TNJySRdWeH0XST+R9LCkpZLqcid5M2uswoJEUhtwFdlg4fsDsyXtX9bsk8DjETGVbOiKr0saW1RNZlaMItdIDgOWR8SKiHgVuIFs4PBSAewsScBOZIOObyqwJjMrQJFBUssg4fOAd5AN1fko8OmI2FxgTWZWgCKDpJZBwo8HHgJ2Bw4C5kl601YLkuZI6pXU29fXV+86zSxRkUFSyyDh5wA3R2Y58DSwX/mCIuLqiOiJiJ6urq7CCjazkSkySBYD+0jaK9+BejrZwOGlVgLHAkh6M/B2YEWBNZlZAYoc+3eTpPOBO4E24NqIWCrpvPz1+cCXgAWSHiXbFLogItYUVZOZFaOwIAGIiIXAwrJp80sePwd8oMgazKx4PrPVzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJIVGiSSpktaJmm5pAurtDlG0kOSlkr6ZZH1mFkxChuOQlIbcBVwHNmoe4sl3RoRj5e06QS+CUyPiJWSdiuqHjMrTpFrJIcByyNiRUS8CtwAzCxrcwbZkJ0rASLi+QLrMbOCFBkk3cCzJc9X5dNK7QuMl/QLSUsknVVpQR5E3Gx0KzJIVGFalD0fA7wbOBE4HvgPSftuNZMHETcb1YocsnMVMKnk+R7AcxXarImI9cB6SfcAU4EnCqzLzOqsyDWSxcA+kvaSNBY4Hbi1rM2PgfdKGiNpHHA48NsCazKzAhS2RhIRmySdD9wJtAHXRsRSSeflr8+PiN9KugN4BNgMXBMRjxVVk5kVQxHluy1Gt56enujt7W12GWbbHUlLIqKn0ms+s9XMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMklUNEkkLJU1pYC1m1qIGWyNZAPxU0ucktTeoHjNrQVVPkY+IH0q6DbgI6JX0PbLT2Adev7wB9ZlZCxjqWpuNwHpgR2BnSoLEzGxA1SCRNB24nOyK3UMi4pWGVWVmLWWwNZLPAadFxNJGFWNmrWmwfSTvbWQhZta6fB6JmSVzkJhZMgeJmSVzkJhZMgeJmSVzkJhZMgeJmSUrNEgkTZe0TNJySRcO0u5QSf2SZhVZj5kVo7AgkdQGXAWcAOwPzJa0f5V2XyUb/8bMWlCRaySHAcsjYkVEvArcAMys0O5TwE3A8wXWYmYFKjJIuoFnS56v
yqe9RlI3cAowf7AFSZojqVdSb19fX90LNbM0RQaJKkwrH9bvCuCCiOgfbEERcXVE9ERET1dXV73qM7M6KWzsX7I1kEklz/cAnitr0wPcIAlgIjBD0qaIuKXAusyszooMksXAPpL2AlYDpwNnlDaIiL0GHktaAPyvQ8Ss9RQWJBGxSdL5ZEdj2oBrI2KppPPy1wfdL2JmraPINRIiYiGwsGxaxQCJiLOLrMXMiuMzW80smYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJI1dRBxSR+V9Ej+d6+kqUXWY2bFaPYg4k8DR0fEgcCXgKuLqsfMitPUQcQj4t6I+HP+dBHZaHxm1mKaOoh4mXOB2wusx8wKUuQAWbUMIp41lN5HFiRHVnl9DjAHYPLkyfWqz8zqpMg1kloGEUfSgcA1wMyIeKHSgiLi6ojoiYierq6uQoo1s5ErMkheG0Rc0liyQcRvLW0gaTJwM/CxiHiiwFrMrEDNHkT8ImBX4JuSADZFRE9RNZlZMRRRcbfFqNXT0xO9vb3NLsNsuyNpSbUfep/ZambJHCRmlsxBYmbJHCRmlsxBYmbJHCRmlsxBYmbJHCRmlsxBYmbJHCRmlsxBYmbJHCRmlsxBYmbJHCRmlsxBYmbJHCRmlsxBYmbJHCRmlsxBYmbJHCRmlsxBYmbJihxpD0nTgSvJhqO4JiIuK3td+eszgFeAsyPigZQ+p1x4W8rs272xbWLj5qDS4AJjdhCbNjdm1AHx+rCMEkRAd2cHnz3+7fQ+s5br7ltJeSnjx7UTAS9u2MguHe1IsO6Vjeyezwcw985lPLduA+PGtrH+1f6t5r/4pHdy8sHd3PLgaubeuYzV6zbQJtEfsUVN1Wqt5T2V1jlQW2mf1eo7c9pknu57mV8/tXbIvrorvOfdOzuYsmsH9z61dqt6BXx02mQuOfmAGt5JhfdW1HAUktqAJ4DjyEbdWwzMjojHS9rMAD5FFiSHA1dGxOGDLXew4SgcItu+HYDNI5ivfQeBYGP/4P/f29vERw6dxE1LVrNhY/+gbeulo72ND7+7u+59trcJAjYOI/zPHCRMmjUcxWHA8ohYERGvAjcAM8vazAS+G5lFQKektxRYk7W4kYQIZF+moUIEsqC5/r5nGxYiABs29hfS58b+GFaIAFx/37Mj6qvIIOkGSqtalU8bbhskzZHUK6m3r6+v7oWalepvwqBxzeizkpHWUWSQqMK0SptmQ7XxIOLWUG2q9N9y2+uzkpHWUWSQrAImlTzfA3huBG3MXjPS/7DtOyjbZzBUuzYx+/BJdLS3jbCn4etobyukz/Y2ZfuGhmH24ZOGblRBkUGyGNhH0l6SxgKnA7eWtbkVOEuZacCLEfGHkXb4+8tOHHm1BmRHbar9KI0Z5n/KFKU9DdTT3dnB5R85iDOnTaZSKePHtdPZ0Y6Azo52xo/LHnd3djD3tKnMnTWV7s4OBLxx7NZf2vHj2pk7ayqXnHwAl556AN2dHcDrv9LV3n2tn8pAu9I6uzs7uPTUA7bos1p9Z06bzBF7T6ipr+7ODubOmsrc015/z92dHRyx94SK9YrBd7QO+d6KHEQ8PypzBdnh32sj4suSzgOIiPn54d95wHSyw7/nRMSgI4R7EHGz5hjsqE2h55FExEJgYdm0+SWPA/hkkTWYWfF8ZquZJXOQmFmyQveRFEFSH/BMDU0nAmsKLifVaK9xtNcHrrEeaq1vz4ioeP5FywVJrST1VtsxNFqM9hpHe33gGuuhHvV508bMkjlIzCzZthwkVze7gBqM9hpHe33gGushub5tdh+JmTXOtrxGYmYN0vJBImm6pGWSlku6sMLrkvSN/PVHJB0yyur7aF7XI5LulTS1kfXVUmNJu0Ml9Uua1cj68r6HrFHSMZIekrRU0i9HU32SdpH0E0kP
5/Wd0+D6rpX0vKTHqrye9j2JiJb9I7uG5yngrcBY4GFg/7I2M4Dbya5LmgbcN8rqew8wPn98QiPrq7XGknZ3k13yMGu01Qh0Ao8Dk/Pnu42y+v4d+Gr+uAtYC4xtYI1HAYcAj1V5Pel70uprJKP9LmxD1hcR90bEn/Oni8hupdBItXyGkN0S8ybg+UYWl6ulxjOAmyNiJUBENLLOWuoLYOf8QtWdyIJkU6MKjIh78j6rSfqetHqQ1O0ubAUZbt/nkv0qNNKQNUrqBk4B5tMctXyO+wLjJf1C0hJJZzWsutrqmwe8g+x+O48Cn46Ikd45sghJ35NCr/5tgLrdha0gNfct6X1kQXJkoRVV6LrCtPIarwAuiIh+NedOXrXUOAZ4N3As0AH8RtKiiHii6OKorb7jgYeA9wN7Az+T9KuIeKng2mqV9D1p9SAZ7Xdhq6lvSQcC1wAnRMQLDaptQC019gA35CEyEZghaVNE3NKQCmv/d14TEeuB9ZLuAaaSjWQwGuo7B7gssh0SyyU9DewH3N+A+mqR9j1p1M6egnYgjQFWAHvx+k6ud5a1OZEtdyLdP8rqmwwsB94zWj/DsvYLaPzO1lo+x3cAd+VtxwGPAe8aRfV9C/h8/vjNwGpgYoM/xylU39ma9D1p6TWSiNgk6XzgTl6/C9vS0ruwkR1lmEH2ZX2F7JdhNNV3EbAr8M38F39TNPACrxprbKpaaoyI30q6A3iEbNSKayKi4qHOZtQHfAlYIOlRsi/rBRHRsCuCJV0PHANMlLQKuBhoL6kv6XviM1vNLFmrH7Uxs1HAQWJmyRwkZpbMQWJmyRwkZpbMQWKFk7SHpB9LelLSCknzJO3Y7LqsfhwkVqj8IrWbgVsiYh9gH7JT2L/W1MKsrnweiRVK0rHAxRFxVMm0N5ENKTIpIl5uWnFWN14jsaK9E1hSOiGyC9V+D7ytGQVZ/TlIrGii8lWkTbmM2IrhILGiLSW7evg1+abNm4FlTanI6s5BYkW7Cxg3cKMhSW3A14F5EbGhqZVZ3ThIrFCR7c0/BZgl6UngBWBzRHy5uZVZPfmojTWUpPcA1wOnRsSSodpba3CQmFkyb9qYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklc5CYWbL/Bxs9GKLaeDlMAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 288x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# plot Y vs Q for df_hS_LS\n",
    "plt.figure(1, figsize=(4, 3))\n",
    "plt.scatter(df_hS_LS[\"Q\"], df_hS_LS[\"Y\"])\n",
    "plt.ylabel(\"Y\")\n",
    "plt.xlabel(\"Q\")\n",
    "plt.title(\"Y vs Q for df_hS_LS\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23a34b5b",
   "metadata": {},
   "source": [
    "#### Compute the optimal classifier hT_LS under the induced distribution shift"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "e091c7c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hT_LS: 0.6499364848136413\n",
      "Err_hT_hT_LS: 0.137\n"
     ]
    }
   ],
   "source": [
    "# compute the optimal hT_LS\n",
    "hT_LS, Err_hT_hT_LS = compute_optimal_threshold_MinInduceRisk_LS(df_init_LS, M = M, alpha = alpha, mu0 = mu0, \n",
    "                                                           mu1 = mu1, sigma = sigma, bottom = bottom, top = top)\n",
    "\n",
    "print(\"hT_LS:\", hT_LS)\n",
    "print(\"Err_hT_hT_LS:\",Err_hT_hT_LS)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "id": "8a8ec6f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute the new data distribution induced by hT\n",
    "df_hT_LS = compute_new_distribution_LS(df_init_LS, hT_LS, alpha = alpha, mu0 = mu0, mu1 = mu1, \n",
    "                                       sigma = sigma, bottom = bottom, top = top)\n",
    "\n",
    "df_hT_LS['Q'] = clf.predict_proba(df_hT_LS[[\"X1\", \"X2\", \"X3\"]])[:, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "62bd68a2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5, 1.0, 'Y vs Q for df_hT_LS')"
      ]
     },
     "execution_count": 100,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARIAAADgCAYAAADPGumFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAT3ElEQVR4nO3de5BcZZ3G8e8zk56YCegkZqAkEBNjUNEliiNkRQVlWULQjSDUcpOFZc1SJaxV7m5BrbWAuipISYEFmkIK4xXKC4sgAUQtxZXrRLkFBUIQCKBMCNckSCbz2z/OmdDpdM905u3TF/J8qibVffo95/1lkn767XN7FRGYmaXoanUBZtb5HCRmlsxBYmbJHCRmlsxBYmbJHCRmlsxBYmbJHCS23STtL+lBSS9K+mgB2z9R0v9tb3+SfiXpXxpdj43PQdImJH1P0mUVyw6Q9LSkNxTYb5+kr0v6s6QNku6R9E/jrPY54KKI2Ckiriqqtkb1J+n9eQi9KGm9pCh7/qKkWWOsWzOcJJ0s6Y+SXpD0F0nXStp5e+t7NZjU6gJsi38DVko6OCJulPQa4BvAv0fEk0V0KKkH+DnwFPC3wBrgIOBbkl4XEV+tseobgZUT7HNSRAxv52oT7g8gIn4D7JT3Pxt4GOibQB1bSDoA+CKwMCJ+L2k68JGJbq/TeUTSJiLiaeA04BJJU4GzgIciYlllW0kL8hFEd9mywyXdnT/eV9KgpOfzT8rza3T7cWAWcFREPBwRmyLierJQ+59qn66SHgLeBFyTf5pPlrSbpKslrZO0StInytqfLelHkr4r6XngxCrbfH2+/vOSbgfmjtXfOL/KN0r6bT5K+JmkGeO0n6j3ALdExO8BImJdRHwrIl4oqL+25iBpIxHxQ2AFcDmwBPjXGu1uBdYDHypbfCzw/fzxhcCFEfFasjflD2p0eTBwXUSsr1j+Y6AXWFCl77nAo8BH8q8af83rXQPsBhwJfFHSQWWrLQZ+BPQB36tSx8XAS8AbgH/Of8bqbyzHAicBuwA9wH+M036ibgMOkfTZfB/OeAH3quYgaT+fJAuIz0XEo2O0uxw4BiAfOSzKlwFsAt4saUZEvJgHTzUzgG2+NuVD/rVA/3jFStoDeB9wekS8FBF3ApeSjXZG3RIRV0XESERsrFi/G/gYcGZErI+Ie4FvjdfvGL4ZEQ/k/fwAeGfCtmrKvy4dAewDXAs8Len88lHijsRB0mYi4i9kb+Lx9gl8Hzgi/yQ8AvhdRDySv3YysCfwR0l3SPpwjW2sJRsFbEXSJLKQGaqj5N2AdRVD+keAmWXPHxtj/X6yfXXlbR6p0bYefy57vIF830gRIuK6iPgIMJ1s1HUisEMeNXKQdKiIuI/sDXcoW3+tISIejIhjyIb35wI/yve7VPo5cGiV1z5GNqq5vY5SngCmV+xPmQU8Xl7uGOsPAcPAHhXrd4x8pPUL4JfAO1pdTys4SDrb98l2jH4A+OHoQknHS+qPiBHg2Xzx5irrf4ds38YPJc2WVJJ0CPBV4MsR8dx4BUTEY8DNwJckvUbS3mQjomr7Qqqtvxm4EjhbUq+kvYDxDj8326T87zb6U5K0WNLRkqYpsy9wAFDra+SrmoOks10OHAj8MiLWli1fSHYo+UWyHa9HR8RLlSvnOy7/juxrxW3ARuB64ALgs9tRxzHAbLLRyf8CZ0XEjdux/qlkX0H+DCwDvrkd6zbD18l+N6M/3wSeAT4BPAg8D3wXOC8i6grQVxv5Dmk2SlIJuI7sa8mJ4f8cViePSGyLiNhEtn/kIeAtLS7HOohHJNZR8q9r1RyaH5Jti23uaBwkZpbMX23MLFnHXbQ3Y8aMmD17dqvLMNvhrFixYm1EVD3bueOCZPbs2QwODra6DLMdjqSaZxz7q42ZJXOQmFmywr7a5Hf7+jDwVERsc/2BJJGddbmI7OKqEyPid43oe/YZ1zZiM22vt9TFxk0j9PZ0s/7l
V86Anzypi5GRETaNFNPvlFIX+8zq49bVz7C5ylE/MfbFNbVMntTFX4cLKtrq8qdzDpvQekWOSJaRnapdy6HAvPxnCdlpyMl2lBAB2LBphICtQgTgr8PFhQjAxk0j/PahdVVDBCYWIoBDpA1M9P1TWJBExE3AujGaLAa+HZlbgb4i701qZsVp5T6SmWx9D4o1bH0Piy0kLclvHTg4NFTPLTLMrJlaGSSqsqzqqDgiLomIgYgY6O8f96ZdZtZkrQySNWx9M5vdyS5DN7MO08oguRo4Ib8pzALguUZMuzDRvc6dqLfUhYCpPVvfJnTypC5KBf7LTil1sf/c6XSr2qCy+lCzHpMn+WyEVpvo+6fIw7+jN92ZIWkN2fQKJYCIWAosJzv0u4rs8O9Jjep7RwoTs3ZQWJDk9wwd6/Ugu2O6mXU4jyXNLJmDxMySOUjMLJmDxMySOUjMLJmDxMySOUjMLJmDxMySOUjMLJmDxMySOUjMLJmDxMySOUjMLJmDxMySOUjMLJmDxMySOUjMLJmDxMySOUjMLJmDxMySFRokkhZKul/SKklnVHn9dZKukXSXpJWSGnYneTNrnsKCRFI3cDHZZOF7AcdI2qui2SeB+yJiPtnUFV+R1FNUTWZWjCJHJPsCqyJidUS8DFxBNnF4uQB2liRgJ7JJx4cLrMnMClBkkNQzSfhFwNvIpuq8B/hURIwUWJOZFaDIIKlnkvBDgDuB3YB3AhdJeu02G5KWSBqUNDg0NNToOs0sUZFBUs8k4ScBV0ZmFfAw8NbKDUXEJRExEBED/f39hRVsZhNTZJDcAcyTNCffgXo02cTh5R4FDgKQtCvwFmB1gTWZWQGKnPt3WNKpwA1AN3BZRKyUdEr++lLg88AySfeQfRU6PSLWFlWTmRWjsCABiIjlwPKKZUvLHj8B/H2RNZhZ8Xxmq5klc5CYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklc5CYWTIHiZklKzRIJC2UdL+kVZLOqNHmQEl3Slop6ddF1mNmxShsOgpJ3cDFwMFks+7dIenqiLivrE0f8DVgYUQ8KmmXouoxs+IUOSLZF1gVEasj4mXgCmBxRZtjyabsfBQgIp4qsB4zK0iRQTITeKzs+Zp8Wbk9gWmSfiVphaQTqm3Ik4ibtbcig0RVlkXF80nAu4HDgEOA/5a05zYreRJxs7ZW5JSda4A9yp7vDjxRpc3aiFgPrJd0EzAfeKDAusyswYockdwBzJM0R1IPcDRwdUWbnwDvlzRJUi+wH/CHAmsyswIUNiKJiGFJpwI3AN3AZRGxUtIp+etLI+IPkq4H7gZGgEsj4t6iajKzYiiicrdFexsYGIjBwcFWl2G2w5G0IiIGqr3mM1vNLJmDxMySOUjMLJmDxMySOUjMLJmDxMySOUjMLFnNIJG0XNLsJtZiZh1qrBHJMuBnkj4jqdSkesysA9U8RT4ifiDpWuBMYFDSd8hOYx99/fwm1GdmHWC8a202AeuBycDOlAWJmdmomkEiaSFwPtkVu/tExIamVWVmHWWsEclngKMiYmWzijGzzjTWPpL3N7MQM+tcPo/EzJI5SMwsmYPEzJI5SMwsmYPEzJI5SMwsmYPEzJIVGiSSFkq6X9IqSWeM0e49kjZLOrLIesysGIUFiaRu4GLgUGAv4BhJe9Vody7Z/Ddm1oGKHJHsC6yKiNUR8TJwBbC4SrvTgB8DTxVYi5kVqMggmQk8VvZ8Tb5sC0kzgcOBpWNtSNISSYOSBoeGhhpeqJmlKTJIVGVZ5bR+FwCnR8TmsTYUEZdExEBEDPT39zeqPjNrkMLm/iUbgexR9nx34ImKNgPAFZIAZgCLJA1HxFUF1mVmDVZkkNwBzJM0B3gcOBo4trxBRMwZfSxpGfBTh4hZ5yksSCJiWNKpZEdjuoHLImKlpFPy18fc
L2JmnaPIEQkRsRxYXrGsaoBExIlF1mJmxfGZrWaWzEFiZskcJGaWzEFiZskcJGaWzEFiZskcJGaWzEFiZskcJGaWzEFiZskcJGaWzEFiZskcJGaWzEFiZskcJGaWzEFiZskcJGaWzEFiZskcJGaWzEFiZslaOom4pOMk3Z3/3CxpfpH1mFkxWj2J+MPAARGxN/B54JKi6jGz4rR0EvGIuDkinsmf3ko2G5+ZdZiWTiJe4WTgugLrMbOCFDlBVj2TiGcNpQ+SBcn7ary+BFgCMGvWrEbVZ2YNUuSIpJ5JxJG0N3ApsDginq62oYi4JCIGImKgv7+/kGLNbOKKDJItk4hL6iGbRPzq8gaSZgFXAh+PiAcKrMXMCtTqScTPBF4PfE0SwHBEDBRVk5kVQxFVd1u0rYGBgRgcHGx1GWY7HEkran3Q+8xWM0vmIDGzZA4SM0vmIDGzZA4SM0vmIDGzZA4SM0vmIDGzZA4SM0vmIDGzZA4SM0vmIDGzZA4SM0vmIDGzZA4SM0vmIDGzZA4SM0vmIDGzZA4SM0vmIDGzZA4SM0tW5Ex7SFoIXEg2HcWlEXFOxevKX18EbABOjIjfpfQ5+4xrU1bfYZW6YNPI2G26BBHZdIndEgveNI37nnyBZzZsGnO9qT3dRAQbanRQ6oLNASPbMaHBzL4pfPCt/fz0rid5dmP1/vefO51bHlpHea/zdpnKjZ8+kKt+/zjn3XA/jz+7EVFjCsjcpC4xXFac8j+qTcBQvq3eUheTS908u2ETfb0lIuC5jZvYLa/9yhVrtvxOJDhuv1k8PPQiv31o3VZ/hzn9O3H5bY+xuazDmX1T+M9D3gLAZ69ZueXfoG9KibP/4e189F0zOe4bt2y1rVFTe7pZ//LmrZaVuuC8o97JR9811qy6tRU2HYWkbuAB4GCyWffuAI6JiPvK2iwCTiMLkv2ACyNiv7G2O9Z0FA4Rq8euO/fw/Eub2bhp8/iN21ipS4wAmysSuNQlZs/o5cGn1m/3Ni/4x9ph0qrpKPYFVkXE6oh4GbgCWFzRZjHw7cjcCvRJekOBNZnxlxde7vgQAdg0EtuEyOjyiYQIwHk33D+h9YoMkpnAY2XP1+TLtrcNkpZIGpQ0ODQ01PBCzSzzxLMbJ7RekUGiKssq47OeNp5E3KxJduubMqH1igySNcAeZc93B56YQBuzhtp15x6mlLpbXUayUpfo7tr2s7jUJebtMnVC2xzdgbu9igySO4B5kuZI6gGOBq6uaHM1cIIyC4DnIuLJiXb4p3MOm3i1O7hSHf8TuvTKELJbYv+505nWWxp3vak93fSO0UGpK9v29pjZN4XjF8yib0rt/vefO32b/+DzdpnKbZ85mC8d8TfMzD99x+t6UkVxIjvKUk354t5SF9N6SwiY1luib0r2eLT28t+JBMcvmMX+c6dv83c4fsEsuis6nNk3hfOOms9Xjpq/1b9B35QS5x01nxs/feA22xo1tWfbEC11jb2jdTyFTiKeH5W5gOzw72UR8QVJpwBExNL88O9FwEKyw78nRcSYM4R7EnGz1hjrqE2h55FExHJgecWypWWPA/hkkTWYWfF8ZquZJXOQmFmyQveRFEHSEPBIHU1nAGsLLidVu9fY7vWBa2yEeut7Y0RUPf+i44KkXpIGa+0YahftXmO71weusREaUZ+/2phZMgeJmSV7NQfJJa0uoA7tXmO71weusRGS63vV7iMxs+Z5NY9IzKxJOj5IJC2UdL+kVZLOqPK6JH01f/1uSfu0YY3H5bXdLelmSfPbqb6ydu+RtFnSkc2sL+973BolHSjpTkkrJf26neqT9DpJ10i6K6/vpCbXd5mkpyTdW+P1tPdJRHTsD9k1PA8BbwJ6gLuAvSraLAKuI7ueagFwWxvW+F5gWv740GbWWE99Ze1+SXbJw5Ft+DvsA+4DZuXPd2mz+v4LODd/3A+sA3qaWOMHgH2Ae2u8nvQ+6fQRSSfchW3cGiPi5oh4Jn96K9nt
FNqmvtxpwI+Bp5pY26h6ajwWuDIiHgWIiGbWWU99AeycX6i6E1mQDDerwIi4Ke+zlqT3SacHScPuwlag7e3/ZLJPhmYZtz5JM4HDgaW0Rj2/wz2BaZJ+JWmFpBOaVl199V0EvI3sfjv3AJ+KiHFut91USe+TQq/+bYKG3YWtQHX3L+mDZEHyvkIrqui2yrLK+i4ATo+Izap1I45i1VPjJODdwEHAFOAWSbdGxANFF0d99R0C3Al8CJgL3CjpNxHxfMG11SvpfdLpQdIJd2Grq39JewOXAodGxNNNqg3qq28AuCIPkRnAIknDEXFVUyqs/995bUSsB9ZLugmYTzaTQTvUdxJwTmQ7JFZJehh4K3B7E+qrR9r7pFk7ewragTQJWA3M4ZWdXG+vaHMYW+9Eur0Na5wFrALe246/w4r2y2j+ztZ6fodvA36Rt+0F7gXe0Ub1fR04O3+8K/A4MKPJv8fZ1N7ZmvQ+6egRSUQMSzoVuIFX7sK2svwubGRHGRaRvVE3kH0ytFuNZwKvB76Wf+oPR5Mu8qqzvpaqp8aI+IOk64G7gRGyCdmqHupsRX3A54Flku4he7OeHhFNuyJY0uXAgcAMSWuAs4BSWX1J7xOf2WpmyTr9qI2ZtQEHiZklc5CYWTIHiZklc5CYWTIHiRVO0u6SfiLpQUmrJV0kaXKr67LGcZBYofKL1K4EroqIecA8slPYv9zSwqyhfB6JFUrSQcBZEfGBsmWvJZtSZI+IeLFlxVnDeERiRXs7sKJ8QWQXqv0JeHMrCrLGc5BY0UT1q0hbchmxFcNBYkVbSXb18Bb5V5tdgftbUpE1nIPEivYLoHf0RkOSuoGvABdFxMaWVmYN4yCxQkW2N/9w4EhJDwJPAyMR8YXWVmaN5KM21lSS3gtcDhwRESvGa2+dwUFiZsn81cbMkjlIzCyZg8TMkjlIzCyZg8TMkjlIzCyZg8TMkv0/CMrPAb3KDHkAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 288x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# plot Y vs Q for df_hT_LS\n",
    "plt.figure(1, figsize=(4, 3))\n",
    "plt.scatter(df_hT_LS[\"Q\"], df_hT_LS[\"Y\"])\n",
    "plt.ylabel(\"Y\")\n",
    "plt.xlabel(\"Q\")\n",
    "plt.title(\"Y vs Q for df_hT_LS\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89bde9aa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "aecf3efe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Err_hS_hS_LS - Err_hT_hT_LS 0.03366666666666665\n"
     ]
    }
   ],
   "source": [
    "# recompute hS_LS, hT_LS, their induced distributions, and the errors Err_hS_hS_LS and Err_hT_hT_LS\n",
    "hS_LS =  compute_optimal_threshold(df_init_LS, M = M)\n",
    "df_hS_LS = compute_new_distribution_LS(df_init_LS, hS_LS, alpha = alpha, mu0 = mu0, mu1 = mu1, \n",
    "                                       sigma = sigma, bottom = bottom, top = top)\n",
    "\n",
    "# fit the logistic regression on the initial data distribution\n",
    "clf = LogisticRegression(solver='liblinear', fit_intercept=True)\n",
    "clf.fit(df_init_LS[[\"X1\", \"X2\", \"X3\"]], df_init_LS['Y'])\n",
    "# compute Q for df_hS_LS\n",
    "df_hS_LS['Q'] = clf.predict_proba(df_hS_LS[[\"X1\", \"X2\", \"X3\"]])[:, 1]\n",
    "\n",
    "\n",
    "Err_hS_hS_LS = compute_error(df_hS_LS, hS_LS)\n",
    "hT_LS, Err_hT_hT_LS = compute_optimal_threshold_MinInduceRisk_LS(df_init_LS, M = M, alpha = alpha, mu0 = mu0, \n",
    "                                                                 mu1 = mu1, sigma = sigma, bottom = bottom, top = top)\n",
    "\n",
    "# compute df_hT_LS\n",
    "df_hT_LS = compute_new_distribution_LS(df_init_LS, hT_LS, alpha = alpha, mu0 = mu0, mu1 = mu1, \n",
    "                                       sigma = sigma, bottom = bottom, top = top)\n",
    "\n",
    "# compute Q for df_hT_LS\n",
    "df_hT_LS['Q'] = clf.predict_proba(df_hT_LS[[\"X1\", \"X2\", \"X3\"]])[:, 1]\n",
    "\n",
    "print(\"Err_hS_hS_LS - Err_hT_hT_LS\", Err_hS_hS_LS -Err_hT_hT_LS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "id": "dfbbdae7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute w(hS), w(hT) and p\n",
    "w_hS_LS = (np.sum(df_hS_LS['Q'] > hS_LS))/len(df_hS_LS)\n",
    "w_hT_LS = (np.sum(df_hT_LS['Q'] > hT_LS))/len(df_hT_LS)\n",
    "p_LS = (np.count_nonzero(df_init_LS['Y'] == 1))/len(df_init_LS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "e5960f4e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dtv_ST_plus 0.010666666666666713\n",
      "dtv_ST_minus 0.009333333333333332\n"
     ]
    }
   ],
   "source": [
    "# compute d_TV(D+(hS), D+(hT)) and d_TV(D-(hS), D-(hT))\n",
    "# Pr_{DS|Y = +1}(hS(X) = +1)\n",
    "P_DSY1_hS1 = np.sum((df_init_LS[\"Q\"] > hS_LS) & (df_init_LS[\"Y\"] == 1))/np.sum(df_init_LS[\"Y\"] == 1)\n",
    "# Pr_{DS|Y = +1}(hT(X) = +1)\n",
    "P_DSY1_hT1 = np.sum((df_init_LS[\"Q\"] > hT_LS) & (df_init_LS[\"Y\"] == 1))/np.sum(df_init_LS[\"Y\"] == 1)\n",
    "\n",
    "dtv_ST_plus = np.abs(P_DSY1_hS1 - P_DSY1_hT1)\n",
    "\n",
    "print(\"dtv_ST_plus\", dtv_ST_plus)\n",
    "\n",
    "# Pr_{DS|Y = 0}(hS(X) = +1)\n",
    "P_DSY0_hS1 = np.sum((df_init_LS[\"Q\"] > hS_LS) & (df_init_LS[\"Y\"] == 0))/np.sum(df_init_LS[\"Y\"] == 0)\n",
    "# Pr_{DS|Y = 0}(hT(X) = +1)\n",
    "P_DSY0_hT1 = np.sum((df_init_LS[\"Q\"] > hT_LS) & (df_init_LS[\"Y\"] == 0))/np.sum(df_init_LS[\"Y\"] == 0)\n",
    "\n",
    "dtv_ST_minus = np.abs(P_DSY0_hS1 - P_DSY0_hT1)\n",
    "\n",
    "print(\"dtv_ST_minus\", dtv_ST_minus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "4002cf84",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "UB_LS 0.048333333333333436\n"
     ]
    }
   ],
   "source": [
    "# compute the upper bound\n",
    "UB_LS = np.abs(w_hS_LS -  w_hT_LS) + (1+p_LS)*(dtv_ST_plus + dtv_ST_minus)\n",
    "print(\"UB_LS\", UB_LS)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23ddf7df",
   "metadata": {},
   "source": [
    "### Lower Bound"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61dfe977",
   "metadata": {},
   "source": [
    "Label Shift Lower Bound is:\n",
    "\n",
    "                   max{Err(DS)(h), Err(h)(h)} >= |p - w(h)|(1 - |TPR_S(h) - FPR_S(h)|)/2\n",
    "\n",
    "Quantity needs to be computed (lower bound):\n",
    "    \n",
    "    - Err_Ds(h)\n",
    "    - Err_h(h), Err_h(h), where h\\in {hs, hT}\n",
    "    - TPR_S(h), FPR_S(h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "b5ab3e83",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max of {Err_S_hs_LS, Err_hs_hs_LS} 0.16366666666666665\n",
      "LB_hS 0.0021979933110367897\n"
     ]
    }
   ],
   "source": [
    "# lower bound for hS:\n",
    "Err_S_hS_LS = compute_error(df_init_LS, hS_LS)\n",
    "print(\"max of {Err_S_hs_LS, Err_hs_hs_LS}\", max(Err_S_hS_LS, Err_hS_hS_LS))\n",
    "\n",
    "TPR_S_hS = np.sum((df_init_LS[\"Q\"] > hS_LS) & (df_init_LS[\"Y\"] == 1))/np.sum(df_init_LS[\"Q\"] > hS_LS)\n",
    "FPR_S_hS = np.sum((df_init_LS[\"Q\"] > hS_LS) & (df_init_LS[\"Y\"] == 0))/np.sum(df_init_LS[\"Q\"] > hS_LS)\n",
    "\n",
    "\n",
    "LB_hS = np.abs(p_LS - w_hS_LS)*(1 - np.abs(TPR_S_hS - FPR_S_hS))/2\n",
    "print(\"LB_hS\", LB_hS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "4bafb921",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max of {Err_S_hT_LS, Err_hT_hT_LS} 0.13\n",
      "LB_hT 0.0023372013651877157\n"
     ]
    }
   ],
   "source": [
    "# lower bound for hT:\n",
    "Err_S_hT_LS = compute_error(df_init_LS, hT_LS)\n",
    "print(\"max of {Err_S_hT_LS, Err_hT_hT_LS}\", max(Err_S_hT_LS, Err_hT_hT_LS))\n",
    "\n",
    "TPR_S_hT = np.sum((df_init_LS[\"Q\"] > hT_LS) & (df_init_LS[\"Y\"] == 1))/np.sum(df_init_LS[\"Q\"] > hT_LS)\n",
    "FPR_S_hT = np.sum((df_init_LS[\"Q\"] > hT_LS) & (df_init_LS[\"Y\"] == 0))/np.sum(df_init_LS[\"Q\"] > hT_LS)\n",
    "\n",
    "\n",
    "LB_hT = np.abs(p_LS - w_hT_LS)*(1 - np.abs(TPR_S_hT - FPR_S_hT))/2\n",
    "print(\"LB_hT\", LB_hT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e05f8726",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "id": "472ab7e3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x7faa1053c190>"
      ]
     },
     "execution_count": 107,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASAAAADQCAYAAACnZrwtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAaM0lEQVR4nO3df3SU1ZnA8e8DRVGw6oG4pUQKGqGrCFHS7haPDYhaf4BNIrtirdXFij8Wnc4utrrdCu1xj93VksalHsVq0R4XRTtmxZ9lXX7oahXUKD+6iFU8Zq0Fo3K0gibk2T/emTAZZpKZyfvOfeed53POHDI3k/feyROeed9773uvqCrGGOPCINcNMMZULktAxhhnLAEZY5yxBGSMccYSkDHGGUtAxhhnPue6AUEYOXKkjh071nUzTJoXX3zxPVWt8vOYFufwKTTOkUxAY8eOZcOGDa6bYdKIyFt+H9PiHD6FxtkuwYwxzkQqAYnILBFZumvXLtdNMQGyOLuhqjz00ENk3j2RqzwfkUpAqrpSVecdeuihrptiAmRxdqO1tZWmpibi8XhPslFV4vE4TU1NtLa2FnzMSPYBhca0ad6/a9a4bIUxvmhoaCAWi9HS0gJAc3Mz8XiclpYWYrEYDQ0NBR/TEpAxJi8iQnNzMwAtLS09iSgWi9Hc3IyIFHzMSF2CGWOClZ6EUopNPmBnQP5LXXYBrF27f5ldjpkylurzSRePx+0MyITItGm9k66JhFTySfX5dHd39/QJpXdMF8LOgPyWfoZjndAmQlpbW3uST+qMJ71PqL6+vuBjRioBicgsYFZNTY3rppgAWZzdaGhoIJFI0NDQ0HO5lUpC9fX1RY2CSRSXZK2rq9NQTNGvpDOgbH1faZ+Isnbti6pa52eVoYmz6SEiBcU5UmdAoVMJiceYAbAEZPzRX99XkcO0JtpsFMyYMhPEPVmuWAIypswEcU+WK3YJZvxnfV+BCuKeLFcsARlTZoK4J8sVuwSLKpuNHGl+35PliiUgY8pQrnuyyqkDGsogAYnIUSJyp4g86LotxoRBEPdkuRJoH5CI3AXMBHao6sS08jOAFmAw8EtV/WmuY6jqG8AlloDyYHfiV4R87slqbGx03Mr8BN0JvQxYAtyTKhCRwcAvgNOAdmC9iDyMl4xuzPj5uaq6I+A2GlNWgrgny5VAE5CqrhORsRnFXwVeT57ZICL3Ad9U1RvxzpaKIiLzgHkAY8aMKfYw5a0C7sS3OHvJJtsZTq7yMHPRBzQaeDvteXuyLCsRGSEitwEniMh1uV6nqktVtU5V66qqfN3/zoSIxTlaXMwDyjZOmLPXTFU7gMvzOrAt01ARLM7R4eIMqB04Mu15NfCOHwe27VrSrFlT0suv/u5P8rkui3NEuEhA64FjRGSciBwAzAEedtAO46P+7k8CDnPZPhNOgSYgEVkOPAdMEJF2EblEVbuA+cCTwO+BFaq62af6bMdMR9LvT0olofS5KsCHftVlcY4QVY3cY8qUKWpKr7u7W2OxmOL16SmgsVhMu7u7FdigFufIKzTOoZ8JXQj7ZHSrVPcnWZyjI1IJSK1z0ikt0f1JFufoiFQCMu6kkk+u+5OMySZS6wHlnB8S0VnBYdLf/Un4OApm84CiI1JnQHZq7k7q/qT0Pp9UEkokEuDjKJjFOToidQZk3InS/UmmdKKbgGxpCmNCL1KXYDY8WxksztFRGVszWye0c4Vu2ZsP25o5fAqNc6TOgIwx5cUSkDHGmUh1QuecH2KXXpFi84CiI1JnQDY/pDJYnKMjUgnIGFNeLAEZY5yxBGSMKZxPW39bAjLGOBOpBGQzZCuDxTk6IpWAKn10RPvZmSIqs94rPc7OpC67pk3z7q9cu7Z3WREilYAqXX87U7S2trptoDEZIjURsdKl70wB3nrM6asUltOe4SaEAtj62xJQhGSuQphKROmrFBoTJnYJFjGl2pnCGD9YAoqYUu1MYSqc
T1t/WwKKkP52prAkZEKnkF0Mw/4AZgFLa2pq8tnEMXISiUSv3UhVe+9WmkgknLUNH3dGDUOcu7u7NZFI9Pye+yuvFIXG2XnSCOJRqVv2hvk/hZ8JSEMQ5zAne5csAVVwAgqzqCWg9GSTSkKZzytRoXG2YXhjimBTHvxhndDGFMmmPAycJSBT1lze+6ZqUx4GyhKQKVsffvihs3vfUvXYlIcBKqTDqFwe1gkdPgTUCe2qI9hGwbIrNM7Ok0UQD0tA4RNUAkr/T596lGIUKsxTHlwqNM6VsTOqcS7InVFVlUGD9vUmdHd3W0ewI5HcGVVEGkTkDhH5TxE53XV7THioWkdwOQs8AYnIXSKyQ0Q2ZZSfISJbReR1Ebm2r2OoaquqXgpcDJwXYHNNmbGO4DJXyPVaMQ/g68CJwKa0ssHAH4CjgAOAV4BjgeOBRzIeR6T93M+AE/ur0/qAwocA+oCOPvpo6wgOmULjHPhMaFVdJyJjM4q/Cryuqm8AiMh9wDdV9UZgZuYxxLug/ynwuKq+lK0eEZkHzAMYM2aMf2/AhEpmnBOJBA0NDT19PqnJgfX19bYCZBnI+xJMRIb5WO9o4O205+3JslyuAk4FZovI5dleoKpLVbVOVeuqqqr8a6kJlcw4NzY27tfhLCJZy0349JuARGSqiGwBfp98PllEbh1gvdn+MnJesKvqLao6RVUvV9Xb+mirbddSASzO0ZHPGVAz8A2gA0BVX8Hr1xmIduDItOfVwDsDPCZq27VUBItzdOR1Caaqb2cU7R1gveuBY0RknIgcAMwBHh7gMY0xZSafBPS2iEwFVEQOEJEFJC/H8iEiy4HngAki0i4il6hqFzAfeDJ5rBWqurmI9mfWZafmFcDiHB39zoQWkZFAC14nsAC/BWKq2hF884pjM6HDJ8iZ0M75tEdWFBQa536H4VX1PeCCAbWqRERkFjCrpqbGdVNMgCzO0dFvAhKRX5FlhEpV5wbSogFQ1ZXAyrq6uktdt8UEx+IcHflMRHwk7euhQCM+jFgZU9ZSl10Aa9fuX2aXY3nJ5xLsN+nPk53K/xVYiwbATs0rg8U5OgpejkNEJgCPqmpoox+azknTwzqhK4PvndAi8hFeH5Ak/30X+EHRLTTGmKR8LsEOKUVDjDGVJ2cCEpET+/rBXHelu2R9A5UhdHG2S6+i5ewDEpHVffycquopwTRp4ELTN2B6RLoPyPTwrQ9IVaf70yRjjMkurwXJRGQi3oqFQ1NlqnpPUI0yxlSGfEbBFgLT8BLQY8CZwDNA6BJQ6PoGTCCyxtmGwstSPnfDzwZmAO+q6t8Bk4EDA21VkWydmMpgcY6OfBLQHlXtBrpE5PPADrzF5I0xZkD6GoZfAiwHXhCRw4A7gBeBj4EXStI6Y/qydeu+Sy+7H6ss9dUHtA24GfgiXtJZDpwGfF5VXy1B24wxEdfXMHwL0CIiX8JbMvVXeKNgy0Vkt6puK1EbjcluwoR9ZznWCV2W+u0DUtW3VPVfVfUE4Ft4y3H8b+AtK4It1VkZLM7Rkc+2PEOSAb8XeBx4DTg38JYVwUZHKoPFOTr66oQ+DTgfOBuv0/k+YJ6q/rlEbTMmf3bpVZb66oT+J+A/gAWq+n6J2mOMqSA5L8FUdbqq3mHJpzCqykMPPUTmTb65yo2pZHnvDW/y09raSlNTE/F4vCfZqCrxeJympiZaW1vdNtCYEMnrZlSTv4aGBmKxGC0tLQA0NzcTj8dpaWkhFovR0NDgtoHGhIglIJ+JCM3NzQC0tLT0JKJYLEZzczMi4rJ5xoRKpC7BwjI/JD0JpVjy8U9Y4mwGLlIJKCzzQ1J9PunS+4TMwIQlzmbgIpWAwiCVfFJ9Pt3d3T19QpaEjOnN+oB81tra2pN8Updd6X1C9fX1NDY2Om6lMeEQ2TMgV/NxGhoaSCQSvfp8UkkokUjYKJgxaSKbgFzNxxERGhsb9+twzlVuTCWL7CWYzccxJvwim4BsPo4x4RfZSzCw
+TjGhF2kE5DNxzEm3EKfgETkL0XkNhF5UESuyPfnbD6OMWVAVQN7AHfhbeOzKaP8DGAr8DpwbZ7HGgTcmc9rp0yZoolEQgGNxWLa3d2tqqrd3d0ai8UU0EQioaZ0gA3q89/XlClTSvsmTL8KjXPQndDLgCWk7aIqIoOBX+DtsNEOrBeRh4HBwI0ZPz9XVXeIyDnAtclj5SU1H6ehoWG/+Tj19fU2CmZMCIgGfCkiImOBR1R1YvL514BFqvqN5PPrAFQ1M/lkO9ajqnp2ju/NA+YBjBkzZspbb73lzxswvhCRF1W1zofjWJxDrNA4u+gDGg28nfa8PVmWlYhME5FbROR2vL3ps1LVpapap6p1VVVV/rXWhIrFOVpczAPKNgae8zRMVdcAa4JqjDHGHRdnQO3AkWnPq4F3/DiwrRNTGSzO0eEiAa0HjhGRcSJyAN6uqw/7cWC1dWIqgsU5OgJNQCKyHHgOmCAi7SJyiap2AfOBJ4HfAytUdbNP9dknYwWwOEdH4KNgLtTV1emGDRtcN8Ok8WsULJ3FOXzKYRQsMPbJWBksztERqQRkfQOVweIcHZFKQMaY8mIJyBjjTKQWJBORWcCsmpqa/b7X2dlJe3s7e/bsKX3DAjJ06FCqq6sZMmSI66aUVKXFOSWK8a6YUbA333yTQw45hBEjRkRiQTJVpaOjg48++ohx48a5bk6/SjUKFrU4p5RLvCt6FKwve/bsidQfpYgwYsSISH7SD0TU4pwS1XhHKgH1NzwbxT/KSlRpcU6J4vuKVAIK+/Ds4MGDqa2t5bjjjmPy5MksXryY7u5uADZs2MDVV18NwKeffsqpp55KbW0t999/P08//TTHHXcctbW17N692+VbCIWwx1lEuPDCC3ued3V1UVVVxcyZMx22KpwilYD8oAFuaHjQQQfR1tbG5s2bWbVqFY899hg//vGPAairq+OWW24B4OWXX6azs5O2tjbOO+887r33XhYsWEBbWxsHHXRQ8W/O9AgyzsOGDWPTpk09HxarVq1i9OicK85UNEtAGUq1oeERRxzB0qVLWbJkCarKmjVrmDlzJjt27ODb3/42bW1t1NbWcvvtt7NixQp+8pOfcMEFF/hStwk+zmeeeSaPPvooAMuXL+f888/v+d4LL7zA1KlTOeGEE5g6dSpbt24FYPHixcydOxeAjRs3MnHiRD755JMBtSP0Clm/NewPYBawtKamZr+1ards2bJfWTbp60an1pPOfF6sYcOG7Vd22GGH6bvvvqurV6/Ws88+W1W119eqqhdddJE+8MADWY+Z7/tyDR/XhC6HOL/yyit67rnn6u7du3Xy5Mm9Yrpr1y7t7OxUVdVVq1ZpU1OTqqru3btXTz75ZE0kEjplyhR95plnin5/rhQa50jNA1LVlcDKurq6S4s9Rqk3NNQIToMIWjnEedKkSWzfvp3ly5dz1lln9frerl27uOiii9i2bRsiQmdnJwCDBg1i2bJlTJo0icsuu4yTTjppQG0oB3YJlkWpNjR84403GDx4MEcccYSvxzX5CTrO55xzDgsWLOh1+QXwox/9iOnTp7Np0yZWrlzZa2h927ZtDB8+nHfe8WWNvtCzBJSFavAbGu7cuZPLL7+c+fPnR3J4tRwEHee5c+dy/fXXc/zxx/cq37VrV0+n9LJly3qVx2Ix1q1bR0dHBw8++KAv7QgzS0AZUn+UQWxouHv37p5h+FNPPZXTTz+dhQsX+th6k68g45xSXV1NLBbbr/z73/8+1113HSeddBJ79+7tKY/H41x55ZWMHz+eO++8k2uvvZYdO3YMuB2hVkiHUbk8sm1Yl2/nXbltaBj2TskUSrQxYVTjnBL2eBcaZ+fJws8HPo2OJBKJ/UZBcpW7FvY/yBQ/E1Alxjkl7PEuNM6RugRTH2bIigiNjY379cvkKjelZ3GOjkglIGNMebEEZIxxxhKQMcYZS0DGGGcsAZXQ9u3bmThxYq+yRYsWcfPNN3PxxRczbtw4
amtr+fKXv9xzl7wpP8OHD9+vbNGiRYwePbonvldccUXPUiyVzBJQX6ZN8x4lctNNN9HW1kZbWxt33303b775ZsnqrmglinM8HqetrY0tW7awceNG1q5dG3idYRepBBSVDetS9wYNGzbMcUvCqdzj/Nlnn7Fnzx4OP/xw101xLlIJyI/5IS5dc8011NbWUl1dzZw5c+wm1RzKNc7Nzc3U1tYyatQoxo8fT21tresmORep5Th8kX4qnjpFTi9bs6boQ+ea3JYqv+mmm5g9ezYff/wxM2bM4Nlnn2Xq1KlF12f6EGCcc4nH4yxYsIDOzk5mz57Nfffdx5w5c3yvp5xE6gwo7EaMGMEHH3zQq+z9999n5MiRvcqGDx/OtGnTeOaZZ0rZPFMiQ4YM4YwzzmDdunWum+KcnQFlSv/kS30i+vRpOHz4cEaNGsVTTz3FjBkzeP/993niiSeIxWKsXr2653VdXV08//zzXHXVVb7Ua7IIMM79UVWeffZZuwTDzoBK7p577uGGG26gtraWU045hYULF3L00UcD+/qAJk2axPHHH09TU5Pj1ppifPLJJ1RXV/c8Fi9eDOzrA5o4cSJdXV1ceeWVA6+sxCO1frMzoBI79thje53tpKQvTGXKW675PYsWLSptQ8qAJaC+lOiU3DhmcXbGEpAx5cbBCF5QrA/IGONMWZwBicgwYB2wUFUfKfY4qhqphaa8BehMpqjFOaUn3g5H8PwW6BmQiNwlIjtEZFNG+RkislVEXheRa/M41A+AFQNpy9ChQ+no6IjMf1pVpaOjg6FDh7puSqhELc4pUY130GdAy4AlwD2pAhEZDPwCOA1oB9aLyMPAYODGjJ+fC0wCtgAD+s1XV1fT3t7Ozp07B3KYUBk6dCjV1dWumxEqUYxzShTjHWgCUtV1IjI2o/irwOuq+gaAiNwHfFNVbwRmZh5DRKYDw4Bjgd0i8piqFryOwZAhQxg3blyhP2bKTMXFuUwvvVJc9AGNBt5Oe94O/FWuF6vqDwFE5GLgvVzJR0TmAfMAxowZ41dbTchYnKPFxShYtt7Bfi/YVXVZXx3QqrpUVetUta6qqmpADTThZXGOFhcJqB04Mu15NeDLRtjlvk6MyY/FOTok6NGCZB/QI6o6Mfn8c8BrwAzg/4D1wLdUdbOPde4E3sooHgm851cdBXBVb9jq/pKq+nrKkiXOYXq/lVB3tnoLinOgfUAishyYBowUkXa8eTx3ish84Em8ka+7/Ew+ANl+ASKyQVXr/KwnH67qrYS6M+Mc9fcbtrr9qDfoUbDzc5Q/BjwWZN3GmPCzWzGMMc5UUgJaWmH1VmLdlfZ+Xdc94HoD74Q2xphcKukMyBgTMmVxN3x/ROQLwM+BrwCfAtuB7wHzgVPwJjruAf5WVd8UkX8BvgMcrqr7b2MZUN3An4AHgKOBvcBKVc3nZtxs9Y4Anko+/ULyeKkboB5K1rcX6AYuU9Xnk6OP30vWX6WqBQ/dFlnvvUAd0Am8kCzvLKLuioqzqxgPoO7C46yqZf3Am1n9HHB5Wlkt8CPgQWBQsqwa7w8R4K+BUcDHpawbOBiYniw7AHgaONOH38EiYEHy668l23Rg8vlI4IvJr08AxuL95xlZwnrPSv6uBFgOXGFxLo8YBx3nKJwBTQc6VfW2VIGqtonIKcAfNXnvmKq2p33/d5B7n64g6wZWJ8s+E5GX8P5o/TQK7565T5P19HwCqurL4Mv7LrTenikXIvICxb1ni/M+rmLcX90FxzkKfUATgRezlK8AZolIm4j8TEROCFPdInIYMIt9p7l++S1wpIi8JiK3iki9z8cvul4RGQJcCDxRxPEtzvu4inFedRcS5ygkoKySn0YTgOvwrlOfEpEZYag7eTvKcuAWTS5L4mPdHwNT8O4Y3wncn1xJIFB51nsrsE5Vn/ax3oqLs6sYF1B33nGOwiXYZmB2tm8kTxMfBx4XkT8BDfj7SVRs3UuBbar6cx/bkl73XmANsEZE
NgIX4S0OF6i+6hWRhUAVcFmRh7c4967XSYz7q7vQOEfhDOi/gQNF5NJUgYh8RUTqReSLyeeD8FZWzLxBteR1i8gNwKF4IxW+E5EJInJMWlEt/r/vguoVke8C3wDO1yIWk0uyOO+r20mM+6u7mDiXfQJSr/u9EThNRP4gIpvxeu0nASvFW4/6VaALb3lYROTfkjfHHiwi7SKyqBR1i0g18EO81R1fSvYdfLfIt57LcOBuEdkiIq8m61oEICJXJ993NfCqiPyyFPUCtwF/ATyXfM/XF3pwi3MvrmLcZ90UEWebCW2Mcabsz4CMMeXLEpAxxhlLQMYYZywBGWOcsQRkjHHGEpBPRGRvcuhxs4i8IiL/kJwb0tfPjBWRbwXQlu+JyMF+H9dYnP1mCcg/u1W1VlWPw9t2+ixgYT8/Mxbw/Q8Tb/JbWf9hhpjF2U9+3K5vD4WMJR+Ao4AOvKUJxuItyfBS8jE1+ZrfAbuANiDex+tGAeuSr9sEnJwsPx1vaYSX8NafGQ5cDXwGbMS7I3sw3jT5TcmyuOvfVTk/LM4+/z5dNyAqj8w/zGTZB3gzQw8GhibLjgE2JL+ehrdnWur1uV73j8APk18PBg7BW4dlHTAsWf4D4Prk19tJrgWDd+PgqrQ6DnP9uyrnh8XZ30cUbkYNs9SiLEPwpujX4q0iNz7H63O9bj1wV3KZg1b11qKpx5sG/z/irf1yAN6nZKY3gKNE5N+BR/GWUzD+sjgXyRJQQETkKLw/rh14fQR/Aibj9bvtyfFj8WyvU9V1IvJ14Gzg1yJyE96n7irNsfdaiqp+ICKT8W4S/Hu8pTTnDuzdmRSL88BYJ3QARKQK78a8JeqdDx/KvpXzLsQ7vQb4CO80OyXr60TkS8AOVb0DuBM4Ea9f4SQRqUm+5mARGZ95XBEZibdk6G/wlhA9MZh3XXkszgNnZ0D+OUhE2vBOr7uAXwOLk9+7FfiNiPwNXofhn5PlrwJdIvIKXgdirtdNA64RkU7gY+A7qrpTvIWglovIgcnX/TPwGt46NI+LyB/xRkp+lTZUfJ2/b7viWJx9ZHfDG2OcsUswY4wzloCMMc5YAjLGOGMJyBjjjCUgY4wzloCMMc5YAjLGOGMJyBjjzP8D3A6lRl6JrzQAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 288x216 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Hard-coded summary values per dataset, stored as two-element pairs.\n",
    "# data_UB pairs are plotted as ('Diff', 'UB'); data_LB pairs as ('Max', 'LB').\n",
    "data_UB = {\n",
    "    'CS1': [0.00133, 0.0498],\n",
    "    'CS2': [0.002, 0.216],\n",
    "    'TS1': [0.023, 0.105],\n",
    "    'TS2': [0.072, 0.18],\n",
    "}\n",
    "data_LB = {\n",
    "    'CS1': [0.0126, 0.00333],\n",
    "    'CS2': [0.016, 0.006],\n",
    "    'TS1': [0.130, 0.000124],\n",
    "    'TS2': [0.24, 0.0266],\n",
    "}\n",
    "\n",
    "# Dataset names serve as the categorical x positions in both panels.\n",
    "names = list(data_UB)\n",
    "\n",
    "# Transpose each dict's pairs into two parallel lists, one per plotted series.\n",
    "values_diff, values_UB = (list(series) for series in zip(*data_UB.values()))\n",
    "values_max, values_LB = (list(series) for series in zip(*data_LB.values()))\n",
    "\n",
    "fig, axs = plt.subplots(1, 2, figsize=(4, 3), sharey=True)\n",
    "\n",
    "# One row per panel: (axes, 'x'-marker series, its label, '+'-marker series, its label).\n",
    "panels = (\n",
    "    (axs[0], values_diff, 'Diff', values_UB, 'UB'),\n",
    "    (axs[1], values_max, 'Max', values_LB, 'LB'),\n",
    ")\n",
    "for ax, cross_series, cross_label, plus_series, plus_label in panels:\n",
    "    ax.scatter(names, cross_series, marker='x', color='black', s=50, label=cross_label)\n",
    "    ax.scatter(names, plus_series, marker='+', color='red', s=50, label=plus_label)\n",
    "    ax.set_xlabel(\"Datasets\")\n",
    "    ax.set_yscale('log')\n",
    "\n",
    "# Only the left panel is given a y-label; the subplots were created with sharey=True.\n",
    "axs[0].set_ylabel(\"Value\")\n",
    "\n",
    "# Legends are added last so the cell still ends on the right-hand legend call.\n",
    "axs[0].legend(loc=\"lower left\")\n",
    "axs[1].legend(loc=\"lower left\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6ab281b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9f52fb6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c729baa8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52424535",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "572a0ef3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
