{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:23:46.662753Z",
     "iopub.status.busy": "2025-06-14T15:23:46.662044Z",
     "iopub.status.idle": "2025-06-14T15:23:46.669143Z",
     "shell.execute_reply": "2025-06-14T15:23:46.668549Z",
     "shell.execute_reply.started": "2025-06-14T15:23:46.662711Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"7\"\n",
    "# os.environ[\"TRANSFORMERS_NO_TF\"] = \"1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:25:23.932688Z",
     "iopub.status.busy": "2025-06-14T15:25:23.932174Z",
     "iopub.status.idle": "2025-06-14T15:27:40.943503Z",
     "shell.execute_reply": "2025-06-14T15:27:40.942642Z",
     "shell.execute_reply.started": "2025-06-14T15:25:23.932663Z"
    }
   },
   "outputs": [],
   "source": [
    "# %pip install git+https://github.com/neelnanda-io/Easy-Transformer.git@clean-transformer-demo\n",
    "# # Install another version of node that makes PySvelte work way faster\n",
    "# !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n",
    "\n",
    "# %pip install fancy_einsum\n",
    "# %pip install einops\n",
    "# %pip install ekphrasis\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
    "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
    "execution": {
     "iopub.execute_input": "2025-06-14T15:27:40.945504Z",
     "iopub.status.busy": "2025-06-14T15:27:40.945232Z",
     "iopub.status.idle": "2025-06-14T15:27:52.653758Z",
     "shell.execute_reply": "2025-06-14T15:27:52.653068Z",
     "shell.execute_reply.started": "2025-06-14T15:27:40.945479Z"
    }
   },
   "outputs": [],
   "source": [
    "import einops\n",
    "from fancy_einsum import einsum\n",
    "from dataclasses import dataclass\n",
    "from easy_transformer import EasyTransformer\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import math\n",
    "from easy_transformer.utils import get_corner, gelu_new, tokenize_and_concatenate\n",
    "import tqdm.auto as tqdm\n",
    "\n",
    "\n",
    "import datasets\n",
    "import transformers\n",
    "import plotly.express as px\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "\n",
    "import os\n",
    "os.environ['CUDA_LAUNCH_BLOCKING'] = \"1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:27:52.654629Z",
     "iopub.status.busy": "2025-06-14T15:27:52.654422Z",
     "iopub.status.idle": "2025-06-14T15:27:52.877798Z",
     "shell.execute_reply": "2025-06-14T15:27:52.877175Z",
     "shell.execute_reply.started": "2025-06-14T15:27:52.654612Z"
    }
   },
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:27:52.879681Z",
     "iopub.status.busy": "2025-06-14T15:27:52.879211Z",
     "iopub.status.idle": "2025-06-14T15:27:56.794846Z",
     "shell.execute_reply": "2025-06-14T15:27:56.794250Z",
     "shell.execute_reply.started": "2025-06-14T15:27:52.879663Z"
    }
   },
   "outputs": [],
   "source": [
    "from ekphrasis.classes.preprocessor import TextPreProcessor\n",
    "from ekphrasis.classes.tokenizer import SocialTokenizer\n",
    "from ekphrasis.dicts.emoticons import emoticons\n",
    "#from transformers import BertTokenizer\n",
    "import string \n",
    "import re\n",
    "import spacy\n",
    "nlp2 = spacy.load('en_core_web_sm')\n",
    "from spacy.symbols import ORTH,NORM,LEMMA\n",
    "import string \n",
    "from spacy.lang.char_classes import LIST_PUNCT, LIST_ELLIPSES, LIST_QUOTES, LIST_CURRENCY\n",
    "from spacy.lang.char_classes import LIST_ICONS, HYPHENS, CURRENCY, UNITS\n",
    "from spacy.lang.char_classes import CONCAT_QUOTES, ALPHA_LOWER, ALPHA_UPPER, ALPHA, PUNCT\n",
    "from spacy.util import compile_infix_regex, compile_prefix_regex, compile_suffix_regex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:27:56.796196Z",
     "iopub.status.busy": "2025-06-14T15:27:56.795587Z",
     "iopub.status.idle": "2025-06-14T15:28:19.275500Z",
     "shell.execute_reply": "2025-06-14T15:28:19.274605Z",
     "shell.execute_reply.started": "2025-06-14T15:27:56.796167Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-12-10 00:32:45.613918: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Moving model to device:  cuda\n",
      "Finished loading pretrained model gpt2-small into EasyTransformer!\n"
     ]
    }
   ],
   "source": [
    "reference_gpt2 = EasyTransformer.from_pretrained(\"gpt2-small\", fold_ln=False, center_unembed=False, center_writing_weights=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:19.277563Z",
     "iopub.status.busy": "2025-06-14T15:28:19.276955Z",
     "iopub.status.idle": "2025-06-14T15:28:19.283671Z",
     "shell.execute_reply": "2025-06-14T15:28:19.282738Z",
     "shell.execute_reply.started": "2025-06-14T15:28:19.277544Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Config(d_model=768, debug=True, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, n_ctx=1024, d_head=64, d_mlp=3072, n_heads=12, n_layers=12, n_classes=3)\n"
     ]
    }
   ],
   "source": [
    "@dataclass\n",
    "class Config:\n",
    "    d_model: int = 768\n",
    "    debug: bool = True\n",
    "    layer_norm_eps: float = 1e-5\n",
    "    d_vocab: int = 50257\n",
    "    init_range: float = 0.02\n",
    "    n_ctx: int = 1024\n",
    "    d_head: int = 64\n",
    "    d_mlp: int = 3072\n",
    "    n_heads: int = 12\n",
    "    n_layers: int = 12\n",
    "    n_classes: int = 3\n",
    "\n",
    "cfg = Config()\n",
    "print(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:19.284751Z",
     "iopub.status.busy": "2025-06-14T15:28:19.284516Z",
     "iopub.status.idle": "2025-06-14T15:28:19.300483Z",
     "shell.execute_reply": "2025-06-14T15:28:19.299939Z",
     "shell.execute_reply.started": "2025-06-14T15:28:19.284726Z"
    }
   },
   "outputs": [],
   "source": [
    "class LayerNorm(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.cfg = cfg\n",
    "        self.w = nn.Parameter(torch.ones(cfg.d_model))\n",
    "        self.b = nn.Parameter(torch.zeros(cfg.d_model))\n",
    "\n",
    "    def forward(self, residual):\n",
    "        # residual: [batch, position, d_model]\n",
    "        \"YOUR CODE HERE\"\n",
    "\n",
    "        residual = residual - einops.reduce(residual, \"batch position d_model -> batch position 1\",reduction=\"mean\")\n",
    "        scale = (einops.reduce(residual.pow(2),\"batch position d_model -> batch position 1\",reduction=\"mean\" ) + + cfg.layer_norm_eps).sqrt()\n",
    "\n",
    "        outputs = residual/scale\n",
    "\n",
    "        outputs = outputs*self.w + self.b\n",
    "\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:19.301629Z",
     "iopub.status.busy": "2025-06-14T15:28:19.301354Z",
     "iopub.status.idle": "2025-06-14T15:28:19.317462Z",
     "shell.execute_reply": "2025-06-14T15:28:19.316806Z",
     "shell.execute_reply.started": "2025-06-14T15:28:19.301608Z"
    }
   },
   "outputs": [],
   "source": [
    "class Embed(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.cfg = cfg\n",
    "        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))\n",
    "        nn.init.normal_(self.W_E, std=self.cfg.init_range)\n",
    "\n",
    "    def forward(self, tokens):\n",
    "        # tokens: [batch, position]\n",
    "        \"YOUR CODE HERE\"\n",
    "        if self.cfg.debug: print(\"Tokens:\", tokens.shape)\n",
    "        embeddings = self.W_E[tokens,:]\n",
    "        if self.cfg.debug: print(\"Embeddings:\", embeddings.shape)\n",
    "\n",
    "        return embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Positional Embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:19.318456Z",
     "iopub.status.busy": "2025-06-14T15:28:19.318226Z",
     "iopub.status.idle": "2025-06-14T15:28:19.331066Z",
     "shell.execute_reply": "2025-06-14T15:28:19.330339Z",
     "shell.execute_reply.started": "2025-06-14T15:28:19.318429Z"
    }
   },
   "outputs": [],
   "source": [
    "class PosEmbed(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.cfg = cfg\n",
    "        self.W_pos = nn.Parameter(torch.empty((cfg.n_ctx, cfg.d_model)))\n",
    "        nn.init.normal_(self.W_pos, std=self.cfg.init_range)\n",
    "\n",
    "    def forward(self, tokens):\n",
    "        \"YOUR CODE HERE\"\n",
    "        if self.cfg.debug: print(\"Tokens:\", tokens.shape)\n",
    "        pos_embed = self.W_pos[:tokens.size(1), :] # [position, d_model]\n",
    "        pos_embed = einops.repeat(pos_embed, \"position d_model -> batch position d_model\", batch=tokens.size(0))\n",
    "        if self.cfg.debug: print(\"POS Embeddings:\", pos_embeddings.shape)\n",
    "        return pos_embed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:19.333704Z",
     "iopub.status.busy": "2025-06-14T15:28:19.333507Z",
     "iopub.status.idle": "2025-06-14T15:28:33.258515Z",
     "shell.execute_reply": "2025-06-14T15:28:33.257650Z",
     "shell.execute_reply.started": "2025-06-14T15:28:19.333690Z"
    }
   },
   "outputs": [],
   "source": [
    "# reference_text = \"I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!\"\n",
    "# tokens = reference_gpt2.to_tokens(reference_text)\n",
    "# tokens = tokens.cuda()\n",
    "# logits, cache = reference_gpt2.run_with_cache(tokens)\n",
    "# print(logits.shape)\n",
    "# pysvelte.AttentionMulti(tokens=reference_gpt2.to_str_tokens(reference_text), attention=cache['blocks.0.attn.hook_attn'][0].permute(1, 2, 0)).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:33.260736Z",
     "iopub.status.busy": "2025-06-14T15:28:33.259957Z",
     "iopub.status.idle": "2025-06-14T15:28:33.272247Z",
     "shell.execute_reply": "2025-06-14T15:28:33.271460Z",
     "shell.execute_reply.started": "2025-06-14T15:28:33.260711Z"
    }
   },
   "outputs": [],
   "source": [
    "class Attention(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.cfg = cfg\n",
    "        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))\n",
    "        nn.init.normal_(self.W_Q, std=self.cfg.init_range)\n",
    "        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))\n",
    "        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))\n",
    "        nn.init.normal_(self.W_K, std=self.cfg.init_range)\n",
    "        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))\n",
    "        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))\n",
    "        nn.init.normal_(self.W_V, std=self.cfg.init_range)\n",
    "        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))\n",
    "\n",
    "        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))\n",
    "        nn.init.normal_(self.W_O, std=self.cfg.init_range)\n",
    "        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))\n",
    "\n",
    "        self.register_buffer(\"IGNORE\", torch.tensor(-1e5, dtype=torch.float32, device=\"cuda\"))\n",
    "\n",
    "    def forward(self, normalized_resid_pre):\n",
    "        # normalized_resid_pre: [batch, position, d_model]\n",
    "        \"YOUR CODE HERE\"\n",
    "\n",
    "        q = einsum(\"batch position d_model, n_heads d_model d_head -> batch position n_heads d_head\",normalized_resid_pre,self.W_Q) + self.b_Q\n",
    "\n",
    "        k = einsum(\"batch position d_model, n_heads d_model d_head -> batch position n_heads d_head\",normalized_resid_pre,self.W_K) + self.b_K\n",
    "        v = einsum(\"batch position d_model, n_heads d_model d_head -> batch position n_heads d_head\",normalized_resid_pre,self.W_V) + self.b_V\n",
    "\n",
    "        score = einsum(\"batch qposition n_heads d_head, batch kposition n_heads d_head -> batch n_heads qposition kposition\",q,k)\n",
    "        score = score / math.sqrt(self.cfg.d_head)\n",
    "        causal_score = self.apply_causal_mask(score)\n",
    "        attn = torch.nn.Softmax(dim=-1)(causal_score)\n",
    "\n",
    "        context = einsum(\"batch n_heads qposition kposition, batch kposition n_heads d_head -> batch qposition n_heads d_head\",attn,v)\n",
    "\n",
    "        context = einsum(\"batch position n_heads d_head, n_heads d_head d_model -> batch position d_model\",context,self.W_O) + self.b_O\n",
    "\n",
    "        return context,attn\n",
    "    def apply_causal_mask(self, attn_scores):\n",
    "        # attn_scores: [batch, n_heads, query_pos, key_pos]\n",
    "        \"YOUR CODE HERE\"\n",
    "        mask = torch.triu(torch.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device), diagonal=1).bool()\n",
    "\n",
    "        #print(mask,self.IGNORE)\n",
    "        attn_scores.masked_fill_(mask, self.IGNORE)\n",
    "        return attn_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:33.273950Z",
     "iopub.status.busy": "2025-06-14T15:28:33.273520Z",
     "iopub.status.idle": "2025-06-14T15:28:33.334054Z",
     "shell.execute_reply": "2025-06-14T15:28:33.333174Z",
     "shell.execute_reply.started": "2025-06-14T15:28:33.273919Z"
    }
   },
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.cfg = cfg\n",
    "        self.W_in = nn.Parameter(torch.empty((cfg.d_model, cfg.d_mlp)))\n",
    "        nn.init.normal_(self.W_in, std=self.cfg.init_range)\n",
    "        self.b_in = nn.Parameter(torch.zeros((cfg.d_mlp)))\n",
    "        self.W_out = nn.Parameter(torch.empty((cfg.d_mlp, cfg.d_model)))\n",
    "        nn.init.normal_(self.W_out, std=self.cfg.init_range)\n",
    "        self.b_out = nn.Parameter(torch.zeros((cfg.d_model)))\n",
    "\n",
    "    def forward(self, normalized_resid_mid):\n",
    "        # normalized_resid_mid: [batch, position, d_model]\n",
    "        \"YOUR CODE HERE\"\n",
    "        outputs = gelu_new(einsum(\"batch position d_model, d_model d_mlp -> batch position d_mlp\",normalized_resid_mid,self.W_in) + self.b_in)\n",
    "        outputs = einsum(\"batch position d_mlp, d_mlp d_model -> batch position d_model\",outputs,self.W_out) + self.b_out\n",
    "        return outputs\n",
    "\n",
    "# rand_float_test(MLP, [2, 4, 768])\n",
    "# load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache[\"blocks.0.ln2.hook_normalized\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transformer Block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:33.335277Z",
     "iopub.status.busy": "2025-06-14T15:28:33.334985Z",
     "iopub.status.idle": "2025-06-14T15:28:33.351617Z",
     "shell.execute_reply": "2025-06-14T15:28:33.351013Z",
     "shell.execute_reply.started": "2025-06-14T15:28:33.335251Z"
    }
   },
   "outputs": [],
   "source": [
    "class TransformerBlock(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.cfg = cfg\n",
    "\n",
    "        self.ln1 = LayerNorm(cfg)\n",
    "        self.attn = Attention(cfg)\n",
    "        self.ln2 = LayerNorm(cfg)\n",
    "        self.mlp = MLP(cfg)\n",
    "\n",
    "    def forward(self, resid_pre):\n",
    "        # resid_pre [batch, position, d_model]\n",
    "        \"YOUR CODE HERE\"\n",
    "        outputs,attn = self.attn(resid_pre) \n",
    "        outputs = outputs + resid_pre  #self.attn(self.ln1(resid_pre)) + resid_pre\n",
    "        outputs = self.mlp(outputs)+outputs #self.mlp(self.ln2(outputs))+outputs\n",
    "        return outputs,attn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Unembedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:33.352670Z",
     "iopub.status.busy": "2025-06-14T15:28:33.352397Z",
     "iopub.status.idle": "2025-06-14T15:28:33.366616Z",
     "shell.execute_reply": "2025-06-14T15:28:33.366086Z",
     "shell.execute_reply.started": "2025-06-14T15:28:33.352646Z"
    }
   },
   "outputs": [],
   "source": [
    "class Unembed(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.cfg = cfg\n",
    "        self.W_U = nn.Parameter(torch.empty((cfg.d_model, cfg.d_vocab)))\n",
    "        nn.init.normal_(self.W_U, std=self.cfg.init_range)\n",
    "        self.b_U = nn.Parameter(torch.zeros((cfg.d_vocab), requires_grad=True))\n",
    "\n",
    "    def forward(self, normalized_resid_final):\n",
    "        # normalized_resid_final [batch, position, d_model]\n",
    "        \"YOUR CODE HERE\"\n",
    "        outputs = einsum(\"batch position d_model, d_model d_vocab -> batch position d_vocab\", normalized_resid_final,self.W_U) + self.b_U\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:33.367711Z",
     "iopub.status.busy": "2025-06-14T15:28:33.367386Z",
     "iopub.status.idle": "2025-06-14T15:28:33.381842Z",
     "shell.execute_reply": "2025-06-14T15:28:33.381165Z",
     "shell.execute_reply.started": "2025-06-14T15:28:33.367687Z"
    }
   },
   "outputs": [],
   "source": [
    "class Classification_Head(nn.Module):\n",
    "    def __init__(self,cfg):\n",
    "        super().__init__()\n",
    "        self.cfg = cfg\n",
    "        self.W_U = nn.Parameter(torch.empty((cfg.d_model,cfg.n_classes)))\n",
    "        nn.init.normal_(self.W_U,std = self.cfg.init_range)\n",
    "        self.b_U = nn.Parameter(torch.zeros((cfg.n_classes),requires_grad=True))\n",
    "    def forward(self, normalized_resid_final):\n",
    "        # normalized_resid_final [batch, position, d_model]\n",
    "        \"YOUR CODE HERE\"\n",
    "        outputs = einsum(\"batch d_model, d_model n_classes -> batch n_classes\", normalized_resid_final,self.W_U) + self.b_U\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "# Decoder-based Transformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:33.382944Z",
     "iopub.status.busy": "2025-06-14T15:28:33.382665Z",
     "iopub.status.idle": "2025-06-14T15:28:33.397996Z",
     "shell.execute_reply": "2025-06-14T15:28:33.397202Z",
     "shell.execute_reply.started": "2025-06-14T15:28:33.382914Z"
    }
   },
   "outputs": [],
   "source": [
    "class DemoTransformer(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.cfg = cfg\n",
    "        self.embed = Embed(cfg)\n",
    "        self.pos_embed = PosEmbed(cfg)\n",
    "        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])\n",
    "        self.ln_final = LayerNorm(cfg)\n",
    "        #self.unembed = Unembed(cfg)\n",
    "        self.cls_head = Classification_Head(cfg)\n",
    "\n",
    "    def forward(self, tokens):\n",
    "        # tokens [batch, position]\n",
    "        \"YOUR CODE HERE\"\n",
    "        embed = self.embed(tokens)\n",
    "        pos_embed = self.pos_embed(tokens)\n",
    "        residual = embed + pos_embed\n",
    "        attention_per_block = []\n",
    "        for block in self.blocks:\n",
    "            residual,attn = block(residual)\n",
    "            attention_per_block.append(attn)\n",
    "        #print(residual.shape)\n",
    "        normalized_resid_final = residual #self.ln_final(residual)\n",
    "        pad_indices = (tokens==0 ).int().argmax(dim=1)\n",
    "\n",
    "        #pad_mask = (torch.arange(tokens.size(1), device=device).unsqueeze(0) <= pad_indices.unsqueeze(1)).float()\n",
    "\n",
    "        \n",
    "        outputs = normalized_resid_final[np.arange(normalized_resid_final.size(0)),pad_indices,:]\n",
    "        #print(outputs.shape)\n",
    "        #normalized_resid_final = einsum(\"batch position dmodel, batch position -> batch position dmodel\", normalized_resid_final, pad_mask)\n",
    "        #outputs = einops.reduce(normalized_resid_final,\"batch position dmodel -> batch dmodel\",reduction=\"sum\"  ) /einops.reduce(pad_mask,\"batch position -> batch 1\",reduction=\"sum\")\n",
    "        \n",
    "        \n",
    "        \n",
    "        outputs = self.cls_head(outputs)\n",
    "        \n",
    "        return outputs,attention_per_block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:33.398980Z",
     "iopub.status.busy": "2025-06-14T15:28:33.398725Z",
     "iopub.status.idle": "2025-06-14T15:28:33.415170Z",
     "shell.execute_reply": "2025-06-14T15:28:33.414380Z",
     "shell.execute_reply.started": "2025-06-14T15:28:33.398955Z"
    }
   },
   "outputs": [],
   "source": [
    "def lm_cross_entropy_loss(logits, tokens):\n",
    "    # Measure next token loss\n",
    "    # Logits have shape [batch, position, d_vocab]\n",
    "    # Tokens have shape [batch, position]\n",
    "    log_probs = logits.log_softmax(dim=-1)\n",
    "    pred_log_probs = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)\n",
    "    return -pred_log_probs.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:33.416551Z",
     "iopub.status.busy": "2025-06-14T15:28:33.416212Z",
     "iopub.status.idle": "2025-06-14T15:28:58.900966Z",
     "shell.execute_reply": "2025-06-14T15:28:58.900325Z",
     "shell.execute_reply.started": "2025-06-14T15:28:33.416532Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading twitter - 1grams ...\n",
      "Reading twitter - 2grams ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/cs22d010/anaconda3/envs/gpt_env/lib/python3.10/site-packages/ekphrasis/classes/tokenizer.py:225: FutureWarning: Possible nested set at position 2190\n",
      "  self.tok = re.compile(r\"({})\".format(\"|\".join(pipeline)))\n",
      "/home/cs22d010/anaconda3/envs/gpt_env/lib/python3.10/site-packages/ekphrasis/classes/exmanager.py:14: FutureWarning: Possible nested set at position 42\n",
      "  regexes = {k.lower(): re.compile(self.expressions[k]) for k, v in\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading english - 1grams ...\n"
     ]
    }
   ],
   "source": [
    "##### text preprocessor for ekphrasis\n",
    "text_processor = TextPreProcessor(\n",
    "    # terms that will be normalized\n",
    "    normalize=['url', 'email', 'percent', 'money', 'phone', 'user',\n",
    "        'time', 'date', 'number'],\n",
    "    # terms that will be annotated\n",
    "    fix_html=True,  # fix HTML tokens\n",
    "    annotate={\"hashtag\", \"allcaps\", \"elongated\", \"repeated\",\n",
    "        'emphasis', 'censored'},\n",
    "    # corpus from which the word statistics are going to be used \n",
    "    # for word segmentation \n",
    "    segmenter=\"twitter\", \n",
    "    \n",
    "    # corpus from which the word statistics are going to be used \n",
    "    # for spell correction\n",
    "    #corrector=\"twitter\", \n",
    "    \n",
    "    unpack_hashtags=True,  # perform word segmentation on hashtags\n",
    "    unpack_contractions=True,  # Unpack contractions (can't -> can not)\n",
    "    spell_correct_elong=False,  # spell correction for elongated words\n",
    "    \n",
    "    # select a tokenizer. You can use SocialTokenizer, or pass your own\n",
    "    # the tokenizer, should take as input a string and return a list of tokens\n",
    "    tokenizer=SocialTokenizer(lowercase=True).tokenize,\n",
    "    \n",
    "    # list of dictionaries, for replacing tokens extracted from the text,\n",
    "    # with other expressions. You can pass more than one dictionaries.\n",
    "    dicts=[emoticons]\n",
    ")\n",
    "#### Bert tokenizer\n",
    "def custom_tokenize(sent,tokenizer,max_length=512):\n",
    "    # `encode` will:\n",
    "    #   (1) Tokenize the sentence.\n",
    "    #   (2) Prepend the `[CLS]` token to the start.\n",
    "    #   (3) Append the `[SEP]` token to the end.\n",
    "    #   (4) Map tokens to their IDs.\n",
    "    try:\n",
    "\n",
    "        encoded_sent = tokenizer.encode(\n",
    "                            sent,                      # Sentence to encode.\n",
    "                            add_special_tokens = False, # Add '[CLS]' and '[SEP]'\n",
    "                            #max_length = max_length,\n",
    "                            # This function also supports truncation and conversion\n",
    "                            # to pytorch tensors, but we need to do padding, so we\n",
    "                            # can't use these features :( .\n",
    "                            #max_length = 128,          # Truncate all sentences.\n",
    "                            #return_tensors = 'pt',     # Return pytorch tensors.\n",
    "                       )\n",
    "\n",
    "        # Add the encoded sentence to the list.\n",
    "\n",
    "    except ValueError:\n",
    "        encoded_sent = tokenizer.encode(\n",
    "                            ' ',                      # Sentence to encode.\n",
    "                            add_special_tokens = False, # Add '[CLS]' and '[SEP]'\n",
    "                            max_length = max_length,\n",
    "                    \n",
    "                       )\n",
    "          ### decide what to later\n",
    "\n",
    "    return encoded_sent\n",
    "\n",
    "\n",
    "#input: text\n",
    "#process: ekphrasis preprocesser + some extra processing  \n",
    "#output: list of tokens      \n",
    "def ek_extra_preprocess(text,tokenizer):\n",
    "    remove_words=['<allcaps>','</allcaps>','<hashtag>','</hashtag>','<elongated>','<emphasis>','<repeated>','\\'','s']\n",
    "    word_list=text_processor.pre_process_doc(text)\n",
    "\n",
    "    word_list=list(filter(lambda a: a not in remove_words, word_list)) \n",
    "    sent=\" \".join(word_list)\n",
    "    sent = re.sub(r\"[<\\*>]\", \" \",sent)\n",
    "    sub_word_list = custom_tokenize(sent,tokenizer)\n",
    "    return sub_word_list\n",
    "\n",
    "\n",
    "#input: text\n",
    "#process: remove html tags  \n",
    "#output: text with no html tags\n",
    "def cleanhtml(raw_html):\n",
    "    cleanr = re.compile('<.*?>')\n",
    "    cleantext = re.sub(cleanr, '', raw_html)\n",
    "    return cleantext\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "##### Preprocessing queries for raw text not needed for implementation\n",
    "special_cases = {}\n",
    "# Times\n",
    "for h in range(1, 12 + 1):\n",
    "    for period in [\"a.m.\", \"am\"]:\n",
    "        special_cases[\"%d%s\" % (h, period)] = [\n",
    "            {ORTH: \"%d\" % h},\n",
    "            {ORTH: period, LEMMA: \"a.m.\", NORM: \"a.m.\"},\n",
    "        ]\n",
    "    for period in [\"p.m.\", \"pm\"]:\n",
    "        special_cases[\"%d%s\" % (h, period)] = [\n",
    "            {ORTH: \"%d\" % h},\n",
    "            {ORTH: period, LEMMA: \"p.m.\", NORM: \"p.m.\"},\n",
    "        ]\n",
    "        \n",
    "for orth in [\n",
    "        \"a.m.\",\n",
    "        \"Adm.\",\n",
    "        \"Bros.\",\n",
    "        \"co.\",\n",
    "        \"Co.\",\n",
    "        \"Corp.\",\n",
    "        \"D.C.\",\n",
    "        \"Dr.\",\n",
    "        \"e.g.\",\n",
    "        \"E.g.\",\n",
    "        \"E.G.\",\n",
    "        \"Gen.\",\n",
    "        \"Gov.\",\n",
    "        \"i.e.\",\n",
    "        \"I.e.\",\n",
    "        \"I.E.\",\n",
    "        \"Inc.\",\n",
    "        \"Jr.\",\n",
    "        \"Ltd.\",\n",
    "        \"Md.\",\n",
    "        \"Messrs.\",\n",
    "        \"Mo.\",\n",
    "        \"Mont.\",\n",
    "        \"Mr.\",\n",
    "        \"Mrs.\",\n",
    "        \"Ms.\",\n",
    "        \"p.m.\",\n",
    "        \"Ph.D.\",\n",
    "        \"Prof.\",\n",
    "        \"Rep.\",\n",
    "        \"Rev.\",\n",
    "        \"Sen.\",\n",
    "        \"St.\",\n",
    "        \"vs.\",\n",
    "        \"v.s.\",\n",
    "        ]:\n",
    "    special_cases[orth] = [{ORTH: orth}]\n",
    "    \n",
    "#print (special_cases)\n",
    "\n",
    "\n",
    "\n",
    "def preProcessing(query):\n",
    "    \"\"\"Tokenize a raw query with the module-level spaCy pipeline `nlp2`.\n",
    "\n",
    "    Steps:\n",
    "      1. If the query starts with 'eli5' (case-insensitive), drop everything\n",
    "         up to and including the first space.\n",
    "      2. Install `special_cases` and the custom prefix/suffix/infix rules below\n",
    "         on `nlp2.tokenizer` (this mutates the shared `nlp2` on every call).\n",
    "      3. Pad '?' between word characters and parentheses with spaces, then\n",
    "         collapse every whitespace run to a single space.\n",
    "      4. Run `nlp2` and keep the text of every non-whitespace token.\n",
    "\n",
    "    Args:\n",
    "        query (str): raw input text.\n",
    "\n",
    "    Returns:\n",
    "        list[str]: token texts. An empty list is returned as-is after a\n",
    "        warning is printed.\n",
    "    \"\"\"\n",
    "    queryLower = query.lower()\n",
    "    if queryLower.startswith('eli5'):\n",
    "        # find() returns -1 when there is no space, so nothing is cut then.\n",
    "        cutMarker = queryLower.find(' ') + 1\n",
    "        query = query[cutMarker:]\n",
    "\n",
    "    nlp2.tokenizer.rules = special_cases\n",
    "\n",
    "    prefixes = (\n",
    "        [\"§\", \"%\", \"=\", \"—\", \"–\", r\"\\+(?![0-9])\"]\n",
    "        + LIST_PUNCT\n",
    "        + LIST_ELLIPSES\n",
    "        + LIST_QUOTES\n",
    "        + LIST_CURRENCY\n",
    "        + LIST_ICONS\n",
    "    )\n",
    "\n",
    "    suffixes = (\n",
    "        LIST_PUNCT\n",
    "        + LIST_ELLIPSES\n",
    "        + LIST_QUOTES\n",
    "        + LIST_ICONS\n",
    "        + [\"'s\", \"'S\", \"’s\", \"’S\", \"—\", \"–\"]\n",
    "        + [\n",
    "            r\"(?<=[0-9])\\+\",\n",
    "            r\"(?<=°[FfCcKk])\\.\",\n",
    "            r\"(?<=[0-9])(?:{c})\".format(c=CURRENCY),\n",
    "            r\"(?<=[0-9])(?:{u})\".format(u=UNITS),\n",
    "            r\"(?<=[0-9{al}{e}{p}(?:{q})])\\.\".format(\n",
    "                al=ALPHA_LOWER, e=r\"%²\\-\\+\", q=CONCAT_QUOTES, p=PUNCT\n",
    "            ),\n",
    "            r\"(?<=[{au}][{au}])\\.\".format(au=ALPHA_UPPER),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    infixes = (\n",
    "        LIST_ELLIPSES\n",
    "        + LIST_ICONS\n",
    "        + [\n",
    "            r\"(?<=[0-9])[+\\-\\*^](?=[0-9-])\",\n",
    "            r\"(?<=[{al}{q}])\\.(?=[{au}{q}])\".format(\n",
    "                al=ALPHA_LOWER, au=ALPHA_UPPER, q=CONCAT_QUOTES\n",
    "            ),\n",
    "            r\"(?<=[{a}]),(?=[{a}])\".format(a=ALPHA),\n",
    "            # Hyphen infix rule intentionally left disabled (presumably to keep\n",
    "            # hyphenated words as one token).\n",
    "            #r\"(?<=[{a}])(?:{h})(?=[{a}])\".format(a=ALPHA, h=HYPHENS),\n",
    "            r\"(?<=[{a}0-9])[:<>=/](?=[{a}])\".format(a=ALPHA),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    prefixes_re = compile_prefix_regex(prefixes)\n",
    "    nlp2.tokenizer.prefix_search = prefixes_re.search\n",
    "\n",
    "    suffixes_re = compile_suffix_regex(suffixes)\n",
    "    nlp2.tokenizer.suffix_search = suffixes_re.search\n",
    "\n",
    "    infix_re = compile_infix_regex(infixes)\n",
    "    nlp2.tokenizer.infix_finditer = infix_re.finditer\n",
    "\n",
    "    # Put spaces around a '?' sitting between word characters and around\n",
    "    # parentheses so they are split off as separate tokens.\n",
    "    query = re.sub(r'(\\w\\w)\\?(\\w\\w)', r'\\1 ? \\2', query)\n",
    "    query = query.replace('(', ' ( ')\n",
    "    query = query.replace(')', ' ) ')\n",
    "    # Collapse all whitespace (newlines, tabs, '\\r', repeated spaces) into single\n",
    "    # spaces. The old chained str.replace calls left some long space runs (a run\n",
    "    # of 15 spaces became 3) and ignored '\\r'; spaCy emitted those as whitespace\n",
    "    # tokens that the old `!= ' '` filter did not remove.\n",
    "    query = re.sub(r'\\s+', ' ', query).strip()\n",
    "\n",
    "    doc = nlp2(query)\n",
    "    tokens = [token.text for token in doc if not token.text.isspace()]\n",
    "\n",
    "    if len(tokens) == 0:\n",
    "        print(\"Zero token sentence detected!\")\n",
    "    return tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:58.902233Z",
     "iopub.status.busy": "2025-06-14T15:28:58.901965Z",
     "iopub.status.idle": "2025-06-14T15:28:58.940817Z",
     "shell.execute_reply": "2025-06-14T15:28:58.939917Z",
     "shell.execute_reply.started": "2025-06-14T15:28:58.902206Z"
    }
   },
   "outputs": [],
   "source": [
    "def returnMask(row, tokenizer, max_length=128, n_annotators=3):\n",
    "    \"\"\"Sub-tokenize a post and project each word-level rationale mask onto the\n",
    "    resulting tokens.\n",
    "\n",
    "    Each word-level mask is split into maximal runs of rationale (1) and\n",
    "    non-rationale (0) words. Every run is joined with spaces and passed to\n",
    "    `ek_extra_preprocess(text, tokenizer)`, and each returned token inherits\n",
    "    the run's mask value.\n",
    "\n",
    "    Args:\n",
    "        row: mapping with 'text' (list of words) and 'rationales' (list of\n",
    "            word-level masks, one per annotator; may be empty). Not modified.\n",
    "        tokenizer: passed through to `ek_extra_preprocess`.\n",
    "        max_length: tokens and masks are truncated to this many entries.\n",
    "        n_annotators: masks are padded with all-zero masks up to this count\n",
    "            (must be >= 1) so every post yields the same number of masks for\n",
    "            stacking downstream.\n",
    "\n",
    "    Returns:\n",
    "        (tokens, masks): `tokens` is the token list built from the first\n",
    "        mask's segmentation; `masks` holds max(n_annotators, len(rationales))\n",
    "        token-level 0/1 masks.\n",
    "    \"\"\"\n",
    "    text_tokens = row['text']\n",
    "    # A very rare corner case: an empty post gets a placeholder word so the\n",
    "    # segmentation below always has at least one position.\n",
    "    if len(text_tokens) == 0:\n",
    "        text_tokens = ['dummy']\n",
    "        print(\"length of text ==0\")\n",
    "\n",
    "    # Copy before padding. The old code aliased row['rationales'] and appended\n",
    "    # to it, mutating the caller's DataFrame cell in place. Its `!= 3` loop\n",
    "    # also never terminated when more than three masks were present.\n",
    "    mask_all = list(row['rationales'])\n",
    "    while len(mask_all) < n_annotators:\n",
    "        mask_all.append([0] * len(text_tokens))\n",
    "\n",
    "    word_mask_all = []\n",
    "    word_tokens_all = []\n",
    "    for mask in mask_all:\n",
    "        # A mask starting with -1 is treated as all zeros. An empty mask (e.g.\n",
    "        # for an empty post) would crash the indexing below, so it is too.\n",
    "        if len(mask) == 0:\n",
    "            mask = [0] * len(text_tokens)\n",
    "        elif mask[0] == -1:\n",
    "            mask = [0] * len(mask)\n",
    "\n",
    "        # Segment the mask into runs: list_pos holds each run's start index\n",
    "        # (plus a final end sentinel) and mask_pos the run's 0/1 value.\n",
    "        list_pos = []\n",
    "        mask_pos = []\n",
    "        flag = 0\n",
    "        for i in range(len(mask)):\n",
    "            if i == 0 and mask[i] == 0:\n",
    "                list_pos.append(0)\n",
    "                mask_pos.append(0)\n",
    "            if flag == 0 and mask[i] == 1:\n",
    "                mask_pos.append(1)\n",
    "                list_pos.append(i)\n",
    "                flag = 1\n",
    "            elif flag == 1 and mask[i] == 0:\n",
    "                flag = 0\n",
    "                mask_pos.append(0)\n",
    "                list_pos.append(i)\n",
    "        if list_pos[-1] != len(mask):\n",
    "            list_pos.append(len(mask))\n",
    "            mask_pos.append(0)\n",
    "\n",
    "        word_tokens = []\n",
    "        word_mask = []\n",
    "        for start, end, value in zip(list_pos, list_pos[1:], mask_pos):\n",
    "            tokens = ek_extra_preprocess(\" \".join(text_tokens[start:end]), tokenizer)\n",
    "            word_tokens += tokens\n",
    "            word_mask += [value] * len(tokens)\n",
    "\n",
    "        word_tokens_all.append(word_tokens[:max_length])\n",
    "        word_mask_all.append(word_mask[:max_length])\n",
    "\n",
    "    # All padded masks are returned. The old trailing slice to\n",
    "    # len(row['rationales']) was a no-op because of the aliasing, and\n",
    "    # combine_features stacks a fixed number of masks per post.\n",
    "    return word_tokens_all[0], word_mask_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:58.942033Z",
     "iopub.status.busy": "2025-06-14T15:28:58.941580Z",
     "iopub.status.idle": "2025-06-14T15:28:58.963587Z",
     "shell.execute_reply": "2025-06-14T15:28:58.962811Z",
     "shell.execute_reply.started": "2025-06-14T15:28:58.942002Z"
    }
   },
   "outputs": [],
   "source": [
    "# dataset = datasets.load_dataset(\"NeelNanda/pile-10k\", split=\"train\")\n",
    "# print(dataset)\n",
    "# print(dataset[0]['text'][:100])\n",
    "# tokens_dataset = tokenize_and_concatenate(dataset, reference_gpt2.tokenizer, streaming=False, max_length=model_cfg.n_ctx, column_name=\"text\", add_bos_token=True, num_proc=4)\n",
    "# data_loader = torch.utils.data.DataLoader(tokens_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:58.965192Z",
     "iopub.status.busy": "2025-06-14T15:28:58.964420Z",
     "iopub.status.idle": "2025-06-14T15:28:58.978652Z",
     "shell.execute_reply": "2025-06-14T15:28:58.978030Z",
     "shell.execute_reply.started": "2025-06-14T15:28:58.965167Z"
    }
   },
   "outputs": [],
   "source": [
    "# dataset = load_dataset(\"hatexplain\")\n",
    "\n",
    "# trainloader = torch.utils.data.DataLoader(dataset[\"train\"], batch_size=1, shuffle=False)\n",
    "\n",
    "# testloader = torch.utils.data.DataLoader(dataset[\"test\"], batch_size=1, shuffle=False)\n",
    "# valloader = torch.utils.data.DataLoader(dataset[\"validation\"], batch_size=1, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:58.979754Z",
     "iopub.status.busy": "2025-06-14T15:28:58.979489Z",
     "iopub.status.idle": "2025-06-14T15:28:59.002600Z",
     "shell.execute_reply": "2025-06-14T15:28:59.002045Z",
     "shell.execute_reply.started": "2025-06-14T15:28:58.979732Z"
    }
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "from transformers import BertTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:28:59.003596Z",
     "iopub.status.busy": "2025-06-14T15:28:59.003400Z",
     "iopub.status.idle": "2025-06-14T15:29:00.215402Z",
     "shell.execute_reply": "2025-06-14T15:29:00.214591Z",
     "shell.execute_reply.started": "2025-06-14T15:28:59.003580Z"
    }
   },
   "outputs": [],
   "source": [
    "# Flatten dataset.json into one record per post: its word tokens, each\n",
    "# annotator's id/target/label, the word-level rationale masks and a\n",
    "# majority-vote final label ('undecided' when no label occurs twice).\n",
    "# NOTE(review): the loop variables (`key`, `temp`, `i`, `final_label`, ...)\n",
    "# leak into the notebook namespace, so later cells can silently depend on\n",
    "# them; `data` is also rebound to a DataFrame in a later cell.\n",
    "with open('dataset.json', 'r') as fp:\n",
    "        data = json.load(fp)\n",
    "dict_data=[]\n",
    "for key in data:\n",
    "    temp={}\n",
    "    temp['post_id']=key\n",
    "    temp['text']=data[key]['post_tokens']\n",
    "    final_label=[]\n",
    "    # Exactly three annotators per post are read (list indexes 0-2).\n",
    "    for i in range(1,4):\n",
    "        temp['annotatorid'+str(i)]=data[key]['annotators'][i-1]['annotator_id']\n",
    "        temp['target'+str(i)]=data[key]['annotators'][i-1]['target']\n",
    "        temp['label'+str(i)]=data[key]['annotators'][i-1]['label']\n",
    "        final_label.append(temp['label'+str(i)])\n",
    "\n",
    "    # Most frequent of the three labels.\n",
    "    final_label_id=max(final_label,key=final_label.count)\n",
    "    # Word-level rationale masks as stored in the JSON; this list may hold\n",
    "    # fewer than three masks (returnMask pads it later).\n",
    "    temp['rationales']=data[key]['rationales']\n",
    "    #print(temp[\"rationales\"])\n",
    "    # A top count of 1 means all three annotators disagree.\n",
    "    if(final_label.count(final_label_id)==1):\n",
    "        temp['final_label']='undecided'\n",
    "    else:\n",
    "        temp['final_label']=final_label_id    \n",
    "    dict_data.append(temp)    \n",
    "temp_read = pd.DataFrame(dict_data)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:34:37.981834Z",
     "iopub.status.busy": "2025-06-14T15:34:37.981506Z",
     "iopub.status.idle": "2025-06-14T15:34:37.987836Z",
     "shell.execute_reply": "2025-06-14T15:34:37.987122Z",
     "shell.execute_reply.started": "2025-06-14T15:34:37.981811Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(['ion', 'hang', 'wit', 'bitches', 'who', 'niggas', 'are', 'insecure'],\n",
       " [[0, 0, 0, 1, 0, 0, 0, 0],\n",
       "  [0, 0, 0, 1, 0, 0, 0, 0],\n",
       "  [0, 0, 0, 1, 0, 0, 0, 0]],\n",
       " 'offensive')"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spot-check one post: its word tokens, rationale masks and final label.\n",
    "tuple(temp_read.iloc[83][['text', 'rationales', 'final_label']])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-14T15:33:45.512908Z",
     "iopub.status.busy": "2025-06-14T15:33:45.512150Z",
     "iopub.status.idle": "2025-06-14T15:33:45.519531Z",
     "shell.execute_reply": "2025-06-14T15:33:45.518751Z",
     "shell.execute_reply.started": "2025-06-14T15:33:45.512878Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "post_id                               1179088797964763136_twitter\n",
       "text            [<user>, i, am, bit, confused, coz, chinese, p...\n",
       "annotatorid1                                                    1\n",
       "target1                                                   [Asian]\n",
       "label1                                                 hatespeech\n",
       "annotatorid2                                                    4\n",
       "target2                                                   [Asian]\n",
       "label2                                                  offensive\n",
       "annotatorid3                                                    3\n",
       "target3                                                   [Asian]\n",
       "label3                                                 hatespeech\n",
       "rationales      [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...\n",
       "final_label                                            hatespeech\n",
       "Name: 3, dtype: object"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the post at position 3 explicitly; this cell used to read a loop\n",
    "# variable `i` left over from an earlier cell (hidden notebook state).\n",
    "temp_read.iloc[3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading BERT tokenizer...\n",
      "total_data 20148\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "50ee3f24a98449f0a12c8baa283e8e84",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/20148 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Sub-tokenize every post whose final label is not 'undecided' with\n",
    "# returnMask and collect post ids, token ids, token-level rationale masks and\n",
    "# labels in parallel lists.\n",
    "data = temp_read\n",
    "print('Loading BERT tokenizer...')\n",
    "# NOTE(review): 'bert-base-uncased' is loaded with do_lower_case=False, so the\n",
    "# input is not lowercased before lookup in an uncased vocabulary; uppercase or\n",
    "# accented characters may map to unexpected word pieces or [UNK]. The posts\n",
    "# shown above are already lowercase -- confirm this setting is intended.\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=False)\n",
    "print('total_data',len(data))\n",
    "post_ids_list=[]\n",
    "text_list=[]\n",
    "rationales_list=[]\n",
    "label_list=[]\n",
    "for index,row in tqdm.tqdm(data.iterrows(),total=len(data)):\n",
    "    #print(params)\n",
    "    text=row['text']\n",
    "    post_id=row['post_id']\n",
    "\n",
    "    # NOTE(review): `text` and `annotation_list` are not used below.\n",
    "    annotation_list=[row['label1'],row['label2'],row['label3']] \n",
    "    annotation=row['final_label']\n",
    "\n",
    "    #print(annotation_list,annotation)\n",
    "        \n",
    "    # Posts without a majority label are skipped.\n",
    "    if(annotation != 'undecided'):\n",
    "        # tokens: the post's sub-tokens; rationales: one token-level mask per annotator slot.\n",
    "        tokens,rationales = returnMask(row,tokenizer)\n",
    "        rationales_list.append(rationales)\n",
    "        text_list.append(tokens)\n",
    "        label_list.append(annotation)\n",
    "        post_ids_list.append(post_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split definitions: post_id_dict['train' | 'val' | 'test'] -> list of post ids.\n",
    "with open('post_id_divisions.json') as fp:\n",
    "    post_id_dict = json.load(fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per decided post: token ids, token-level rationale masks and label.\n",
    "tdata = pd.DataFrame({\n",
    "    'Post_id': post_ids_list,\n",
    "    'Text': text_list,\n",
    "    'Attention': rationales_list,\n",
    "    'Label': label_list,\n",
    "})\n",
    "\n",
    "# Partition by the post ids listed for each split in post_id_dict.\n",
    "X_train, X_val, X_test = (\n",
    "    tdata[tdata['Post_id'].isin(post_id_dict[split])]\n",
    "    for split in ('train', 'val', 'test')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Post_id</th>\n",
       "      <th>Text</th>\n",
       "      <th>Attention</th>\n",
       "      <th>Label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1179063826874032128_twitter</td>\n",
       "      <td>[2057, 3685, 3613, 4214, 9731, 10469, 2015, 20...</td>\n",
       "      <td>[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td>\n",
       "      <td>normal</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1178793830532956161_twitter</td>\n",
       "      <td>[6583, 26677, 8038, 3363, 9152, 13327, 2015, 9...</td>\n",
       "      <td>[[0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, ...</td>\n",
       "      <td>normal</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1179088797964763136_twitter</td>\n",
       "      <td>[5310, 1045, 2572, 2978, 5457, 2522, 2480, 282...</td>\n",
       "      <td>[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td>\n",
       "      <td>hatespeech</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>24198545_gab</td>\n",
       "      <td>[1998, 2023, 2003, 2339, 1045, 2203, 2039, 200...</td>\n",
       "      <td>[[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1,...</td>\n",
       "      <td>hatespeech</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>14567516_gab</td>\n",
       "      <td>[2053, 5620, 5181, 1998, 22212, 2015, 2129, 43...</td>\n",
       "      <td>[[1, 1, 1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 0, ...</td>\n",
       "      <td>offensive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19223</th>\n",
       "      <td>9988840_gab</td>\n",
       "      <td>[2023, 25047, 16939, 17276, 4632, 12873, 2121,...</td>\n",
       "      <td>[[0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0,...</td>\n",
       "      <td>offensive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19225</th>\n",
       "      <td>9990225_gab</td>\n",
       "      <td>[2043, 1045, 2034, 2288, 2006, 2182, 1998, 205...</td>\n",
       "      <td>[[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0,...</td>\n",
       "      <td>offensive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19226</th>\n",
       "      <td>9991681_gab</td>\n",
       "      <td>[2001, 24532, 2102, 4315, 9587, 25016, 2213, 1...</td>\n",
       "      <td>[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td>\n",
       "      <td>normal</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19227</th>\n",
       "      <td>9992513_gab</td>\n",
       "      <td>[2009, 2003, 9643, 2298, 2012, 2088, 28321, 40...</td>\n",
       "      <td>[[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1,...</td>\n",
       "      <td>hatespeech</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19228</th>\n",
       "      <td>9998729_gab</td>\n",
       "      <td>[1996, 3644, 3795, 2923, 7069, 2031, 2069, 109...</td>\n",
       "      <td>[[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,...</td>\n",
       "      <td>offensive</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>15383 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                           Post_id  \\\n",
       "1      1179063826874032128_twitter   \n",
       "2      1178793830532956161_twitter   \n",
       "3      1179088797964763136_twitter   \n",
       "6                     24198545_gab   \n",
       "7                     14567516_gab   \n",
       "...                            ...   \n",
       "19223                  9988840_gab   \n",
       "19225                  9990225_gab   \n",
       "19226                  9991681_gab   \n",
       "19227                  9992513_gab   \n",
       "19228                  9998729_gab   \n",
       "\n",
       "                                                    Text  \\\n",
       "1      [2057, 3685, 3613, 4214, 9731, 10469, 2015, 20...   \n",
       "2      [6583, 26677, 8038, 3363, 9152, 13327, 2015, 9...   \n",
       "3      [5310, 1045, 2572, 2978, 5457, 2522, 2480, 282...   \n",
       "6      [1998, 2023, 2003, 2339, 1045, 2203, 2039, 200...   \n",
       "7      [2053, 5620, 5181, 1998, 22212, 2015, 2129, 43...   \n",
       "...                                                  ...   \n",
       "19223  [2023, 25047, 16939, 17276, 4632, 12873, 2121,...   \n",
       "19225  [2043, 1045, 2034, 2288, 2006, 2182, 1998, 205...   \n",
       "19226  [2001, 24532, 2102, 4315, 9587, 25016, 2213, 1...   \n",
       "19227  [2009, 2003, 9643, 2298, 2012, 2088, 28321, 40...   \n",
       "19228  [1996, 3644, 3795, 2923, 7069, 2031, 2069, 109...   \n",
       "\n",
       "                                               Attention       Label  \n",
       "1      [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...      normal  \n",
       "2      [[0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, ...      normal  \n",
       "3      [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...  hatespeech  \n",
       "6      [[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1,...  hatespeech  \n",
       "7      [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 0, ...   offensive  \n",
       "...                                                  ...         ...  \n",
       "19223  [[0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0,...   offensive  \n",
       "19225  [[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0,...   offensive  \n",
       "19226  [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...      normal  \n",
       "19227  [[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1,...  hatespeech  \n",
       "19228  [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,...   offensive  \n",
       "\n",
       "[15383 rows x 4 columns]"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the training split (first rows only; the full frame is large).\n",
    "X_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "class textDataset(Dataset):\n",
    "    \"\"\"Map-style dataset over a DataFrame with 'Text', 'Label' and 'Attention' columns.\n",
    "\n",
    "    Indexing returns only the 'Text' entry (optionally passed through\n",
    "    `transform`); labels and rationales are kept as attributes.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data, transform=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            data (pandas.DataFrame): frame with 'Text', 'Label' and 'Attention'\n",
    "                columns.\n",
    "            transform (callable, optional): applied to each 'Text' item in\n",
    "                __getitem__ (it used to be stored but never applied).\n",
    "        \"\"\"\n",
    "        self.inputs = data[\"Text\"]\n",
    "        self.labels = data[\"Label\"]\n",
    "        self.rationales = data[\"Attention\"]\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.inputs)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # Positional access: the split frames keep their original,\n",
    "        # non-contiguous index labels.\n",
    "        item = self.inputs.iloc[idx]\n",
    "        if self.transform is not None:\n",
    "            item = self.transform(item)\n",
    "        return item"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import LabelEncoder\n",
    "from torch.utils.data import TensorDataset, DataLoader, RandomSampler,SequentialSampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def combine_features(tuple_data, is_train=False, max_length=128, classes_path='classes.npy'):\n",
    "    \"\"\"Turn (token_ids, rationale_masks, label) tuples into a DataLoader.\n",
    "\n",
    "    Args:\n",
    "        tuple_data: iterable of (token_ids, masks, label) tuples as built by\n",
    "            encodeData. `masks` is a list of token-level 0/1 masks, and every\n",
    "            post must have the same number of them (they are stacked).\n",
    "        is_train: passed to return_dataloader (random vs. sequential order).\n",
    "        max_length: token ids and masks are post-padded with 0 / truncated\n",
    "            to this length (default 128, as before).\n",
    "        classes_path: .npy file holding the LabelEncoder classes.\n",
    "\n",
    "    Returns:\n",
    "        DataLoader over (input_ids, rationales, attention_mask, label).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `tuple_data` is empty.\n",
    "    \"\"\"\n",
    "    # Materialize once: a generator would otherwise be exhausted by the first\n",
    "    # comprehension below.\n",
    "    tuple_data = list(tuple_data)\n",
    "    if len(tuple_data) == 0:\n",
    "        raise ValueError(\"combine_features: tuple_data is empty\")\n",
    "    input_ids = [ele[0] for ele in tuple_data]\n",
    "    att_vals = [ele[1] for ele in tuple_data]\n",
    "    labels = [ele[2] for ele in tuple_data]\n",
    "\n",
    "    # Use the saved class list so every split gets the same label -> id\n",
    "    # mapping. NOTE(review): allow_pickle=True unpickles the file; only load a\n",
    "    # classes file you created yourself.\n",
    "    encoder = LabelEncoder()\n",
    "    encoder.classes_ = np.load(classes_path, allow_pickle=True)\n",
    "    labels = encoder.transform(labels)\n",
    "\n",
    "    input_ids = pad_sequences(input_ids, maxlen=max_length,\n",
    "                              dtype=\"long\", value=0, truncating=\"post\", padding=\"post\")\n",
    "\n",
    "    # One (n_masks, max_length) tensor per post, stacked to\n",
    "    # (n_posts, n_masks, max_length).\n",
    "    rationales_vals = torch.stack([\n",
    "        torch.tensor(pad_sequences(values, maxlen=max_length,\n",
    "                                   dtype=\"long\", value=0, truncating=\"post\", padding=\"post\"))\n",
    "        for values in att_vals\n",
    "    ], dim=0)\n",
    "\n",
    "    att_masks = custom_att_masks(input_ids)\n",
    "    return return_dataloader(input_ids, labels, rationales_vals, att_masks, is_train)\n",
    "\n",
    "def return_dataloader(input_ids, labels, att_vals, att_masks, is_train=False, batch_size=32):\n",
    "    \"\"\"Wrap already-padded features in a TensorDataset and a DataLoader.\n",
    "\n",
    "    Args:\n",
    "        input_ids: padded token ids, one row per example.\n",
    "        labels: integer class id per example.\n",
    "        att_vals: rationale tensor whose first dimension is the example count;\n",
    "            used as-is (its shape is printed).\n",
    "        att_masks: 0/1 attention mask rows, stored as uint8.\n",
    "        is_train: RandomSampler when truthy, SequentialSampler otherwise.\n",
    "        batch_size: examples per batch (default 32, as before).\n",
    "\n",
    "    Returns:\n",
    "        DataLoader yielding (inputs, rationales, masks, labels) batches.\n",
    "    \"\"\"\n",
    "    inputs = torch.tensor(input_ids)\n",
    "    labels = torch.tensor(labels, dtype=torch.long)\n",
    "    masks = torch.tensor(np.array(att_masks), dtype=torch.uint8)\n",
    "    print(att_vals.shape)\n",
    "    data = TensorDataset(inputs, att_vals, masks, labels)\n",
    "    sampler = RandomSampler(data) if is_train else SequentialSampler(data)\n",
    "    return DataLoader(data, sampler=sampler, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def custom_att_masks(input_ids):\n",
    "    \"\"\"Build a 0/1 attention mask for every padded sequence.\n",
    "\n",
    "    A position is 1 when its token id is > 0 (a real token) and 0 otherwise\n",
    "    (id 0 is the padding value).\n",
    "\n",
    "    Args:\n",
    "        input_ids: iterable of integer token-id sequences.\n",
    "\n",
    "    Returns:\n",
    "        list[list[int]]: one mask per sequence, same lengths as the inputs.\n",
    "    \"\"\"\n",
    "    return [[1 if token_id > 0 else 0 for token_id in sent] for sent in input_ids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encodeData(dataframe):\n",
    "    \"\"\"Turn a split DataFrame into a list of (Text, Attention, Label) tuples.\n",
    "\n",
    "    Rows are visited in order behind a tqdm progress bar.\n",
    "\n",
    "    Args:\n",
    "        dataframe: frame with 'Text', 'Attention' and 'Label' columns.\n",
    "\n",
    "    Returns:\n",
    "        list of (text, attention, label) tuples, one per row.\n",
    "    \"\"\"\n",
    "    progress = tqdm.tqdm(dataframe.iterrows(), total=len(dataframe))\n",
    "    return [(record['Text'], record['Attention'], record['Label']) for _, record in progress]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5556817e9519438bbc0049d6393f3e77",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/15383 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "742783885cf0497ea1b2a023f755e12d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1922 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0566354e06a8498aae5c64ceaa962aec",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1924 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# NOTE: this rebinds the split names from DataFrames to lists of tuples, so\n",
    "# re-running this cell alone fails; re-run the cell that builds the splits first.\n",
    "X_train, X_val, X_test = (encodeData(split) for split in (X_train, X_val, X_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.preprocessing.sequence import pad_sequences\n",
    "from torch.utils.data import TensorDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([15383, 3, 128])\n",
      "torch.Size([1922, 3, 128])\n",
      "torch.Size([1924, 3, 128])\n"
     ]
    }
   ],
   "source": [
    "# Build the loaders: random sampling for training, fixed order for evaluation.\n",
    "train_dataloader = combine_features(X_train, is_train=True)\n",
    "validation_dataloader = combine_features(X_val)  # is_train defaults to False\n",
    "test_dataloader = combine_features(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# trainset = textDataset(X_train)\n",
    "# valset = textDataset(X_val)\n",
    "# testset = textDataset(X_test)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# trainloader = DataLoader(trainset, batch_size=16000, shuffle=False)\n",
    "# valloader = DataLoader(valset,batch_size=3000,shuffle=False)\n",
    "# testloader = DataLoader(testset,batch_size=3000,shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grab one (randomly sampled) training batch for quick inspection.\n",
    "tokens, rationales, mask, labels = next(iter(train_dataloader))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_performance(model, dataloader, dataset=\"train\"):\n",
    "    \"\"\"Print and return the classification accuracy of `model` on `dataloader`.\n",
    "\n",
    "    The model is called as `model(inputs)` and must return `(logits, extra)`;\n",
    "    the predicted class is the argmax over dim 1. Inference runs in eval mode\n",
    "    under torch.no_grad(), and the model's previous train/eval mode is\n",
    "    restored afterwards (it used to be forced to train mode).\n",
    "\n",
    "    Args:\n",
    "        model: torch module returning (logits, _) for a batch of token ids.\n",
    "        dataloader: yields (inputs, rationales, mask, labels) batches.\n",
    "        dataset: split name used in the printed message.\n",
    "\n",
    "    Returns:\n",
    "        float: fraction of correctly classified examples.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the dataloader yields no batches.\n",
    "    \"\"\"\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    pred = []\n",
    "    gt = []\n",
    "    try:\n",
    "        # Evaluation needs no gradients; avoid building an autograd graph.\n",
    "        with torch.no_grad():\n",
    "            for inputs, _rationales, _mask, tlabels in dataloader:\n",
    "                outputs, _ = model(inputs.to(device))\n",
    "                pred.append(torch.argmax(outputs, dim=1).cpu())\n",
    "                gt.append(tlabels.cpu())\n",
    "    finally:\n",
    "        model.train(was_training)\n",
    "    if not pred:\n",
    "        raise ValueError(\"calculate_performance: dataloader is empty\")\n",
    "    pred = torch.hstack(pred)\n",
    "    gt = torch.hstack(gt)\n",
    "    accuracy = (pred == gt).float().mean().item()\n",
    "    print(\"Accuracy on \" + dataset + \":\", accuracy)\n",
    "    return accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _bucket_fractions(values):\n",
    "    \"\"\"Return the fractions of `values` that are <= .33, in (.33, .66] and > .66.\"\"\"\n",
    "    n = len(values)\n",
    "    return (\n",
    "        np.sum(values <= 0.33) / n,\n",
    "        np.sum(np.logical_and(values > 0.33, values <= 0.66)) / n,\n",
    "        np.sum(values > 0.66) / n,\n",
    "    )\n",
    "\n",
    "\n",
    "def plot_heatmaps(model, dataloader, name):\n",
    "    \"\"\"Plot how attention on rationale tokens relates to true-class confidence.\n",
    "\n",
    "    For each example two scores are computed:\n",
    "      * attention score: the weights in `attn[0][i, 0, -1]` summed over the\n",
    "        positions marked by each rationale mask, averaged over the example's\n",
    "        masks (assumes attn[0] is (batch, heads, query, key) -- TODO confirm);\n",
    "      * prediction score: softmax probability of the true label.\n",
    "    The share of examples per score bucket is printed, and a 3x3 heatmap of\n",
    "    the joint distribution (percent of all examples) is saved to\n",
    "    ./plots/<name>.pdf.\n",
    "\n",
    "    NOTE(review): query position -1 is a padding position for posts shorter\n",
    "    than the padded length; confirm it is the intended query. The printed\n",
    "    buckets use `<= .33` while histogram bins are half-open `[a, b)`, so a\n",
    "    value of exactly 0.33 or 0.66 lands in different buckets in the two outputs.\n",
    "\n",
    "    Args:\n",
    "        model: torch module; `model(inputs)` must return (logits, attn).\n",
    "        dataloader: yields (inputs, rationales, mask, labels) batches.\n",
    "        name: file stem of the saved PDF.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the dataloader yields no examples.\n",
    "    \"\"\"\n",
    "    was_training = model.training\n",
    "    # Eval mode disables dropout (calculate_performance may have left the model\n",
    "    # in train mode); the caller's mode is restored afterwards.\n",
    "    model.eval()\n",
    "    attn_values = []\n",
    "    prediction_values = []\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            for inputs, rationales, _mask, tlabels in dataloader:\n",
    "                outputs, attn = model(inputs.to(device))\n",
    "                probs = torch.softmax(outputs, dim=1)\n",
    "                for i in range(len(attn[0])):\n",
    "                    temp_attn = attn[0][i, 0, -1].cpu().detach().numpy()\n",
    "                    masks = rationales[i]\n",
    "                    total = 0\n",
    "                    for rationale in masks:\n",
    "                        selected = rationale.long().cpu().numpy().astype(bool)\n",
    "                        total += sum(temp_attn[selected])\n",
    "                    # Average over this example's masks (was a hard-coded / 3).\n",
    "                    attn_values.append(total / len(masks))\n",
    "                    prediction_values.append(probs[i, int(tlabels[i])].item())\n",
    "    finally:\n",
    "        model.train(was_training)\n",
    "\n",
    "    if not attn_values:\n",
    "        raise ValueError(\"plot_heatmaps: dataloader is empty\")\n",
    "    attn_arr = np.array(attn_values)\n",
    "    pred_arr = np.array(prediction_values)\n",
    "    print(\"attention\", *_bucket_fractions(attn_arr))\n",
    "    print(\"prediction\", *_bucket_fractions(pred_arr))\n",
    "\n",
    "    # Same counts that matplotlib's hist2d produced, without creating and\n",
    "    # closing a throwaway figure.\n",
    "    bins = [0, 0.33, 0.66, 1.1]\n",
    "    h, _, _ = np.histogram2d(attn_arr, pred_arr, bins=[bins, bins])\n",
    "    percent = np.round((h.T / h.sum()) * 100, 2)\n",
    "\n",
    "    plt.figure(figsize=(7, 7))\n",
    "    ax = sns.heatmap(percent, vmin=5, vmax=70, annot=percent, fmt=\"g\",\n",
    "                     cmap=sns.color_palette(\"coolwarm\"),\n",
    "                     yticklabels=[0.33, 0.66, 1.],\n",
    "                     xticklabels=[0.33, 0.66, 1], annot_kws={\"size\": 18}, cbar=False)\n",
    "    ax.invert_yaxis()\n",
    "    plt.xlabel(r\"distinct token attention\", fontweight=\"bold\", fontsize=14)\n",
    "    plt.ylabel(r\"true token probability\", fontweight=\"bold\", fontsize=14)\n",
    "    plt.xticks([0, 1, 2, 3], [0, 0.33, 0.66, 1], weight=\"bold\", fontsize=14)\n",
    "    plt.yticks([0, 1, 2, 3], [0, 0.33, 0.66, 1], weight=\"bold\", va=\"top\", fontsize=14)\n",
    "    # savefig does not create missing directories.\n",
    "    os.makedirs(\"./plots\", exist_ok=True)\n",
    "    plt.savefig(\"./plots/\" + name + \".pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def rollout_single_step(layer_attentions, alpha=0.3):\n",
    "    \"\"\"\n",
    "    Compute rollout attention for classification.\n",
    "    Includes identity correction and per-layer normalization.\n",
    "    \"\"\"\n",
    "    attentions = []\n",
    "    for attn in layer_attentions:\n",
    "        # Mean over heads\n",
    "        if attn.ndim == 4:   # (batch, heads, seq, seq)\n",
    "            attn = attn[0].mean(0)\n",
    "        else:                # (heads, seq, seq)\n",
    "            attn = attn.mean(0)\n",
    "        \n",
    "        # Add identity skip connection\n",
    "        attn = alpha * attn + (1 - alpha) * np.eye(attn.shape[0])\n",
    "        \n",
    "        # Normalize rows\n",
    "        attn = attn / (attn.sum(-1, keepdims=True) + 1e-12)\n",
    "        attentions.append(attn)\n",
    "    \n",
    "    # Rollout: multiply attention matrices\n",
    "    R = attentions[0]\n",
    "    for attn in attentions[1:]:\n",
    "        R = attn @ R   \n",
    "    \n",
    "    # Final row normalization\n",
    "    return R / (R.sum(-1, keepdims=True) + 1e-12)\n",
    "\n",
    "\n",
    "def attention_mass_classification(attention_weights, rationales, alpha=0.3):\n",
    "    \"\"\"\n",
    "    Compute attention mass on rationale positions for classification.\n",
    "    \n",
    "    Args:\n",
    "        attention_weights: List of attention tensors (one per layer)\n",
    "        rationale_positions: List of token positions corresponding to rationale\n",
    "        alpha: Weight for rollout identity connection\n",
    "        \n",
    "    Returns:\n",
    "        rollout_mass: Rollout attention mass on rationale\n",
    "        layer_avg_mass: Layer-averaged attention mass on rationale\n",
    "        max_pool_mass: Max-pooled attention mass on rationale\n",
    "        rollout_attn_probs: Rollout attention distribution (for comp/suff)\n",
    "    \"\"\"\n",
    "    # --- Method 1: Rollout Attention ---\n",
    "    R = rollout_single_step(attention_weights, alpha=alpha)\n",
    "    # Use last token (position -1) attention for classification\n",
    "    \n",
    "    #print(R.shape, R[-1].sum(), rationales.shape)\n",
    "    p = R[-1]  # last token distribution over all positions\n",
    "    rollout_attn_probs = p\n",
    "    \n",
    "\n",
    "    rollout_mass = (rollout_attn_probs*rationales.numpy()).sum(axis=1).mean()\n",
    "\n",
    "    \n",
    "    \n",
    "\n",
    "    \n",
    "    # --- Method 2: Layer-Averaged Attention ---\n",
    "    layer_attns = []\n",
    "    for layer_attn in attention_weights:\n",
    "        if layer_attn.ndim == 4:  # (batch, heads, seq, seq)\n",
    "            layer_attn = layer_attn[0].mean(0)  # Average over heads\n",
    "        else:  # (heads, seq, seq)\n",
    "            layer_attn = layer_attn.mean(0)\n",
    "        layer_attns.append(layer_attn)\n",
    "    \n",
    "    avg_attn = np.mean(layer_attns, axis=0)  # (seq, seq)\n",
    "    cls_attn = avg_attn[-1]  # [CLS] token's attention\n",
    "    \n",
    "    \n",
    "    layer_avg_mass = (cls_attn*rationales.numpy()).sum(axis=1).mean()\n",
    "  \n",
    "    \n",
    " \n",
    "\n",
    "\n",
    "    \n",
    "    # --- Method 3: Max Pooling Across Layers ---\n",
    "    max_attn = np.maximum.reduce(layer_attns)  # (seq, seq)\n",
    "    cls_attn_max = max_attn[-1]  # last token's attention\n",
    "    \n",
    "    \n",
    "    max_pool_mass = (cls_attn_max*rationales.numpy()).sum(axis=1).mean()\n",
    "\n",
    "    \n",
    "    return rollout_mass, layer_avg_mass, max_pool_mass, rollout_attn_probs\n",
    "\n",
    "\n",
    "def compute_comprehensiveness_sufficiency_classification(\n",
    "    model, inputs, predicted_class, original_prob, \n",
    "    rollout_attn_probs, rationale_positions, k_percent=20, device='cuda'):\n",
    "    \"\"\"\n",
    "    Compute comprehensiveness and sufficiency for classification task.\n",
    "    \n",
    "    Comprehensiveness = f(x) - f(x\\r:k%) [drop after removing top-k%]\n",
    "    Sufficiency = f(x) - f(r:k%) [drop when keeping only top-k%]\n",
    "    \n",
    "    Args:\n",
    "        model: Classification model\n",
    "        inputs: Input token ids (tensor of shape [batch_size, seq_len])\n",
    "        predicted_class: Predicted class label (int)\n",
    "        original_prob: Original probability for predicted class (float)\n",
    "        rollout_attn_probs: Rollout attention distribution\n",
    "        rationale_positions: Token positions of rationale (for diagnostic)\n",
    "        k_percent: Percentage of tokens to use for top-k\n",
    "        device: torch device\n",
    "        \n",
    "    Returns:\n",
    "        comprehensiveness: Probability drop after removing top-k%\n",
    "        sufficiency: Probability drop when keeping only top-k%\n",
    "    \"\"\"\n",
    "    seq_len = inputs.size(1)\n",
    "    attn_scores = rollout_attn_probs[:seq_len]\n",
    "    \n",
    "    # Calculate k (number of top tokens)\n",
    "    k = max(1, int(np.ceil(0.01 * k_percent * seq_len)))\n",
    "    \n",
    "    # Get top-k indices\n",
    "    topk_indices = np.argpartition(-attn_scores, min(k, len(attn_scores)-1))[:k]\n",
    "    topk_indices_set = set(topk_indices)\n",
    "    \n",
    "    # Diagnostic: check overlap with rationale\n",
    "    overlap = len(topk_indices_set & set(rationale_positions))\n",
    "    # print(f\"Top-{k} contains {overlap}/{len(rationale_positions)} rationale tokens\")\n",
    "    \n",
    "    # --- Comprehensiveness: f(x) - f(x\\r:k%) ---\n",
    "    # Remove top-k tokens completely\n",
    "    mask_keep_indices = [i for i in range(seq_len) if i not in topk_indices_set]\n",
    "    if len(mask_keep_indices) > 0:\n",
    "        inputs_removed = inputs[:, mask_keep_indices]\n",
    "        \n",
    "        logits_removed, _ = model(inputs_removed)\n",
    "        probs_removed = F.softmax(logits_removed, dim=-1)\n",
    "        prob_without_topk = probs_removed[0, predicted_class].item()\n",
    "    else:\n",
    "        prob_without_topk = 0.0\n",
    "    \n",
    "    comp_score = original_prob - prob_without_topk\n",
    "    \n",
    "    # --- Sufficiency: f(x) - f(r:k%) ---\n",
    "    # Keep only top-k tokens\n",
    "    inputs_keep = inputs[:, topk_indices]\n",
    "    \n",
    "    if inputs_keep.size(1) > 0:\n",
    "        logits_keep, _ = model(inputs_keep)\n",
    "        probs_keep = F.softmax(logits_keep, dim=-1)\n",
    "        prob_with_only_topk = probs_keep[0, predicted_class].item()\n",
    "    else:\n",
    "        prob_with_only_topk = 0.0\n",
    "    \n",
    "    suff_score = original_prob - prob_with_only_topk\n",
    "    \n",
    "    return comp_score, suff_score\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " Training Seed: 1234\n",
      "Total steps: 14430\n",
      "Evaluating every 10 epochs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|████▊                                                                                                                                             | 1/30 [02:01<58:29, 121.02s/it]"
     ]
    }
   ],
   "source": [
    "from transformers import get_cosine_schedule_with_warmup\n",
    "import random\n",
    "\n",
    "from tqdm import tqdm as tqdm\n",
    "def set_seed(seed):\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "\n",
    "train_seeds = [1234, 1235, 1236, 1237, 1238]\n",
    "\n",
    "for n_seed in train_seeds:\n",
    "\n",
    "    print(\"\\n Training Seed:\", n_seed)\n",
    "    set_seed(n_seed)\n",
    "\n",
    "    batch_size = 32\n",
    "    num_epochs = 30\n",
    "    lr = 5e-5\n",
    "    weight_decay = 1e-3\n",
    "    momentum = 0.99\n",
    "    \n",
    "    \n",
    "    torch.manual_seed(n_seed)\n",
    "    # Build model\n",
    "    model_cfg = Config(debug=False, d_model=64, n_heads=4, d_head=64,\n",
    "                       d_mlp=256, n_layers=4, n_ctx=256,\n",
    "                       d_vocab=tokenizer.vocab_size)\n",
    "    model = DemoTransformer(model_cfg).cuda()\n",
    "\n",
    "    # -------------------------------\n",
    "    # Optimizer: SGD + momentum\n",
    "    # -------------------------------\n",
    "    \n",
    "    \n",
    "    # -------------------------------\n",
    "    # Per-layer LR scaling\n",
    "    # -------------------------------\n",
    "    param_groups = []\n",
    "\n",
    "    # Embedding & positional embeddings\n",
    "    param_groups.append({\n",
    "        \"params\": model.embed.parameters(),\n",
    "        \"lr\": 30 * lr,\n",
    "        \"weight_decay\": weight_decay\n",
    "    })\n",
    "    param_groups.append({\n",
    "        \"params\": model.pos_embed.parameters(),\n",
    "        \"lr\": 1 * lr,\n",
    "        \"weight_decay\": weight_decay\n",
    "    })\n",
    "\n",
    "    # Transformer blocks\n",
    "    for i, block in enumerate(model.blocks):\n",
    "        # Attention parameters\n",
    "        param_groups.append({\"params\": block.attn.W_Q, \"lr\": 30 * lr, \"weight_decay\": weight_decay})\n",
    "        param_groups.append({\"params\": block.attn.b_Q, \"lr\": lr, \"weight_decay\": weight_decay})\n",
    "\n",
    "        param_groups.append({\"params\": block.attn.W_K, \"lr\": 30 * lr, \"weight_decay\": weight_decay})\n",
    "        param_groups.append({\"params\": block.attn.b_K, \"lr\": lr, \"weight_decay\": weight_decay})\n",
    "\n",
    "        param_groups.append({\"params\": block.attn.W_V, \"lr\": lr, \"weight_decay\": weight_decay})\n",
    "        param_groups.append({\"params\": block.attn.b_V, \"lr\": lr, \"weight_decay\": weight_decay})\n",
    "\n",
    "        param_groups.append({\"params\": block.attn.W_O, \"lr\": lr, \"weight_decay\": weight_decay})\n",
    "        param_groups.append({\"params\": block.attn.b_O, \"lr\": lr, \"weight_decay\": weight_decay})\n",
    "\n",
    "        # MLP (if exists)\n",
    "        if hasattr(block, \"mlp\"):\n",
    "            param_groups.append({\"params\": block.mlp.parameters(), \"lr\": lr, \"weight_decay\": weight_decay})\n",
    "\n",
    "    # Classifier head\n",
    "    param_groups.append({\n",
    "        \"params\": model.cls_head.parameters(),\n",
    "        \"lr\": lr,\n",
    "        \"weight_decay\": weight_decay\n",
    "    })\n",
    "\n",
    "    # Finally create optimizer\n",
    "    optimizer = torch.optim.AdamW(param_groups, lr=lr)\n",
    "\n",
    "#     optimizer = torch.optim.SGD([{'params':model.embed.parameters(),\"lr\":100*lr,\"momentum\":0.99,\"weight_decay\":weight_decay},\n",
    "#                            {\"params\":model.pos_embed.parameters(),\"lr\":100*lr,\"momentum\":0.99,\"weight_decay\":weight_decay},\n",
    "#                            {'params':model.blocks[0].attn.W_Q,'lr':100*lr,\"momentum\":0.99,\"weight_decay\":weight_decay},\n",
    "#                            {'params':model.blocks[0].attn.b_Q,'lr':100*lr,\"momentum\":0.99,\"weight_decay\":weight_decay},\n",
    "#                            {'params':model.blocks[0].attn.W_K,'lr':100*lr,\"momentum\":0.99,\"weight_decay\":weight_decay},\n",
    "#                            {'params':model.blocks[0].attn.b_K,'lr':100*lr,\"momentum\":0.99,\"weight_decay\":weight_decay},\n",
    "#                            {'params':model.blocks[0].attn.W_V,'lr':lr,\"momentum\":0.99,\"weight_decay\":weight_decay},\n",
    "#                            {'params':model.blocks[0].attn.b_V,'lr':lr,\"momentum\":0.99,\"weight_decay\":weight_decay},\n",
    "#                            {'params':model.blocks[0].attn.W_O,'lr':lr,\"momentum\":0.99,\"weight_decay\":weight_decay},\n",
    "#                            {'params':model.blocks[0].attn.b_O,'lr':lr,\"momentum\":0.99,\"weight_decay\":weight_decay},\n",
    "#                            {'params':model.cls_head.parameters(),'lr':lr,\"momentum\":0.99,\"weight_decay\":weight_decay},\n",
    "#                                ],lr=lr)\n",
    "\n",
    "    # -------------------------------\n",
    "    # Scheduler setup\n",
    "    # Total steps = epochs * batches\n",
    "    # -------------------------------\n",
    "    total_steps = num_epochs * len(train_dataloader)\n",
    "    warmup_steps = total_steps // 10    # 10% warmup\n",
    "\n",
    "    scheduler = get_cosine_schedule_with_warmup(\n",
    "        optimizer,\n",
    "        num_warmup_steps=warmup_steps,\n",
    "        num_training_steps=total_steps,\n",
    "    )\n",
    "\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    print(\"Total steps:\", total_steps)\n",
    "    print(\"Evaluating every 10 epochs\")\n",
    "    \n",
    "    step_count = 0\n",
    "    ep_loss = []\n",
    "    \n",
    "    for epoch in tqdm(range(num_epochs)):\n",
    "        model.train()\n",
    "        losses = []\n",
    "\n",
    "        for batch in train_dataloader:\n",
    "\n",
    "            inputs, rationales, mask, tlabels = batch\n",
    "            inputs = inputs.cuda()\n",
    "            tlabels = tlabels.cuda()\n",
    "            mask = mask.cuda()\n",
    "\n",
    "            logits, _ = model(inputs)\n",
    "            loss = criterion(logits, tlabels)\n",
    "\n",
    "            loss.backward()\n",
    "\n",
    "            # Gradient clipping (IMPORTANT for transformers)\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "\n",
    "            optimizer.step()\n",
    "            scheduler.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            step_count += 1\n",
    "            losses.append(loss.item())\n",
    "\n",
    "        ep_loss.append(np.mean(losses))\n",
    "\n",
    "        # -------------------------------\n",
    "        # Evaluation every 10 epochs\n",
    "        # -------------------------------\n",
    "        if (epoch + 1) % 10 == 0:\n",
    "            print(f\"\\nEpoch {epoch+1}/{num_epochs}: Loss = {ep_loss[-1]:.4f}\")\n",
    "            calculate_performance(model, train_dataloader, dataset=\"train_set\")\n",
    "            calculate_performance(model, validation_dataloader, dataset=\"validation_set\")\n",
    "            calculate_performance(model, test_dataloader, dataset=\"test_set\")\n",
    "\n",
    "    # Final eval\n",
    "    print(\"\\nFinal Evaluation for Seed:\", n_seed)\n",
    "    calculate_performance(model, train_dataloader, dataset=\"train_set\")\n",
    "    calculate_performance(model, validation_dataloader, dataset=\"validation_set\")\n",
    "    calculate_performance(model, test_dataloader, dataset=\"test_set\")\n",
    "    \n",
    "    # Save the model's state_dict\n",
    "    torch.save(model.state_dict(), \"./Saved_Models/4_layer_Faster_QK_HX_\"+str(n_seed)+\".pt\")\n",
    "    \n",
    "\n",
    "    # Heatmaps AFTER training (not inside loop)\n",
    "#     plot_heatmaps(model, train_dataloader, name=f\"4_layer_Faster_train_no_res_{n_seed}\")\n",
    "#     plot_heatmaps(model, validation_dataloader, name=f\"4_layer_Faster_val_no_res_{n_seed}\")\n",
    "#     plot_heatmaps(model, test_dataloader, name=f\"4_layer_Faster_test_no_res_{n_seed}\")\n",
    "    \n",
    "    #metric,_ = evaluate_interpretability_dataset(model,tokenizer,train_dataloader,device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main evaluation script\n",
    "seeds_list = [1234]#, 1235, 1236, 1237, 1238]\n",
    "\n",
    "# Load your dataset and model setup here\n",
    "# val_dataloader = ...\n",
    "\n",
    "for nseed in seeds_list:\n",
    "    BASELINE_CKPT = f\"./Saved_Models/4_layer_Faster_QK_HX_{nseed}.pt\"\n",
    "    \n",
    "    print(f\"Initializing baseline model instance...{nseed}\")\n",
    "    \n",
    "    baseline_cfg = Config(debug=False, d_model=64, n_heads=4, d_head=64,\n",
    "                       d_mlp=256, n_layers=4, n_ctx=256,\n",
    "                       d_vocab=tokenizer.vocab_size)\n",
    "    baseline = DemoTransformer(baseline_cfg).cuda()\n",
    "\n",
    "    baseline_ckpt = torch.load(BASELINE_CKPT, map_location=device)\n",
    "    baseline.load_state_dict(baseline_ckpt)\n",
    "    baseline.eval()\n",
    "    print(\"Faster QK model loaded.\")\n",
    "    \n",
    "    # Evaluation\n",
    "    n_samples = 15000\n",
    "    Attn_Scores_Rollout = []\n",
    "    Attn_Scores_LayerAvg = []\n",
    "    Attn_Scores_MaxPool = []\n",
    "    Comprehensiveness_Scores = []\n",
    "    Sufficiency_Scores = []\n",
    "    \n",
    "    sample_count = 0\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(train_dataloader):\n",
    "            if sample_count >= n_samples:\n",
    "                break\n",
    "            \n",
    "            inputs, rationales, mask, labels = batch\n",
    "            inputs = inputs.to(device)\n",
    "            labels = labels.to(device)\n",
    "            \n",
    "            batch_size = inputs.size(0)\n",
    "            \n",
    "            # Process each sample in the batch\n",
    "            for i in range(batch_size):\n",
    "                if sample_count >= n_samples:\n",
    "                    break\n",
    "                \n",
    "                # Get single sample\n",
    "                sample_input = inputs[i:i+1]  # Keep batch dimension\n",
    "                sample_rationale = rationales[i]\n",
    "                sample_label = labels[i]\n",
    "                \n",
    "                # Get predictions with attention for this sample\n",
    "                logits, attention_weights = baseline(sample_input)\n",
    "                \n",
    "                # Convert attention to numpy (extract from batch dimension)\n",
    "                attention_weights_np = [att[0].cpu().numpy() for att in attention_weights]\n",
    "                \n",
    "                # Get predicted class and probability\n",
    "                probs = F.softmax(logits, dim=-1)\n",
    "                predicted_class = logits.argmax(-1).item()\n",
    "                original_prob = probs[0, predicted_class].item()\n",
    "                \n",
    "                # Get rationale positions (indices where rationales == 1)\n",
    "                #rationale_positions = (sample_rationale == 1).nonzero(as_tuple=True)[0].cpu().tolist()\n",
    "                #if not rationale_positions:\n",
    "                #    continue\n",
    "                \n",
    "                # Compute attention mass metrics\n",
    "                rollout_mass, layer_avg_mass, max_pool_mass, rollout_attn_probs = \\\n",
    "                    attention_mass_classification(attention_weights_np, sample_rationale, alpha=1)\n",
    "                \n",
    "                # Compute comprehensiveness and sufficiency\n",
    "                comp, suff = compute_comprehensiveness_sufficiency_classification(\n",
    "                    baseline, sample_input, predicted_class, original_prob,\n",
    "                    rollout_attn_probs, sample_rationale , k_percent=5, device=device\n",
    "                )\n",
    "                \n",
    "                Attn_Scores_Rollout.append(rollout_mass)\n",
    "                Attn_Scores_LayerAvg.append(layer_avg_mass)\n",
    "                Attn_Scores_MaxPool.append(max_pool_mass)\n",
    "                Comprehensiveness_Scores.append(comp)\n",
    "                Sufficiency_Scores.append(suff)\n",
    "                \n",
    "                sample_count += 1\n",
    "    \n",
    "    # Print results\n",
    "    print(f\"Faster Train Results...{nseed}\")\n",
    "    print(f\"Rollout:           {np.mean(Attn_Scores_Rollout):.4f}\")\n",
    "    print(f\"Layer-Avg:         {np.mean(Attn_Scores_LayerAvg):.4f}\")\n",
    "    print(f\"Max-Pooling:       {np.mean(Attn_Scores_MaxPool):.4f}\")\n",
    "    print(f\"Comprehensiveness: {np.mean(Comprehensiveness_Scores):.4f} (higher = better)\")\n",
    "    print(f\"Sufficiency:       {np.mean(Sufficiency_Scores):.4f} (lower = better)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main evaluation script\n",
    "seeds_list = [1234, 1235, 1236, 1237, 1238]\n",
    "\n",
    "# Load your dataset and model setup here\n",
    "# val_dataloader = ...\n",
    "\n",
    "for nseed in seeds_list:\n",
    "    BASELINE_CKPT = f\"./Saved_Models/4_layer_Faster_QK_HX_{nseed}.pt\"\n",
    "    \n",
    "    print(f\"Initializing baseline model instance...{nseed}\")\n",
    "    \n",
    "    baseline_cfg = Config(debug=False, d_model=64, n_heads=4, d_head=64,\n",
    "                       d_mlp=256, n_layers=4, n_ctx=256,\n",
    "                       d_vocab=tokenizer.vocab_size)\n",
    "    baseline = DemoTransformer(baseline_cfg).cuda()\n",
    "\n",
    "    baseline_ckpt = torch.load(BASELINE_CKPT, map_location=device)\n",
    "    baseline.load_state_dict(baseline_ckpt)\n",
    "    baseline.eval()\n",
    "    print(\"Faster QK model loaded.\")\n",
    "    \n",
    "    # Evaluation\n",
    "    n_samples = 5000\n",
    "    Attn_Scores_Rollout = []\n",
    "    Attn_Scores_LayerAvg = []\n",
    "    Attn_Scores_MaxPool = []\n",
    "    Comprehensiveness_Scores = []\n",
    "    Sufficiency_Scores = []\n",
    "    \n",
    "    sample_count = 0\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(validation_dataloader):\n",
    "            if sample_count >= n_samples:\n",
    "                break\n",
    "            \n",
    "            inputs, rationales, mask, labels = batch\n",
    "            inputs = inputs.to(device)\n",
    "            labels = labels.to(device)\n",
    "            \n",
    "            batch_size = inputs.size(0)\n",
    "            \n",
    "            # Process each sample in the batch\n",
    "            for i in range(batch_size):\n",
    "                if sample_count >= n_samples:\n",
    "                    break\n",
    "                \n",
    "                # Get single sample\n",
    "                sample_input = inputs[i:i+1]  # Keep batch dimension\n",
    "                sample_rationale = rationales[i]\n",
    "                sample_label = labels[i]\n",
    "                \n",
    "                # Get predictions with attention for this sample\n",
    "                logits, attention_weights = baseline(sample_input)\n",
    "                \n",
    "                # Convert attention to numpy (extract from batch dimension)\n",
    "                attention_weights_np = [att[0].cpu().numpy() for att in attention_weights]\n",
    "                \n",
    "                # Get predicted class and probability\n",
    "                probs = F.softmax(logits, dim=-1)\n",
    "                predicted_class = logits.argmax(-1).item()\n",
    "                original_prob = probs[0, predicted_class].item()\n",
    "                \n",
    "                # Get rationale positions (indices where rationales == 1)\n",
    "                #rationale_positions = (sample_rationale == 1).nonzero(as_tuple=True)[0].cpu().tolist()\n",
    "                #if not rationale_positions:\n",
    "                #    continue\n",
    "                \n",
    "                # Compute attention mass metrics\n",
    "                rollout_mass, layer_avg_mass, max_pool_mass, rollout_attn_probs = \\\n",
    "                    attention_mass_classification(attention_weights_np, sample_rationale, alpha=1)\n",
    "                \n",
    "                # Compute comprehensiveness and sufficiency\n",
    "                comp, suff = compute_comprehensiveness_sufficiency_classification(\n",
    "                    baseline, sample_input, predicted_class, original_prob,\n",
    "                    rollout_attn_probs, sample_rationale , k_percent=5, device=device\n",
    "                )\n",
    "                \n",
    "                Attn_Scores_Rollout.append(rollout_mass)\n",
    "                Attn_Scores_LayerAvg.append(layer_avg_mass)\n",
    "                Attn_Scores_MaxPool.append(max_pool_mass)\n",
    "                Comprehensiveness_Scores.append(comp)\n",
    "                Sufficiency_Scores.append(suff)\n",
    "                \n",
    "                sample_count += 1\n",
    "    \n",
    "    # Print results\n",
    "    print(f\"Faster Train Results...{nseed}\")\n",
    "    print(f\"Rollout:           {np.mean(Attn_Scores_Rollout):.4f}\")\n",
    "    print(f\"Layer-Avg:         {np.mean(Attn_Scores_LayerAvg):.4f}\")\n",
    "    print(f\"Max-Pooling:       {np.mean(Attn_Scores_MaxPool):.4f}\")\n",
    "    print(f\"Comprehensiveness: {np.mean(Comprehensiveness_Scores):.4f} (higher = better)\")\n",
    "    print(f\"Sufficiency:       {np.mean(Sufficiency_Scores):.4f} (lower = better)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main evaluation script\n",
    "seeds_list = [1234, 1235, 1236, 1237, 1238]\n",
    "\n",
    "# Load your dataset and model setup here\n",
    "# val_dataloader = ...\n",
    "\n",
    "for nseed in seeds_list:\n",
    "    BASELINE_CKPT = f\"./Saved_Models/4_layer_Faster_QK_HX_{nseed}.pt\"\n",
    "    \n",
    "    print(f\"Initializing baseline model instance...{nseed}\")\n",
    "    \n",
    "    baseline_cfg = Config(debug=False, d_model=64, n_heads=4, d_head=64,\n",
    "                       d_mlp=256, n_layers=4, n_ctx=256,\n",
    "                       d_vocab=tokenizer.vocab_size)\n",
    "    baseline = DemoTransformer(baseline_cfg).cuda()\n",
    "\n",
    "    baseline_ckpt = torch.load(BASELINE_CKPT, map_location=device)\n",
    "    baseline.load_state_dict(baseline_ckpt)\n",
    "    baseline.eval()\n",
    "    print(\"Faster QK model loaded.\")\n",
    "    \n",
    "    # Evaluation\n",
    "    n_samples = 5000\n",
    "    Attn_Scores_Rollout = []\n",
    "    Attn_Scores_LayerAvg = []\n",
    "    Attn_Scores_MaxPool = []\n",
    "    Comprehensiveness_Scores = []\n",
    "    Sufficiency_Scores = []\n",
    "    \n",
    "    sample_count = 0\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(test_dataloader):\n",
    "            if sample_count >= n_samples:\n",
    "                break\n",
    "            \n",
    "            inputs, rationales, mask, labels = batch\n",
    "            inputs = inputs.to(device)\n",
    "            labels = labels.to(device)\n",
    "            \n",
    "            batch_size = inputs.size(0)\n",
    "            \n",
    "            # Process each sample in the batch\n",
    "            for i in range(batch_size):\n",
    "                if sample_count >= n_samples:\n",
    "                    break\n",
    "                \n",
    "                # Get single sample\n",
    "                sample_input = inputs[i:i+1]  # Keep batch dimension\n",
    "                sample_rationale = rationales[i]\n",
    "                sample_label = labels[i]\n",
    "                \n",
    "                # Get predictions with attention for this sample\n",
    "                logits, attention_weights = baseline(sample_input)\n",
    "                \n",
    "                # Convert attention to numpy (extract from batch dimension)\n",
    "                attention_weights_np = [att[0].cpu().numpy() for att in attention_weights]\n",
    "                \n",
    "                # Get predicted class and probability\n",
    "                probs = F.softmax(logits, dim=-1)\n",
    "                predicted_class = logits.argmax(-1).item()\n",
    "                original_prob = probs[0, predicted_class].item()\n",
    "                \n",
    "                # Get rationale positions (indices where rationales == 1)\n",
    "                #rationale_positions = (sample_rationale == 1).nonzero(as_tuple=True)[0].cpu().tolist()\n",
    "                #if not rationale_positions:\n",
    "                #    continue\n",
    "                \n",
    "                # Compute attention mass metrics\n",
    "                rollout_mass, layer_avg_mass, max_pool_mass, rollout_attn_probs = \\\n",
    "                    attention_mass_classification(attention_weights_np, sample_rationale, alpha=1)\n",
    "                \n",
    "                # Compute comprehensiveness and sufficiency\n",
    "                comp, suff = compute_comprehensiveness_sufficiency_classification(\n",
    "                    baseline, sample_input, predicted_class, original_prob,\n",
    "                    rollout_attn_probs, sample_rationale , k_percent=5, device=device\n",
    "                )\n",
    "                \n",
    "                Attn_Scores_Rollout.append(rollout_mass)\n",
    "                Attn_Scores_LayerAvg.append(layer_avg_mass)\n",
    "                Attn_Scores_MaxPool.append(max_pool_mass)\n",
    "                Comprehensiveness_Scores.append(comp)\n",
    "                Sufficiency_Scores.append(suff)\n",
    "                \n",
    "                sample_count += 1\n",
    "    \n",
    "    # Print results\n",
    "    print(f\"Faster Train Results...{nseed}\")\n",
    "    print(f\"Rollout:           {np.mean(Attn_Scores_Rollout):.4f}\")\n",
    "    print(f\"Layer-Avg:         {np.mean(Attn_Scores_LayerAvg):.4f}\")\n",
    "    print(f\"Max-Pooling:       {np.mean(Attn_Scores_MaxPool):.4f}\")\n",
    "    print(f\"Comprehensiveness: {np.mean(Comprehensiveness_Scores):.4f} (higher = better)\")\n",
    "    print(f\"Sufficiency:       {np.mean(Sufficiency_Scores):.4f} (lower = better)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#px.line(y=losses, x=np.arange(len(losses))*(model_cfg.n_ctx * batch_size), labels={\"y\":\"Loss\", \"x\":\"Tokens\"}, title=\"Training curve for my tiny demo model!\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# setting1 = \"same_lr_train_no_res_connection\"\n",
    "# setting2 = \"same_lr_validation_no_res_connection\"\n",
    "# setting3 = \"same_lr_test_no_res_connection\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    # model.eval()\n",
    "    # pred = []\n",
    "    # gt = []\n",
    "    # for c, batch in tqdm.tqdm(enumerate(dataloader)):\n",
    "    #     inputs,rationales,mask,tlabels = batch\n",
    "    #     inputs,tlabels,mask = inputs.to(device),tlabels.to(device),mask.to(device)\n",
    "    #     _,outputs,_= model(inputs)\n",
    "    #     #print(outputs.shape)\n",
    "    #     pred.append(torch.argmax(outputs,dim=1).cpu())\n",
    "    #     gt.append(tlabels.cpu())\n",
    "    #     #print(len(pred))\n",
    "    # pred = torch.hstack(pred)\n",
    "    # #print(pred.size())\n",
    "    # gt = torch.hstack(gt)\n",
    "    # print(\"Acccuracy on \"+ dataset + \":\", sum(pred == gt)/len(pred))\n",
    "    # model.train()"
   ]
  }
 ],
 "metadata": {
  "kaggle": {
   "accelerator": "gpu",
   "dataSources": [
    {
     "datasetId": 6760154,
     "sourceId": 10879915,
     "sourceType": "datasetVersion"
    }
   ],
   "dockerImageVersionId": 31011,
   "isGpuEnabled": true,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
