{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/d/repos/AttentionMD/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import html\n",
    "import math\n",
    "\n",
    "import matplotlib\n",
    "import torch\n",
    "from transformers import AutoTokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name or path handed to AutoTokenizer.from_pretrained.\n",
    "TOKENIZER_PATH = \"tokenizer\"\n",
    "\n",
    "# Shared tokenizer used by to_tensor() and compare_models().\n",
    "tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_tensor(sentence):\n",
    "    \"\"\"Tokenize ``sentence`` and wrap its input ids in a tensor.\n",
    "\n",
    "    The ids are placed inside an outer list before conversion, so the result\n",
    "    carries one extra leading axis of size 1.\n",
    "    \"\"\"\n",
    "    input_ids = tokenizer(sentence)[\"input_ids\"]\n",
    "    batch = [input_ids]\n",
    "    return torch.tensor(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_attention_map(model, x):\n",
    "    \"\"\"Recompute the softmax attention weights of every block in ``model``.\n",
    "\n",
    "    For each block, the query/key projections are rebuilt from\n",
    "    ``block.attn.attn_matrix`` and turned into scaled dot-product attention\n",
    "    weights; the block itself is then applied to advance ``x``.\n",
    "\n",
    "    NOTE(review): the weights are computed from the raw block input and without\n",
    "    applying ``mask``; confirm this matches what ``block.attn`` does internally\n",
    "    (a pre-attention norm or padding/causal masking would change the values).\n",
    "    NOTE(review): the model's train/eval mode is left untouched; if the blocks\n",
    "    contain dropout, confirm the model is in eval mode before calling.\n",
    "\n",
    "    Args:\n",
    "        model: object exposing ``emb_static``, ``emb_pos`` and ``blocks``; each\n",
    "            block exposes ``attn.attn_matrix``, ``attn.n_embd`` and\n",
    "            ``attn.n_head`` and is callable as ``block(x, mask, \"cpu\")``.\n",
    "        x: token ids of shape (B, T).\n",
    "\n",
    "    Returns:\n",
    "        The per-block attention weights concatenated along dim 0; each block\n",
    "        contributes a (B, n_head, T, T) slice. The result carries no autograd\n",
    "        history.\n",
    "    \"\"\"\n",
    "    attn_maps = []\n",
    "\n",
    "    # Inference only: skip autograd bookkeeping for the whole forward pass.\n",
    "    with torch.no_grad():\n",
    "        # Embed\n",
    "        _, T = x.size()\n",
    "        # presumably id 0 is the padding token -- TODO confirm against the tokenizer\n",
    "        mask = (x != 0)\n",
    "        x = model.emb_static(x) + model.emb_pos(torch.arange(0, T))\n",
    "\n",
    "        for block in model.blocks:\n",
    "            attn = block.attn\n",
    "            B, T, C = x.size()\n",
    "            head_dim = C // attn.n_head\n",
    "\n",
    "            # The value projection is not needed to build the attention map.\n",
    "            q, k, _ = attn.attn_matrix(x).split(attn.n_embd, dim=2)\n",
    "            q = q.view(B, T, attn.n_head, head_dim).transpose(1, 2)\n",
    "            k = k.view(B, T, attn.n_head, head_dim).transpose(1, 2)\n",
    "\n",
    "            scale_factor = 1 / math.sqrt(head_dim)\n",
    "            attn_weight = torch.softmax(q @ k.transpose(-2, -1) * scale_factor, dim=-1)\n",
    "            attn_maps.append(attn_weight)\n",
    "\n",
    "            # Pass through block\n",
    "            x = block(x, mask, \"cpu\")\n",
    "\n",
    "    return torch.cat(attn_maps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def colorize(words, color_array):\n",
    "    \"\"\"Render ``words`` as HTML spans shaded by ``color_array``.\n",
    "\n",
    "    Args:\n",
    "        words: list of strings to display.\n",
    "        color_array: numbers between 0 and 1, one per word; each is mapped\n",
    "            through matplotlib's 'Greens' colormap to a background colour.\n",
    "\n",
    "    Returns:\n",
    "        One HTML string holding a barcode-class span per word, in order.\n",
    "    \"\"\"\n",
    "    # matplotlib.cm.get_cmap is deprecated since Matplotlib 3.7; the colormap\n",
    "    # registry lookup is the supported replacement.\n",
    "    cmap = matplotlib.colormaps['Greens']\n",
    "    template = '<span class=\"barcode\" style=\"color: black; background-color: {};\">{}</span>'\n",
    "    spans = []\n",
    "    for word, color in zip(words, color_array):\n",
    "        hex_color = matplotlib.colors.rgb2hex(cmap(color)[:3])\n",
    "        # Escape the token so characters such as '<' or '&' cannot break the markup.\n",
    "        spans.append(template.format(hex_color, '&nbsp;' + html.escape(word) + '&nbsp;'))\n",
    "    return ''.join(spans)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _normalized_cls_attention(attn_maps):\n",
    "    \"\"\"Min-max scale the attention row of query position 0, averaged over dim 1.\n",
    "\n",
    "    Reads ``attn_maps[-1, :, 0, :]`` (last entry along dim 0, position 0 along\n",
    "    dim 2), averages over dim 1 (the head axis in get_attention_map's output)\n",
    "    and rescales the result to [0, 1].\n",
    "    \"\"\"\n",
    "    scores = attn_maps[-1, :, 0, :].mean(dim=[0]).detach().numpy()\n",
    "    low = scores.min()\n",
    "    span = scores.max() - low\n",
    "    if span == 0:\n",
    "        # All scores equal: dividing by the zero range would yield NaNs.\n",
    "        return scores - low\n",
    "    return (scores - low) / span\n",
    "\n",
    "\n",
    "def compare_models(attn_maps1, attn_maps2, sentence, selected):\n",
    "    \"\"\"Colour ``sentence`` by two models' attention and score the selected tokens.\n",
    "\n",
    "    Args:\n",
    "        attn_maps1: attention maps of the first model for ``sentence``.\n",
    "        attn_maps2: attention maps of the second model for ``sentence``.\n",
    "        sentence: the text both attention maps were computed from.\n",
    "        selected: collection of token strings whose attention is reported.\n",
    "\n",
    "    Returns:\n",
    "        ``(ok, html1, html2, scores1, scores2)``. ``ok`` is False (and the other\n",
    "        four entries are None) when some word of ``selected`` is not one of the\n",
    "        sentence's tokens. Otherwise ``html1``/``html2`` are the coloured\n",
    "        sentences and ``scores1``/``scores2`` list the normalized attention of\n",
    "        every token that is in ``selected``, in sentence order.\n",
    "    \"\"\"\n",
    "    # Tokenize the text for display\n",
    "    tokens = tokenizer.convert_ids_to_tokens(tokenizer(sentence)[\"input_ids\"])\n",
    "\n",
    "    # Drop the first and last token (special tokens like [CLS] and [SEP]).\n",
    "    # assumes the tokenizer adds exactly one special token at each end -- TODO confirm\n",
    "    tokens = tokens[1:-1]\n",
    "    if any(word not in tokens for word in selected):\n",
    "        return False, None, None, None, None\n",
    "    # Positions are taken on the stripped list, so they index the stripped\n",
    "    # score arrays directly and can never wrap around to a special token.\n",
    "    positions = [pos for pos, token in enumerate(tokens) if token in selected]\n",
    "\n",
    "    scores1 = _normalized_cls_attention(attn_maps1)[1:-1]\n",
    "    scores2 = _normalized_cls_attention(attn_maps2)[1:-1]\n",
    "\n",
    "    return (\n",
    "        True,\n",
    "        colorize(tokens, scores1),\n",
    "        colorize(tokens, scores2),\n",
    "        [scores1[pos] for pos in positions],\n",
    "        [scores2[pos] for pos in positions]\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation sentences. The results loop reads entry i of this list together\n",
    "# with good_tokens[i] and positive_review[i], so the three lists must stay\n",
    "# index-aligned and equally long.\n",
    "sentences = [\n",
    "    \"The movie was fantastic\",\n",
    "    \"I hated the movie\",\n",
    "    \"The plot was boring\",\n",
    "    \"I love this movie\",\n",
    "    \"The plot was terrible\",\n",
    "    \"This movie is great\",\n",
    "    \"The scenes were dirty\",\n",
    "    \"I'm satisfied with movie\",\n",
    "    \"The DVD arrived late\",\n",
    "    \"The subtitles work perfectly\",\n",
    "    \"The movie was disappointing\",\n",
    "    \"I enjoyed the movie\",\n",
    "    \"The pacing is unreliable\",\n",
    "    \"The cast were friendly\",\n",
    "    \"The script is slow\",\n",
    "    \"The movie was great\",\n",
    "    \"The DVD was poor\",\n",
    "    \"The plot was fascinating\",\n",
    "    \"The set was sturdy\",\n",
    "    \"The cinematography was ruined\",\n",
    "    \"The documentary was engaging\",\n",
    "    \"The DVD crashes often\",\n",
    "    \"The scenes were delicious\",\n",
    "    \"The DVD broke down\",\n",
    "    \"The scenery was breathtaking\",\n",
    "    \"The service was prompt\",\n",
    "    \"The plot was predictable\",\n",
    "    \"The tickets overpriced\",\n",
    "    \"The service was excellent\",\n",
    "    \"The projector overheats\",\n",
    "    \"The theater is scenic\",\n",
    "    \"The projector stopped\",\n",
    "    \"The festival was vibrant\",\n",
    "    \"The popcorn runs out\",\n",
    "    \"The movie was fun\",\n",
    "    \"The screening was delayed\",\n",
    "    \"The impact was pleasant\",\n",
    "    \"The streaming is unstable\",\n",
    "    \"The snacks are fresh\",\n",
    "    \"The DVD cracked\",\n",
    "    \"The theater has selection\",\n",
    "    \"The interface is difficult\",\n",
    "    \"The cinema is spacious\",\n",
    "    \"The equipment broke\",\n",
    "    \"The staff are friendly\",\n",
    "    \"The seats are uncomfortable\",\n",
    "    \"The movie was heavenly\",\n",
    "    \"The equipment is outdated\",\n",
    "    \"The theater is well-kept\",\n",
    "    \"The plot was confusing\"\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each entry of `sentences` (same index), the token whose attention score\n",
    "# is reported. compare_models rejects a pair when the token does not occur in\n",
    "# the tokenized sentence, so a mismatch silently removes that sentence.\n",
    "good_tokens = [\n",
    "    {\"fantastic\"},\n",
    "    {\"hated\"},\n",
    "    {\"boring\"},\n",
    "    {\"love\"},\n",
    "    {\"terrible\"},\n",
    "    {\"great\"},\n",
    "    {\"dirty\"},\n",
    "    {\"satisfied\"},\n",
    "    {\"late\"},\n",
    "    {\"perfectly\"},\n",
    "    {\"disappointing\"},\n",
    "    # Was \"excellent\", which never occurs in \"I enjoyed the movie\".\n",
    "    {\"enjoyed\"},\n",
    "    {\"unreliable\"},\n",
    "    {\"friendly\"},\n",
    "    {\"slow\"},\n",
    "    {\"great\"},\n",
    "    {\"poor\"},\n",
    "    {\"fascinating\"},\n",
    "    {\"sturdy\"},\n",
    "    {\"ruined\"},\n",
    "    {\"engaging\"},\n",
    "    {\"crashes\"},\n",
    "    {\"delicious\"},\n",
    "    {\"broke\"},\n",
    "    {\"breathtaking\"},\n",
    "    {\"prompt\"},\n",
    "    {\"predictable\"},\n",
    "    {\"overpriced\"},\n",
    "    {\"excellent\"},\n",
    "    {\"overheats\"},\n",
    "    {\"scenic\"},\n",
    "    {\"stopped\"},\n",
    "    {\"vibrant\"},\n",
    "    # NOTE(review): the paired sentence is \"The popcorn runs out\", which does\n",
    "    # not contain this word, so the pair is always skipped -- fix one of the two.\n",
    "    {\"quickly\"},\n",
    "    {\"fun\"},\n",
    "    {\"delayed\"},\n",
    "    {\"pleasant\"},\n",
    "    {\"unstable\"},\n",
    "    {\"fresh\"},\n",
    "    {\"cracked\"},\n",
    "    {\"selection\"},\n",
    "    {\"difficult\"},\n",
    "    {\"spacious\"},\n",
    "    {\"broke\"},\n",
    "    {\"friendly\"},\n",
    "    {\"uncomfortable\"},\n",
    "    {\"heavenly\"},\n",
    "    {\"outdated\"},\n",
    "    # NOTE(review): the paired sentence is \"The theater is well-kept\", which\n",
    "    # does not contain this word, so the pair is always skipped -- fix one of the two.\n",
    "    {\"happy\"},\n",
    "    {\"confusing\"}\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sentiment label per entry of `sentences` (same index): True for a positive\n",
    "# review, False for a negative one. Rendered as \"+\" / \"-\" in the HTML table.\n",
    "positive_review = [\n",
    "    True,\n",
    "    False,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    True,\n",
    "    False,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False,\n",
    "    True,\n",
    "    False\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_18466/2057364446.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  result = torch.load(\"results/1.1/masked_first/1.pt\", map_location=torch.device('cpu'))\n",
      "/tmp/ipykernel_18466/2057364446.py:3: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  result = torch.load(\"results/2/masked_first/1.pt\", map_location=torch.device('cpu'))\n",
      "/tmp/ipykernel_18466/260022292.py:4: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.\n",
      "  cmap = matplotlib.cm.get_cmap('Greens')\n"
     ]
    }
   ],
   "source": [
    "# The checkpoints store whole pickled model objects (result[\"model\"] is used\n",
    "# as a module below), so weights_only must be False; it is spelled out because\n",
    "# the implicit default is announced to flip to True.\n",
    "# SECURITY: unpickling can execute arbitrary code -- only load trusted files.\n",
    "result = torch.load(\"results/1.1/6/1.pt\", map_location=torch.device('cpu'), weights_only=False)\n",
    "model_1 = result[\"model\"]\n",
    "result = torch.load(\"results/2/6/1.pt\", map_location=torch.device('cpu'), weights_only=False)\n",
    "model_2 = result[\"model\"]\n",
    "\n",
    "# Per-sentence results; all lists below stay index-aligned with each other.\n",
    "colored1 = []\n",
    "colored2 = []\n",
    "score_list1 = []\n",
    "score_list2 = []\n",
    "good_sentences = []\n",
    "important = []\n",
    "review_rating = []\n",
    "# Sentences rejected by compare_models (selected token missing after tokenization).\n",
    "skipped_sentences = []\n",
    "\n",
    "# The three input lists are consumed in lockstep; fail loudly if they drift apart.\n",
    "assert len(sentences) == len(good_tokens) == len(positive_review), \"input lists must be index-aligned\"\n",
    "\n",
    "for sentence, selected, is_positive in zip(sentences, good_tokens, positive_review):\n",
    "    x = to_tensor(sentence)\n",
    "    attn_maps_1 = get_attention_map(model_1, x)\n",
    "    attn_maps_2 = get_attention_map(model_2, x)\n",
    "    (\n",
    "        good_parse,\n",
    "        highlighted1,\n",
    "        highlighted2,\n",
    "        scores1,\n",
    "        scores2,\n",
    "    ) = compare_models(attn_maps_1, attn_maps_2, sentence, selected)\n",
    "    if not good_parse:\n",
    "        skipped_sentences.append(sentence)\n",
    "        continue\n",
    "    colored1.append(highlighted1)\n",
    "    colored2.append(highlighted2)\n",
    "    # Only the first matching token's score is kept.\n",
    "    score_list1.append(scores1[0])\n",
    "    score_list2.append(scores2[0])\n",
    "    good_sentences.append(sentence)\n",
    "    important.append(next(iter(selected)))\n",
    "    review_rating.append(is_positive)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "# HTML page header and table head. Layout adapted from W3Schools, MathJax\n",
    "# setup from Stack Exchange:\n",
    "# https://www.w3schools.com/html/html_tables.asp\n",
    "# https://tex.stackexchange.com/questions/23804/how-to-incorporate-tex-mathematics-into-a-website\n",
    "#\n",
    "# This cell (re)initialises `string`; the following cells append to it, so\n",
    "# re-run from here when regenerating the page.\n",
    "\n",
    "string = \"\"\"<!DOCTYPE html>\n",
    "<html>\n",
    "<head>\n",
    "<style>\n",
    "table {\n",
    "  font-family: arial, sans-serif;\n",
    "  border-collapse: collapse;\n",
    "  width: 50%;\n",
    "}\n",
    "\n",
    "td, th {\n",
    "  border: 1px solid #dddddd;\n",
    "  text-align: left;\n",
    "  padding: 8px;\n",
    "}\n",
    "\n",
    "tr:nth-child(even) {\n",
    "  background-color: #dddddd;\n",
    "}\n",
    "</style>\n",
    "<script type=\"text/x-mathjax-config\">\n",
    "  MathJax.Hub.Config({tex2jax: {inlineMath: [['$','$'], ['\\\\(','\\\\)']]}});\n",
    "</script>\n",
    "<script type=\"text/javascript\"\n",
    "  src=\"https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML\">\n",
    "</script>\n",
    "</head>\n",
    "<body>\n",
    "\n",
    "<table>\n",
    "  <tr>\n",
    "    <th>Label</th>\n",
    "    <th>Optimal Token</th>\n",
    "    <th>$\\\\ell_{1.1}$-MD Token Selection</th>\n",
    "    <th>GD Token Selection</th>\n",
    "    <th>Better Selector</th>\n",
    "  </tr>\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upper bound on the number of rows written to the HTML table.\n",
    "MAX_ROWS = 40\n",
    "\n",
    "# Some sentences are skipped upstream, so never index past what was collected.\n",
    "# NOTE: this cell appends to `string`; re-run the template cell first to avoid\n",
    "# duplicated rows.\n",
    "for idx in range(min(MAX_ROWS, len(important))):\n",
    "    label = \"+\" if review_rating[idx] else \"-\"\n",
    "    if score_list1[idx] == score_list2[idx]:\n",
    "        better = \"=\"\n",
    "    elif score_list1[idx] > score_list2[idx]:\n",
    "        better = \"1.1\"\n",
    "    else:\n",
    "        better = \"2\"\n",
    "    string += f\"\"\"\n",
    "  <tr>\n",
    "    <td>{label}</td>\n",
    "    <td>{important[idx]}</td>\n",
    "    <td>{colored1[idx]}</td>\n",
    "    <td>{colored2[idx]}</td>\n",
    "    <td>{better}</td>\n",
    "  </tr>\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Close the table, body and document opened by the template cell.\n",
    "# This appends to `string`, so running the cell twice duplicates the footer.\n",
    "string += \"\"\"</table>\n",
    "\n",
    "</body>\n",
    "</html>\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the assembled page; the context manager closes the file even if the\n",
    "# write fails, and the explicit encoding keeps the output platform-independent.\n",
    "with open(\"results/img/attention_map.html\", \"w\", encoding=\"utf-8\") as f:\n",
    "    f.write(string)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Score for 1.1-MD = 0.4387154281139374\n"
     ]
    }
   ],
   "source": [
    "# Mean normalized attention on the selected token for the first model (model_1).\n",
    "mean_score_md = torch.tensor(score_list1).mean().item()\n",
    "print(f\"Score for 1.1-MD = {mean_score_md}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Score for GD = 0.3122487962245941\n"
     ]
    }
   ],
   "source": [
    "# Mean normalized attention on the selected token for the second model (model_2).\n",
    "mean_score_gd = torch.tensor(score_list2).mean().item()\n",
    "print(f\"Score for GD = {mean_score_gd}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
