{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from main import load_data\n",
    "from utils import Acc, NotAbstainAcc, AdjustAcc, Flip_L, CheckLFs_Acc, Snorkel, GetStats\n",
    "from label_prop import PropagationSoft, PropagationHard\n",
    "from snorkel.labeling.model import LabelModel\n",
    "from extension import alpha_from_LPA, LPA_with_dongle_with_labeled_inds_custom_alpha, LPA_with_dongle_with_labeled_inds\n",
    "from extension import Adaboost_weight_norm, Adaboost_weight\n",
    "from extension import Generate_data_var_reg, alpha_from_reg\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_experiment(data_name, num_labels, lamb, seed, euc_th = 10, wl_th = 10, thresh = 10):\n",
    "    \"\"\"Run every pseudo-labelling method on one dataset and collect one stats row per method.\n",
    "\n",
    "    Args:\n",
    "        data_name: dataset identifier forwarded to load_data.\n",
    "        num_labels: number of points drawn (without replacement) to receive ground-truth labels.\n",
    "        lamb: value forwarded as `lamb` to every dongle-based LPA call and stored in each row.\n",
    "        seed: forwarded to load_data and stored in each row.\n",
    "        euc_th: forwarded to load_data and stored in each row.\n",
    "        wl_th: forwarded to load_data.\n",
    "        thresh: forwarded as `thresh` to alpha_from_LPA.\n",
    "\n",
    "    Returns:\n",
    "        A list with one entry per method, in evaluation order. Each entry is\n",
    "        list(GetStats(pred, labels)) + [method, data_name, num_labels, lamb, seed, euc_th].\n",
    "    \"\"\"\n",
    "    results = []\n",
    "    X, L, labels, W_x, S_x, W_x_large, S_x_large = load_data(data_name, euc_th, wl_th, seed)\n",
    "\n",
    "    def record(method, pseudolabels):\n",
    "        # One row: the GetStats values followed by the metadata identifying this run.\n",
    "        results.append(list(GetStats(pseudolabels, labels)) + [method, data_name, num_labels, lamb, seed, euc_th])\n",
    "\n",
    "    # Accuracy approx: L_acc comes from Snorkel(L) alone, while L_acc_oracle comes from\n",
    "    # CheckLFs_Acc, which also receives the true labels.\n",
    "    L_acc, snorkel_pred = Snorkel(L)\n",
    "    L_acc_oracle, Coverage = CheckLFs_Acc(L, labels, show = False)\n",
    "    L_acc_oracle = np.nan_to_num(L_acc_oracle)  # NaN entries become 0\n",
    "\n",
    "    # Labeled points.\n",
    "    # NOTE(review): this draw uses NumPy's global RNG, and `seed` is only passed to load_data\n",
    "    # in this function -- confirm load_data seeds the global RNG, otherwise the draw is not reproducible.\n",
    "    labeled_inds = np.random.choice(range(W_x.shape[0]), size= num_labels, replace=False)\n",
    "    # Overwrite the labeled rows with a two-column (1 - label, label) encoding of the true labels.\n",
    "    # Assumes labels are binary 0/1 -- TODO confirm.\n",
    "    snorkel_pred[labeled_inds,:] = np.stack((1-labels[labeled_inds], labels[labeled_inds]), axis = 1)\n",
    "\n",
    "    # LPA + WL\n",
    "    LPA_WL = PropagationHard(snorkel_pred, W_x, labels, labeled_inds, alpha = 1)\n",
    "\n",
    "    # Baselines. The 'Snorkel' row is scored after the overwrite above, so it includes the injected true labels.\n",
    "    record('Snorkel', snorkel_pred)\n",
    "    record('LPA + WL', LPA_WL)\n",
    "\n",
    "    # LPA with dongle nodes: one alpha vector per method. The all-zero alpha is reported as 'LPA'.\n",
    "    vector_alphas = [\n",
    "        ('Dongle + alpha_s', L_acc),\n",
    "        ('Dongle + alpha_*', L_acc_oracle),\n",
    "        ('Dongle + alpha_1', np.ones_like(L_acc)),\n",
    "        ('LPA', np.zeros_like(L_acc)),\n",
    "        # Optimal weight: Adaboost_weight_norm applied to the accuracies divided by 100\n",
    "        # (presumably converting percentages to fractions -- verify against Snorkel / CheckLFs_Acc).\n",
    "        ('Dongle + opt_s', Adaboost_weight_norm(L_acc/100, clip = 5)),\n",
    "        ('Dongle + opt_*', Adaboost_weight_norm(L_acc_oracle/100, clip = 5)),\n",
    "    ]\n",
    "    for method, alpha_j in vector_alphas:\n",
    "        record(method, LPA_with_dongle_with_labeled_inds(W_x, L, alpha_j, labels, labeled_inds, lamb = lamb))\n",
    "\n",
    "    # Alpha_j depends on x: alphas built from X by alpha_from_LPA / alpha_from_reg.\n",
    "    matrix_alphas = [\n",
    "        ('Dongle + LPA alpha', alpha_from_LPA(X,L,labels, L_acc, labeled_inds, thresh = thresh, alpha_LPA = 1)),\n",
    "        ('Dongle + Reg alpha lin', alpha_from_reg(X, L, labels, L_acc,labeled_inds, kernel = 'linear')),\n",
    "        ('Dongle + Reg alpha poly', alpha_from_reg(X, L, labels, L_acc,labeled_inds, kernel = 'polynomial')),\n",
    "        ('Dongle + Reg alpha rbf', alpha_from_reg(X, L, labels, L_acc,labeled_inds, kernel = 'rbf')),\n",
    "    ]\n",
    "    for method, alpha_mat in matrix_alphas:\n",
    "        # Forward the caller's lamb: it was hard-coded to 1 here, which disagreed with the\n",
    "        # lamb stored in the row whenever lamb != 1.\n",
    "        record(method, LPA_with_dongle_with_labeled_inds_custom_alpha(W_x, L, alpha_mat, labels, labeled_inds, lamb = lamb))\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [00:00<00:00, 2108.79epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 2099.70epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 2091.88epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 2086.05epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 1936.92epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 601.47epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 684.11epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 643.62epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 700.19epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 653.93epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 2166.52epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 2088.23epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 1877.53epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 1899.16epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 2185.69epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 1431.41epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 1419.50epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 1447.97epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 1324.12epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 1453.79epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 1866.13epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 1933.56epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 1869.98epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 1871.65epoch/s]\n",
      "100%|██████████| 200/200 [00:00<00:00, 1873.24epoch/s]\n"
     ]
    }
   ],
   "source": [
    "# Sweep every dataset over seeds 1-5 and pool the per-method rows into one flat list.\n",
    "results_labels = []\n",
    "for data_name in ['youtube','sms','basketball','cdr','tennis']:\n",
    "    for seed in range(1, 6):\n",
    "        results_labels.extend(\n",
    "            run_experiment(data_name, num_labels = 100, lamb = 1, seed = seed, euc_th = 10, wl_th = 10, thresh = 5)\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per (method, dataset, seed). The first three columns hold the GetStats values and the\n",
    "# remaining six the run metadata, in the order run_experiment builds each row.\n",
    "# NOTE(review): the three metric names assume GetStats returns (not-abstain acc, coverage, acc) -- confirm.\n",
    "df_labels = pd.DataFrame(\n",
    "    results_labels,\n",
    "    # 'Coverage' previously had a stray leading space, so df_labels['Coverage'] raised KeyError.\n",
    "    columns = ['Not abstain Acc', 'Coverage', 'Acc', 'Method', 'Data name', 'num_labels', 'lambda', 'seed', 'euc_th'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>Data name</th>\n",
       "      <th>basketball</th>\n",
       "      <th>cdr</th>\n",
       "      <th>sms</th>\n",
       "      <th>tennis</th>\n",
       "      <th>youtube</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Method</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Dongle + LPA alpha</th>\n",
       "      <td>70.42</td>\n",
       "      <td>72.65</td>\n",
       "      <td>83.26</td>\n",
       "      <td>86.91</td>\n",
       "      <td>85.41</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Dongle + Reg alpha lin</th>\n",
       "      <td>80.99</td>\n",
       "      <td>65.94</td>\n",
       "      <td>83.29</td>\n",
       "      <td>87.13</td>\n",
       "      <td>87.97</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Dongle + Reg alpha poly</th>\n",
       "      <td>83.12</td>\n",
       "      <td>65.94</td>\n",
       "      <td>83.28</td>\n",
       "      <td>87.15</td>\n",
       "      <td>87.99</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Dongle + Reg alpha rbf</th>\n",
       "      <td>76.32</td>\n",
       "      <td>61.41</td>\n",
       "      <td>82.63</td>\n",
       "      <td>87.08</td>\n",
       "      <td>85.81</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Dongle + alpha_*</th>\n",
       "      <td>70.38</td>\n",
       "      <td>73.73</td>\n",
       "      <td>83.57</td>\n",
       "      <td>86.92</td>\n",
       "      <td>88.41</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Dongle + alpha_1</th>\n",
       "      <td>75.52</td>\n",
       "      <td>63.47</td>\n",
       "      <td>82.29</td>\n",
       "      <td>86.90</td>\n",
       "      <td>82.08</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Dongle + alpha_s</th>\n",
       "      <td>70.42</td>\n",
       "      <td>72.65</td>\n",
       "      <td>83.26</td>\n",
       "      <td>86.91</td>\n",
       "      <td>85.41</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Dongle + opt_*</th>\n",
       "      <td>73.12</td>\n",
       "      <td>71.09</td>\n",
       "      <td>83.20</td>\n",
       "      <td>86.90</td>\n",
       "      <td>87.77</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Dongle + opt_s</th>\n",
       "      <td>73.20</td>\n",
       "      <td>68.79</td>\n",
       "      <td>82.54</td>\n",
       "      <td>86.90</td>\n",
       "      <td>81.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>LPA</th>\n",
       "      <td>61.07</td>\n",
       "      <td>56.35</td>\n",
       "      <td>75.58</td>\n",
       "      <td>65.80</td>\n",
       "      <td>62.45</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>LPA + WL</th>\n",
       "      <td>77.20</td>\n",
       "      <td>69.05</td>\n",
       "      <td>82.45</td>\n",
       "      <td>86.50</td>\n",
       "      <td>82.38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Snorkel</th>\n",
       "      <td>56.45</td>\n",
       "      <td>70.52</td>\n",
       "      <td>70.25</td>\n",
       "      <td>85.99</td>\n",
       "      <td>83.15</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "Data name                basketball    cdr    sms  tennis  youtube\n",
       "Method                                                            \n",
       "Dongle + LPA alpha            70.42  72.65  83.26   86.91    85.41\n",
       "Dongle + Reg alpha lin        80.99  65.94  83.29   87.13    87.97\n",
       "Dongle + Reg alpha poly       83.12  65.94  83.28   87.15    87.99\n",
       "Dongle + Reg alpha rbf        76.32  61.41  82.63   87.08    85.81\n",
       "Dongle + alpha_*              70.38  73.73  83.57   86.92    88.41\n",
       "Dongle + alpha_1              75.52  63.47  82.29   86.90    82.08\n",
       "Dongle + alpha_s              70.42  72.65  83.26   86.91    85.41\n",
       "Dongle + opt_*                73.12  71.09  83.20   86.90    87.77\n",
       "Dongle + opt_s                73.20  68.79  82.54   86.90    81.00\n",
       "LPA                           61.07  56.35  75.58   65.80    62.45\n",
       "LPA + WL                      77.20  69.05  82.45   86.50    82.38\n",
       "Snorkel                       56.45  70.52  70.25   85.99    83.15"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean accuracy over seeds: methods as rows, datasets as columns.\n",
    "(\n",
    "    df_labels\n",
    "    .groupby(['Data name', 'Method'])['Acc']\n",
    "    .mean()\n",
    "    .round(2)\n",
    "    .unstack()\n",
    "    .T\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>Data name</th>\n",
       "      <th>basketball</th>\n",
       "      <th>cdr</th>\n",
       "      <th>sms</th>\n",
       "      <th>tennis</th>\n",
       "      <th>youtube</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Method</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Dongle + LPA alpha</th>\n",
       "      <td>0.80</td>\n",
       "      <td>0.20</td>\n",
       "      <td>0.99</td>\n",
       "      <td>0.13</td>\n",
       "      <td>0.48</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Dongle + Reg alpha lin</th>\n",
       "      <td>1.47</td>\n",
       "      <td>1.16</td>\n",
       "      <td>1.22</td>\n",
       "      <td>0.37</td>\n",
       "      <td>0.42</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Dongle + Reg alpha poly</th>\n",
       "      <td>0.62</td>\n",
       "      <td>1.15</td>\n",
       "      <td>1.16</td>\n",
       "      <td>0.33</td>\n",
       "      <td>0.66</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Dongle + Reg alpha rbf</th>\n",
       "      <td>1.50</td>\n",
       "      <td>1.91</td>\n",
       "      <td>1.15</td>\n",
       "      <td>0.31</td>\n",
       "      <td>1.34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Dongle + alpha_*</th>\n",
       "      <td>0.73</td>\n",
       "      <td>0.54</td>\n",
       "      <td>0.97</td>\n",
       "      <td>0.14</td>\n",
       "      <td>0.33</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Dongle + alpha_1</th>\n",
       "      <td>0.49</td>\n",
       "      <td>1.13</td>\n",
       "      <td>1.27</td>\n",
       "      <td>0.44</td>\n",
       "      <td>0.84</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Dongle + alpha_s</th>\n",
       "      <td>0.80</td>\n",
       "      <td>0.20</td>\n",
       "      <td>0.99</td>\n",
       "      <td>0.13</td>\n",
       "      <td>0.48</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Dongle + opt_*</th>\n",
       "      <td>0.81</td>\n",
       "      <td>0.98</td>\n",
       "      <td>1.02</td>\n",
       "      <td>0.13</td>\n",
       "      <td>0.79</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Dongle + opt_s</th>\n",
       "      <td>0.69</td>\n",
       "      <td>0.63</td>\n",
       "      <td>1.13</td>\n",
       "      <td>0.13</td>\n",
       "      <td>2.28</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>LPA</th>\n",
       "      <td>0.30</td>\n",
       "      <td>0.45</td>\n",
       "      <td>1.81</td>\n",
       "      <td>0.77</td>\n",
       "      <td>1.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>LPA + WL</th>\n",
       "      <td>0.64</td>\n",
       "      <td>0.56</td>\n",
       "      <td>1.17</td>\n",
       "      <td>0.19</td>\n",
       "      <td>1.04</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Snorkel</th>\n",
       "      <td>0.25</td>\n",
       "      <td>0.57</td>\n",
       "      <td>0.85</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.25</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "Data name                basketball   cdr   sms  tennis  youtube\n",
       "Method                                                          \n",
       "Dongle + LPA alpha             0.80  0.20  0.99    0.13     0.48\n",
       "Dongle + Reg alpha lin         1.47  1.16  1.22    0.37     0.42\n",
       "Dongle + Reg alpha poly        0.62  1.15  1.16    0.33     0.66\n",
       "Dongle + Reg alpha rbf         1.50  1.91  1.15    0.31     1.34\n",
       "Dongle + alpha_*               0.73  0.54  0.97    0.14     0.33\n",
       "Dongle + alpha_1               0.49  1.13  1.27    0.44     0.84\n",
       "Dongle + alpha_s               0.80  0.20  0.99    0.13     0.48\n",
       "Dongle + opt_*                 0.81  0.98  1.02    0.13     0.79\n",
       "Dongle + opt_s                 0.69  0.63  1.13    0.13     2.28\n",
       "LPA                            0.30  0.45  1.81    0.77     1.00\n",
       "LPA + WL                       0.64  0.56  1.17    0.19     1.04\n",
       "Snorkel                        0.25  0.57  0.85    0.12     0.25"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Standard deviation of accuracy over seeds: methods as rows, datasets as columns.\n",
    "(\n",
    "    df_labels\n",
    "    .groupby(['Data name', 'Method'])['Acc']\n",
    "    .std()\n",
    "    .round(2)\n",
    "    .unstack()\n",
    "    .T\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "39ca8058bc909080e5cc8b58cc5496375c0141f717bb9f1bcad8ab5053b7d3f8"
  },
  "kernelspec": {
   "display_name": "Python 3.9.7 ('nn-pruning')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
