{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1d54d8cb-d4a9-494a-96ea-c2101440d807",
   "metadata": {},
   "source": [
    "## Compute BoW Importances offline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7b14b4fe-b47f-4ab2-9453-e3104ceab1b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Move the working directory one level up; the relative paths used further\n",
    "# down (e.g. the `ground_truth/` output folder) are resolved from there.\n",
    "# Re-running this cell climbs one more level each time it is executed.\n",
    "os.chdir(os.pardir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7fc3350e-1e41-4650-8ac8-c4af0634edd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from slalom_explanations.attribution_methods import get_groundtruth_importance, BoW, NaiveBayesEstim\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6c03249d-c1a5-4862-8cd3-d52f674de8c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "def compute_ref_importances(gt_list, use_dataset, tokenizer, max_seq_len) -> dict:\n",
    "    \"\"\"Gather reference importance scores for each entry of ``gt_list``.\n",
    "\n",
    "    One ``BoW`` is built from ``use_dataset``, ``tokenizer`` and ``max_seq_len``\n",
    "    and is passed to ``get_groundtruth_importance`` once per entry.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping every entry of ``gt_list`` to a deep copy of what\n",
    "        ``get_importance()`` returns for it. An empty ``gt_list`` gives ``{}``\n",
    "        without building the ``BoW`` or printing anything.\n",
    "    \"\"\"\n",
    "    reference_scores = {}\n",
    "    if len(gt_list) > 0:\n",
    "        print(f\"Tokenizing with max_seq_len = {max_seq_len}\")\n",
    "        shared_bow = BoW(ds=use_dataset, tokenizer=tokenizer, max_seq_len=max_seq_len)\n",
    "        for gt_name in gt_list:\n",
    "            print(\"getting ground_truth for model\", gt_name)\n",
    "            estimator = get_groundtruth_importance(gt_name, shared_bow)\n",
    "            # deep copy so the stored scores are decoupled from the estimator object\n",
    "            reference_scores[gt_name] = deepcopy(estimator.get_importance())\n",
    "        print(\"Got reference importances.\")\n",
    "    return reference_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e7754e3d-399f-4b63-badf-3006cc45877b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "self.tokenizer: BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={\n",
      "\t0: AddedToken(\"[PAD]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "\t100: AddedToken(\"[UNK]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "\t101: AddedToken(\"[CLS]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "\t102: AddedToken(\"[SEP]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "\t103: AddedToken(\"[MASK]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "} type(self.tokenizer: <class 'transformers.models.bert.tokenization_bert_fast.BertTokenizerFast'>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (936 > 512). Running this sequence through the model will result in indexing errors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5000, 21377)\n",
      "self.tokenizer: BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={\n",
      "\t0: AddedToken(\"[PAD]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "\t100: AddedToken(\"[UNK]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "\t101: AddedToken(\"[CLS]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "\t102: AddedToken(\"[SEP]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "\t103: AddedToken(\"[MASK]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "} type(self.tokenizer: <class 'transformers.models.bert.tokenization_bert_fast.BertTokenizerFast'>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (936 > 512). Running this sequence through the model will result in indexing errors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5000, 21377)\n",
      "self.tokenizer: GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={\n",
      "\t50256: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
      "} type(self.tokenizer: <class 'transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast'>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (1088 > 1024). Running this sequence through the model will result in indexing errors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5000, 23669)\n",
      "self.tokenizer: RobertaTokenizerFast(name_or_path='FacebookAI/roberta-base', vocab_size=50265, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={\n",
      "\t0: AddedToken(\"<s>\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
      "\t1: AddedToken(\"<pad>\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
      "\t2: AddedToken(\"</s>\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
      "\t3: AddedToken(\"<unk>\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
      "\t50264: AddedToken(\"<mask>\", rstrip=False, lstrip=True, single_word=False, normalized=False, special=True),\n",
      "} type(self.tokenizer: <class 'transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast'>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (914 > 512). Running this sequence through the model will result in indexing errors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5000, 23669)\n",
      "self.tokenizer: BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={\n",
      "\t0: AddedToken(\"[PAD]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "\t100: AddedToken(\"[UNK]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "\t101: AddedToken(\"[CLS]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "\t102: AddedToken(\"[SEP]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "\t103: AddedToken(\"[MASK]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "} type(self.tokenizer: <class 'transformers.models.bert.tokenization_bert_fast.BertTokenizerFast'>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (977 > 512). Running this sequence through the model will result in indexing errors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5000, 14979)\n",
      "self.tokenizer: BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={\n",
      "\t0: AddedToken(\"[PAD]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "\t100: AddedToken(\"[UNK]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "\t101: AddedToken(\"[CLS]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "\t102: AddedToken(\"[SEP]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "\t103: AddedToken(\"[MASK]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
      "} type(self.tokenizer: <class 'transformers.models.bert.tokenization_bert_fast.BertTokenizerFast'>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (977 > 512). Running this sequence through the model will result in indexing errors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5000, 14979)\n",
      "self.tokenizer: GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={\n",
      "\t50256: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
      "} type(self.tokenizer: <class 'transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast'>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (1178 > 1024). Running this sequence through the model will result in indexing errors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5000, 16709)\n",
      "self.tokenizer: RobertaTokenizerFast(name_or_path='FacebookAI/roberta-base', vocab_size=50265, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={\n",
      "\t0: AddedToken(\"<s>\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
      "\t1: AddedToken(\"<pad>\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
      "\t2: AddedToken(\"</s>\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
      "\t3: AddedToken(\"<unk>\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
      "\t50264: AddedToken(\"<mask>\", rstrip=False, lstrip=True, single_word=False, normalized=False, special=True),\n",
      "} type(self.tokenizer: <class 'transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast'>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (963 > 512). Running this sequence through the model will result in indexing errors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5000, 16709)\n"
     ]
    }
   ],
   "source": [
    "from copy import deepcopy\n",
    "# Hugging Face hub id for each dataset name; the dataset name itself is what\n",
    "# ends up in the output file name.\n",
    "dataset_hub_ids = {\"imdb\": \"imdb\", \"yelp\": \"yelp_polarity\"}\n",
    "# Tokenizer checkpoint per model type; \"distilbert\" and \"bert\" both load\n",
    "# 'bert-base-uncased'.\n",
    "tokenizer_checkpoints = {\n",
    "    \"distilbert\": \"bert-base-uncased\",\n",
    "    \"bert\": \"bert-base-uncased\",\n",
    "    \"gpt2\": \"gpt2\",\n",
    "    \"roberta\": \"FacebookAI/roberta-base\",\n",
    "}\n",
    "\n",
    "# torch.save does not create missing parent directories, so make sure the\n",
    "# output folder exists before the (slow) loop starts.\n",
    "output_dir = \"ground_truth\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "for dataset_name in [\"imdb\", \"yelp\"]:\n",
    "    if dataset_name not in dataset_hub_ids:\n",
    "        raise ValueError(f\"Unknown dataset {dataset_name}.\")\n",
    "    # format to pytorch tensors, but leave data on cpu\n",
    "    dataset = load_dataset(dataset_hub_ids[dataset_name]).with_format('torch', device=\"cpu\")\n",
    "    # Fixed-seed subsamples: 5000 training and 20000 test examples per dataset.\n",
    "    dataset[\"train\"] = dataset[\"train\"].shuffle(seed=42).select(range(5000))\n",
    "    dataset[\"test\"] = dataset[\"test\"].shuffle(seed=42).select(range(20000))\n",
    "\n",
    "    for model_type in [\"distilbert\", \"bert\", \"gpt2\", \"roberta\"]:\n",
    "        # Fail loudly instead of silently reusing the previous iteration's tokenizer.\n",
    "        if model_type not in tokenizer_checkpoints:\n",
    "            raise ValueError(f\"Unknown model type {model_type}.\")\n",
    "        # NOTE(review): `padding=512` is carried over unchanged from the original\n",
    "        # cell; confirm that `from_pretrained` actually does anything with it.\n",
    "        tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoints[model_type], use_fast=True, padding=512)\n",
    "        # Originally only guarded for gpt2; a no-op whenever a pad token is already set.\n",
    "        if tokenizer.pad_token is None:\n",
    "            tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "        bow = BoW(ds=dataset, tokenizer=tokenizer)\n",
    "        bow_nb_mult = NaiveBayesEstim(bow, multiplicities=True)\n",
    "        # deep copy so the saved scores are decoupled from the estimator object\n",
    "        importances_nb_mult = deepcopy(bow_nb_mult.get_signed_importance())\n",
    "        torch.save({\"nb\": importances_nb_mult}, f\"{output_dir}/gt_{model_type}_{dataset_name}.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "209d1518-7ac0-4e88-998c-1031d5e95517",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "transformer",
   "language": "python",
   "name": "transformer"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
