{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "import types\n",
    "from src.hooking.mamba import MambaBlock_Hook_Points, MambaBlockForwardPatcher\n",
    "import torch\n",
    "import transformers\n",
    "import baukit\n",
    "from tqdm.auto import tqdm\n",
    "import json\n",
    "import os\n",
    "from src import functional\n",
    "import src.tokens as tokenization_utils\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "# Root-level DEBUG also turns on chatty third-party loggers (e.g. matplotlib's\n",
    "# font manager, urllib3 during model downloads); cap those at WARNING so the\n",
    "# notebook output stays readable while this notebook's own logger keeps DEBUG.\n",
    "for _noisy_logger in (\"matplotlib\", \"urllib3\", \"filelock\", \"PIL\"):\n",
    "    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)\n",
    "\n",
    "# A bare `torch.__version__, ...` tuple is only displayed when it is the last\n",
    "# expression of a cell, so log the environment versions explicitly.\n",
    "logger.info(\n",
    "    f\"torch={torch.__version__} | transformers={transformers.__version__} | \"\n",
    "    f\"cuda={torch.version.cuda}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# Checkpoint to analyse; \"state-spaces/mamba-2.8b-slimpj\" is the listed alternative.\n",
    "MODEL_PATH = \"state-spaces/mamba-2.8b\"\n",
    "\n",
    "mt = ModelandTokenizer(model_path=MODEL_PATH, torch_dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hook-point names used with MambaBlockForwardPatcher (see src.hooking.mamba).\n",
    "# NOTE: load_mean_activations keeps its own copy of this list.\n",
    "hooks = [\n",
    "    # ssm_* hook points\n",
    "    \"ssm_after_up_proj\",\n",
    "    \"ssm_after_conv1D\",\n",
    "    \"ssm_after_silu\",\n",
    "    \"ssm_after_ssm\",\n",
    "    # mlp_* hook points\n",
    "    \"mlp_after_up_proj\",\n",
    "    \"mlp_after_silu\",\n",
    "    # around the down projection\n",
    "    \"before_down_proj\",\n",
    "    \"after_down_proj\",  # output of the Mamba block itself -- NOT the residual stream\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.tracing import detensorize_indirect_effects\n",
    "DATA_DIR = \"../data\"\n",
    "\n",
    "@torch.inference_mode()\n",
    "def load_mean_activations(\n",
    "    mt: ModelandTokenizer,\n",
    "    num_docs=128,\n",
    "    n_tok_per_doc=128,\n",
    "):\n",
    "    \"\"\"Return mean Mamba-mixer activations per layer and hook point, cached as JSON.\n",
    "\n",
    "    Cache hit (``DATA_DIR/mean_activations/<model name>.json``): the stored values are\n",
    "    converted back to tensors on ``mt.device`` and returned.\n",
    "\n",
    "    Cache miss: the first sample text of each of the first ``num_docs`` entries of\n",
    "    ``DATA_DIR/attribute_snippets.json`` is run through the model (truncated to\n",
    "    ``n_tok_per_doc`` tokens) with every mixer patched to retain its hook points. Each\n",
    "    retained activation is averaged over dim 1 within a document, the per-document\n",
    "    means are averaged across documents, and the result is written to the cache.\n",
    "\n",
    "    NOTE(review): the cache key is only the model name, so a file produced with other\n",
    "    ``num_docs`` / ``n_tok_per_doc`` values is returned unchanged; delete it to recompute.\n",
    "\n",
    "    Returns:\n",
    "        dict: layer name -> {hook name -> mean activation tensor}.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if there are no documents to average over.\n",
    "    \"\"\"\n",
    "    act_dir = os.path.join(DATA_DIR, \"mean_activations\")\n",
    "    os.makedirs(act_dir, exist_ok=True)\n",
    "    cache_path = os.path.join(act_dir, mt.name.lower().split(\"/\")[-1] + \".json\")\n",
    "\n",
    "    if os.path.isfile(cache_path):\n",
    "        logger.info(\"Loading mean activations from cache\")\n",
    "        with open(cache_path, \"r\") as f:\n",
    "            cached = json.load(f)\n",
    "        return {\n",
    "            layer: {\n",
    "                hook: torch.tensor(values).to(mt.device)\n",
    "                for hook, values in per_hook.items()\n",
    "            }\n",
    "            for layer, per_hook in cached.items()\n",
    "        }\n",
    "\n",
    "    logger.info(\"Calculating mean activations\")\n",
    "\n",
    "    with open(os.path.join(DATA_DIR, \"attribute_snippets.json\"), \"r\") as f:\n",
    "        attribute_snippets = json.load(f)\n",
    "\n",
    "    # NOTE: these are the first `num_docs` snippets, not a random sample.\n",
    "    texts = [\n",
    "        attribute_snippets[i][\"samples\"][0][\"text\"]\n",
    "        for i in range(min(len(attribute_snippets), num_docs))\n",
    "    ]\n",
    "    if not texts:\n",
    "        # Without this guard the final division would run on a None accumulator.\n",
    "        raise ValueError(\n",
    "            f\"No documents to average over (num_docs={num_docs}, \"\n",
    "            f\"{len(attribute_snippets)} attribute snippets available)\"\n",
    "        )\n",
    "\n",
    "    hooks = [\n",
    "        \"ssm_after_up_proj\",\n",
    "        \"ssm_after_conv1D\",\n",
    "        \"ssm_after_silu\",\n",
    "        \"ssm_after_ssm\",\n",
    "        \"mlp_after_up_proj\",\n",
    "        \"mlp_after_silu\",\n",
    "        \"before_down_proj\",\n",
    "        \"after_down_proj\",  # the output of the mamba block #! Not the residual\n",
    "    ]\n",
    "\n",
    "    # Running sums of the per-document means, filled on the first document.\n",
    "    sums = {layer: {hook: None for hook in hooks} for layer in mt.layer_names}\n",
    "    n_docs = 0\n",
    "    for text in tqdm(texts):\n",
    "        input_ids = mt.tokenizer(text, return_tensors=\"pt\").to(mt.device)[\"input_ids\"]\n",
    "        input_ids = input_ids[:, :n_tok_per_doc]\n",
    "\n",
    "        # Patch every mixer so this document's activations land in a fresh retainer.\n",
    "        mt.reset_forward()\n",
    "        current_states = {\n",
    "            layer: {hook: None for hook in hooks} for layer in mt.layer_names\n",
    "        }\n",
    "        for layer in mt.layer_names:\n",
    "            mambablock = baukit.get_module(mt.model, name=layer + \".mixer\")\n",
    "            mambablock.forward = types.MethodType(\n",
    "                MambaBlockForwardPatcher(retainer=current_states[layer]), mambablock\n",
    "            )\n",
    "\n",
    "        mt.model(input_ids)\n",
    "\n",
    "        for layer in mt.layer_names:\n",
    "            for hook in hooks:\n",
    "                # dim 1 is assumed to be the token axis -- TODO confirm for every hook\n",
    "                doc_mean = current_states[layer][hook].detach().mean(dim=1)\n",
    "                if sums[layer][hook] is None:\n",
    "                    sums[layer][hook] = doc_mean\n",
    "                else:\n",
    "                    sums[layer][hook] += doc_mean\n",
    "        n_docs += 1\n",
    "        functional.free_gpu_cache()\n",
    "\n",
    "    # Presumably removes the retaining patches (other cells call reset_forward before\n",
    "    # re-patching), so later model calls don't keep writing into `current_states`.\n",
    "    mt.reset_forward()\n",
    "\n",
    "    avg_activations = {\n",
    "        layer: {hook: total / n_docs for hook, total in per_hook.items()}\n",
    "        for layer, per_hook in sums.items()\n",
    "    }\n",
    "\n",
    "    # Log shapes for one representative layer.\n",
    "    first_layer = next(iter(avg_activations))\n",
    "    for hook in hooks:\n",
    "        logger.info(f\"{hook} => {avg_activations[first_layer][hook].shape}\")\n",
    "\n",
    "    with open(cache_path, \"w\") as f:\n",
    "        json.dump(\n",
    "            {\n",
    "                layer: detensorize_indirect_effects(per_hook)\n",
    "                for layer, per_hook in avg_activations.items()\n",
    "            },\n",
    "            f,\n",
    "        )\n",
    "    logger.info(f\"Mean activations saved to {cache_path}\")\n",
    "\n",
    "    return avg_activations\n",
    "\n",
    "\n",
    "mean_activations = load_mean_activations(mt, num_docs=12, n_tok_per_doc=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): dead cell. load_mean_activations (previous cell) already runs\n",
    "# detensorize_indirect_effects and writes DATA_DIR/mean_activations/<model name>.json,\n",
    "# and its result is bound to `mean_activations`, not `avg_activations`. Safe to delete.\n",
    "# from src.tracing import detensorize_indirect_effects\n",
    "\n",
    "# for layer in avg_activations:\n",
    "#     avg_activations[layer] = detensorize_indirect_effects(avg_activations[layer])\n",
    "\n",
    "# with open(\"../data/avg_activations.json\", \"w\") as f:\n",
    "#     json.dump(avg_activations, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_window(layer_idx, num_layers=64, window_size=10):\n",
    "    \"\"\"Return the indices of ``window_size`` consecutive layers anchored on ``layer_idx``.\n",
    "\n",
    "    The window starts ``window_size // 2`` layers below ``layer_idx`` and holds exactly\n",
    "    ``window_size`` layers before being clipped to ``[0, num_layers)``. Odd sizes give a\n",
    "    window symmetric around ``layer_idx``; even sizes have one more layer below it than\n",
    "    above it (e.g. size 10 -> ``layer_idx - 5`` .. ``layer_idx + 4``).\n",
    "\n",
    "    Args:\n",
    "        layer_idx: layer the window is anchored on.\n",
    "        num_layers: total number of layers; indices are clipped to this range.\n",
    "        window_size: number of layers in the window before clipping.\n",
    "\n",
    "    Returns:\n",
    "        Ascending list of layer indices (shorter near the first/last layers).\n",
    "    \"\"\"\n",
    "    start = layer_idx - window_size // 2\n",
    "    end = start + window_size\n",
    "    return list(range(max(0, start), min(num_layers, end)))\n",
    "\n",
    "\n",
    "get_window(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.inference_mode()\n",
    "def retention_knockout_on_single_fact(\n",
    "    mt: ModelandTokenizer,\n",
    "    subject: str,\n",
    "    prompt_template: str,\n",
    "    patch_hook: str = \"ssm_after_up_proj\",\n",
    "    window=10,\n",
    "    mean_acts: dict | None = None,\n",
    "):\n",
    "    \"\"\"Mean-ablate `patch_hook` at selected positions over a sliding window of layers.\n",
    "\n",
    "    A clean run caches ``patch_hook`` at every layer and records the model's top-1 next\n",
    "    token (the \"answer\") and its probability. Then, for every ablation setting\n",
    "    (\"subject\", \"subj_last\", \"non_subject\", \"prompt_last\") and every layer index, the\n",
    "    layers in ``get_window(layer_idx, ...)`` are patched so that ablated positions get the\n",
    "    mean activation and the remaining positions before the last prompt token get their\n",
    "    clean activation. The last prompt token is patched only in the \"prompt_last\" setting;\n",
    "    otherwise the model computes it. The probability of the clean answer is recorded.\n",
    "\n",
    "    Args:\n",
    "        mt: model + tokenizer wrapper.\n",
    "        subject: subject string; must occur in ``prompt_template.format(subject)``.\n",
    "        prompt_template: format string with a single ``{}`` slot for the subject.\n",
    "        patch_hook: hook point to ablate (passed to ``MambaBlockForwardPatcher``).\n",
    "        window: number of layers knocked out together (see ``get_window``).\n",
    "        mean_acts: layer name -> hook name -> mean activation. Defaults to the\n",
    "            notebook-level ``mean_activations`` produced by ``load_mean_activations``.\n",
    "\n",
    "    Returns:\n",
    "        dict with \"answer\", \"p_answer\" and \"knock_out_from_last\", which maps each\n",
    "        setting to a list of p(answer) values, one per layer index.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the subject covers no tokens of the prompt.\n",
    "    \"\"\"\n",
    "    if mean_acts is None:\n",
    "        mean_acts = mean_activations\n",
    "\n",
    "    prompt = tokenization_utils.maybe_prefix_eos(mt, prompt_template.format(subject))\n",
    "    inputs = mt.tokenizer(prompt, return_tensors=\"pt\", return_offsets_mapping=True).to(mt.device)\n",
    "    offsets = inputs.pop(\"offset_mapping\")[0]\n",
    "\n",
    "    e_range = functional.find_token_range(\n",
    "        string=prompt,\n",
    "        substring=subject,\n",
    "        offset_mapping=offsets,\n",
    "    )\n",
    "    logger.debug(f\"Subject range: {e_range}\")\n",
    "\n",
    "    n_tokens = inputs.input_ids.shape[-1]\n",
    "    prompt_last = n_tokens - 1\n",
    "    subj_positions = list(range(e_range[0], e_range[1]))\n",
    "    if not subj_positions:\n",
    "        raise ValueError(f\"Subject {subject!r} covers no tokens of prompt {prompt!r}\")\n",
    "    non_subj_positions = [\n",
    "        i for i in range(n_tokens) if i not in subj_positions and i != prompt_last\n",
    "    ]\n",
    "\n",
    "    # Clean run: cache `patch_hook` at every layer and record the model's answer.\n",
    "    mt.reset_forward()\n",
    "    clean_states = {layer: {patch_hook: None} for layer in mt.layer_names}\n",
    "    for layer in mt.layer_names:\n",
    "        mambablock = baukit.get_module(mt.model, name=layer + \".mixer\")\n",
    "        mambablock.forward = types.MethodType(\n",
    "            MambaBlockForwardPatcher(retainer=clean_states[layer]),\n",
    "            mambablock,\n",
    "        )\n",
    "\n",
    "    output_clean = mt(**inputs)\n",
    "    proba = torch.nn.functional.softmax(output_clean[:, -1], dim=-1)\n",
    "    ans_t = proba.argmax(dim=-1)\n",
    "    ans = mt.tokenizer.decode(ans_t)\n",
    "    p_ans = proba[0, ans_t].item()\n",
    "\n",
    "    logger.info(f\"{subject} -> {ans} ({p_ans})\")\n",
    "\n",
    "    ablate_positions = {\n",
    "        \"subject\": subj_positions,\n",
    "        \"subj_last\": [subj_positions[-1]],\n",
    "        \"non_subject\": non_subj_positions,\n",
    "        \"prompt_last\": [prompt_last],\n",
    "    }\n",
    "    result = {\n",
    "        \"answer\": ans,\n",
    "        \"p_answer\": p_ans,\n",
    "        \"knock_out_from_last\": {},\n",
    "    }\n",
    "\n",
    "    for setting, ablate_position in ablate_positions.items():\n",
    "        corrupted_states = {}\n",
    "        ablate_set = set(ablate_position)\n",
    "        # Don't restore prompt_last (let the model calculate this)\n",
    "        restore_positions = [i for i in range(prompt_last) if i not in ablate_set]\n",
    "\n",
    "        layer_wise_p_ans = []\n",
    "        for layer_idx in range(mt.n_layer):\n",
    "            # Corrupted run: mean-ablate the window of layers around `layer_idx`.\n",
    "            mt.reset_forward()\n",
    "            for l in get_window(layer_idx, num_layers=mt.n_layer, window_size=window):\n",
    "                layername = mt.layer_name_format.format(l)\n",
    "                mambablock = baukit.get_module(mt.model, name=layername + \".mixer\")\n",
    "                patch_spec = {i: mean_acts[layername][patch_hook] for i in ablate_position}\n",
    "                for i in restore_positions:\n",
    "                    patch_spec[i] = clean_states[layername][patch_hook][:, i, :]\n",
    "\n",
    "                mambablock.forward = types.MethodType(\n",
    "                    MambaBlockForwardPatcher(\n",
    "                        patch_spec=patch_spec,\n",
    "                        patch_hook=patch_hook,\n",
    "                        retainer=corrupted_states,\n",
    "                    ),\n",
    "                    mambablock,\n",
    "                )\n",
    "\n",
    "            output_corrupted = mt(**inputs)\n",
    "            proba_corrupted = torch.nn.functional.softmax(output_corrupted[:, -1], dim=-1)\n",
    "            layer_wise_p_ans.append(proba_corrupted[0, ans_t].item())\n",
    "\n",
    "        logger.info(f\"{setting} -> {layer_wise_p_ans}\")\n",
    "        result[\"knock_out_from_last\"][setting] = layer_wise_p_ans\n",
    "\n",
    "    mt.reset_forward()\n",
    "    return result\n",
    "\n",
    "\n",
    "knock_out_results = retention_knockout_on_single_fact(\n",
    "    mt=mt,\n",
    "    subject=\"The Louvre\",\n",
    "    prompt_template=\"{} is located in the city of\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "for setting, layer_wise_p_ans in knock_out_results[\"knock_out_from_last\"].items():\n",
    "    ax.plot(layer_wise_p_ans, label=setting)\n",
    "\n",
    "ax.set_xlabel(\"Layer (window center)\")\n",
    "ax.set_ylabel(\"p(clean answer) after knockout\")\n",
    "ax.set_title(f\"Retention knockout, answer = {knock_out_results['answer']!r}\")\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.dataset.dataclasses import load_relation\n",
    "RELATION_FILE = \"../data/relation/factual/place_in_city.json\"\n",
    "relation = load_relation(file=RELATION_FILE)\n",
    "# Presumably the number of in-context examples to prepend (0 -> zero-shot); confirm in src.dataset.\n",
    "relation.select_icl_examples(0)\n",
    "\n",
    "SAMPLE_IDX = 10  # sample printed as a quick sanity check of the prompt format\n",
    "question, answer = relation[SAMPLE_IDX]\n",
    "print(question, end=\"\\n\\n\")\n",
    "print(f\"{answer=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import filter_samples_by_model_knowledge\n",
    "\n",
    "# Presumably clears forward patches left by earlier cells before querying the model.\n",
    "mt.reset_forward()\n",
    "# NOTE: rebinds `relation` to the filtered relation returned by the helper.\n",
    "relation = filter_samples_by_model_knowledge(mt=mt, relation=relation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "knock_out_results = []\n",
    "for sample in tqdm(relation.samples):\n",
    "    # Every sample is probed with the relation's first prompt template.\n",
    "    single_fact_result = retention_knockout_on_single_fact(\n",
    "        mt=mt,\n",
    "        subject=sample.subject,\n",
    "        prompt_template=relation.prompt_templates[0],\n",
    "    )\n",
    "    knock_out_results.append(single_fact_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert len(knock_out_results) > 0, \"knock_out_results is empty -- nothing to aggregate\"\n",
    "\n",
    "# Relative change of p(answer) under knockout w.r.t. the clean run,\n",
    "# (p_knockout - p_clean) / p_clean: one tensor (one value per layer) per fact.\n",
    "processed_results = {k: [] for k in knock_out_results[0][\"knock_out_from_last\"]}\n",
    "\n",
    "for result in knock_out_results:\n",
    "    p_ans = result[\"p_answer\"]\n",
    "    for k, v in result[\"knock_out_from_last\"].items():\n",
    "        relative_change = (torch.tensor(v) - p_ans) / p_ans\n",
    "        processed_results[k].append(relative_change)\n",
    "\n",
    "\n",
    "#####################################################################################\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 16\n",
    "MEDIUM_SIZE = 18\n",
    "BIGGER_SIZE = 24\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "#####################################################################################\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "for k, v in processed_results.items():\n",
    "    stacked = torch.stack(v)\n",
    "    mean = stacked.mean(dim=0)\n",
    "    std = stacked.std(dim=0)  # NaN (with a warning) when there is only one fact\n",
    "\n",
    "    ax.plot(mean, label=k)\n",
    "    ax.fill_between(range(len(mean)), mean - std, mean + std, alpha=0.1)\n",
    "\n",
    "ax.set_title(f\"Retention knockout ({len(knock_out_results)} facts)\")\n",
    "ax.set_xlabel(\"Layer (window center)\")\n",
    "ax.set_ylabel(\"Relative change in p(answer)\")\n",
    "ax.set_ylim(-1.2, 0.5)\n",
    "# Legend hangs below the axes, clear of the x-axis label.\n",
    "ax.legend(ncol=4, bbox_to_anchor=(0.5, -0.17), loc=\"upper center\", frameon=False)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "RESULTS_DIR = \"../results\"\n",
    "os.makedirs(RESULTS_DIR, exist_ok=True)  # open(..., \"w\") fails if the directory is missing\n",
    "\n",
    "results_path = os.path.join(RESULTS_DIR, \"retention_knockout_results.json\")\n",
    "with open(results_path, \"w\") as f:\n",
    "    json.dump(knock_out_results, f)\n",
    "logger.info(f\"Saved {len(knock_out_results)} knockout results to {results_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "relations",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
