{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import argparse\n",
    "import glob\n",
    "import math\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy.special import softmax\n",
    "import scipy.stats as stats\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import display, HTML\n",
    "from captum.attr import visualization\n",
    "\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "import datasets\n",
    "from datasets import load_dataset, load_metric \n",
    "from datasets import list_datasets, list_metrics\n",
    "\n",
    "from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification\n",
    "\n",
    "import os\n",
    "import torch.backends.cudnn as cudnn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint name shared by the classifier and its tokenizer.\n",
    "checkpoint = \"textattack/bert-base-uncased-SST-2\"\n",
    "\n",
    "# Load the classifier, move it to the GPU and switch it to eval mode.\n",
    "model = BertForSequenceClassification.from_pretrained(checkpoint)\n",
    "model = model.to(\"cuda\").eval()\n",
    "\n",
    "# Load the tokenizer from the same checkpoint.\n",
    "tokenizer = AutoTokenizer.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open the training split of the dataset in streaming mode.\n",
    "dataset_name = \"sst2\"\n",
    "dataset = load_dataset(\n",
    "    dataset_name,\n",
    "    split=\"train\",\n",
    "    streaming=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Special tokens and the ids treated as special by preprocess_sample.\n",
    "# NOTE(review): 101/102 are presumably the vocabulary ids of [CLS]/[SEP]\n",
    "# for this tokenizer - confirm against tokenizer.cls_token_id / sep_token_id.\n",
    "special_tokens = {\"[CLS]\", \"[SEP]\"}\n",
    "special_idxs = {101, 102}\n",
    "\n",
    "# Token used as the mask, with its id.\n",
    "# NOTE(review): 0 is presumably the id of [PAD] - confirm against the tokenizer.\n",
    "mask, mask_id = \"[PAD]\", 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_sample(text, device=\"cuda\"):\n",
    "    \"\"\"Tokenize ``text`` and build the tensors fed to the model.\n",
    "\n",
    "    Relies on the notebook-level ``tokenizer`` and ``special_idxs``.\n",
    "\n",
    "    Args:\n",
    "        text: Raw input string to encode.\n",
    "        device: Torch device the returned tensors are moved to. Defaults to\n",
    "            'cuda', the device the model is moved to in this notebook.\n",
    "\n",
    "    Returns:\n",
    "        A tuple ``(text_ids, att_mask, text_words)`` where\n",
    "        ``text_ids`` is a tensor of token ids with a leading batch dimension\n",
    "        of 1, ``att_mask`` is a tensor of the same layout holding 0 at every\n",
    "        position whose id is in ``special_idxs`` and 1 elsewhere, and\n",
    "        ``text_words`` holds the token strings for ``text_ids[0]``.\n",
    "    \"\"\"\n",
    "    tokenized_input = tokenizer(text, add_special_tokens=True, truncation=True)\n",
    "    input_ids = tokenized_input['input_ids']\n",
    "    text_ids = torch.tensor([input_ids]).to(device)\n",
    "    text_words = tokenizer.convert_ids_to_tokens(text_ids[0])\n",
    "\n",
    "    # Mask special tokens: 0 where the id is special, 1 everywhere else.\n",
    "    # Built straight from the ids (set lookup) instead of going through an\n",
    "    # index list; as before, the tokenizer's own attention_mask values are\n",
    "    # not consulted.\n",
    "    att_mask = [0 if token_id in special_idxs else 1 for token_id in input_ids]\n",
    "    att_mask = torch.tensor([att_mask]).to(device)\n",
    "\n",
    "    return text_ids, att_mask, text_words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.7/site-packages/transformers/modeling_utils.py:867: FutureWarning: The `device` argument is deprecated and will be removed in v5 of Transformers.\n",
      "  \"The `device` argument is deprecated and will be removed in v5 of Transformers.\", FutureWarning\n"
     ]
    }
   ],
   "source": [
    "# Build a two-bucket library of hidden-state activations from correctly\n",
    "# classified sentences and save it to disk.\n",
    "MIN_TOKENS = 10      # inputs with fewer tokens (special tokens included) are skipped\n",
    "MAX_LIB_SIZE = 500   # maximum number of entries kept per bucket\n",
    "LIB_PATH = './act_lib/sst.npy'\n",
    "\n",
    "# Create the output directory up front so the final save cannot fail on a\n",
    "# missing folder after the whole dataset has been processed.\n",
    "os.makedirs(os.path.dirname(LIB_PATH), exist_ok=True)\n",
    "\n",
    "model.eval()\n",
    "\n",
    "library = [[], []]\n",
    "for test_instance in dataset:\n",
    "    text = test_instance['sentence']\n",
    "    target = test_instance['label']\n",
    "\n",
    "    text_ids, att_mask, text_words = preprocess_sample(text)\n",
    "\n",
    "    # Skip inputs that are too short.\n",
    "    if len(text_words) < MIN_TOKENS:\n",
    "        continue\n",
    "\n",
    "    # NOTE(review): att_mask from preprocess_sample is not passed to the model;\n",
    "    # attention_mask stays None here. Confirm this is intended.\n",
    "    # No torch.no_grad() on purpose: the project-local BERT implementation may\n",
    "    # register gradient hooks during forward - verify before adding it.\n",
    "    result = model(input_ids=text_ids, attention_mask=None, output_hidden_states=True)\n",
    "\n",
    "    logits = result[0]\n",
    "    hidden_states = result[1]\n",
    "\n",
    "    pred_class = torch.argmax(logits, dim=1).item()\n",
    "    pred_class_prob = softmax(logits.detach().cpu().numpy(), axis=1).squeeze()\n",
    "\n",
    "    # Only sentences the model classifies correctly contribute to the library.\n",
    "    if target != pred_class:\n",
    "        continue\n",
    "\n",
    "    # The entry is filed under the class opposite to the sentence's label and\n",
    "    # scored with the probability the model assigns to that opposite class.\n",
    "    lib_target = 1 if target == 0 else 0\n",
    "\n",
    "    cpu_hs = tuple(h.detach().cpu() for h in hidden_states)\n",
    "    entry = {'target_cls_confi': pred_class_prob[lib_target], 'activation': cpu_hs}\n",
    "\n",
    "    bucket = library[lib_target]\n",
    "    bucket.append(entry)\n",
    "    bucket.sort(key=lambda x: x['target_cls_confi'], reverse=False)\n",
    "\n",
    "    # Ascending order puts the highest score last, so dropping the tail keeps\n",
    "    # the MAX_LIB_SIZE lowest-scored entries.\n",
    "    if len(bucket) > MAX_LIB_SIZE:\n",
    "        bucket.pop()\n",
    "\n",
    "# dtype=object keeps the save working when the two buckets differ in length\n",
    "# (a ragged nested list raises ValueError on recent NumPy versions).\n",
    "# Reload with np.load(LIB_PATH, allow_pickle=True).\n",
    "np.save(LIB_PATH, np.array(library, dtype=object))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
