{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "import copy\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src import functional\n",
    "from src.models import ModelandTokenizer\n",
    "# from src.data import load_relation\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import transformers\n",
    "import baukit\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# state-spaces/mamba-2.8b | state-spaces/mamba-2.8b-slimpj\n",
    "MODEL_PATH = \"state-spaces/mamba-2.8b\"  \n",
    "\n",
    "mt = ModelandTokenizer(\n",
    "    model_path=MODEL_PATH, \n",
    "    torch_dtype=torch.float32\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_template = mt.tokenizer.eos_token + \" {} is located in the city of\"\n",
    "subject = \"Louvre\"\n",
    "alt_subject = \"The Space Needle\"\n",
    "\n",
    "functional.predict_next_token(\n",
    "    mt = mt,\n",
    "    prompt = [\n",
    "        prompt_template.format(subject),\n",
    "        prompt_template.format(alt_subject),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import json\n",
    "import logging\n",
    "import types\n",
    "from collections import defaultdict\n",
    "from typing import Literal, Optional, get_args\n",
    "\n",
    "import baukit\n",
    "import numpy\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "import src.tokens as tokenizer_utils\n",
    "\n",
    "# from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel as Mamba\n",
    "from mamba_minimal.model import Mamba\n",
    "from src import functional\n",
    "from src.functional import (\n",
    "    decode_tokens,\n",
    "    find_token_range,\n",
    "    make_inputs,\n",
    "    predict_from_input,\n",
    ")\n",
    "from src.hooking.mamba import MambaBlock_Hook_Points, MambaBlockForwardPatcher\n",
    "from src.models import ModelandTokenizer, is_mamba_variant"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.typing import Mamba\n",
    "type(mt.model), Mamba"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "isinstance(mt.model, Mamba)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def path_ablation(\n",
    "    mt,\n",
    "    inp,  # A set of inputs\n",
    "    residual_states_to_patch,  # A list of (token index, layername) pairs to restore\n",
    "    answers_t,  # Answer probabilities to collect\n",
    "    tokens_to_mix,  # Range of tokens to corrupt (begin, end)\n",
    "    noise=0.1,  # Noise scale, or a callable applied to the sampled noise\n",
    "    uniform_noise=False,\n",
    "    replace=False,  # True to replace with instead of add noise\n",
    "    alt_subj_patching: bool = False,  # If True, will assume inp shape to be (2, L). Uncorrupted activations with inp[0] will be patched in the run with inp[1]. Will not corrupt the embeddings\n",
    "    block_states_to_unpatch: Optional[list] = None,  # A list of (token index, layername) pairs whose mixer hook is reset to its first-run value\n",
    "    hook_to_unpatch: Optional[MambaBlock_Hook_Points] = None,\n",
    "):\n",
    "    \"\"\"Patch first-run states into a second run and return the answer probability.\n",
    "\n",
    "    Run 1 (no intervention) records the outputs of the embedding layer and of\n",
    "    every layer named in `residual_states_to_patch`, and lets\n",
    "    `MambaBlockForwardPatcher(retainer=...)` fill a dict for the `<layer>.mixer`\n",
    "    module of every layer named in either spec.\n",
    "\n",
    "    Run 2 applies the interventions:\n",
    "      * unless `alt_subj_patching`, noise is added to (or, with `replace`,\n",
    "        written over) the embedding output of batch rows 1: at positions\n",
    "        `tokens_to_mix`;\n",
    "      * for each (token, layer) in `residual_states_to_patch`, rows 1: of that\n",
    "        layer's output at `token` are overwritten with row 0 of the run-1 output;\n",
    "      * for each (token, layer) in `block_states_to_unpatch`, row 1 of the run-1\n",
    "        `hook_to_unpatch` activation at `token` is handed to\n",
    "        `MambaBlockForwardPatcher(patch_spec=..., patch_hook=hook_to_unpatch)`.\n",
    "\n",
    "    Returns the softmax probability of `answers_t` at the last position,\n",
    "    averaged over batch rows 1:. The model's forward functions are reset before\n",
    "    returning, also when one of the runs raises.\n",
    "    \"\"\"\n",
    "    assert is_mamba_variant(mt), \"This function is only for Mamba models. check trace_with_repatch in the original implementation for other models\"\n",
    "    embed_layername = mt.embedder_name\n",
    "\n",
    "    # A None default avoids sharing one mutable list between calls.\n",
    "    if block_states_to_unpatch is None:\n",
    "        block_states_to_unpatch = []\n",
    "\n",
    "    rs = numpy.random.RandomState(1)  # For reproducibility, use pseudorandom noise\n",
    "    if uniform_noise:\n",
    "        prng = lambda *shape: rs.uniform(-1, 1, shape)\n",
    "    else:\n",
    "        prng = lambda *shape: rs.randn(*shape)\n",
    "\n",
    "    # layer name -> token positions to restore from the first run\n",
    "    residual_patch_spec = defaultdict(list)\n",
    "    for token_idx, layer_name in residual_states_to_patch:\n",
    "        residual_patch_spec[layer_name].append(token_idx)\n",
    "\n",
    "    # layer name -> token positions whose mixer hook activation is reset\n",
    "    block_unpatch_spec = defaultdict(list)\n",
    "    for token_idx, layer_name in block_states_to_unpatch:\n",
    "        block_unpatch_spec[layer_name].append(token_idx)\n",
    "\n",
    "    def untuple(x):\n",
    "        return x[0] if isinstance(x, tuple) else x\n",
    "\n",
    "    # Define the model-patching rule.\n",
    "    # A callable is used as-is; any plain number (float, int, numpy scalar) scales the noise.\n",
    "    if callable(noise):\n",
    "        noise_fn = noise\n",
    "    else:\n",
    "        noise_fn = lambda x: noise * x\n",
    "\n",
    "    first_run_residual_activations = None  # set once the first run has finished\n",
    "\n",
    "    # NOTE: keep the parameter names; baukit passes hook arguments by name where it can.\n",
    "    def patch_rep(repr, layer):\n",
    "        assert first_run_residual_activations is not None\n",
    "\n",
    "        if layer == embed_layername and alt_subj_patching == False:\n",
    "            # If requested, we corrupt a range of token embeddings on batch items x[1:]\n",
    "            if tokens_to_mix is not None:\n",
    "                b, e = tokens_to_mix\n",
    "                noise_data = noise_fn(\n",
    "                    torch.from_numpy(prng(repr.shape[0] - 1, e - b, repr.shape[2]))\n",
    "                ).to(repr.device)\n",
    "                if replace:\n",
    "                    repr[1:, b:e] = noise_data\n",
    "                else:\n",
    "                    repr[1:, b:e] += noise_data\n",
    "            return repr\n",
    "\n",
    "        if layer not in residual_patch_spec:\n",
    "            return repr\n",
    "\n",
    "        # If this layer is in the patch_spec, restore the uncorrupted hidden state\n",
    "        # for selected tokens from the first run\n",
    "        h = untuple(repr)\n",
    "        for t in residual_patch_spec[layer]:\n",
    "            h[1:, t] = untuple(first_run_residual_activations[layer].output)[0, t]\n",
    "\n",
    "        return repr\n",
    "\n",
    "    # need to run twice to store the first-run activations\n",
    "    # of hooks inside the MambaBlock (Is there a better way?)\n",
    "    interested_layers = list(\n",
    "        set(list(residual_patch_spec.keys()) + list(block_unpatch_spec.keys()))\n",
    "    )\n",
    "    first_run_hook_activations = {layer: {} for layer in interested_layers}\n",
    "    traced_layers = [embed_layername] + list(residual_patch_spec.keys())\n",
    "\n",
    "    mt.reset_forward()  # reset the model to use default forward functions\n",
    "    try:\n",
    "        # ------------------------------------------------------\n",
    "        # first run: record everything, no intervention\n",
    "        for layer in interested_layers:\n",
    "            block = baukit.get_module(\n",
    "                mt.model, name=layer + \".mixer\"\n",
    "            )  # MambaBlock naming format\n",
    "            block.forward = types.MethodType(\n",
    "                MambaBlockForwardPatcher(\n",
    "                    retainer=first_run_hook_activations[layer],\n",
    "                ),\n",
    "                block,\n",
    "            )\n",
    "        with torch.inference_mode(), baukit.TraceDict(\n",
    "            mt.model,\n",
    "            traced_layers,\n",
    "            edit_output=None,  # No intervention on the first run\n",
    "        ) as td:\n",
    "            mt(**inp)\n",
    "\n",
    "        first_run_residual_activations = td\n",
    "\n",
    "        # ------------------------------------------------------\n",
    "        # second run with patching / unpatching\n",
    "        mt.reset_forward()  # reset the model to use default forward functions\n",
    "        for layer in block_unpatch_spec:\n",
    "            block = baukit.get_module(mt.model, name=layer + \".mixer\")\n",
    "            retained = first_run_hook_activations[layer][hook_to_unpatch]\n",
    "\n",
    "            # Each position gets its own first-run activation from batch row 1.\n",
    "            # NOTE(review): row 1 is the alt-subject run when alt_subj_patching is set;\n",
    "            # in noise mode the first run is not noised -- confirm that is intended.\n",
    "            cur_patch_spec = {\n",
    "                token_idx: retained[1, token_idx]\n",
    "                for token_idx in block_unpatch_spec[layer]\n",
    "            }\n",
    "            block.forward = types.MethodType(\n",
    "                MambaBlockForwardPatcher(\n",
    "                    patch_spec=cur_patch_spec,\n",
    "                    patch_hook=hook_to_unpatch,\n",
    "                ),\n",
    "                block,\n",
    "            )\n",
    "        with torch.inference_mode(), baukit.TraceDict(\n",
    "            mt.model,\n",
    "            traced_layers,  # make sure to patch from the first-run activations\n",
    "            edit_output=patch_rep,  # noises the embeddings and restores the residual states\n",
    "        ):\n",
    "            outputs_exp = mt.model(input_ids=inp[\"input_ids\"])\n",
    "        # ------------------------------------------------------\n",
    "    finally:\n",
    "        # Never leave patched forward functions behind, even if a run failed.\n",
    "        mt.reset_forward()\n",
    "\n",
    "    # We report softmax probabilities for the answers_t token predictions of interest.\n",
    "    logits = outputs_exp.logits if hasattr(outputs_exp, \"logits\") else outputs_exp\n",
    "    probs = torch.softmax(logits[1:, -1, :], dim=1).mean(dim=0)[answers_t]\n",
    "\n",
    "    return probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trace_important_states_with_ablation(\n",
    "    mt,\n",
    "    inp,\n",
    "    e_range,\n",
    "    answer_t,\n",
    "    noise=0.1,\n",
    "    uniform_noise=False,\n",
    "    replace=False,\n",
    "    token_range=None,\n",
    "    ablate_mambahook: Optional[MambaBlock_Hook_Points] = None,\n",
    "    alt_subj_patching: bool = False,\n",
    "):\n",
    "    \"\"\"Build a (token, layer) grid of answer probabilities with `path_ablation`.\n",
    "\n",
    "    For every position in `token_range` (all positions of `inp['input_ids']`\n",
    "    when it is None) and every layer index in `range(mt.n_layer)`,\n",
    "    `path_ablation` is called with that single (position, layer) residual state\n",
    "    to patch. When `ablate_mambahook` is given, the (position, layer) pairs of\n",
    "    all layers are also passed as `block_states_to_unpatch`.\n",
    "\n",
    "    Returns the results stacked per position, then over positions.\n",
    "    \"\"\"\n",
    "    seq_len = inp[\"input_ids\"].shape[1]\n",
    "    positions = range(seq_len) if token_range is None else token_range\n",
    "\n",
    "    grid = []\n",
    "    for position in positions:\n",
    "        # With a hook to ablate, that hook is reset at this position in every layer.\n",
    "        if ablate_mambahook is None:\n",
    "            severed_states = []\n",
    "        else:\n",
    "            severed_states = [\n",
    "                (position, mt.layer_name_format.format(idx))\n",
    "                for idx in range(0, mt.n_layer)\n",
    "            ]\n",
    "\n",
    "        per_layer = [\n",
    "            path_ablation(\n",
    "                mt=mt,\n",
    "                inp=inp,\n",
    "                residual_states_to_patch=[\n",
    "                    (position, mt.layer_name_format.format(idx))\n",
    "                ],\n",
    "                answers_t=answer_t,\n",
    "                tokens_to_mix=e_range,\n",
    "                noise=noise,\n",
    "                uniform_noise=uniform_noise,\n",
    "                replace=replace,\n",
    "                block_states_to_unpatch=severed_states,\n",
    "                hook_to_unpatch=ablate_mambahook,\n",
    "                alt_subj_patching=alt_subj_patching,\n",
    "            )\n",
    "            for idx in range(mt.n_layer)\n",
    "        ]\n",
    "        grid.append(torch.stack(per_layer))\n",
    "\n",
    "    return torch.stack(grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.tracing import trace_with_patch, replace_eos_with_pad\n",
    "\n",
    "def calculate_hidden_flow_with_ablation(\n",
    "    mt: ModelandTokenizer,\n",
    "    prompt: str,\n",
    "    subject: str,\n",
    "    alt_subject: Optional[str] = None,\n",
    "    num_samples=10,\n",
    "    noise=0.1,\n",
    "    token_range=None,\n",
    "    uniform_noise=False,\n",
    "    replace=False,\n",
    "    window=10,\n",
    "    ablate_mambahook: Optional[MambaBlock_Hook_Points] = None,\n",
    "):\n",
    "    \"\"\"Collect the patching scores and the high/low reference scores for one prompt.\n",
    "\n",
    "    Without `alt_subject`, the prompt is repeated `num_samples + 1` times and the\n",
    "    low score comes from `trace_with_patch` called with no states to restore.\n",
    "    With `alt_subject`, a two-row batch [prompt with subject, prompt with\n",
    "    alt_subject] is tokenized under `set_padding_side(..., 'left')`; the high\n",
    "    and low scores are the probabilities rows 0 and 1 assign to row 0's most\n",
    "    likely next token.\n",
    "\n",
    "    `token_range` may be None, 'subject_last' or 'prompt_last'; any other value\n",
    "    raises ValueError.\n",
    "\n",
    "    Returns a dict with the score grid from\n",
    "    `trace_important_states_with_ablation`, the low/high scores, the input ids\n",
    "    and tokens of row 0, the subject range, the decoded answer, `window` and\n",
    "    the ablated hook; the alt-subject tokens are added in alt-subject mode.\n",
    "    \"\"\"\n",
    "    use_alt_subject = alt_subject is not None\n",
    "\n",
    "    if not use_alt_subject:\n",
    "        # Noise mode: one reference row plus `num_samples` rows to corrupt.\n",
    "        model_inputs = make_inputs(mt.tokenizer, [prompt] * (num_samples + 1))\n",
    "        with torch.no_grad():\n",
    "            answer_t, base_score = [\n",
    "                d[0] for d in predict_from_input(mt.model, model_inputs)\n",
    "            ]\n",
    "        subj_span = find_token_range(\n",
    "            string=prompt,\n",
    "            substring=subject,\n",
    "            tokenizer=mt.tokenizer,\n",
    "        )\n",
    "        low_score = trace_with_patch(\n",
    "            mt,\n",
    "            model_inputs,\n",
    "            [],\n",
    "            answer_t,\n",
    "            subj_span,\n",
    "            noise=noise,\n",
    "            uniform_noise=uniform_noise,\n",
    "            mamba_block_hook=None,  # don't need to patch for calculating the low score\n",
    "            alt_subj_patching=use_alt_subject,\n",
    "        ).item()\n",
    "    else:\n",
    "        # Alt-subject mode: row 0 is the clean prompt, row 1 the counterfactual one.\n",
    "        clean_prompt = prompt.format(subject) if \"{}\" in prompt else prompt\n",
    "        alt_prompt = clean_prompt.replace(subject, alt_subject)\n",
    "        with tokenizer_utils.set_padding_side(mt.tokenizer, padding_side=\"left\"):\n",
    "            model_inputs = mt.tokenizer(\n",
    "                [clean_prompt, alt_prompt],\n",
    "                return_tensors=\"pt\",\n",
    "                padding=\"longest\",\n",
    "                return_offsets_mapping=True,\n",
    "            ).to(mt.device)\n",
    "        offsets = model_inputs.pop(\"offset_mapping\")\n",
    "        clean_span = find_token_range(\n",
    "            string=clean_prompt,\n",
    "            substring=subject,\n",
    "            tokenizer=mt.tokenizer,\n",
    "            offset_mapping=offsets[0],\n",
    "        )\n",
    "        alt_span = find_token_range(\n",
    "            string=alt_prompt,\n",
    "            substring=alt_subject,\n",
    "            tokenizer=mt.tokenizer,\n",
    "            offset_mapping=offsets[1],\n",
    "        )\n",
    "        # Both subjects have to end at the same token position.\n",
    "        assert clean_span[1] == alt_span[1]\n",
    "        subj_span = (min(clean_span[0], alt_span[0]), clean_span[1])\n",
    "\n",
    "        with torch.no_grad():\n",
    "            outputs = mt(**model_inputs)\n",
    "        last_logits = getattr(outputs, \"logits\", outputs)[:, -1]\n",
    "        next_token_probs = last_logits.float().softmax(dim=-1)\n",
    "        answer_t = next_token_probs[0].argmax(dim=-1)\n",
    "        base_score = next_token_probs[0, answer_t]  # p(ans|subj)\n",
    "        low_score = next_token_probs[1, answer_t]  # p(ans|alt_subj)\n",
    "\n",
    "    # Resolve the symbolic token ranges; None means every position.\n",
    "    if token_range is not None:\n",
    "        if token_range == \"subject_last\":\n",
    "            token_range = [subj_span[1] - 1]\n",
    "        elif token_range == \"prompt_last\":\n",
    "            token_range = [model_inputs[\"input_ids\"].shape[1] - 1]\n",
    "        else:\n",
    "            raise ValueError(f\"Unknown token_range: {token_range}\")\n",
    "\n",
    "    (answer,) = decode_tokens(mt.tokenizer, [answer_t])\n",
    "\n",
    "    scores = trace_important_states_with_ablation(\n",
    "        mt=mt,\n",
    "        inp=model_inputs,\n",
    "        e_range=subj_span,\n",
    "        answer_t=answer_t,\n",
    "        noise=noise,\n",
    "        uniform_noise=uniform_noise,\n",
    "        replace=replace,\n",
    "        token_range=token_range,\n",
    "        ablate_mambahook=ablate_mambahook,\n",
    "        alt_subj_patching=use_alt_subject,\n",
    "    )\n",
    "    scores = scores.detach().cpu()\n",
    "\n",
    "    result = {\n",
    "        \"scores\": scores,\n",
    "        \"low_score\": low_score,\n",
    "        \"high_score\": base_score,\n",
    "        \"input_ids\": model_inputs[\"input_ids\"][0],\n",
    "        \"input_tokens\": replace_eos_with_pad(\n",
    "            mt.tokenizer,\n",
    "            list(decode_tokens(mt.tokenizer, model_inputs[\"input_ids\"][0])),\n",
    "        ),\n",
    "        \"subject_range\": subj_span,\n",
    "        \"answer\": answer,\n",
    "        \"window\": window,\n",
    "        \"correct_prediction\": True,\n",
    "        \"kind\": ablate_mambahook,\n",
    "    }\n",
    "\n",
    "    if use_alt_subject:\n",
    "        alt_token_ids = model_inputs[\"input_ids\"][1, subj_span[0] : subj_span[1]]\n",
    "        result[\"alt_subject\"] = replace_eos_with_pad(\n",
    "            mt.tokenizer,\n",
    "            list(decode_tokens(mt.tokenizer, alt_token_ids)),\n",
    "        )\n",
    "\n",
    "    return result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from src.plotting import plot_trace_heatmap\n",
    "\n",
    "# indirect_effect = calculate_hidden_flow_with_ablation(\n",
    "#     mt = mt,\n",
    "#     prompt = prompt_template,\n",
    "#     subject = subject,\n",
    "#     alt_subject = alt_subject,\n",
    "#     num_samples = 10,\n",
    "#     noise = 0.1,\n",
    "#     token_range = None,\n",
    "#     uniform_noise = False,\n",
    "#     replace = False,\n",
    "#     window = 10,\n",
    "#     ablate_mambahook = None,\n",
    "# )\n",
    "\n",
    "# plot_trace_heatmap(indirect_effect, modelname=MODEL_PATH.split(\"/\")[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# indirect_effect_ablate_ssm = calculate_hidden_flow_with_ablation(\n",
    "#     mt = mt,\n",
    "#     prompt = prompt_template,\n",
    "#     subject = subject,\n",
    "#     alt_subject = alt_subject,\n",
    "#     num_samples = 10,\n",
    "#     noise = 0.1,\n",
    "#     token_range = None,\n",
    "#     uniform_noise = False,\n",
    "#     replace = False,\n",
    "#     window = 10,\n",
    "#     ablate_mambahook = \"ssm_after_ssm\",\n",
    "# )\n",
    "# plot_trace_heatmap(indirect_effect_ablate_ssm, modelname=MODEL_PATH.split(\"/\")[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# indirect_effect_ablate_other = calculate_hidden_flow_with_ablation(\n",
    "#     mt = mt,\n",
    "#     prompt = prompt_template,\n",
    "#     subject = subject,\n",
    "#     alt_subject = alt_subject,\n",
    "#     num_samples = 10,\n",
    "#     noise = 0.1,\n",
    "#     token_range = None,\n",
    "#     uniform_noise = False,\n",
    "#     replace = False,\n",
    "#     window = 10,\n",
    "#     ablate_mambahook = \"mlp_after_silu\",\n",
    "# )\n",
    "# plot_trace_heatmap(indirect_effect_ablate_other, modelname=MODEL_PATH.split(\"/\")[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arguments shared by every run in this cell (states are patched at the last subject token).\n",
    "kwargs = {\n",
    "    \"mt\": mt,\n",
    "    \"prompt\": prompt_template,\n",
    "    \"subject\": subject,\n",
    "    \"alt_subject\": alt_subject,\n",
    "    \"num_samples\": 10,\n",
    "    \"noise\": 0.1,\n",
    "    \"token_range\": \"subject_last\",\n",
    "    \"uniform_noise\": False,\n",
    "    \"replace\": False,\n",
    "    \"window\": 10,\n",
    "}\n",
    "\n",
    "# One run without a hook to ablate, then one run per MambaBlock hook.\n",
    "indirect_effect = calculate_hidden_flow_with_ablation(**kwargs, ablate_mambahook=None)\n",
    "indirect_effect_ssm_severed = calculate_hidden_flow_with_ablation(\n",
    "    **kwargs, ablate_mambahook=\"ssm_after_ssm\"\n",
    ")\n",
    "indirect_effect_mlp_severed = calculate_hidden_flow_with_ablation(\n",
    "    **kwargs, ablate_mambahook=\"mlp_after_silu\"\n",
    ")\n",
    "indirect_effect_block_severed = calculate_hidden_flow_with_ablation(\n",
    "    **kwargs, ablate_mambahook=\"after_down_proj\"\n",
    ")\n",
    "\n",
    "# All runs are normalised with the high/low scores of the run without ablation.\n",
    "high_score = indirect_effect[\"high_score\"].item()\n",
    "low_score = indirect_effect[\"low_score\"].item()\n",
    "\n",
    "ie = (indirect_effect[\"scores\"] - low_score) / (high_score - low_score)\n",
    "ie_ssm_severed = (indirect_effect_ssm_severed[\"scores\"] - low_score) / (\n",
    "    high_score - low_score\n",
    ")\n",
    "ie_mlp_severed = (indirect_effect_mlp_severed[\"scores\"] - low_score) / (\n",
    "    high_score - low_score\n",
    ")\n",
    "ie_block_severed = (indirect_effect_block_severed[\"scores\"] - low_score) / (\n",
    "    high_score - low_score\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "#####################################################################################\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 16\n",
    "MEDIUM_SIZE = 18\n",
    "BIGGER_SIZE = 24\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE+5)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "#####################################################################################\n",
    "\n",
    "def plot_aie_subject_last(\n",
    "    indirect_effects,\n",
    "    indirect_effects_with_ssm_severed,\n",
    "    indirect_effects_with_mlp_severed,\n",
    "    indirect_effect_block_severed=None,\n",
    "):\n",
    "    \"\"\"Plot grouped per-layer bars of the indirect effect for each ablation setting.\n",
    "\n",
    "    Each argument is a tensor that holds one value per layer after `.squeeze()`.\n",
    "    `indirect_effect_block_severed` is optional: when it is None the\n",
    "    'Block Severed' group is left out, so calls that only have three series\n",
    "    work as well. The number of layers is taken from the data rather than from\n",
    "    the global model object.\n",
    "    \"\"\"\n",
    "    candidates = {\n",
    "        \"Single State\": indirect_effects,\n",
    "        \"SSM Severed\": indirect_effects_with_ssm_severed,\n",
    "        \"Gate Severed\": indirect_effects_with_mlp_severed,\n",
    "        \"Block Severed\": indirect_effect_block_severed,\n",
    "    }\n",
    "    # Drop the series that were not supplied.\n",
    "    state_indirect_effects = {\n",
    "        config: values.squeeze().numpy()\n",
    "        for config, values in candidates.items()\n",
    "        if values is not None\n",
    "    }\n",
    "    config_color = {\n",
    "        \"Single State\": \"purple\",\n",
    "        \"SSM Severed\": \"red\",\n",
    "        \"Gate Severed\": \"green\",\n",
    "        \"Block Severed\": \"blue\",\n",
    "    }\n",
    "\n",
    "    n_layers = len(state_indirect_effects[\"Single State\"])\n",
    "    layer_positions = np.arange(n_layers)\n",
    "\n",
    "    # plt.rcdefaults()\n",
    "    plt.figure(figsize=(20, 6))\n",
    "\n",
    "    bar_width = 0.22\n",
    "    # Each setting is shifted by one bar width so the groups sit side by side.\n",
    "    for idx, (config, values) in enumerate(state_indirect_effects.items()):\n",
    "        plt.bar(\n",
    "            layer_positions + idx * bar_width,\n",
    "            values,\n",
    "            width=bar_width,\n",
    "            label=config,\n",
    "            # edgecolor = \"black\",\n",
    "            color=config_color[config],\n",
    "            alpha=0.9,\n",
    "        )\n",
    "\n",
    "    def layer_tick(layer_idx, jump=5):\n",
    "        # Only every `jump`-th layer gets a visible tick label.\n",
    "        return f\"{layer_idx}\" if layer_idx % jump == 0 else \"\"\n",
    "\n",
    "    plt.ylim(0, 1)\n",
    "    plt.xticks(\n",
    "        layer_positions + bar_width,\n",
    "        [layer_tick(i) for i in range(n_layers)],\n",
    "    )\n",
    "\n",
    "    plt.ylabel(\"Indirect Effect (% of p(ans) recovered)\")\n",
    "    plt.legend(\n",
    "        ncol=len(state_indirect_effects),\n",
    "        bbox_to_anchor=(0.5, -.18),\n",
    "        loc=\"lower center\",\n",
    "        frameon=False,\n",
    "    )\n",
    "    # plt.savefig(f\"figs/faithfulness_lre_models.pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_aie_subject_last(ie, ie_ssm_severed, ie_mlp_severed, ie_block_severed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arguments shared by every run in this cell (states are patched at the last prompt token).\n",
    "kwargs = {\n",
    "    \"mt\": mt,\n",
    "    \"prompt\": prompt_template,\n",
    "    \"subject\": subject,\n",
    "    \"alt_subject\": alt_subject,\n",
    "    \"num_samples\": 10,\n",
    "    \"noise\": 0.1,\n",
    "    \"token_range\": \"prompt_last\",\n",
    "    \"uniform_noise\": False,\n",
    "    \"replace\": False,\n",
    "    \"window\": 10,\n",
    "}\n",
    "\n",
    "# One run without a hook to ablate, then one run per MambaBlock hook.\n",
    "indirect_effect = calculate_hidden_flow_with_ablation(**kwargs, ablate_mambahook=None)\n",
    "indirect_effect_ssm_severed = calculate_hidden_flow_with_ablation(\n",
    "    **kwargs, ablate_mambahook=\"ssm_after_ssm\"\n",
    ")\n",
    "indirect_effect_gate_severed = calculate_hidden_flow_with_ablation(\n",
    "    **kwargs, ablate_mambahook=\"mlp_after_silu\"\n",
    ")\n",
    "indirect_effect_block_severed = calculate_hidden_flow_with_ablation(\n",
    "    **kwargs, ablate_mambahook=\"after_down_proj\"\n",
    ")\n",
    "\n",
    "# All runs are normalised with the high/low scores of the run without ablation.\n",
    "high_score = indirect_effect[\"high_score\"].item()\n",
    "low_score = indirect_effect[\"low_score\"].item()\n",
    "\n",
    "ie = (indirect_effect[\"scores\"] - low_score) / (high_score - low_score)\n",
    "ie_ssm_severed = (indirect_effect_ssm_severed[\"scores\"] - low_score) / (\n",
    "    high_score - low_score\n",
    ")\n",
    "ie_gate_severed = (indirect_effect_gate_severed[\"scores\"] - low_score) / (\n",
    "    high_score - low_score\n",
    ")\n",
    "ie_block_severed = (indirect_effect_block_severed[\"scores\"] - low_score) / (\n",
    "    high_score - low_score\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_aie_subject_last(ie, ie_ssm_severed, ie_gate_severed, ie_block_severed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from src.data.dataclasses import CounterFactDataset\n",
    "\n",
    "from src.dataset.dataclasses import load_relation\n",
    "relation = load_relation(file = \"../data/relation/factual/place_in_city.json\")\n",
    "relation.select_icl_examples(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "question, answer = relation[10]\n",
    "print(question, end=\"\\n\\n\")\n",
    "print(f\"{answer=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import filter_samples_by_model_knowledge\n",
    "relation = filter_samples_by_model_knowledge(\n",
    "    mt = mt,\n",
    "    relation = relation,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------------------------------------------------\n",
    "n_trials = 30\n",
    "# -------------------------------------------------\n",
    "\n",
    "samples = relation.samples[: n_trials]\n",
    "edit_targets = functional.random_edit_targets(samples = samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Arguments shared by all samples (states are patched at the last subject token).\n",
    "common_kwargs = dict(\n",
    "    mt = mt,\n",
    "    prompt = relation.prompt_templates[0],\n",
    "    token_range = \"subject_last\",\n",
    "    window = 6,\n",
    ")\n",
    "\n",
    "all_indirect_effects = []\n",
    "all_indirect_effects_ssm_severed = []\n",
    "all_indirect_effects_mlp_severed = []\n",
    "\n",
    "for sample in tqdm(samples):\n",
    "    alt_sample = edit_targets[sample]\n",
    "\n",
    "    print(f\"sample={str(sample)}, alt_sample={str(alt_sample)}\")\n",
    "\n",
    "    # All three runs must use this sample's own edit target as the counterfactual.\n",
    "    # Building the kwargs once keeps the severed runs from falling back to the\n",
    "    # notebook-level `alt_subject` defined for the single-prompt experiment.\n",
    "    sample_kwargs = dict(\n",
    "        subject = sample.subject,\n",
    "        alt_subject = alt_sample.subject,\n",
    "        **common_kwargs,\n",
    "    )\n",
    "\n",
    "    indirect_effect = calculate_hidden_flow_with_ablation(\n",
    "        ablate_mambahook=None,\n",
    "        **sample_kwargs,\n",
    "    )\n",
    "\n",
    "    indirect_effect_ssm_severed = calculate_hidden_flow_with_ablation(\n",
    "        ablate_mambahook=\"ssm_after_ssm\",\n",
    "        **sample_kwargs,\n",
    "    )\n",
    "\n",
    "    indirect_effect_mlp_severed = calculate_hidden_flow_with_ablation(\n",
    "        ablate_mambahook=\"mlp_after_silu\",\n",
    "        **sample_kwargs,\n",
    "    )\n",
    "\n",
    "    # Normalise every run with the scores of the run without ablation.\n",
    "    high_score = indirect_effect[\"high_score\"].item()\n",
    "    low_score = indirect_effect[\"low_score\"].item()\n",
    "\n",
    "    ie = (indirect_effect[\"scores\"] - low_score) / (high_score - low_score)\n",
    "    ie_ssm_severed = (indirect_effect_ssm_severed[\"scores\"] - low_score) / (high_score - low_score)\n",
    "    ie_mlp_severed = (indirect_effect_mlp_severed[\"scores\"] - low_score) / (high_score - low_score)\n",
    "\n",
    "    all_indirect_effects.append(ie)\n",
    "    all_indirect_effects_ssm_severed.append(ie_ssm_severed)\n",
    "    all_indirect_effects_mlp_severed.append(ie_mlp_severed)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_indirect_effects = torch.stack(all_indirect_effects).squeeze().mean(dim=0)\n",
    "avg_indirect_effects_ssm_severed = torch.stack(all_indirect_effects_ssm_severed).squeeze().mean(dim=0)\n",
    "avg_indirect_effects_mlp_severed = torch.stack(all_indirect_effects_mlp_severed).squeeze().mean(dim=0)\n",
    "\n",
    "plot_aie_subject_last(\n",
    "    avg_indirect_effects, \n",
    "    avg_indirect_effects_ssm_severed, \n",
    "    avg_indirect_effects_mlp_severed\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Arguments shared by all samples (states are patched at the last prompt token).\n",
    "common_kwargs = {\n",
    "    \"mt\": mt,\n",
    "    \"prompt\": relation.prompt_templates[0],\n",
    "    \"token_range\": \"prompt_last\",\n",
    "    \"window\": 6,\n",
    "}\n",
    "\n",
    "all_indirect_effects = []\n",
    "all_indirect_effects_ssm_severed = []\n",
    "all_indirect_effects_mlp_severed = []\n",
    "\n",
    "for sample in tqdm(samples):\n",
    "    alt_sample = edit_targets[sample]\n",
    "\n",
    "    print(f\"sample={str(sample)}, alt_sample={str(alt_sample)}\")\n",
    "\n",
    "    # Run without a hook to ablate; its scores normalise the severed runs below.\n",
    "    indirect_effect = calculate_hidden_flow_with_ablation(\n",
    "        **common_kwargs,\n",
    "        subject=sample.subject,\n",
    "        alt_subject=alt_sample.subject,\n",
    "        ablate_mambahook=None,\n",
    "    )\n",
    "\n",
    "    indirect_effect_ssm_severed = calculate_hidden_flow_with_ablation(\n",
    "        **common_kwargs,\n",
    "        subject=sample.subject,\n",
    "        alt_subject=alt_sample.subject,\n",
    "        ablate_mambahook=\"ssm_after_ssm\",\n",
    "    )\n",
    "\n",
    "    indirect_effect_mlp_severed = calculate_hidden_flow_with_ablation(\n",
    "        **common_kwargs,\n",
    "        subject=sample.subject,\n",
    "        alt_subject=alt_sample.subject,\n",
    "        ablate_mambahook=\"mlp_after_silu\",\n",
    "    )\n",
    "\n",
    "    high_score = indirect_effect[\"high_score\"].item()\n",
    "    low_score = indirect_effect[\"low_score\"].item()\n",
    "\n",
    "    ie = (indirect_effect[\"scores\"] - low_score) / (high_score - low_score)\n",
    "    ie_ssm_severed = (indirect_effect_ssm_severed[\"scores\"] - low_score) / (\n",
    "        high_score - low_score\n",
    "    )\n",
    "    ie_mlp_severed = (indirect_effect_mlp_severed[\"scores\"] - low_score) / (\n",
    "        high_score - low_score\n",
    "    )\n",
    "\n",
    "    all_indirect_effects.append(ie)\n",
    "    all_indirect_effects_ssm_severed.append(ie_ssm_severed)\n",
    "    all_indirect_effects_mlp_severed.append(ie_mlp_severed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_indirect_effects = torch.stack(all_indirect_effects).squeeze().mean(dim=0)\n",
    "avg_indirect_effects_ssm_severed = torch.stack(all_indirect_effects_ssm_severed).squeeze().mean(dim=0)\n",
    "avg_indirect_effects_mlp_severed = torch.stack(all_indirect_effects_mlp_severed).squeeze().mean(dim=0)\n",
    "\n",
    "plot_aie_subject_last(\n",
    "    avg_indirect_effects, \n",
    "    avg_indirect_effects_ssm_severed, \n",
    "    avg_indirect_effects_mlp_severed\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "#! Didn't really reveal a clean separation of roles between different layers.\n",
    "#! But, did reveal a clear separation of roles between different blocks.\n",
    "\n",
    "# TODO: The problem with severing the SSM is that it doesn't only affect this token, but also directly affects the 3 tokens that come after it\n",
    "# because of the convolution. And, as a ripple effect, the whole computation after this token is affected. But this was also the case for ATTN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "relations",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
