{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "rbHij6RQTFny"
   },
   "source": [
    "# Indirect Object Identification Circuit in Pythia"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import warnings\n",
    "from functools import partial\n",
    "\n",
    "import einops\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import plotly.io as pio\n",
    "import torch\n",
    "from fancy_einsum import einsum\n",
    "from jaxtyping import Float\n",
    "from torch import Tensor\n",
    "from torchtyping import TensorType as TT\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "# import circuitsvis as cv\n",
    "\n",
    "import transformer_lens.patching as patching\n",
    "import transformer_lens.utils as tl_utils\n",
    "from transformer_lens import HookedTransformer\n",
    "\n",
    "from path_patching_cm.path_patching import Node, IterNode, path_patch, act_patch\n",
    "from path_patching_cm.ioi_dataset import IOIDataset, NAMES\n",
    "from neel_plotly import imshow as imshow_n\n",
    "\n",
    "from utils.metrics import compute_logit_diff\n",
    "from utils.data_utils import UniversalPatchingDataset\n",
    "\n",
    "\n",
    "# Use the GPU when torch can see one; otherwise run everything on the CPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n",
    "    \"\"\"Render `tensor` with px.imshow as an RdBu heatmap whose colour scale is centred on zero.\"\"\"\n",
    "    axis_labels = {\"x\": xaxis, \"y\": yaxis}\n",
    "    figure = px.imshow(\n",
    "        tl_utils.to_numpy(tensor),\n",
    "        color_continuous_midpoint=0.0,\n",
    "        color_continuous_scale=\"RdBu\",\n",
    "        labels=axis_labels,\n",
    "        **kwargs,\n",
    "    )\n",
    "    figure.show(renderer)\n",
    "\n",
    "def line(tensor, renderer=None, **kwargs):\n",
    "    \"\"\"Plot `tensor` as the y-values of a px.line chart.\"\"\"\n",
    "    y_values = tl_utils.to_numpy(tensor)\n",
    "    px.line(y=y_values, **kwargs).show(renderer)\n",
    "\n",
    "def two_lines(tensor1, tensor2, renderer=None, **kwargs):\n",
    "    \"\"\"Plot two tensors together by passing both as y-values to px.line.\"\"\"\n",
    "    series = [tl_utils.to_numpy(t) for t in (tensor1, tensor2)]\n",
    "    px.line(y=series, **kwargs).show(renderer)\n",
    "\n",
    "def scatter(x, y, xaxis=\"\", yaxis=\"\", caxis=\"\", renderer=None, **kwargs):\n",
    "    \"\"\"Scatter-plot `y` against `x`, labelling the x, y and colour axes.\"\"\"\n",
    "    x_values, y_values = tl_utils.to_numpy(x), tl_utils.to_numpy(y)\n",
    "    axis_labels = {\"x\": xaxis, \"y\": yaxis, \"color\": caxis}\n",
    "    px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs).show(renderer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# import kaleido\n",
    "# pio.renderers.default = 'png' # USE IF MAKING GRAPHS FOR NOTEBOOK EXPORT\n",
    "# ============================================================================="
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference only: turn autograd off globally for the rest of the notebook.\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "# Model selection\n",
    "BASE_MODEL = \"pythia-160m\"\n",
    "VARIANT = \"EleutherAI/pythia-160m-alldropout\"  # Hugging Face repo of alternative weights\n",
    "CHECKPOINT = 143_000  # training step to load\n",
    "CACHE = \"model_cache\"  # download cache directory\n",
    "\n",
    "# Passed to UniversalPatchingDataset.from_ioi when building the prompt set\n",
    "IOI_DATASET_SIZE = 70"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "Q5t_LmW3Tuej"
   },
   "source": [
    "## Model Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "61yvZIWCTmFM",
    "outputId": "437321a7-df54-41ab-8fca-557b796d02ec"
   },
   "outputs": [],
   "source": [
    "def get_model(base_model=\"pythia-160m\", variant=None, checkpoint=143000, cache=\"model_cache\"):\n",
    "    \"\"\"Load a Pythia model as a bfloat16 TransformerLens ``HookedTransformer``.\n",
    "\n",
    "    Args:\n",
    "        base_model: TransformerLens model name that defines the architecture.\n",
    "        variant: Optional Hugging Face repo id of alternative weights (e.g. ``VARIANT``).\n",
    "            When falsy, ``base_model``'s own weights are loaded.\n",
    "        checkpoint: Training step. Passed as ``checkpoint_value`` for the base model,\n",
    "            or used as the ``step{checkpoint}`` revision when loading ``variant``.\n",
    "        cache: Directory used as the download cache.\n",
    "\n",
    "    Returns:\n",
    "        The ``HookedTransformer``, with the ``use_split_qkv_input``, ``use_attn_result``\n",
    "        and ``use_hook_mlp_in`` config flags switched on.\n",
    "    \"\"\"\n",
    "    # Weight processing shared by both loading paths.\n",
    "    common_kwargs = dict(\n",
    "        center_unembed=True,\n",
    "        center_writing_weights=True,\n",
    "        fold_ln=True,\n",
    "        dtype=torch.bfloat16,\n",
    "        cache_dir=cache,\n",
    "    )\n",
    "\n",
    "    if not variant:\n",
    "        model = HookedTransformer.from_pretrained(\n",
    "            base_model,\n",
    "            checkpoint_value=checkpoint,\n",
    "            refactor_factored_attn_matrices=False,\n",
    "            **common_kwargs,\n",
    "        )\n",
    "    else:\n",
    "        # Load the requested variant's weights with Hugging Face (the `variant` argument,\n",
    "        # not the notebook-level VARIANT constant), then wrap them in base_model's architecture.\n",
    "        source_model = AutoModelForCausalLM.from_pretrained(\n",
    "            variant, revision=f\"step{checkpoint}\", cache_dir=cache\n",
    "        ).to(device).to(torch.bfloat16)\n",
    "\n",
    "        model = HookedTransformer.from_pretrained(\n",
    "            base_model,\n",
    "            hf_model=source_model,\n",
    "            **common_kwargs,\n",
    "        )\n",
    "\n",
    "    model.cfg.use_split_qkv_input = True\n",
    "    model.cfg.use_attn_result = True\n",
    "    model.cfg.use_hook_mlp_in = True\n",
    "\n",
    "    return model"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "x4diyMzfPCBD"
   },
   "source": [
    "## Data Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base Pythia model at the configured checkpoint, plus the IOI prompt dataset.\n",
    "model = get_model(BASE_MODEL, checkpoint=CHECKPOINT, cache=CACHE)\n",
    "ds = UniversalPatchingDataset.from_ioi(model, IOI_DATASET_SIZE)\n",
    "# logit_diff_metric = partial(compute_logit_diff,mode='simple')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "lrxnGyi7E19v",
    "outputId": "ef4ce3a5-ea1b-456a-c898-05188e08ffdc"
   },
   "outputs": [],
   "source": [
    "# Forward pass on the clean prompts, keeping the logits and the activation cache.\n",
    "(clean_logits, clean_cache) = model.run_with_cache(ds.toks)\n",
    "\n",
    "# Disabled checks: they need `logit_diff_metric` (commented out in the data-setup cell)\n",
    "# and a corrupted run on `ds.flipped_toks`.\n",
    "# corrupted_logits, corrupted_cache = model.run_with_cache(ds.flipped_toks)\n",
    "\n",
    "# clean_logit_diff = logit_diff_metric(clean_logits, ds.answer_toks, ds.positions)\n",
    "# print(f\"Clean logit diff: {clean_logit_diff:.4f}\")\n",
    "\n",
    "# corrupted_logit_diff = logit_diff_metric(corrupted_logits, ds.answer_toks, ds.positions)\n",
    "# print(f\"Corrupted logit diff: {corrupted_logit_diff:.4f}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attention Head Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_components(cache, positions, layer_idx, head_idx):\n",
    "    \"\"\"Pull one attention head's activations at `positions` out of `cache`, as detached CPU tensors.\n",
    "\n",
    "    Returns a dict with keys ``act_Q``, ``act_K``, ``act_V``, ``act_Z`` (sliced from\n",
    "    ``hook_q/k/v/z`` with the head on dim 2) and ``attn_pattern`` (sliced from\n",
    "    ``hook_pattern`` with the head on dim 1 and `positions` on dim 2).\n",
    "    \"\"\"\n",
    "    # NOTE(review): if `positions` is a per-prompt index tensor, `[:, positions]` takes every\n",
    "    # listed position for every prompt (a cross product), not one position per prompt - confirm.\n",
    "    hook_prefix = f\"blocks.{layer_idx}.attn\"\n",
    "    per_head_index = (slice(None), positions, head_idx, slice(None))\n",
    "    pattern_index = (slice(None), head_idx, positions, slice(None))\n",
    "\n",
    "    def fetch(hook_name, index):\n",
    "        return cache[f\"{hook_prefix}.{hook_name}\"][index].detach().to(\"cpu\")\n",
    "\n",
    "    components = {\n",
    "        key: fetch(hook_name, per_head_index)\n",
    "        for key, hook_name in ((\"act_Q\", \"hook_q\"), (\"act_K\", \"hook_k\"), (\"act_V\", \"hook_v\"), (\"act_Z\", \"hook_z\"))\n",
    "    }\n",
    "    components[\"attn_pattern\"] = fetch(\"hook_pattern\", pattern_index)\n",
    "    return components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def component_cosine_sim(model_a, model_b, heads, dataset=None):\n",
    "    \"\"\"Cosine similarity of per-head attention components between two models.\n",
    "\n",
    "    Both models are run on ``dataset.toks``. For every ``(layer_idx, head_idx)`` in\n",
    "    ``heads`` the components returned by ``get_components`` (at ``dataset.positions``)\n",
    "    are flattened and compared with one cosine similarity each.\n",
    "\n",
    "    Args:\n",
    "        model_a, model_b: HookedTransformer models to compare.\n",
    "        heads: Iterable of ``(layer_idx, head_idx)`` pairs.\n",
    "        dataset: Object with ``toks`` and ``positions``. Defaults to the notebook's\n",
    "            ``ds``, looked up at call time (not frozen when this cell was run).\n",
    "\n",
    "    Returns:\n",
    "        ``{\"L{layer}H{head}\": {component_name: float}}``.\n",
    "    \"\"\"\n",
    "    if dataset is None:\n",
    "        dataset = ds\n",
    "    heads = list(heads)  # iterated twice below\n",
    "\n",
    "    # Cache only the hooks get_components reads instead of every activation in the model.\n",
    "    wanted_hooks = {\n",
    "        f\"blocks.{layer_idx}.attn.{hook_name}\"\n",
    "        for layer_idx, _ in heads\n",
    "        for hook_name in (\"hook_q\", \"hook_k\", \"hook_v\", \"hook_z\", \"hook_pattern\")\n",
    "    }\n",
    "\n",
    "    def keep_hook(name):\n",
    "        return name in wanted_hooks\n",
    "\n",
    "    _, cache_a = model_a.run_with_cache(dataset.toks, names_filter=keep_hook)\n",
    "    _, cache_b = model_b.run_with_cache(dataset.toks, names_filter=keep_hook)\n",
    "\n",
    "    cosine_sims = dict()\n",
    "    for layer_idx, head_idx in heads:\n",
    "        components_a = get_components(cache_a, dataset.positions, layer_idx, head_idx)\n",
    "        components_b = get_components(cache_b, dataset.positions, layer_idx, head_idx)\n",
    "\n",
    "        # Upcast to float32 first: activations from get_model's bfloat16 models have an\n",
    "        # ~8-bit mantissa, too coarse for dot products accumulated over many elements.\n",
    "        cosine_sims[f\"L{layer_idx}H{head_idx}\"] = {\n",
    "            name: torch.nn.functional.cosine_similarity(\n",
    "                components_a[name].flatten().float(),\n",
    "                components_b[name].flatten().float(),\n",
    "                dim=0,\n",
    "            ).item()\n",
    "            for name in components_a\n",
    "        }\n",
    "\n",
    "    return cosine_sims"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot check on a single head (L9H4) between two intermediate checkpoints.\n",
    "model_a = get_model(checkpoint=75_000)\n",
    "model_b = get_model(checkpoint=56_000)\n",
    "cosine_sims = component_cosine_sim(model_a, model_b, heads=[(9, 4)])\n",
    "cosine_sims"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Top Heads: Component Similarity Through Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep: compare every intermediate checkpoint with the reference (final) checkpoint.\n",
    "similarity_through_time = dict()\n",
    "\n",
    "heads = [(9, 4), (8, 9), (10, 1), (4, 11), (6, 6), (7, 2), (2, 6), (6, 5), (4, 6), (10, 7), (8, 2), (8, 10), (8, 1),\n",
    "         (7, 9), (4, 1), (4, 8), (5, 8), (9, 8), (10, 11), (9, 6), (6, 11), (5, 0), (3, 0), (11, 6), (10, 2), (9, 5), (9, 7)]\n",
    "\n",
    "# Silence warnings for this sweep only, rather than for the rest of the session.\n",
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter(\"ignore\")\n",
    "\n",
    "    model_a = get_model(checkpoint=CHECKPOINT)\n",
    "\n",
    "    # Steps 1000, 2000, ..., CHECKPOINT - 1000. The stop value is exclusive, so it must be\n",
    "    # CHECKPOINT for the last step before the reference to be included.\n",
    "    for c in range(1000, CHECKPOINT, 1000):\n",
    "        model_b = get_model(checkpoint=c)\n",
    "\n",
    "        cosine_sims = component_cosine_sim(model_a, model_b, heads)\n",
    "        similarity_through_time[c] = cosine_sims\n",
    "\n",
    "        print(f\"Checkpoint: {c} - Head Component Cosine Similarity to {CHECKPOINT // 1000}K: {cosine_sims}\")\n",
    "\n",
    "        # Rewrite the file after every checkpoint so partial results survive an interrupted sweep.\n",
    "        with open(\"similarity_through_time_last_pos.json\", \"w\") as f:\n",
    "            json.dump(similarity_through_time, f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the sweep results written above. JSON stores the integer checkpoint keys as strings.\n",
    "with open(\"similarity_through_time_last_pos.json\", \"r\") as f:\n",
    "    similarity_through_time = json.load(f)\n",
    "\n",
    "def plot_head_component_cosine_similarity(head, similarity_through_time, renderer=None):\n",
    "    \"\"\"Plot each component's cosine similarity to the final checkpoint for one head.\n",
    "\n",
    "    Args:\n",
    "        head: Head key such as ``\"L8H9\"``, as produced by ``component_cosine_sim``.\n",
    "        similarity_through_time: ``{checkpoint: {head: {component: value}}}``; checkpoint\n",
    "            keys may be ints or numeric strings (e.g. after a JSON round-trip).\n",
    "        renderer: Optional plotly renderer passed to ``fig.show``.\n",
    "\n",
    "    Returns:\n",
    "        The long-format DataFrame (``Checkpoint``, ``Item``, ``Value``) that was plotted.\n",
    "    \"\"\"\n",
    "    # One row per (checkpoint, component). Checkpoints are cast to int so plotly draws a\n",
    "    # numeric, correctly ordered x-axis instead of a categorical axis of JSON string keys.\n",
    "    records = [\n",
    "        {\"Checkpoint\": int(checkpoint), \"Item\": item, \"Value\": value}\n",
    "        for checkpoint, per_head in similarity_through_time.items()\n",
    "        for item, value in per_head[head].items()\n",
    "    ]\n",
    "    # Explicit columns keep px.line working even when there are no records.\n",
    "    df = (\n",
    "        pd.DataFrame(records, columns=[\"Checkpoint\", \"Item\", \"Value\"])\n",
    "        .sort_values(\"Checkpoint\", kind=\"stable\")\n",
    "        .reset_index(drop=True)\n",
    "    )\n",
    "\n",
    "    # One line per component ('Item').\n",
    "    fig = px.line(df, x=\"Checkpoint\", y=\"Value\", color=\"Item\",\n",
    "                  title=f\"Attention Head {head} Component Cosine Similarity To Final Checkpoint\")\n",
    "    fig.show(renderer=renderer)\n",
    "\n",
    "    return df\n",
    "\n",
    "#df = plot_head_component_cosine_similarity(\"L8H9\", similarity_through_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot every head recorded in the sweep. Head names come from the first stored checkpoint,\n",
    "# which works for int keys (fresh sweep) and string keys (reloaded JSON) alike.\n",
    "first_record = next(iter(similarity_through_time.values()), {})\n",
    "for head in first_record:\n",
    "    plot_head_component_cosine_similarity(head, similarity_through_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": [],
   "toc_visible": true
  },
  "gpuClass": "premium",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
