{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "import numpy as np\n",
    "\n",
    "import os \n",
    "from tqdm.auto import tqdm\n",
    "import pandas as pd\n",
    "import sklearn\n",
    "import sklearn.metrics\n",
    "from math import nan\n",
    "\n",
    "# Placeholder string that LIMEEvaluator.__init__ substitutes for a word (code 1) in perturbed samples.\n",
    "unk_token = '<UNK>'\n",
    "\n",
    "\n",
    "\n",
    "class LIMEEvaluator:\n",
    "    \"\"\"Score LIME word weights and ReX-style rules as surrogates of a classifier.\n",
    "\n",
    "    The constructor draws ``num_samples`` perturbations of ``words`` and labels\n",
    "    them with ``predict_lr``; ``calc_acc`` and ``calc_AUROC`` then measure how\n",
    "    well the LIME weights and the rules reproduce those labels.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, predict_lr, words, num_samples = 1000, permute = False, seed = None):\n",
    "        \"\"\"Draw perturbed sentences from ``words`` and label them with ``predict_lr``.\n",
    "\n",
    "        Args:\n",
    "            predict_lr: callable taking a numpy array of strings and returning one\n",
    "                label per string (compared against 0/1 predictions later).\n",
    "            words: tokens of the sentence being explained.\n",
    "            num_samples: number of perturbed sentences to draw.\n",
    "            permute: if True, also shuffle the token order inside every sample.\n",
    "            seed: optional seed for a private ``np.random.RandomState``; the default\n",
    "                ``None`` draws from the global ``np.random`` state as before.\n",
    "        \"\"\"\n",
    "        self.predict_lr = predict_lr\n",
    "        rng = np.random if seed is None else np.random.RandomState(seed)\n",
    "        # One code per (sample, position): 0 -> drop the word,\n",
    "        # 1 -> replace it with unk_token, 2 -> keep the original word.\n",
    "        codes = rng.choice(3, size=(num_samples, len(words)))\n",
    "        grid = np.where(codes == 2, np.array(words, dtype=str),\n",
    "                        np.where(codes == 1, unk_token, ''))\n",
    "\n",
    "        if permute:\n",
    "            for row in grid:\n",
    "                rng.shuffle(row)  # shuffles this sample's tokens in place\n",
    "\n",
    "        # Join, then re-split so dropped (empty) words leave no extra spaces.\n",
    "        instances = np.array([' '.join(' '.join(row).split()) for row in grid])\n",
    "\n",
    "        self.instances = instances\n",
    "        self.preds = self.predict_lr(instances)\n",
    "        self.words = words\n",
    "\n",
    "    @staticmethod\n",
    "    def select(lime_dist, num):\n",
    "        \"\"\"Return at most ``num`` rules, keeping the first rule of each word pair.\n",
    "\n",
    "        Rules are ``((wordx, wordy, dist), weight)``; rules sharing the same\n",
    "        ``(wordx, wordy)`` are duplicates and only the earliest one is kept.\n",
    "        \"\"\"\n",
    "        selected = []\n",
    "        seen_pairs = set()\n",
    "        for rule in lime_dist:\n",
    "            # Checked before appending so that num == 0 yields no rules at all.\n",
    "            if len(selected) >= num:\n",
    "                break\n",
    "            pair = (rule[0][0], rule[0][1])\n",
    "            if pair not in seen_pairs:\n",
    "                seen_pairs.add(pair)\n",
    "                selected.append(rule)\n",
    "        return selected\n",
    "\n",
    "    @staticmethod\n",
    "    def fit(now, rules):\n",
    "        \"\"\"Sum the weights of the rules that fire on the sentence ``now``.\n",
    "\n",
    "        Each rule is ``((wordx, wordy, dist), wt)``. Positions are those of the\n",
    "        last occurrence of each word in ``now.split()``, and a rule never fires\n",
    "        when one of its words is absent. A rule over two different words fires\n",
    "        when ``pos(wordy) - pos(wordx) > dist``; a rule whose two words are equal\n",
    "        fires when that word sits at position ``>= dist``. A ``dist`` of -1 is\n",
    "        lowered to -2 before the comparison (kept from the original code).\n",
    "        \"\"\"\n",
    "        # Later indices overwrite earlier ones, so this maps token -> last position.\n",
    "        last_pos = {tok: i for i, tok in enumerate(now.split())}\n",
    "        total = 0\n",
    "        for (wordx, wordy, dist), wt in rules:\n",
    "            if wordx not in last_pos or wordy not in last_pos:\n",
    "                continue\n",
    "            posx, posy = last_pos[wordx], last_pos[wordy]\n",
    "            if dist == -1:\n",
    "                # NOTE(review): intent of the -1 -> -2 shift is undocumented; kept as-is.\n",
    "                dist -= 1\n",
    "            if wordx != wordy:\n",
    "                fired = posy - posx > dist\n",
    "            else:\n",
    "                fired = posx >= dist\n",
    "            if fired:\n",
    "                total += wt\n",
    "        return total\n",
    "\n",
    "    @staticmethod\n",
    "    def fit_LIME(now, limes, words):\n",
    "        \"\"\"Sum the LIME weights of words found at their original position.\n",
    "\n",
    "        ``limes`` is a sequence of ``(word, weight)`` pairs. A weight is added once\n",
    "        for each index ``i`` at which both ``now.split()[i]`` and ``words[i]`` equal\n",
    "        the word; tokens beyond the end of ``words`` never match.\n",
    "        \"\"\"\n",
    "        # Tokens of ``now`` that still sit where they were in ``words``.\n",
    "        aligned = [tok for tok, orig in zip(now.split(), words) if tok == orig]\n",
    "        total = 0\n",
    "        for word, wt in limes:\n",
    "            for tok in aligned:\n",
    "                if tok == word:\n",
    "                    total += wt\n",
    "        return total\n",
    "\n",
    "    def _score_instances(self, rules, limes, select, lime_domain):\n",
    "        \"\"\"Return ``(preds, lime_scores, rule_scores)`` on the evaluated samples.\n",
    "\n",
    "        With ``select`` the rules are first reduced by ``self.select`` to as many\n",
    "        rules as there are LIME features; with ``lime_domain`` only samples whose\n",
    "        LIME score is non-zero are kept.\n",
    "        \"\"\"\n",
    "        if select:\n",
    "            rules = self.select(rules, len(limes))\n",
    "        # Computed once and reused for both the domain mask and the LIME predictions.\n",
    "        lime_scores = np.array([self.fit_LIME(x, limes, self.words) for x in self.instances])\n",
    "        if lime_domain:\n",
    "            keep = lime_scores != 0\n",
    "        else:\n",
    "            keep = np.ones(len(self.instances), dtype=bool)\n",
    "        rule_scores = np.array([self.fit(x, rules) for x in self.instances[keep]])\n",
    "        return self.preds[keep], lime_scores[keep], rule_scores\n",
    "\n",
    "    @staticmethod\n",
    "    def _auroc(y_true, y_score):\n",
    "        \"\"\"ROC AUC of ``y_score`` against ``y_true`` with 1 as the positive class.\n",
    "\n",
    "        Returns 0.5 (chance level) when the AUROC is undefined: no samples, or a\n",
    "        NaN from sklearn (e.g. when ``y_true`` contains a single class).\n",
    "        \"\"\"\n",
    "        if len(y_true) == 0:\n",
    "            return 0.5\n",
    "        fpr, tpr, _ = sklearn.metrics.roc_curve(y_true=y_true, y_score=y_score, pos_label=1)\n",
    "        auroc = sklearn.metrics.auc(fpr, tpr)\n",
    "        # ``auroc is nan`` never matches a NaN computed by numpy; test the value instead.\n",
    "        return 0.5 if np.isnan(auroc) else auroc\n",
    "\n",
    "    def calc_acc(self, rules, limes, selectrule = False, lime_domain = False):\n",
    "        \"\"\"Fraction of samples on which LIME / the rules agree with ``self.preds``.\n",
    "\n",
    "        A sample is predicted 1 when its summed LIME (resp. rule) weight is > 0.\n",
    "\n",
    "        Args:\n",
    "            rules: ReX rules ``((wordx, wordy, dist), wt)``.\n",
    "            limes: LIME ``(word, weight)`` pairs.\n",
    "            selectrule: if True, deduplicate and limit the rules with ``select``.\n",
    "            lime_domain: if True, only score samples with a non-zero LIME score.\n",
    "\n",
    "        Returns:\n",
    "            ``[acc_LIME, acc_rule]``; both are NaN when no sample is scored.\n",
    "        \"\"\"\n",
    "        preds, lime_scores, rule_scores = self._score_instances(rules, limes, selectrule, lime_domain)\n",
    "        if len(preds) == 0:\n",
    "            return [nan, nan]\n",
    "        pred_LIME = (lime_scores > 0).astype(int)\n",
    "        pred_rule = (rule_scores > 0).astype(int)\n",
    "        return [(pred_LIME == preds).sum() / len(preds),\n",
    "                (pred_rule == preds).sum() / len(preds)]\n",
    "\n",
    "    def calc_AUROC(self, rules, limes, select = False, lime_domain = False):\n",
    "        \"\"\"AUROC of the 0/1 LIME and rule predictions against ``self.preds``.\n",
    "\n",
    "        Args mirror ``calc_acc`` (``select`` plays the role of ``selectrule``).\n",
    "\n",
    "        Returns:\n",
    "            ``[auroc_LIME, auroc_rule]``; an undefined AUROC is reported as 0.5.\n",
    "        \"\"\"\n",
    "        preds, lime_scores, rule_scores = self._score_instances(rules, limes, select, lime_domain)\n",
    "        # Same > 0 threshold as calc_acc: a negative total weight points to class 0.\n",
    "        pred_LIME = (lime_scores > 0).astype(int)\n",
    "        pred_rule = (rule_scores > 0).astype(int)\n",
    "        return [self._auroc(preds, pred_LIME), self._auroc(preds, pred_rule)]\n",
    "        \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "from transformers import BertTokenizer, AutoModelForSequenceClassification\n",
    "import numpy as np\n",
    "import torch\n",
    "device = 'cpu'\n",
    "# Fine-tuned checkpoint directory, relative to the notebook's working directory.\n",
    "MODEL_DIR = '../bert-base-cased-finetuned-sst2/'\n",
    "# Move the model to `device` as well: predict/predict_lr send their inputs there.\n",
    "model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR).to(device)\n",
    "tokenizer = BertTokenizer.from_pretrained(MODEL_DIR)\n",
    "\n",
    "\n",
    "# %%\n",
    "def predict(text,batch_size = 32):\n",
    "    \"\"\"Return the model's softmax class probabilities for one or more texts.\n",
    "\n",
    "    Args:\n",
    "        text: a string, a list or tuple of strings, or a numpy array of strings.\n",
    "        batch_size: number of texts tokenized and scored per forward pass.\n",
    "\n",
    "    Returns:\n",
    "        numpy array with one row per text (softmax over the logits' class axis).\n",
    "\n",
    "    Raises:\n",
    "        TypeError: if ``text`` is not one of the accepted types.\n",
    "    \"\"\"\n",
    "    if isinstance(text, str):\n",
    "        text = [text]\n",
    "    elif isinstance(text, np.ndarray):\n",
    "        text = text.tolist()\n",
    "    elif isinstance(text, tuple):\n",
    "        text = list(text)\n",
    "    if not isinstance(text, list):\n",
    "        raise TypeError('Input must be a string or a list of strings')\n",
    "    if not text:\n",
    "        # torch.concat cannot join an empty list of batches.\n",
    "        return np.empty((0, model.config.num_labels), dtype=np.float32)\n",
    "    probs = []\n",
    "    # Inference only: without no_grad every batch kept in `probs` holds its autograd graph.\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, len(text), batch_size):\n",
    "            batch = text[start:start + batch_size]\n",
    "            inputs = tokenizer(batch, return_tensors=\"pt\", padding=True).to(device)\n",
    "            probs.append(model(**inputs).logits.softmax(1))\n",
    "    return torch.concat(probs).cpu().numpy()\n",
    "\n",
    "\n",
    "# %%\n",
    "def predict_lr(text,batch_size = 32):\n",
    "    \"\"\"Return a hard 0/1 label per text: 1 when logit 1 exceeds logit 0.\n",
    "\n",
    "    Args:\n",
    "        text: a string, a list or tuple of strings, or a numpy array of strings.\n",
    "        batch_size: number of texts tokenized and scored per forward pass.\n",
    "\n",
    "    Returns:\n",
    "        numpy integer array with one label per text (ties map to 0).\n",
    "\n",
    "    Raises:\n",
    "        TypeError: if ``text`` is not one of the accepted types.\n",
    "    \"\"\"\n",
    "    if isinstance(text, str):\n",
    "        text = [text]\n",
    "    elif isinstance(text, np.ndarray):\n",
    "        text = text.tolist()\n",
    "    elif isinstance(text, tuple):\n",
    "        text = list(text)\n",
    "    if not isinstance(text, list):\n",
    "        raise TypeError('Input must be a string or a list of strings')\n",
    "    if not text:\n",
    "        # np.concatenate cannot join an empty list of batches.\n",
    "        return np.zeros(0, dtype=np.int64)\n",
    "    labels = []\n",
    "    # Inference only: no gradients are needed for the forward passes.\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, len(text), batch_size):\n",
    "            batch = text[start:start + batch_size]\n",
    "            inputs = tokenizer(batch, return_tensors=\"pt\", padding=True).to(device)\n",
    "            logits = model(**inputs)[0]\n",
    "            labels.append((logits[:, 1] > logits[:, 0]).to(int).cpu().numpy())\n",
    "    return np.concatenate(labels)\n",
    "\n",
    "# %%\n",
    "\n",
    "\n",
    "# %%\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4fec90dc1aa343c1a0f1526df25083b4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1652 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.7006, 0.5948]\n"
     ]
    }
   ],
   "source": [
    "from textwrap import wrap\n",
    "\n",
    "\n",
    "LIME_RES_DIR = '../Llama-2-7b/lime_res'\n",
    "MAX_FILES = 1  # number of .txt result files to evaluate; None evaluates all of them\n",
    "\n",
    "\n",
    "def load_lime_result(path):\n",
    "    \"\"\"Parse one result file into ``(words, lime, rex)``.\n",
    "\n",
    "    Line 0 holds the sentence tokens, line 1 the LIME ``(word, weight)`` pairs and\n",
    "    line 3 the ReX rules, whose first element is itself a string-encoded tuple.\n",
    "    \"\"\"\n",
    "    with open(path) as f:\n",
    "        lines = f.readlines()\n",
    "    # NOTE(review): eval() executes file contents; only use on trusted result files.\n",
    "    lime = eval(lines[1])\n",
    "    rex = [(eval(t[0]), t[1]) for t in eval(lines[3])]\n",
    "    # Presumably a space-separated list repr; turn the separators into commas first.\n",
    "    words = eval(lines[0].replace(', ', ',').replace(' ', ','))\n",
    "    return words, lime, rex\n",
    "\n",
    "\n",
    "n_done = 0\n",
    "pbar = tqdm(os.listdir(LIME_RES_DIR), position=0)\n",
    "for file in pbar:\n",
    "    if not file.endswith('.txt'):\n",
    "        continue\n",
    "    words, lime, rex = load_lime_result(os.path.join(LIME_RES_DIR, file))\n",
    "    evaluator = LIMEEvaluator(predict_lr=predict_lr, words=words, permute=True, num_samples=5000)\n",
    "    acc = evaluator.calc_acc(rules=rex, limes=lime)\n",
    "    pbar.write(str(acc))\n",
    "    n_done += 1\n",
    "    if MAX_FILES is not None and n_done >= MAX_FILES:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.40111809923130676, 0.5765199161425576]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# [LIME accuracy, rule accuracy] restricted to samples whose LIME score is non-zero\n",
    "evaluator.calc_acc(limes=lime, rules=rex, lime_domain=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.49677547226550756, 0.5]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "# [LIME AUROC, rule AUROC]; sklearn and sklearn.metrics are imported in the first cell\n",
    "evaluator.calc_AUROC(rules=rex, limes=lime)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 186,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.        ,  0.18234991,  0.        ,  0.        ,  0.        ,\n",
       "       -0.01587144, -0.01939538, -0.01693947,  0.        ,  0.02227669,\n",
       "       -0.01587144,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "       -0.01587144, -0.01587144,  0.        ,  0.        , -0.00245591,\n",
       "        0.        ,  0.        , -0.00352393,  0.        , -0.01587144,\n",
       "        0.        , -0.01587144,  0.        ,  0.13226209, -0.00245591,\n",
       "        0.        ,  0.        ,  0.        ,  0.        , -0.01832735,\n",
       "        0.        , -0.01587144,  0.        , -0.01587144, -0.01587144,\n",
       "        0.        , -0.01832735,  0.10218118, -0.01693947, -0.01587144,\n",
       "       -0.01587144,  0.        ,  0.        , -0.00106802,  0.        ,\n",
       "       -0.00352393,  0.1194405 ,  0.179894  , -0.01587144,  0.        ,\n",
       "       -0.00106802,  0.        , -0.01832735,  0.        , -0.01587144,\n",
       "       -0.00245591, -0.00352393, -0.01587144, -0.01587144,  0.        ,\n",
       "        0.        ,  0.        , -0.01939538, -0.01939538,  0.        ,\n",
       "        0.        , -0.01587144,  0.12433134,  0.        ,  0.10463709,\n",
       "        0.10111315, -0.01832735, -0.01832735,  0.        ,  0.11698459,\n",
       "        0.        ,  0.        ,  0.        , -0.01587144,  0.        ,\n",
       "       -0.00245591,  0.        ,  0.        ,  0.10218118,  0.11532262,\n",
       "       -0.01832735,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.10463709,  0.        ,  0.        ,  0.        ])"
      ]
     },
     "execution_count": 186,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 205,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ac0bcde0217642e7b0285ded69ccbdc4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1821 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Rename cached LIME results from '<sentence>.txt' to '<row index>.txt'.\n",
    "# Paths are joined instead of using os.chdir, so a failure cannot leave the cwd changed.\n",
    "RESULTS_ROOT = 'Llama-2-7b'\n",
    "# TODO: absolute local path; make configurable.\n",
    "SENTIMENT_TEST_PATH = '/home/outerform/github-repo/ReX/iclr2016/data/sentiment-test'\n",
    "\n",
    "test = pd.read_csv(SENTIMENT_TEST_PATH, header=None, sep='\\t')\n",
    "pbar = tqdm(enumerate(test[0]), total=len(test[0]), position=0)\n",
    "lime_res_dir = os.path.join(RESULTS_ROOT, 'lime_res')\n",
    "os.makedirs(lime_res_dir, exist_ok=True)\n",
    "os.makedirs(os.path.join(RESULTS_ROOT, 'LIMEcover'), exist_ok=True)\n",
    "for idx, x in pbar:\n",
    "    # Mirror the '/' -> ' ' substitution used when the result files were named.\n",
    "    x = x.replace('/', ' ')\n",
    "    src = os.path.join(lime_res_dir, x + '.txt')\n",
    "    if os.path.exists(src):\n",
    "        os.rename(src, os.path.join(lime_res_dir, f'{idx}.txt'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 208,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Rebinds the global `tokenizer` that predict/predict_lr use.\n",
    "# NOTE(review): './' assumes a different working directory than the '../' path used at model load - confirm.\n",
    "tokenizer = AutoTokenizer.from_pretrained('./bert-base-cased-finetuned-sst2/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 211,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['I', 'am', 'not', 'ok']"
      ]
     },
     "execution_count": 211,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
