{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b2dc0504-791b-46e8-bc8d-90ae34d487b8",
   "metadata": {},
   "source": [
    "# DO NOT MODIFY THIS FILE!\n",
    "\n",
    "This file provides a template showing how to **compute FreeShap values** for each data point when fine-tuning a BERT model on the Movie Review (MR) dataset, and how to **evaluate** a selected MR subset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3bba6107-d879-40d4-9d2b-ac45b2876d03",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Connected to cuda:0\n"
     ]
    }
   ],
   "source": [
    "from main.utils.gpu import connect_to_, use_gpus_\n",
    "\n",
    "# GPU setup. use_gpus_ presumably selects which GPUs this session may use\n",
    "# (see main/utils/gpu to confirm); the saved run reports a connection to cuda:0.\n",
    "use_gpus_([0])\n",
    "\n",
    "# `device` is passed to evaluate_mr_subset in the evaluation cell.\n",
    "device = connect_to_(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "382fe400-ed23-4804-8550-a9d59430188e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keyword arguments unpacked into free_shapley(**configs) by the next cell.\n",
    "# configs[\"seed\"] is also reused when evaluating the selected subset.\n",
    "# NOTE(review): num_dp presumably is the number of training points; 8530\n",
    "# matches the MR training-set size shown in the saved outputs - confirm.\n",
    "configs = dict(\n",
    "    yaml_path=\"main/configs/mr-bert_ntk.yaml\",\n",
    "    dataset_name=\"mr\",\n",
    "    file_path=\"saved_data/mr-bert/\",\n",
    "    num_dp=8530,\n",
    "    tmc_iter=1000,\n",
    "    approximate='inv',\n",
    "    early_stopping=True,\n",
    "    parallel=40,\n",
    "    seed=2023,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a6b89b88-9c2a-4385-a9e0-15d8865a3d22",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(8530, 1066)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from main.shapley.freeshap import free_shapley\n",
    "\n",
    "# free_shapley returns a pair: `shapleys` is ranked in the evaluation cell,\n",
    "# `ps_shapleys` is only inspected here.\n",
    "shapleys, ps_shapleys = free_shapley(**configs)\n",
    "\n",
    "# The saved run shows (8530, 1066).\n",
    "# NOTE(review): presumably (training points, test points) - confirm.\n",
    "ps_shapleys.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74d654f9-380f-4ff2-a272-ba03a19b4ebf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5d0a6f8c-04d5-42b2-8b54-0cbfe0ffc649",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Constructing <class 'dataset.EasyReader'>\n",
      "Constructing <class 'dataset.ListDataset'>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/tianxiao/miniconda3/envs/S25001/lib/python3.10/site-packages/huggingface_hub/file_download.py:945: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Label 0 to word terrible (6659)\n",
      "Label 1 to word great (2307)\n",
      "label_to_word:  {0: 6659, 1: 2307}\n",
      "label_list:  [6659, 2307]\n",
      "Constructing <class 'probe.PromptFinetuneProbe'>\n",
      "Constructing PromptFinetuneProbe\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n",
      "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Probe has 109514298 parameters\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[loading]: 304it [00:05, 71.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mr: 8530\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[loading]: 8530it [00:06, 1385.45it/s]\n",
      "[loading]: 1066it [00:03, 274.99it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mr: 1066\n",
      "Constructing <class 'dataset.EasyReader'>\n",
      "Constructing <class 'dataset.ListDataset'>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Label 0 to word terrible (6659)\n",
      "Label 1 to word great (2307)\n",
      "label_to_word:  {0: 6659, 1: 2307}\n",
      "label_list:  [6659, 2307]\n",
      "Constructing <class 'probe.PromptFinetuneProbe'>\n",
      "Constructing PromptFinetuneProbe\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n",
      "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Probe has 109514298 parameters\n",
      "Probe has 109515836 parameters\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/tianxiao/miniconda3/envs/S25001/lib/python3.10/site-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='630' max='630' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [630/630 00:18, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>0.264400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.124900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>150</td>\n",
       "      <td>0.094900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.047000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>250</td>\n",
       "      <td>0.031100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.014700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>350</td>\n",
       "      <td>0.010100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.000900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>450</td>\n",
       "      <td>0.007000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.005100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>550</td>\n",
       "      <td>0.003500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.000100</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='17' max='17' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [17/17 00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 1.0286288261413574, 'eval_accuracy': 83.67729831144464, 'eval_runtime': 0.7846, 'eval_samples_per_second': 1358.626, 'eval_steps_per_second': 21.667, 'epoch': 10.0}\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from main.evaluate.mr import evaluate_mr_subset\n",
    "\n",
    "n_selected = 1000\n",
    "indices = np.argsort(shapleys)[::-1][:n_selected]\n",
    "\n",
    "results = evaluate_mr_subset(indices, device=device, seed=configs[\"seed\"])\n",
    "print(results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "s25001",
   "language": "python",
   "name": "s25001"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
