{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython import get_ipython\n",
    "from IPython.display import clear_output, display\n",
    "\n",
    "# Enable autoreload so edits to imported project modules are picked up without restarting the kernel.\n",
    "# `run_line_magic` is the supported replacement for the long-deprecated `InteractiveShell.magic` helper.\n",
    "ipython = get_ipython()\n",
    "ipython.run_line_magic(\"load_ext\", \"autoreload\")\n",
    "ipython.run_line_magic(\"autoreload\", \"2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from typing import List, Optional, Union, Dict, Tuple\n",
    "from pathlib import Path \n",
    "\n",
    "import torch\n",
    "from torch import Tensor\n",
    "import numpy as np\n",
    "import einops\n",
    "from fancy_einsum import einsum\n",
    "import circuitsvis as cv\n",
    "\n",
    "import transformer_lens.utils as tl_utils\n",
    "\n",
    "from transformer_lens import HookedTransformer\n",
    "import transformer_lens.patching as patching\n",
    "\n",
    "from transformers import AutoModelForCausalLM\n",
    "\n",
    "from torch import Tensor\n",
    "from jaxtyping import Float\n",
    "import plotly.express as px\n",
    "\n",
    "from functools import partial\n",
    "\n",
    "from torchtyping import TensorType as TT\n",
    "\n",
    "from path_patching_cm.path_patching import Node, IterNode, path_patch, act_patch\n",
    "from path_patching_cm.ioi_dataset import IOIDataset, NAMES\n",
    "from neel_plotly import imshow as imshow_n\n",
    "\n",
    "from utils.visualization import imshow_p, plot_attention_heads, plot_attention\n",
    "from utils.data_utils import generate_data_and_caches, UniversalPatchingDataset\n",
    "from utils.metrics import compute_logit_diff, compute_probability_diff, compute_probability_mass, compute_rank_0_rate\n",
    "# Additional plotting helpers from the project's visualization module.\n",
    "from utils.visualization import (\n",
    "    plot_attention_heads,\n",
    "    scatter_attention_and_contribution,\n",
    "    get_attn_head_patterns\n",
    ")\n",
    "\n",
    "# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Disable autograd globally for the rest of the session.\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_mean_reciprocal_rank(\n",
    "        logits: torch.Tensor, \n",
    "        answer_token_indices: torch.Tensor,\n",
    "        positions: torch.Tensor = None,\n",
    "        flags_tensor: torch.Tensor = None,\n",
    "        mode=\"simple\"\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Computes the Mean Reciprocal Rank (MRR) of the correct answer tokens, averaged over the batch.\n",
    "\n",
    "    The meaning of answer_token_indices depends on the mode:\n",
    "\n",
    "    - \"simple\": column 0 holds the correct token of each batch item; it is ranked against the full vocabulary.\n",
    "\n",
    "    - \"pairs\": answer_token_indices[i] is iterated pair by pair; the first entry of every pair is ranked against\n",
    "      the full vocabulary and the reciprocal ranks are averaged over answer_token_indices.size(1).\n",
    "\n",
    "    - \"groups\": answer_token_indices[i] lists candidate tokens and flags_tensor[i] marks the correct ones with 1.\n",
    "      Candidates are ranked among the listed candidates only (not the full vocabulary) and the reciprocal ranks\n",
    "      of the correct ones are averaged. A batch item without any correct candidate contributes 0.\n",
    "\n",
    "    Args:\n",
    "        logits (torch.Tensor): Logits to use.\n",
    "        answer_token_indices (torch.Tensor): Indices of the answer tokens, laid out as described above.\n",
    "        positions (torch.Tensor): Positions to get logits at, one position per batch item.\n",
    "        flags_tensor (torch.Tensor): Flags indicating the grouping of tokens (required in \"groups\" mode).\n",
    "        mode (str): Mode of operation - \"simple\", \"pairs\", or \"groups\".\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: Scalar tensor with the Mean Reciprocal Rank for the batch.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If mode is not one of the supported values.\n",
    "        AssertionError: If mode is \"groups\" and flags_tensor is None.\n",
    "    \"\"\"\n",
    "    logits = get_positional_logits(logits, positions)\n",
    "    probabilities = torch.softmax(logits, dim=-1)\n",
    "    batch_size = logits.size(0)\n",
    "    mrr = torch.zeros(batch_size, device=logits.device)\n",
    "\n",
    "    # Mode 1: Simple\n",
    "    if mode == \"simple\":\n",
    "        correct_indices = answer_token_indices[:, 0]\n",
    "        for i in range(batch_size):\n",
    "            order = probabilities[i].sort(descending=True)[1]\n",
    "            rank = (order == correct_indices[i]).nonzero(as_tuple=True)[0].item() + 1\n",
    "            mrr[i] = 1.0 / rank\n",
    "\n",
    "    # Mode 2: Pairs\n",
    "    elif mode == \"pairs\":\n",
    "        for i in range(batch_size):\n",
    "            # The ordering only depends on the batch item, so sort once instead of once per pair.\n",
    "            order = probabilities[i].sort(descending=True)[1]\n",
    "            for pair in answer_token_indices[i]:\n",
    "                rank = (order == pair[0]).nonzero(as_tuple=True)[0].item() + 1\n",
    "                mrr[i] += 1.0 / rank\n",
    "            mrr[i] /= answer_token_indices.size(1)\n",
    "\n",
    "    # Mode 3: Groups\n",
    "    elif mode == \"groups\":\n",
    "        assert flags_tensor is not None\n",
    "        for i in range(batch_size):\n",
    "            selected_probs = probabilities[i, answer_token_indices[i]]\n",
    "            order = selected_probs.sort(descending=True)[1]\n",
    "            # Invert the sort permutation: rank_of[j] is the 0-based rank of candidate j.\n",
    "            rank_of = torch.empty_like(order)\n",
    "            rank_of[order] = torch.arange(order.numel(), device=order.device)\n",
    "            is_correct = (flags_tensor[i] == 1).to(rank_of.device)\n",
    "            correct_ranks = rank_of[is_correct] + 1\n",
    "            # Leave the entry at 0 when nothing is flagged correct (the mean of an empty tensor is NaN).\n",
    "            if correct_ranks.numel() > 0:\n",
    "                mrr[i] = (1.0 / correct_ranks).mean()\n",
    "\n",
    "    else:\n",
    "        raise ValueError(\"Invalid mode specified\")\n",
    "\n",
    "    return mrr.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model load without a pinned checkpoint or dtype override; kept commented out for reference.\n",
    "\n",
    "# model = HookedTransformer.from_pretrained(\n",
    "#     \"EleutherAI/pythia-2.8b\",\n",
    "#     center_unembed=True,\n",
    "#     center_writing_weights=True,\n",
    "#     fold_ln=True,\n",
    "#     refactor_factored_attn_matrices=False,\n",
    "# )\n",
    "# model.set_use_hook_mlp_in(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load Pythia-2.8B at the checkpoint selected by `checkpoint_value=10000`, with weights in bfloat16.\n",
    "model = HookedTransformer.from_pretrained(\n",
    "    \"EleutherAI/pythia-2.8b\",\n",
    "    checkpoint_value=10000,\n",
    "    dtype=torch.bfloat16,\n",
    "    # Weight-processing switches passed to the loader.\n",
    "    fold_ln=True,\n",
    "    center_writing_weights=True,\n",
    "    center_unembed=True,\n",
    "    refactor_factored_attn_matrices=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_positional_logits(\n",
    "        logits: Float[Tensor, \"batch seq d_vocab\"],\n",
    "        positions: Float[Tensor, \"batch\"] = None\n",
    ")-> Float[Tensor, \"batch d_vocab\"]:\n",
    "    \"\"\"Selects one vocabulary-sized logit vector per batch item.\n",
    "\n",
    "    Args:\n",
    "        logits (torch.Tensor): Logits of shape (batch, seq, d_vocab).\n",
    "        positions (torch.Tensor): Sequence position to read for each batch item, shape (batch,).\n",
    "            The values are used as indices. When omitted, the last sequence position is used.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: Logits of shape (batch, d_vocab).\n",
    "    \"\"\"\n",
    "    if positions is not None:\n",
    "        # Pair every batch row with its own position.\n",
    "        batch_indices = torch.arange(logits.size(0))\n",
    "        return logits[batch_indices, positions]\n",
    "\n",
    "    return logits[:, -1]\n",
    "\n",
    "\n",
    "def compute_logit_diff(\n",
    "        logits: Float[Tensor, \"batch seq d_vocab\"], \n",
    "        answer_token_indices: Float[Tensor, \"batch num_answers\"],\n",
    "        positions: Float[Tensor, \"batch\"] = None,\n",
    "        flags_tensor: torch.Tensor = None,\n",
    "        per_prompt=False,\n",
    "        mode=\"simple\"\n",
    ")-> Float[Tensor, \"batch num_answers\"]:\n",
    "    \"\"\"Computes the logit difference between correct and incorrect answer tokens at one position per prompt.\n",
    "\n",
    "    The layout of answer_token_indices depends on the mode:\n",
    "\n",
    "    - \"simple\": shape (batch, 2); column 0 is the correct token and column 1 the incorrect token.\n",
    "\n",
    "    - \"pairs\": shape (batch, num_pairs, 2); the first-minus-second logit difference is taken for every pair\n",
    "      and averaged over the pairs of each batch item.\n",
    "\n",
    "    - \"groups\": answer_token_indices[i] lists the candidate tokens of batch item i and flags_tensor[i] labels each\n",
    "      candidate as correct (1) or incorrect (-1). The result is the mean logit of the correct candidates minus the\n",
    "      mean logit of the incorrect ones; an empty group counts as 0.\n",
    "\n",
    "    Args:\n",
    "        logits (torch.Tensor): Logits to use.\n",
    "        answer_token_indices (torch.Tensor): Indices of the tokens to compare, laid out as described above.\n",
    "        positions (torch.Tensor): Positions to get logits at. Should be one position per batch item.\n",
    "        flags_tensor (torch.Tensor): Group labels; must not be None in \"groups\" mode.\n",
    "        per_prompt (bool): If True, return one value per batch item instead of the batch mean.\n",
    "        mode (str): \"simple\", \"pairs\", or \"groups\".\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: The batch-mean logit difference, or the per-prompt differences when per_prompt is True.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If mode is not one of the supported values.\n",
    "    \"\"\"\n",
    "    logits = get_positional_logits(logits, positions)\n",
    "    num_prompts = logits.size(0)\n",
    "\n",
    "    if mode == \"simple\":\n",
    "        rows = torch.arange(num_prompts)\n",
    "        correct_logits = logits[rows, answer_token_indices[:, 0]]\n",
    "        incorrect_logits = logits[rows, answer_token_indices[:, 1]]\n",
    "        logit_diff = correct_logits - incorrect_logits\n",
    "\n",
    "    elif mode == \"pairs\":\n",
    "        # A column of row indices broadcasts against the (batch, num_pairs) token indices.\n",
    "        rows = torch.arange(num_prompts)[:, None]\n",
    "        first_logits = logits[rows, answer_token_indices[..., 0]]\n",
    "        second_logits = logits[rows, answer_token_indices[..., 1]]\n",
    "        logit_diff = (first_logits - second_logits).mean(dim=1)\n",
    "\n",
    "    elif mode == \"groups\":\n",
    "        assert flags_tensor is not None\n",
    "        logit_diff = torch.zeros(num_prompts, device=logits.device)\n",
    "\n",
    "        for i in range(num_prompts):\n",
    "            candidate_logits = logits[i, answer_token_indices[i]]\n",
    "            flags = flags_tensor[i]\n",
    "\n",
    "            correct_logits = candidate_logits[flags == 1]\n",
    "            incorrect_logits = candidate_logits[flags == -1]\n",
    "\n",
    "            # An empty group contributes 0 rather than the NaN an empty mean would produce.\n",
    "            correct_mean = correct_logits.mean() if len(correct_logits) > 0 else 0\n",
    "            incorrect_mean = incorrect_logits.mean() if len(incorrect_logits) > 0 else 0\n",
    "\n",
    "            logit_diff[i] = correct_mean - incorrect_mean\n",
    "\n",
    "    else:\n",
    "        raise ValueError(\"Invalid mode specified\")\n",
    "\n",
    "    if per_prompt:\n",
    "        return logit_diff\n",
    "    return logit_diff.mean()\n",
    "\n",
    "\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def compute_probability_diff(\n",
    "        logits: torch.Tensor, \n",
    "        answer_token_indices: torch.Tensor,\n",
    "        positions: torch.Tensor = None,\n",
    "        flags_tensor: torch.Tensor = None,\n",
    "        per_prompt=False,\n",
    "        mode=\"simple\"\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"Computes the probability difference between correct and incorrect answer tokens at one position per prompt.\n",
    "\n",
    "    Probabilities are the softmax of the selected logits over the vocabulary. Every mode returns\n",
    "    correct minus incorrect. The layout of answer_token_indices depends on the mode:\n",
    "\n",
    "    - \"simple\": shape (batch, 2); column 0 is the correct token and column 1 the incorrect token.\n",
    "\n",
    "    - \"pairs\": shape (batch, num_pairs, 2); the first-minus-second probability difference is taken for every\n",
    "      pair and averaged over the pairs of each batch item.\n",
    "\n",
    "    - \"groups\": answer_token_indices[i] lists the candidate tokens of batch item i and flags_tensor[i] labels each\n",
    "      candidate as correct (1) or incorrect (-1). The result is the mean probability of the correct candidates\n",
    "      minus the mean probability of the incorrect ones; an empty group counts as 0.\n",
    "\n",
    "    - \"group_sum\": same inputs as \"groups\", but the probabilities of each group are summed instead of averaged.\n",
    "\n",
    "    Args:\n",
    "        logits (torch.Tensor): Logits to use.\n",
    "        answer_token_indices (torch.Tensor): Indices of the tokens to compare, laid out as described above.\n",
    "        positions (torch.Tensor): Positions to get logits at. Should be one position per batch item.\n",
    "        flags_tensor (torch.Tensor): Group labels; must not be None in \"groups\" and \"group_sum\" modes.\n",
    "        per_prompt (bool): If True, return one value per batch item instead of the batch mean.\n",
    "        mode (str): \"simple\", \"pairs\", \"groups\", or \"group_sum\".\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: The batch-mean probability difference, or the per-prompt differences when per_prompt is True.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If mode is not one of the supported values.\n",
    "    \"\"\"\n",
    "    logits = get_positional_logits(logits, positions)\n",
    "    probabilities = torch.softmax(logits, dim=-1)\n",
    "    batch_size = logits.size(0)\n",
    "\n",
    "    # Mode 1: Simple\n",
    "    if mode == \"simple\":\n",
    "        rows = torch.arange(batch_size)\n",
    "        correct_probs = probabilities[rows, answer_token_indices[:, 0]]\n",
    "        incorrect_probs = probabilities[rows, answer_token_indices[:, 1]]\n",
    "        prob_diff = correct_probs - incorrect_probs\n",
    "\n",
    "    # Mode 2: Pairs\n",
    "    elif mode == \"pairs\":\n",
    "        rows = torch.arange(batch_size)[:, None]\n",
    "        pair_diffs = probabilities[rows, answer_token_indices[..., 0]] - probabilities[rows, answer_token_indices[..., 1]]\n",
    "        prob_diff = pair_diffs.mean(dim=1)\n",
    "\n",
    "    # Modes 3 and 4: Groups (mean per group) and Group Sum (sum per group)\n",
    "    elif mode in (\"groups\", \"group_sum\"):\n",
    "        assert flags_tensor is not None\n",
    "        prob_diff = torch.zeros(batch_size, device=logits.device)\n",
    "\n",
    "        for i in range(batch_size):\n",
    "            # Probabilities of the candidate tokens of this batch item, split by flag\n",
    "            selected_probs = probabilities[i, answer_token_indices[i]]\n",
    "            correct_probs = selected_probs[flags_tensor[i] == 1]\n",
    "            incorrect_probs = selected_probs[flags_tensor[i] == -1]\n",
    "\n",
    "            if mode == \"group_sum\":\n",
    "                # Correct minus incorrect, matching the sign convention of every other mode.\n",
    "                prob_diff[i] = correct_probs.sum() - incorrect_probs.sum()\n",
    "            else:\n",
    "                # An empty group contributes 0 rather than the NaN an empty mean would produce.\n",
    "                correct_mean = correct_probs.mean() if len(correct_probs) > 0 else 0\n",
    "                incorrect_mean = incorrect_probs.mean() if len(incorrect_probs) > 0 else 0\n",
    "                prob_diff[i] = correct_mean - incorrect_mean\n",
    "\n",
    "    else:\n",
    "        raise ValueError(\"Invalid mode specified\")\n",
    "\n",
    "    return prob_diff.mean() if not per_prompt else prob_diff\n",
    "\n",
    "\n",
    "def compute_probability_mass(\n",
    "        logits: torch.Tensor, \n",
    "        answer_token_indices: torch.Tensor,\n",
    "        positions: torch.Tensor = None,\n",
    "        flags_tensor: torch.Tensor = None,\n",
    "        group=\"correct\",\n",
    "        mode=\"simple\"\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"Computes the batch-mean probability assigned to the correct or the incorrect answer tokens.\n",
    "\n",
    "    Args:\n",
    "        logits (torch.Tensor): Logits to use.\n",
    "        answer_token_indices (torch.Tensor): Answer token indices. In \"simple\" and \"pairs\" modes the first entry of\n",
    "            a row/pair is the correct token and the second the incorrect one; in \"groups\" and \"group_sum\" modes\n",
    "            each row lists candidate tokens labelled by flags_tensor.\n",
    "        positions (torch.Tensor): Positions to get logits at, one position per batch item.\n",
    "        flags_tensor (torch.Tensor): Labels (1 correct, -1 incorrect); must not be None in the group modes.\n",
    "        group (str): \"correct\" selects the correct tokens; any other value selects the incorrect ones.\n",
    "        mode (str): \"simple\", \"pairs\", \"groups\" (mean over the group) or \"group_sum\" (sum over the group).\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: Scalar tensor with the mean over the batch.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If mode is not one of the supported values.\n",
    "    \"\"\"\n",
    "    logits = get_positional_logits(logits, positions)\n",
    "    probabilities = torch.softmax(logits, dim=-1)\n",
    "    num_prompts = logits.size(0)\n",
    "    wants_correct = group == \"correct\"\n",
    "\n",
    "    if mode == \"simple\":\n",
    "        column = 0 if wants_correct else 1\n",
    "        group_probs = probabilities[torch.arange(num_prompts), answer_token_indices[:, column]]\n",
    "\n",
    "    elif mode == \"pairs\":\n",
    "        column = 0 if wants_correct else 1\n",
    "        group_probs = torch.zeros(num_prompts, device=logits.device)\n",
    "        for i in range(num_prompts):\n",
    "            for pair in answer_token_indices[i]:\n",
    "                group_probs[i] += probabilities[i, pair[column]]\n",
    "            group_probs[i] /= answer_token_indices.size(1)\n",
    "\n",
    "    elif mode in (\"groups\", \"group_sum\"):\n",
    "        assert flags_tensor is not None\n",
    "        flag_value = 1 if wants_correct else -1\n",
    "        group_probs = torch.zeros(num_prompts, device=logits.device)\n",
    "\n",
    "        for i in range(num_prompts):\n",
    "            selected_probs = probabilities[i, answer_token_indices[i]]\n",
    "            member_probs = selected_probs[flags_tensor[i] == flag_value]\n",
    "            group_probs[i] = member_probs.mean() if mode == \"groups\" else member_probs.sum()\n",
    "\n",
    "    else:\n",
    "        raise ValueError(\"Invalid mode specified\")\n",
    "\n",
    "    return group_probs.mean()\n",
    "\n",
    "\n",
    "\n",
    "def compute_rank_0_rate(\n",
    "        logits: torch.Tensor, \n",
    "        answer_token_indices: torch.Tensor,\n",
    "        positions: torch.Tensor = None,\n",
    "        flags_tensor: torch.Tensor = None,\n",
    "        group=\"correct\",\n",
    "        mode=\"simple\"\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"Computes how often the most probable token belongs to the requested answer group, averaged over the batch.\n",
    "\n",
    "    Args:\n",
    "        logits (torch.Tensor): Logits to use.\n",
    "        answer_token_indices (torch.Tensor): Answer token indices. In \"simple\" and \"pairs\" modes the first entry of\n",
    "            a row/pair is the correct token and the second the incorrect one; in \"groups\" mode each row lists\n",
    "            candidate tokens labelled by flags_tensor.\n",
    "        positions (torch.Tensor): Positions to get logits at, one position per batch item.\n",
    "        flags_tensor (torch.Tensor): Labels (1 correct, -1 incorrect); must not be None in \"groups\" mode.\n",
    "        group (str): \"correct\" checks against the correct tokens; any other value checks the incorrect ones.\n",
    "        mode (str): \"simple\" and \"pairs\" compare the argmax over the full vocabulary with the answer tokens;\n",
    "            \"groups\" takes the argmax among the listed candidates only and checks its flag.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: Scalar tensor with the mean rate over the batch.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If mode is not one of the supported values.\n",
    "    \"\"\"\n",
    "    logits = get_positional_logits(logits, positions)\n",
    "    probabilities = torch.softmax(logits, dim=-1)\n",
    "    num_prompts = logits.size(0)\n",
    "    wants_correct = group == \"correct\"\n",
    "\n",
    "    if mode == \"simple\":\n",
    "        column = 0 if wants_correct else 1\n",
    "        predicted_tokens = probabilities.argmax(dim=-1)\n",
    "        rank_0_rate = (predicted_tokens == answer_token_indices[:, column]).float().mean()\n",
    "\n",
    "    elif mode == \"pairs\":\n",
    "        column = 0 if wants_correct else 1\n",
    "        rank_0_rate = torch.zeros(num_prompts, device=logits.device)\n",
    "        for i in range(num_prompts):\n",
    "            # The top prediction is the same for every pair of this batch item.\n",
    "            predicted_token = probabilities[i].argmax()\n",
    "            for pair in answer_token_indices[i]:\n",
    "                rank_0_rate[i] += (predicted_token == pair[column]).float()\n",
    "            rank_0_rate[i] /= answer_token_indices.size(1)\n",
    "\n",
    "    elif mode == \"groups\":\n",
    "        assert flags_tensor is not None\n",
    "        target_flag = 1 if wants_correct else -1\n",
    "        rank_0_rate = torch.zeros(num_prompts, device=logits.device)\n",
    "\n",
    "        for i in range(num_prompts):\n",
    "            # Position (within the candidate list) of the most probable candidate.\n",
    "            top_candidate = probabilities[i, answer_token_indices[i]].argmax()\n",
    "            rank_0_rate[i] = (flags_tensor[i, top_candidate] == target_flag).float()\n",
    "\n",
    "    else:\n",
    "        raise ValueError(\"Invalid mode specified\")\n",
    "\n",
    "    return rank_0_rate.mean()\n",
    "\n",
    "\n",
    "import torch\n",
    "\n",
    "def compute_max_group_rank_reciprocal(\n",
    "        logits: torch.Tensor, \n",
    "        answer_token_indices: torch.Tensor,\n",
    "        positions: torch.Tensor = None,\n",
    "        flags_tensor: torch.Tensor = None,\n",
    "        mode=\"simple\"\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Computes the batch mean of the reciprocal of the best rank reached by any correct token.\n",
    "\n",
    "    Ranks are taken over the full vocabulary; rank 1 is the most probable token. \"Best\" means the\n",
    "    numerically smallest rank among the correct tokens of a batch item.\n",
    "\n",
    "    - \"simple\": the correct token is answer_token_indices[i, 0].\n",
    "    - \"pairs\": the correct tokens are the first entries of the pairs in answer_token_indices[i].\n",
    "    - \"groups\": the correct tokens are the entries of answer_token_indices[i] whose flag in flags_tensor[i] is 1.\n",
    "\n",
    "    A batch item without any correct token contributes 0.\n",
    "\n",
    "    Args:\n",
    "        logits (torch.Tensor): Logits to use.\n",
    "        answer_token_indices (torch.Tensor): Indices of the tokens for comparison or grouping.\n",
    "        positions (torch.Tensor): Positions to get logits at, one position per batch item.\n",
    "        flags_tensor (torch.Tensor): Flags indicating the grouping of tokens (required in \"groups\" mode).\n",
    "        mode (str): Operation mode - \"simple\", \"pairs\", or \"groups\".\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: Scalar tensor with the mean reciprocal best rank over the batch.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If mode is not one of the supported values.\n",
    "        AssertionError: If mode is \"groups\" and flags_tensor is None.\n",
    "    \"\"\"\n",
    "    logits = get_positional_logits(logits, positions)\n",
    "    probabilities = torch.softmax(logits, dim=-1)\n",
    "    batch_size = logits.size(0)\n",
    "\n",
    "    if mode not in (\"simple\", \"pairs\", \"groups\"):\n",
    "        raise ValueError(\"Invalid mode specified\")\n",
    "    if mode == \"groups\":\n",
    "        assert flags_tensor is not None\n",
    "\n",
    "    # Reciprocal of the best rank for each item in the batch; stays 0 when there is no correct token\n",
    "    reciprocal_max_rank = torch.zeros(batch_size, device=logits.device)\n",
    "\n",
    "    for i in range(batch_size):\n",
    "        # Sort the vocabulary once per batch item and invert the permutation,\n",
    "        # so that rank_of[token] is the 1-based rank of that token.\n",
    "        order = probabilities[i].sort(descending=True)[1]\n",
    "        rank_of = torch.empty_like(order)\n",
    "        rank_of[order] = torch.arange(1, order.numel() + 1, device=order.device)\n",
    "\n",
    "        if mode == \"simple\":\n",
    "            correct_tokens = answer_token_indices[i, :1]\n",
    "        elif mode == \"pairs\":\n",
    "            # Only the first index in each pair is considered correct\n",
    "            correct_tokens = answer_token_indices[i][:, 0]\n",
    "        else:\n",
    "            is_correct = (flags_tensor[i] == 1).to(answer_token_indices.device)\n",
    "            correct_tokens = answer_token_indices[i][is_correct]\n",
    "\n",
    "        if correct_tokens.numel() > 0:\n",
    "            best_rank = rank_of[correct_tokens.to(rank_of.device)].min().item()\n",
    "            reciprocal_max_rank[i] = 1.0 / best_rank\n",
    "\n",
    "    return reciprocal_max_rank.mean()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_compute_max_group_rank_reciprocal():\n",
    "    \"\"\"Checks compute_max_group_rank_reciprocal against hand-computed values in every mode.\"\"\"\n",
    "    # One prompt, one sequence position, four-token vocabulary.\n",
    "    # Sorted by descending logit the tokens are 2, 3, 1, 0, so their ranks are 1, 2, 3, 4.\n",
    "    logits = torch.tensor([[[0.1, 0.2, 0.7, 0.6]]])\n",
    "\n",
    "    def assert_metric(expected, answer_token_indices, flags=None, mode=\"simple\"):\n",
    "        flags_tensor = None if flags is None else torch.tensor(flags, dtype=torch.long)\n",
    "        result = compute_max_group_rank_reciprocal(\n",
    "            logits,\n",
    "            torch.tensor(answer_token_indices, dtype=torch.long),\n",
    "            flags_tensor=flags_tensor,\n",
    "            mode=mode,\n",
    "        )\n",
    "        assert torch.isclose(result, torch.tensor(expected)), f\"{mode}: expected {expected}, got {result.item()}\"\n",
    "        print(f\"{mode} mode: {result.item():.4f}\")\n",
    "\n",
    "    # Simple: the correct token (column 0) is token 3, which has rank 2.\n",
    "    assert_metric(0.5, [[3, 0]], mode=\"simple\")\n",
    "\n",
    "    # Pairs: the correct tokens are 1 (rank 3) and 0 (rank 4); the best rank is 3.\n",
    "    assert_metric(1 / 3, [[[1, 2], [0, 3]]], mode=\"pairs\")\n",
    "\n",
    "    # Groups: tokens 0 (rank 4) and 3 (rank 2) are flagged correct; the best rank is 2.\n",
    "    assert_metric(0.5, [[0, 1, 2, 3]], flags=[[1, -1, -1, 1]], mode=\"groups\")\n",
    "\n",
    "    # Groups: a prompt without any token flagged correct contributes 0.\n",
    "    assert_metric(0.0, [[0, 1, 2, 3]], flags=[[-1, -1, -1, -1]], mode=\"groups\")\n",
    "\n",
    "# Execute the test function\n",
    "test_compute_max_group_rank_reciprocal()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## IOI"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Old"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _logits_to_mean_logit_diff(logits: Float[Tensor, \"batch seq d_vocab\"], ioi_dataset: IOIDataset, per_prompt=False):\n",
    "    '''\n",
    "    Returns the logit difference between the indirect-object token and the subject token.\n",
    "\n",
    "    For every prompt the logits are read at the position stored in ioi_dataset.word_idx[\"end\"];\n",
    "    the logit of ioi_dataset.s_tokenIDs is subtracted from the logit of ioi_dataset.io_tokenIDs.\n",
    "\n",
    "    If per_prompt=True, return the array of differences rather than the average.\n",
    "    '''\n",
    "    batch_indices = range(logits.size(0))\n",
    "    end_positions = ioi_dataset.word_idx[\"end\"]\n",
    "\n",
    "    # Logits of the indirect object / subject tokens at each prompt's \"end\" position\n",
    "    io_logits: Float[Tensor, \"batch\"] = logits[batch_indices, end_positions, ioi_dataset.io_tokenIDs]\n",
    "    s_logits: Float[Tensor, \"batch\"] = logits[batch_indices, end_positions, ioi_dataset.s_tokenIDs]\n",
    "\n",
    "    answer_logit_diff = io_logits - s_logits\n",
    "    return answer_logit_diff if per_prompt else answer_logit_diff.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size argument passed to generate_data_and_caches (presumably the number of prompts -- confirm).\n",
    "N = 70\n",
    "# Only the first two of the five returned values are kept.\n",
    "ioi_dataset, abc_dataset, _, _, _ = generate_data_and_caches(model, N, verbose=True)\n",
    "# Token tensors of both datasets, moved to the compute device.\n",
    "clean_toks = ioi_dataset.toks.to(device)\n",
    "corrupted_toks = abc_dataset.toks.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference values: per-prompt logit differences from the old helper on the clean tokens.\n",
    "logits = model(clean_toks)\n",
    "_logits_to_mean_logit_diff(logits, ioi_dataset, per_prompt=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### New"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# IOI data through UniversalPatchingDataset; 100 is presumably the number of prompts -- confirm.\n",
    "ds = UniversalPatchingDataset.from_ioi(model, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward pass over the dataset tokens; `logits` is consumed by the metric cells below.\n",
    "logits = model(ds.toks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch-mean logit difference in the default \"simple\" mode, read at ds.positions.\n",
    "compute_logit_diff(logits, ds.answer_toks, ds.positions, per_prompt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch-mean probability difference in the default \"simple\" mode, read at ds.positions.\n",
    "compute_probability_diff(logits, ds.answer_toks, ds.positions, per_prompt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean reciprocal rank of the token in column 0 of ds.answer_toks.\n",
    "compute_mean_reciprocal_rank(logits, ds.answer_toks, ds.positions, mode=\"simple\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same inputs through compute_max_group_rank_reciprocal in \"simple\" mode.\n",
    "compute_max_group_rank_reciprocal(logits, ds.answer_toks, ds.positions, mode=\"simple\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Greater-Than"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Old"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.greater_than_dataset import get_prob_diff, YearDataset, get_valid_years"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Old greater-than dataset: years from get_valid_years(tokenizer, 1100, 1800), a size argument of 1000\n",
    "# (presumably the number of examples -- confirm), the noun list file and the model tokenizer.\n",
    "ds_old = YearDataset(get_valid_years(model.tokenizer, 1100, 1800), 1000, Path(\"data/potential_nouns.txt\"), model.tokenizer)\n",
    "\n",
    "# def batch(iterable, n:int=1):\n",
    "#    current_batch = []\n",
    "#    for item in iterable:\n",
    "#        current_batch.append(item)\n",
    "#        if len(current_batch) == n:\n",
    "#            yield current_batch\n",
    "#            current_batch = []\n",
    "#    if current_batch:\n",
    "#        yield current_batch\n",
    "\n",
    "# clean = list(batch(ds.good_sentences, 9))\n",
    "# labels = list(batch(ds.years_YY, 9))\n",
    "# corrupted = list(batch(ds.bad_sentences, 9))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example index; in the visible cells it is referenced only by the commented-out inspection line below.\n",
    "IDX = 768\n",
    "#model.to_str_tokens(ds.good_toks[IDX]), model.to_str_tokens(ds.bad_toks[IDX])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def prepare_indices_for_prob_diff(tokenizer, years):\n",
    "    \"\"\"\n",
    "    Builds the token-ID and flag tensors expected by compute_probability_diff in its group modes.\n",
    "\n",
    "    Every row of the first tensor is a copy of the year token IDs. In the matching row of the second tensor,\n",
    "    columns after the prompt's year are flagged correct (1) and columns up to and including it incorrect (-1).\n",
    "\n",
    "    Args:\n",
    "        tokenizer (PreTrainedTokenizer): Tokenizer passed to get_year_indices.\n",
    "        years (torch.Tensor): Tensor containing the year for each prompt in the batch.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor, torch.Tensor: Two tensors, one for token IDs and one for correct/incorrect flags.\n",
    "    \"\"\"\n",
    "    # NOTE(review): get_year_indices is not imported or defined in any visible cell -- confirm it is in scope.\n",
    "    # It is expected to return one token ID per two-digit year (00 to 99).\n",
    "    year_token_ids = get_year_indices(tokenizer)\n",
    "\n",
    "    # One copy of the year token IDs per batch item\n",
    "    token_ids_tensor = year_token_ids.repeat(years.size(0), 1)\n",
    "\n",
    "    # Start with every column flagged incorrect, then flip the columns after each prompt's year to correct\n",
    "    flags_tensor = torch.full_like(token_ids_tensor, -1)\n",
    "    for row, year in enumerate(years):\n",
    "        flags_tensor[row, year + 1:] = 1\n",
    "\n",
    "    return token_ids_tensor, flags_tensor\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#input_length = 1 + len(model.tokenizer(ds.good_sentences[0])[0])\n",
    "prob_diff = get_prob_diff(model.tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.circuit_utils import run_with_batches\n",
    "\n",
    "clean_logits = run_with_batches(model, ds_old.good_toks.to(device), batch_size=20, max_seq_len=12)\n",
    "corrupted_logits = run_with_batches(model, ds_old.bad_toks.to(device), batch_size=20, max_seq_len=12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prob_diff(clean_logits,ds_old.years_YY)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### New"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = UniversalPatchingDataset.from_greater_than(model, 1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logits = model(ds.toks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "compute_probability_diff(logits, ds.answer_toks, flags_tensor=ds.group_flags, mode=\"group_sum\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "compute_mean_reciprocal_rank(logits, ds.answer_toks, ds.positions, ds.group_flags, mode=\"groups\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "compute_max_group_rank_reciprocal(logits, ds.answer_toks, ds.positions, ds.group_flags, mode=\"groups\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sentiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.sentiment_datasets import get_dataset, PromptType, get_prompts\n",
    "from utils.circuit_analysis import get_logit_diff as get_logit_diff_ca"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classification"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Old"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_type = PromptType.CLASSIFICATION_4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = get_dataset(model, device, prompt_type=ds_type)\n",
    "ds.all_prompts[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds.clean_tokens.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds.answer_tokens.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_logits = model(ds.clean_tokens.to(device))\n",
    "corrupted_logits = model(ds.corrupted_tokens.to(device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds.answer_tokens.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.metrics import CircuitMetric\n",
    "logit_diff_metric = CircuitMetric(\"logit_diff_multi\", partial(get_logit_diff_ca, answer_tokens=ds.answer_tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logit_diff_metric(clean_logits), logit_diff_metric(corrupted_logits)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### New"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = UniversalPatchingDataset.from_sentiment(model, \"class\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logits = model(ds.toks)\n",
    "flipped_logits = model(ds.flipped_toks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "compute_logit_diff(logits, ds.answer_toks, mode=\"pairs\"), compute_logit_diff(flipped_logits, ds.answer_toks, mode=\"pairs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "compute_mean_reciprocal_rank(logits, ds.answer_toks, ds.positions, mode=\"pairs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "compute_max_group_rank_reciprocal(logits, ds.answer_toks, ds.positions, mode=\"pairs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_lens.utils import test_prompt\n",
    "for prompt_tokens in ds.toks:\n",
    "    prompt = model.to_string(prompt_tokens[1:])\n",
    "    test_prompt(prompt, \" Positive\", model, top_k=5)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Continuation"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Old"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_type = PromptType.COMPLETION_2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = get_dataset(model, device, prompt_type=ds_type)\n",
    "ds.all_prompts[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds.clean_tokens.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds.answer_tokens.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_logits = model(ds.clean_tokens.to(device))\n",
    "corrupted_logits = model(ds.corrupted_tokens.to(device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds.answer_tokens.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.metrics import CircuitMetric\n",
    "logit_diff_metric = CircuitMetric(\"logit_diff_multi\", partial(get_logit_diff_ca, answer_tokens=ds.answer_tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logit_diff_metric(clean_logits), logit_diff_metric(corrupted_logits)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### New"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = UniversalPatchingDataset.from_sentiment(model, \"cont\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logits = model(ds.toks)\n",
    "compute_logit_diff(logits, ds.answer_toks, mode=\"pairs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.metrics import compute_accuracy\n",
    "compute_accuracy(logits, ds.answer_toks, mode=\"pairs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "compute_mean_reciprocal_rank(logits, ds.answer_toks, mode=\"pairs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "compute_max_group_rank_reciprocal(logits, ds.answer_toks, mode=\"pairs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_lens.utils import test_prompt\n",
    "for prompt_tokens in ds.toks:\n",
    "    prompt = model.to_string(prompt_tokens[1:])\n",
    "    test_prompt(prompt, \"bad\", model, top_k=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds.ans"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import random\n",
    "from torch.utils.data import DataLoader\n",
    "from datasets import Dataset, concatenate_datasets, load_from_disk\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "model.name = \"EleutherAI/pythia-2.8b\"\n",
    "\n",
    "\n",
    "def filter_function(example, model):\n",
    "    prompt = model.to_tokens(example['text'] + \" Review Sentiment:\", prepend_bos=False)\n",
    "    answer = torch.tensor([29071, 32725]).unsqueeze(0).unsqueeze(0).to(device) if example['label'] == 1 else torch.tensor([32725, 29071]).unsqueeze(0).unsqueeze(0).to(device)\n",
    "    #print(answer.shape)\n",
    "    logits = model(prompt, return_type=\"logits\")\n",
    "    logit_diff = compute_logit_diff(logits, answer, mode=\"pairs\")\n",
    "    \n",
    "    # Determine if the top answer (index 0) token is in top 10 logits\n",
    "    _, top_indices = logits.topk(10, dim=-1)  # Get indices of top 10 logits\n",
    "    top_answer_token = answer[0, 0, 0]  # Assuming answer is of shape (1, 1, 2) and the top answer token is at index 0\n",
    "    is_top_answer_in_top_10_logits = (top_indices == top_answer_token).any()\n",
    "    \n",
    "    # Add a new field 'keep_example' to the example\n",
    "    example['keep_example'] = (logit_diff > 0.0) and is_top_answer_in_top_10_logits.item()\n",
    "    return example\n",
    "\n",
    "\n",
    "def concatenate_classification_prompts(examples):\n",
    "    return {\"text\": (examples['text'] + \" Review Sentiment:\")}\n",
    "\n",
    "\n",
    "def get_final_pos_index(examples):\n",
    "    return {'final_pos_index': examples[\"attention_mask\"].sum() - 1}\n",
    "\n",
    "\n",
    "def tokenize_function(examples, tokenizer):\n",
    "    return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True, max_length=64)\n",
    "\n",
    "\n",
    "def find_dataset_positions(example, token_id=13):\n",
    "    # Create a tensor of zeros with the same shape as example['tokens']\n",
    "    positions = torch.zeros_like(torch.tensor(example['tokens']), dtype=torch.int)\n",
    "\n",
    "    # Find positions where tokens match the given token_id\n",
    "    positions[example['tokens'] == token_id] = 1\n",
    "    has_token = True if positions.sum() > 0 else False\n",
    "\n",
    "    return {'positions': positions, 'has_token': has_token}\n",
    "\n",
    "\n",
    "def convert_answers(example, pos_answer_id=29071, neg_answer_id=32725):\n",
    "    if example['label'] == 1:\n",
    "        answers = torch.tensor([pos_answer_id, neg_answer_id])\n",
    "    else:\n",
    "        answers = torch.tensor([neg_answer_id, pos_answer_id])\n",
    "\n",
    "    return {'answers': answers}\n",
    "\n",
    "\n",
    "def get_random_subset(dataset, n):\n",
    "    total_size = len(dataset)\n",
    "    random_indices = random.sample(range(total_size), n)\n",
    "    return dataset.select(random_indices)\n",
    "\n",
    "\n",
    "def prepare_sst_for_model(\n",
    "        model: HookedTransformer,\n",
    "        dataset_name: str = \"sst2\", \n",
    "        batch_size: int = 5,\n",
    "        pad_token_id: int = 1, \n",
    "        pos_answer_id: int = 29071, \n",
    "        neg_answer_id: int = 32725\n",
    "    ) -> Tuple[DataLoader, DataLoader, DataLoader]:\n",
    "    # Define the batch size\n",
    "    BATCH_SIZE = batch_size\n",
    "\n",
    "    sst_data = load_from_disk(dataset_name)\n",
    "\n",
    "    # Use the map function to apply the filter_function\n",
    "    filter_function_for_model = partial(filter_function, model=model)\n",
    "    sst_data_with_flag_train = sst_data['train'].map(filter_function_for_model, keep_in_memory=True)\n",
    "    sst_data_with_flag_dev = sst_data['dev'].map(filter_function_for_model, keep_in_memory=True)\n",
    "    sst_data_with_flag_test = sst_data['test'].map(filter_function_for_model, keep_in_memory=True)\n",
    "    #sst_data_with_flag = concatenate_datasets([sst_data['train'], sst_data['dev'], sst_data['test']])\n",
    "    sst_data_with_flag = concatenate_datasets([sst_data_with_flag_train, sst_data_with_flag_dev, sst_data_with_flag_test])\n",
    "    #sst_data_with_flag = sst_data_with_flag_dev\n",
    "\n",
    "    # Use the filter function to keep only the examples where 'keep_example' is True\n",
    "    sst_zero_shot = sst_data_with_flag.filter(lambda x: x['keep_example'])\n",
    "    # print number of items in dataset\n",
    "    print(f\"Number of items in dataset: {len(sst_zero_shot)}\")\n",
    "    # save dataset\n",
    "    #new model name without slashes\n",
    "    model_abbr = re.sub(r'/', '_', model.name)\n",
    "    sst_zero_shot.save_to_disk(f\"sst_zero_shot_{model_abbr}\")\n",
    "\n",
    "    # Load a tokenizer (you'll need to specify the appropriate model)\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model.name)\n",
    "    # set padding token\n",
    "    tokenizer.pad_token = model.to_string([pad_token_id])\n",
    "\n",
    "    dataset = sst_zero_shot.map(concatenate_classification_prompts, batched=False)\n",
    "    tokenizer_function_for_model = partial(tokenize_function, tokenizer=tokenizer)\n",
    "    dataset = dataset.map(tokenizer_function_for_model, batched=False)\n",
    "    \n",
    "    convert_answers_for_model = partial(convert_answers, pos_answer_id=pos_answer_id, neg_answer_id=neg_answer_id)\n",
    "    dataset = dataset.map(convert_answers_for_model, batched=False)\n",
    "    dataset = dataset.rename_column(\"input_ids\", \"tokens\")\n",
    "    dataset.set_format(type=\"torch\", columns=[\"tokens\", \"attention_mask\", \"label\", \"answers\"])\n",
    "    dataset = dataset.map(get_final_pos_index, batched=False)\n",
    "    dataset = dataset.map(find_dataset_positions, batched=False)\n",
    "    dataset = dataset.filter(lambda example: example['has_token']==True)\n",
    "\n",
    "    # create a subset with only positive labels\n",
    "    pos_dataset = dataset.filter(lambda example: example['label']==1)\n",
    "    neg_dataset = dataset.filter(lambda example: example['label']==0)\n",
    "    len(pos_dataset), len(neg_dataset)\n",
    "\n",
    "    subset_size = (min(len(pos_dataset), len(neg_dataset)) // BATCH_SIZE) * BATCH_SIZE\n",
    "\n",
    "    pos_subset = get_random_subset(pos_dataset, subset_size)\n",
    "    neg_subset = get_random_subset(neg_dataset, subset_size)\n",
    "    balanced_subset = concatenate_datasets([pos_subset, neg_subset])\n",
    "    # randomize the order of balanced_subset\n",
    "    balanced_subset = balanced_subset.shuffle(len(balanced_subset))\n",
    "\n",
    "    balanced_subset.save_to_disk(f\"sst_zero_shot_balanced_{model_abbr}\")\n",
    "\n",
    "\n",
    "    print(f\"Number of items in pos dataset: {len(pos_subset)}\")\n",
    "    print(f\"Number of items in neg dataset: {len(neg_subset)}\")\n",
    "    print(f\"Number of items in balanced dataset: {len(balanced_subset)}\")\n",
    "    return pos_subset, neg_subset, balanced_subset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pos_ds, neg_ds, balanced_ds = prepare_sst_for_model(model, \"data/sst2\", 5, 1, 29071, 32725)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import random\n",
    "from torch.utils.data import DataLoader\n",
    "from datasets import Dataset, concatenate_datasets, load_from_disk\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "ds = load_from_disk(\"sst_zero_shot_balanced_EleutherAI_pythia-2.8b\")\n",
    "\n",
    "# Turn all items in ['tokens'] into a single tensor\n",
    "all_tokens = torch.cat([item['tokens'].unsqueeze(0) for item in ds], dim=0)\n",
    "all_answers = torch.cat([item['answers'].unsqueeze(0) for item in ds], dim=0)\n",
    "all_positions = torch.cat([item['final_pos_index'].unsqueeze(0) for item in ds], dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds[0]['tokens'], ds[0]['answers'], ds[0]['final_pos_index'], ds[0]['tokens'][ds[0]['final_pos_index']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_tokens.shape, all_answers.shape, all_positions.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.circuit_utils import run_with_batches\n",
    "\n",
    "logits = run_with_batches(model, all_tokens[:1000].to(device), batch_size=10, max_seq_len=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.metrics import compute_accuracy\n",
    "compute_accuracy(logits, all_answers[:1000], positions=all_positions[:1000], mode=\"simple\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = HookedTransformer.from_pretrained(\n",
    "    \"EleutherAI/pythia-2.8b\",\n",
    "    checkpoint_value=10000,\n",
    "    center_unembed=True,\n",
    "    center_writing_weights=True,\n",
    "    fold_ln=True,\n",
    "    dtype=torch.bfloat16,\n",
    "    refactor_factored_attn_matrices=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.circuit_utils import run_with_batches\n",
    "from utils.metrics import compute_accuracy\n",
    "logits = run_with_batches(model, all_tokens[:100].to(device), batch_size=10, max_seq_len=64)\n",
    "compute_accuracy(logits, all_answers[:100], positions=all_positions[:100], mode=\"simple\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
