{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import argparse\n",
    "import glob\n",
    "import math\n",
    "import pandas as pd\n",
    "from scipy.special import softmax\n",
    "import scipy.stats as stats\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import display, HTML\n",
    "from captum.attr import visualization\n",
    "\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "import datasets\n",
    "from datasets import load_dataset, load_metric \n",
    "from datasets import list_datasets, list_metrics\n",
    "\n",
    "from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification\n",
    "import pickle\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import random\n",
    "import torch.backends.cudnn as cudnn\n",
    "\n",
    "def seed_everything(seed):\n",
    "    \"\"\"Seed every random number generator this notebook relies on.\n",
    "\n",
    "    Seeds Python's `random`, NumPy and PyTorch (CPU and every CUDA device),\n",
    "    exports PYTHONHASHSEED and configures cuDNN for repeatable kernels.\n",
    "\n",
    "    Args:\n",
    "        seed: Integer seed shared by all generators.\n",
    "    \"\"\"\n",
    "    random.seed(seed)\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    # manual_seed_all covers every visible GPU, not only the current one.\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    # Autotuning (benchmark=True) may pick different kernels from run to run,\n",
    "    # which undermines the deterministic flag above, so it is switched off.\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    \n",
    "# Seed value used for every seed_everything() call in this notebook.\n",
    "ran_seed = 41\n",
    "seed_everything(ran_seed) # Fix the seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the SST-2 fine-tuned BERT classifier, move it to the GPU and switch it to eval mode.\n",
    "model = BertForSequenceClassification.from_pretrained(\"textattack/bert-base-uncased-SST-2\")\n",
    "model = model.to(\"cuda\")\n",
    "model.eval()\n",
    "\n",
    "# Load the tokenizer published under the same checkpoint name.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"textattack/bert-base-uncased-SST-2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset parquet (/root/.cache/huggingface/datasets/parquet/sst2-70a988cc7ad74849/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "efd9eddc65df4258aec9e4652c694870",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load the SST-2 dataset through the `datasets` library; a cached copy is reused when present.\n",
    "dataset_name = \"sst2\"\n",
    "dataset = load_dataset(dataset_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token strings that truncate_words() never removes.\n",
    "special_tokens = {\"[CLS]\", \"[SEP]\"}\n",
    "# Vocabulary ids that preprocess_sample() masks out of the attention mask.\n",
    "# NOTE(review): presumably the ids of [CLS] and [SEP] for this tokenizer -- confirm against its vocabulary.\n",
    "special_idxs = {101,102}    \n",
    "# Padding token and its id; not referenced by any other cell visible in this notebook.\n",
    "mask = \"[PAD]\"\n",
    "mask_id = 0   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_sample(text):\n",
    "    \"\"\"Tokenize one sentence and build the tensors fed to the model.\n",
    "\n",
    "    Relies on the notebook-level `tokenizer` and `special_idxs`.\n",
    "\n",
    "    Args:\n",
    "        text: Raw input sentence.\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(text_ids, att_mask, text_words)` where\n",
    "        `text_ids` is a (1, seq_len) CUDA tensor of token ids,\n",
    "        `att_mask` is a (1, seq_len) CUDA tensor that is 0 at special-token\n",
    "        positions and 1 everywhere else, and\n",
    "        `text_words` is the list of token strings matching `text_ids`.\n",
    "    \"\"\"\n",
    "    tokenized_input = tokenizer(text, add_special_tokens=True, truncation=True)\n",
    "    input_ids = tokenized_input['input_ids']\n",
    "    text_ids = torch.tensor([input_ids]).to(\"cuda\")\n",
    "    text_words = tokenizer.convert_ids_to_tokens(text_ids[0])\n",
    "\n",
    "    # Mask special tokens: one entry per token id, 0 for special ids and 1 otherwise.\n",
    "    # (The values of the tokenizer's own attention mask were never used, only its length,\n",
    "    # which equals len(input_ids).)\n",
    "    att_mask = [0 if token_id in special_idxs else 1 for token_id in input_ids]\n",
    "    att_mask = torch.tensor([att_mask]).to(\"cuda\")\n",
    "\n",
    "    return text_ids, att_mask, text_words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(model, text_ids, target, att_mask=None, seg_ids=None):\n",
    "    \"\"\"Classify `text_ids` and report the probability given to class `target`.\n",
    "\n",
    "    Args:\n",
    "        model: Callable accepting `(text_ids, attention_mask=..., token_type_ids=...)`\n",
    "            whose first output element holds the class scores.\n",
    "        text_ids: Token-id tensor passed to the model unchanged.\n",
    "        target: Class index whose probability is returned.\n",
    "        att_mask: Optional attention mask forwarded to the model.\n",
    "        seg_ids: Optional token-type ids forwarded to the model.\n",
    "\n",
    "    Returns:\n",
    "        `(pred_class, target_prob)` for the first sample of the batch: the index of\n",
    "        the highest score and the softmax probability of `target`.\n",
    "    \"\"\"\n",
    "    logits = model(text_ids, attention_mask=att_mask, token_type_ids=seg_ids)[0]\n",
    "    predicted = torch.argmax(logits, axis=1).cpu().detach().numpy()\n",
    "    class_probs = softmax(logits.cpu().detach().numpy(), axis=1)\n",
    "    target_prob = class_probs[:, target][0]\n",
    "    return predicted[0], target_prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def truncate_words(sorted_idx, text_words, text_ids, replaced_num, seg_ids=None):\n",
    "    \"\"\"Drop the highest-ranked non-special tokens from a tokenized sentence.\n",
    "\n",
    "    Args:\n",
    "        sorted_idx: Token positions ordered from most to least important.\n",
    "        text_words: Token strings of the sentence.\n",
    "        text_ids: Token-id tensor indexed as `text_ids[0, positions]`.\n",
    "        replaced_num: Maximum number of tokens to remove.\n",
    "        seg_ids: Optional token-type ids, truncated the same way when given.\n",
    "\n",
    "    Returns:\n",
    "        `(truncated_text_ids, truncated_text_words, seg_ids)` with the remaining\n",
    "        tokens kept in their original order; the ids regain a leading batch axis.\n",
    "    \"\"\"\n",
    "    dropped = []\n",
    "    last_rank = len(text_words) - 1\n",
    "    rank = 0\n",
    "    # NOTE(review): the scan stops before the final entry of sorted_idx, so the\n",
    "    # lowest-ranked position is never examined -- confirm this is intended.\n",
    "    while rank != last_rank and len(dropped) < replaced_num:\n",
    "        candidate = sorted_idx[rank]\n",
    "        if text_words[candidate] not in special_tokens:\n",
    "            dropped.append(candidate)\n",
    "        rank += 1\n",
    "\n",
    "    # Surviving positions, restored to sentence order.\n",
    "    remaining_idx = sorted(set(sorted_idx) - set(dropped))\n",
    "    truncated_text_ids = text_ids[0, np.array(remaining_idx)]\n",
    "    if seg_ids is not None:\n",
    "        seg_ids = seg_ids[0, np.array(remaining_idx)]\n",
    "    truncated_text_words = np.array(text_words)[remaining_idx]\n",
    "    return truncated_text_ids.unsqueeze(0), truncated_text_words, seg_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "class Generator:\n",
    "    \"\"\"Computes Contrast-CAT token attributions for a BERT sequence classifier.\n",
    "\n",
    "    Attributes:\n",
    "        lib: Reference library indexed by class id. Each entry is a sequence of\n",
    "            records whose 'activation' item is indexed by layer.\n",
    "        model: The classifier being explained; it is put into eval mode.\n",
    "    \"\"\"\n",
    "\n",
    "    # Upper bound on the number of reference activations used per class.\n",
    "    MAX_REFERENCES = 30\n",
    "\n",
    "    def __init__(self, model,lib):\n",
    "        self.lib = lib\n",
    "        self.model = model\n",
    "        self.model.eval()\n",
    "\n",
    "    def forward(self, input_ids, attention_mask):\n",
    "        \"\"\"Run the wrapped model on the given ids and mask and return its output.\"\"\"\n",
    "        return self.model(input_ids, attention_mask)\n",
    "\n",
    "    @staticmethod\n",
    "    def _min_max_scale(values):\n",
    "        \"\"\"Rescale a tensor to [0, 1]; a constant tensor maps to all zeros.\"\"\"\n",
    "        low = values.min()\n",
    "        span = values.max() - low\n",
    "        if span == 0:\n",
    "            # Avoid 0/0 when every score is identical.\n",
    "            span = span + 1e-6\n",
    "        return (values - low) / span\n",
    "\n",
    "    def generate_ContrastCAT(self, input_ids, attention_mask,index=None, start_layer=0, text_words=None,text_ids=None, original_prob = None):\n",
    "        \"\"\"Return one relevance score per input token.\n",
    "\n",
    "        For every reference in `self.lib[index]` (at most MAX_REFERENCES) a map is\n",
    "        built by summing, over layers, gradient * (activation - reference)\n",
    "        weighted by the layer's averaged attention. Each map is then scored by\n",
    "        how much removing its top-ranked tokens lowers the probability of class\n",
    "        `index`; maps whose mean drop reaches mean + std are averaged (all maps\n",
    "        if none qualifies) and the result is min-max scaled.\n",
    "\n",
    "        Relies on the notebook-level helpers `truncate_words` and `predict`.\n",
    "\n",
    "        Args:\n",
    "            input_ids: Token ids fed to the model; the code assumes a batch of one.\n",
    "            attention_mask: Attention mask fed to the model.\n",
    "            index: Target class id; defaults to the predicted class.\n",
    "            start_layer: Unused; kept so existing callers keep working.\n",
    "            text_words: Token strings of the sentence (required).\n",
    "            text_ids: Token ids used for the truncation test (required).\n",
    "            original_prob: Probability of class `index` on the full sentence (required).\n",
    "\n",
    "        Returns:\n",
    "            A 1-D CPU tensor with one score per token, scaled to [0, 1].\n",
    "        \"\"\"\n",
    "        result = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)\n",
    "\n",
    "        output = result[0]\n",
    "        hs = result[1]\n",
    "        seq_length = hs[0].shape[1]\n",
    "\n",
    "        blocks = self.model.bert.encoder.layer\n",
    "        num_blocks = len(blocks)\n",
    "\n",
    "        # Hidden states are non-leaf tensors, so their gradients must be retained explicitly.\n",
    "        for blk_id in range(num_blocks):\n",
    "            hs[blk_id].retain_grad()\n",
    "\n",
    "        if index is None:\n",
    "            # Plain int so it can index self.lib like a caller-supplied class id.\n",
    "            index = int(np.argmax(output.detach().cpu().numpy(), axis=-1).item())\n",
    "\n",
    "        # Back-propagate the score of the target class only.\n",
    "        one_hot = torch.zeros((1, output.size()[-1]), dtype=torch.float32, device=output.device)\n",
    "        one_hot[0, index] = 1\n",
    "        target_score = torch.sum(one_hot * output)\n",
    "\n",
    "        self.model.zero_grad()\n",
    "        target_score.backward(retain_graph=True)\n",
    "\n",
    "        # Per-layer quantities do not depend on the reference, so compute them once.\n",
    "        # They must be read before any further forward pass of the model.\n",
    "        layer_acts, layer_grads, layer_atts = [], [], []\n",
    "        for blk_id in range(num_blocks):\n",
    "            layer_acts.append(hs[blk_id].detach().cpu())\n",
    "            layer_grads.append(hs[blk_id].grad.detach().cpu())\n",
    "            # NOTE(review): assumes get_attn() is (1, heads, seq, seq), so the two means\n",
    "            # average over heads and then over the first sequence axis -- confirm.\n",
    "            att = blocks[blk_id].attention.self.get_attn().squeeze(0).detach().cpu()\n",
    "            layer_atts.append(att.mean(dim=0).mean(dim=0))\n",
    "\n",
    "        output_cam = []\n",
    "        for reference_hs in self.lib[index][:self.MAX_REFERENCES]:\n",
    "            expln = 0\n",
    "            for blk_id in range(num_blocks):\n",
    "                ### Load reference ###\n",
    "                ref_hs = reference_hs['activation'][blk_id]\n",
    "                ref_hs_seq_length = ref_hs.shape[1]\n",
    "\n",
    "                # Crop or zero-pad the reference along the sequence axis to match the input.\n",
    "                if ref_hs_seq_length >= seq_length:\n",
    "                    reference = ref_hs[:,:seq_length,:]\n",
    "                else:\n",
    "                    pad_length = seq_length - ref_hs_seq_length\n",
    "                    reference = F.pad(input=ref_hs, pad=(0, 0, 0, pad_length), mode='constant', value=0)\n",
    "\n",
    "                activation = layer_acts[blk_id] - reference\n",
    "                cat = (layer_grads[blk_id] * activation).sum(dim=-1).squeeze(0)\n",
    "                expln = expln + cat * layer_atts[blk_id]\n",
    "\n",
    "            output_cam.append(self._min_max_scale(expln))\n",
    "\n",
    "        cam = torch.stack(output_cam)\n",
    "\n",
    "        # Numbers of tokens to remove: ten evenly spaced fractions of the sentence, de-duplicated.\n",
    "        total_len = len(text_words)\n",
    "        granularity = np.linspace(0, 1, 10)\n",
    "        trunc_words_num = [int(g) for g in np.round(granularity*total_len)]\n",
    "        trunc_words_num = list(dict.fromkeys(trunc_words_num))\n",
    "\n",
    "        ############## Reference filtering ################\n",
    "        descending_sorted_idx = torch.argsort(-cam,dim=1).detach().cpu().numpy()\n",
    "\n",
    "        per_cam_trunc_proba_info = []\n",
    "        for sorted_idx in descending_sorted_idx:\n",
    "            trunc_proba_list = []\n",
    "            for num in trunc_words_num[1:]: # exclude 0\n",
    "                truncated_ids, _, _ = truncate_words(sorted_idx, text_words, text_ids, num, seg_ids=None)\n",
    "                # Use the wrapped model rather than a notebook global.\n",
    "                _, trunc_prob = predict(self.model, truncated_ids, index, seg_ids=None)\n",
    "                trunc_proba_list.append(original_prob - trunc_prob)\n",
    "            per_cam_trunc_proba_info.append(trunc_proba_list)\n",
    "\n",
    "        per_cam_trunc_proba_info = np.array(per_cam_trunc_proba_info)\n",
    "        mean_per_cam_trunc_proba_info = per_cam_trunc_proba_info.mean(axis=1)\n",
    "\n",
    "        # Keep the references whose mean probability drop is at least one std above the mean.\n",
    "        threshold = mean_per_cam_trunc_proba_info.mean() + mean_per_cam_trunc_proba_info.std()\n",
    "        filterd_idx = (mean_per_cam_trunc_proba_info >= threshold)\n",
    "\n",
    "        if filterd_idx.sum() != 0:\n",
    "            trunc_choice_cam = cam[filterd_idx].mean(dim=0)\n",
    "        else:\n",
    "            trunc_choice_cam = cam.mean(dim=0)\n",
    "\n",
    "        return self._min_max_scale(trunc_choice_cam)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_explns(explanations, input_ids, attention_mask, start_layer=0, true_class = 1, text_words=None, text_ids=None, proba = None):\n",
    "    \"\"\"Forward the arguments to `explanations.generate_ContrastCAT` and return its result.\n",
    "\n",
    "    `true_class` is passed on as `index` and `proba` as `original_prob`; every\n",
    "    other argument keeps its name.\n",
    "    \"\"\"\n",
    "    call_kwargs = dict(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        index=true_class,\n",
    "        start_layer=start_layer,\n",
    "        text_words=text_words,\n",
    "        text_ids=text_ids,\n",
    "        original_prob=proba,\n",
    "    )\n",
    "    return explanations.generate_ContrastCAT(**call_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.7/site-packages/transformers/modeling_utils.py:867: FutureWarning: The `device` argument is deprecated and will be removed in v5 of Transformers.\n",
      "  \"The `device` argument is deprecated and will be removed in v5 of Transformers.\", FutureWarning\n"
     ]
    }
   ],
   "source": [
    "# Explain a validation sentence with Contrast-CAT.\n",
    "seed_everything(ran_seed)\n",
    "to_test = np.array(dataset['validation'])\n",
    "\n",
    "# Reference library handed to Generator.\n",
    "# NOTE: allow_pickle=True unpickles Python objects and can execute arbitrary code;\n",
    "# only load this file from a trusted source.\n",
    "libra = np.load('./act_lib/sst.npy',allow_pickle=True)\n",
    "explanations = Generator(model,libra)\n",
    "\n",
    "for i, test_instance in enumerate(to_test):\n",
    "\n",
    "    text = test_instance['sentence']\n",
    "    target = test_instance['label'] \n",
    "\n",
    "    text_ids, att_mask, text_words = preprocess_sample(text)\n",
    "\n",
    "    # Forward pass without an attention mask to get the prediction on the full sentence.\n",
    "    result = model(text_ids, attention_mask=None, token_type_ids=None)\n",
    "    prob = result[0]\n",
    "\n",
    "    # Probability the model assigns to its own predicted class.\n",
    "    pred_class_prob = softmax(prob.cpu().detach().numpy(), axis=1)\n",
    "    pred_class = torch.argmax(prob, axis=1).cpu().detach().numpy()[0]\n",
    "    original_prob = pred_class_prob[:, pred_class][0]\n",
    "\n",
    "    # get attributions\n",
    "    Libra_cat = \\\n",
    "    generate_explns(explanations, text_ids, att_mask, start_layer=0, true_class = pred_class, text_words=text_words, text_ids=text_ids, proba = original_prob) \n",
    "    \n",
    "    # Only the first validation example is processed.\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[  101,  2009,  1005,  1055,  1037, 11951,  1998,  2411, 12473,  4990,\n",
      "          1012,   102]])\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlsAAABRCAYAAADlwWM+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAIpklEQVR4nO3da4xcZR3H8e+PXUuhRChWEYEABoI0GgM0iJAQwiVCMNREJZBoioHUFyJITLRqogmvqjFeXjQmDaBECWDwwqoERC7xhYZQAcOl1lYU2lpouYOXQsvfFzuYddmlsjNnzs7M95Ns5lyenuefJ9uZ357nzDmpKiRJktSMvdouQJIkaZgZtiRJkhpk2JIkSWqQYUuSJKlBhi1JkqQGGbYkSZIa1FXYSnJgktuTbOy8Lp6l3e4kD3R+JrrpU5IkaZCkm/tsJfkG8ExVrU6yClhcVV+cod1LVbVfF3VKkiQNpG7D1gbgtKraluRg4O6qOmaGdoYtSZI0kroNW89V1QGd5QDPvrY+rd0u4AFgF7C6qn4+y/FWAisBxjJ+wqLxGWcl57Xdixa0XcKcjb24s+0S5qR27267hDnLwr3bLmFOaq/BvdwzO19uu4Q5qYWD+96SXQP6f/TVV9uuYM6OPvaFtkuYk43r39p2CXP2wis7nqqqt8+0b3xP/zjJb4B3zrDrK1NXqqqSzJbcDq+qrUneDdyZ5MGq+sv0RlW1FlgLsP+Cd9TJS87fU3nzzosnHd52CXO2390b2i5hTnY/93zbJczZ2FGvOxE8EF7dd3A/+PfatKXtEuZk9zGHtV3CnI1vH8wPfv75r7YrmLNf3XZb2yXMybnHf6jtEubs1m1rHptt3x7DVlWdOdu+JE8mOXjKNOL2WY6xtfP6aJK7geOA14UtSZKkYdPtXMAEsKKzvAK4eXqDJIuT7N1ZXgKcAjzSZb+SJEkDoduwtRo4K8lG4MzOOkmWJbmq0+ZYYF2SPwJ3MXnNlmFLkiSNhD1OI76RqnoaOGOG7euASzrLvwPe100/kiRJg2pwv1IkSZI0AAxbkiRJDTJsSZIkNciwJUmS1CDDliRJUoMMW5IkSQ0ybEmSJDXIsCVJktQgw5YkSVKDDFuSJEkNMmxJkiQ1yLAlSZLUIMOWJElSgwxbkiRJDepJ2EpydpINSTYlWTXD/r2T3NjZf0+SI3rRryRJ0nzXddhKMgasAc4BlgIXJlk6rdnFwLNVdRTwbeDr3fYrSZI0CHpxZutEYFNVPVpVLwM3AMuntVkOXNtZvgk4I0l60LckSdK81ouwdQiwecr6ls62GdtU1S7geeBtPehbkiRpXhtvu4CpkqwEVgIsHNuv5WokSZK614szW1uBw6asH9rZNmObJOPA/sDT0w9UVWurallVLVuw1z49KE2SJKldvQhb9wJHJzkyyQLgAmBiWpsJYEVn+WPAnVVVPehbkiRpXut6GrGqdiW5FLgNGAOuqaqHk1wJrKuqCeBq4IdJNgHPMBnIJEmShl5PrtmqqluAW6Zt++qU5X8DH+9FX5IkSYPEO8hLkiQ1yLAlSZLUIMOWJElSgwxbkiRJDTJsSZIkNciwJUmS1CDDliRJUoMMW5IkSQ0ybEmSJDXIsCVJktQgw5YkSVKDDFuSJEkNMmxJkiQ1yLAlSZLUIMOWJElSg3oStpKcnWRDkk1JVs2w/6IkO5I80Pm5pBf9SpIkzXfj3R4gyRiwBjgL2ALcm2Siqh6Z1vTGqrq02/4kSZIGSS/ObJ0IbKqqR6vqZeAGYHkPjitJkjTwuj6zBRwCbJ6yvgX4wAztPprkVODPwBVVtXl6gyQrgZWd1Zdu3bZmQw/qm80S4KmeH/VnPT/iMGlmzAfZQ4334Jj3XzNj/vueH3GY+Hs+zdjBjXfR0Jhv6v0h++fw2Xb0Imz9P34BXF9VO5N8GrgWOH16o6paC6ztR0FJ1lXVsn70pUmO
ef855v3nmPefY95/jvmb04tpxK3AYVPWD+1s+6+qerqqdnZWrwJO6EG/kiRJ814vwta9wNFJjkyyALgAmJjaIMnUE5rnAet70K8kSdK81/U0YlXtSnIpcBswBlxTVQ8nuRJYV1UTwGVJzgN2Ac8AF3Xbbw/0ZbpS/8Mx7z/HvP8c8/5zzPvPMX8TUlVt1yBJkjS0vIO8JElSgwxbkiRJDRq5sLWnRwupt5IcluSuJI8keTjJ5W3XNCqSjCW5P8kv265lFCQ5IMlNSf6UZH2SD7Zd07BLckXnfeWhJNcnWdh2TcMoyTVJtid5aMq2A5PcnmRj53VxmzXOdyMVtqY8WugcYClwYZKl7VY19HYBn6+qpcBJwGcc8765HL/520/fBW6tqvcA78exb1SSQ4DLgGVV9V4mv6B1QbtVDa0fAGdP27YKuKOqjgbu6KxrFiMVtvDRQn1XVduq6r7O8otMfgAd0m5Vwy/JocC5TN7XTg1Lsj9wKnA1QFW9XFXPtVvVSBgH9kkyDuwL/L3leoZSVf2WyTsJTLWcyRuU03n9SF+LGjCjFrZmerSQH/x9kuQI4DjgnnYrGQnfAb4AvNp2ISPiSGAH8P3O1O1VSRa1XdQwq6qtwDeBx4FtwPNV9et2qxopB1XVts7yE8BBbRYz341a2FJLkuwH/AT4XFW90HY9wyzJh4HtVfWHtmsZIePA8cD3quo44B84rdKozjVCy5kMuu8CFiX5RLtVjaaavIeU95F6A6MWtvb4aCH1XpK3MBm0rquqn7Zdzwg4BTgvyd+YnCo/PcmP2i1p6G0BtlTVa2dtb2IyfKk5ZwJ/raodVfUK8FPg5JZrGiVPvvZ0mM7r9pbrmddGLWzt8dFC6q0kYfI6lvVV9a226xkFVfWlqjq0qo5g8nf8zqryL/4GVdUTwOYkx3Q2nQE80mJJo+Bx4KQk+3beZ87ALyX00wSworO8Ari5xVrmva4f1zNIZnu0UMtlDbtTgE8CDyZ5oLPty1V1S4s1SU34LHBd5w+5R4FPtVzPUKuqe5LcBNzH5Lee78dHyDQiyfXAacCSJFuArwGrgR8nuRh4DDi/vQrnPx/XI0mS1KBRm0aUJEnqK8OWJElSgwxbkiRJDTJsSZIkNciwJUmS1CDDliRJUoMMW5IkSQ36D2qjWCtXzCfwAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 720x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show the token ids of the explained sentence.\n",
    "print(text_ids.detach().cpu())\n",
    "\n",
    "# Render the per-token attribution scores as a single-row heat map.\n",
    "plt.figure(figsize=(10,4))\n",
    "plt.imshow(Libra_cat.unsqueeze(dim=0).numpy())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
