{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data_utils import *\n",
    "from data_utils import test_distribution\n",
    "from model_utils import *\n",
    "import joblib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model used by every cell below (get_model comes from the star imports above).\n",
    "model = get_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- DAS position direction, hook nodes and head groups --------------------------\n",
    "POSITION_SUBSPACE_PATH = 'subspaces/state_dict_head_outputs_1dim_922test.pt'\n",
    "rotation = RotationMatrix.load_rotation_old(path=POSITION_SUBSPACE_PATH, n=768)\n",
    "rotation.requires_grad_(False)\n",
    "# first column of the loaded rotation; used below as the DAS resid_post.8 direction\n",
    "POSITION_DIRECTION = rotation.R.weight.data[:, 0].detach().cuda()\n",
    "resid_node = Node('resid_post', layer=8, seq_pos=-1)\n",
    "resid_node_mid = Node('resid_mid', layer=8, seq_pos=-1)\n",
    "mlp_node = Node('post', layer=8, seq_pos=-1)\n",
    "NAME_MOVERS = [(9, 6), (9, 9), (10, 0)]\n",
    "BACKUP_NAME_MOVERS = [\n",
    "    (9, 0), (9, 7),\n",
    "    (10, 1), (10, 2), (10, 6), (10, 10),\n",
    "    (11, 2), (11, 9),\n",
    "]\n",
    "\n",
    "def _load_cuda_tensor(filename):\n",
    "    \"\"\"Load an array saved with joblib and return it as a torch.Tensor on the GPU.\"\"\"\n",
    "    return torch.Tensor(joblib.load(filename=filename)).cuda()\n",
    "\n",
    "# --- DAS directions ------------------------------------------------------------------\n",
    "das_mlp8 = _load_cuda_tensor('das_mlp8.joblib')\n",
    "das_resid = POSITION_DIRECTION\n",
    "das_resid_mid = _load_cuda_tensor('das_resid_mid.joblib')\n",
    "\n",
    "# Split das_mlp8 (length 3072, matching W_out's first dim) into the component in the\n",
    "# rowspace of W_out and the component that x -> x @ W_out maps to zero.\n",
    "W_out = model.W_out[8]\n",
    "# Q is (3072, 768); its columns are an orthonormal basis for the rowspace of W_out,\n",
    "# i.e. the orthogonal complement of the nullspace of x -> x @ W_out.\n",
    "Q, _ = torch.linalg.qr(W_out)\n",
    "das_mlp8_row = das_mlp8 @ Q @ Q.T\n",
    "das_mlp8_null = das_mlp8 - das_mlp8_row\n",
    "das_row_unit, das_null_unit = (u / u.norm() for u in (das_mlp8_row, das_mlp8_null))\n",
    "\n",
    "# --- gradients and mean-difference direction ----------------------------------------\n",
    "head_gs = [torch.Tensor(g).cuda() for g in joblib.load(filename='name_mover_gradients.joblib')]\n",
    "summed_gradient = _load_cuda_tensor('summed_gradient.joblib')\n",
    "v_mean = _load_cuda_tensor('v_mean.joblib')\n",
    "\n",
    "# --- datasets -------------------------------------------------------------------------\n",
    "def _sample_position_das(base_pattern, source_pattern):\n",
    "    \"\"\"Sample a DAS dataset for one base/source pattern pair (1,000 samples, position labels).\"\"\"\n",
    "    return test_distribution.sample_das(\n",
    "        model=model,\n",
    "        base_patterns=[base_pattern],\n",
    "        source_patterns=[source_pattern],\n",
    "        labels='position',\n",
    "        samples_per_combination=1_000,\n",
    "    )\n",
    "\n",
    "random.seed(42)\n",
    "TEST_DATASET = _sample_position_das('ABB', 'BAB')\n",
    "# both directions: ABB bases with BAB sources, followed by BAB bases with ABB sources\n",
    "PATCHING_DATASET = _sample_position_das('ABB', 'BAB') + _sample_position_das('BAB', 'ABB')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dot product and cosine similarity between the DAS resid direction and the summed\n",
    "# gradient, returned as one tuple so the cell displays both numbers.\n",
    "(\n",
    "    das_resid @ summed_gradient,\n",
    "    torch.cosine_similarity(das_resid, summed_gradient, dim=0),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def decompose_resid_along_namemovers(v: Tensor, heads=None):\n",
    "    \"\"\"\n",
    "    Split a residual-stream vector into the part lying in the span of the chosen heads'\n",
    "    W_Q columns and the part orthogonal to it.\n",
    "\n",
    "    Args:\n",
    "        v: residual-stream vector(s); the last dim must be d_model.\n",
    "        heads: (layer, head) pairs whose W_Q matrices define the subspace. Defaults to\n",
    "            NAME_MOVERS, so existing calls behave as before; pass e.g.\n",
    "            BACKUP_NAME_MOVERS to decompose along a different head group.\n",
    "\n",
    "    Returns:\n",
    "        (rowspace_component, nullspace_component), which sum to v. The nullspace\n",
    "        component satisfies null @ W_Q[layer, head] ~= 0 for every chosen head.\n",
    "    \"\"\"\n",
    "    if heads is None:\n",
    "        heads = NAME_MOVERS\n",
    "    # each W_Q[layer, head] is (d_model, d_head); stack them column-wise\n",
    "    WQ_concat = torch.cat([model.W_Q[layer, head] for layer, head in heads], dim=1)\n",
    "    # reduced QR: Q's columns are an orthonormal basis of WQ_concat's column span\n",
    "    # (assuming WQ_concat has full column rank)\n",
    "    Q, _ = torch.linalg.qr(WQ_concat)\n",
    "    rowspace_component = v @ Q @ Q.T\n",
    "    nullspace_component = v - rowspace_component\n",
    "    return rowspace_component, nullspace_component\n",
    "\n",
    "# Components of the DAS resid direction and of the summed gradient inside the span of\n",
    "# the name movers' W_Q columns (row) and orthogonal to it (null).\n",
    "das_row, das_null = decompose_resid_along_namemovers(das_resid)\n",
    "grad_row, grad_null = decompose_resid_along_namemovers(summed_gradient)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of each vector's L2 norm lying in the name-mover W_Q span. Dividing by the full\n",
    "# norm keeps the comparison meaningful, since the two vectors need not have the same norm.\n",
    "das_row.norm() / das_resid.norm(), grad_row.norm() / summed_gradient.norm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_all_hooks(model):\n",
    "    \"\"\"Clear the forward hooks of `model` and every submodule (pre-hooks and backward hooks are untouched).\"\"\"\n",
    "    for module in model.modules():\n",
    "        module._forward_hooks.clear()\n",
    "\n",
    "def get_patching_acc_and_ld(patching_dataset: PatchingDataset,\n",
    "                            patcher: Patcher,\n",
    "                            batch_size: int = 100) -> Tuple[float, float]:\n",
    "    \"\"\"\n",
    "    Run `patcher` over `patching_dataset` in batches, using the global `model`, and\n",
    "    summarise the patched predictions.\n",
    "\n",
    "    Args:\n",
    "        patching_dataset: base/source prompt pairs and their patched answer tokens.\n",
    "        patcher: the intervention to apply.\n",
    "        batch_size: examples per patched forward pass (default 100).\n",
    "\n",
    "    Returns:\n",
    "        (interchange_accuracy, base_answers_logit_diff) where\n",
    "        - interchange_accuracy is the fraction of examples whose patched argmax token\n",
    "          equals the first patched answer token, and\n",
    "        - base_answers_logit_diff is the mean of (logit of base answer 0 - logit of\n",
    "          base answer 1) under the patch.\n",
    "    \"\"\"\n",
    "    patched_predictions = []\n",
    "    base_answer_logits_patched = []\n",
    "    # ceil division so the progress-bar total also counts a final partial batch\n",
    "    num_batches = -(-len(patching_dataset) // batch_size)\n",
    "    batches = patching_dataset.batches(batch_size=batch_size, shuffle=False)\n",
    "    for batch_ds in tqdm(batches, total=num_batches):\n",
    "        _, _, _, logits_patched = patcher.run_patching(\n",
    "            model=model,\n",
    "            P_base=batch_ds.base.tokens,\n",
    "            P_source=batch_ds.source.tokens,\n",
    "            answer_tokens_base=batch_ds.base.answer_tokens,\n",
    "            answer_tokens_source=batch_ds.source.answer_tokens,\n",
    "            patched_answer_tokens=batch_ds.patched_answer_tokens,\n",
    "            return_full_patched_logits=True,\n",
    "        )\n",
    "        patched_predictions.append(logits_patched.argmax(dim=-1))\n",
    "        # patched logits of the base answer tokens, one column per answer token\n",
    "        base_answer_logits_patched.append(logits_patched.gather(dim=1, index=batch_ds.base.answer_tokens.cuda()))\n",
    "    patched_predictions = torch.cat(patched_predictions)\n",
    "    patched_answers = patching_dataset.patched_answer_tokens[:, 0].cuda()\n",
    "    interchange_accuracy = (patched_predictions == patched_answers).float().mean().item()\n",
    "    base_answer_logits_patched = torch.cat(base_answer_logits_patched)\n",
    "    base_answers_logit_diff = (base_answer_logits_patched[:, 0] - base_answer_logits_patched[:, 1]).mean().item()\n",
    "    return interchange_accuracy, base_answers_logit_diff\n",
    "\n",
    "def get_resid_projections(v: Tensor, patching_dataset: PatchingDataset,\n",
    "                          base_pattern: str = 'ABB', source_pattern: str = 'BAB'):\n",
    "    \"\"\"\n",
    "    Project resid_post.8 activations (seq_pos=-1) onto `v` for every base and source prompt.\n",
    "\n",
    "    Returns a DataFrame with columns `pattern` and `projection`, one row per prompt:\n",
    "    all base prompts first (labelled `base_pattern`), then all source prompts (labelled\n",
    "    `source_pattern`). The defaults match TEST_DATASET (ABB bases, BAB sources).\n",
    "    \"\"\"\n",
    "    node = Node('resid_post', layer=8, seq_pos=-1)\n",
    "    base_prompts = patching_dataset.base.prompts\n",
    "    source_prompts = patching_dataset.source.prompts\n",
    "    prompts = np.concatenate([base_prompts, source_prompts], axis=0)\n",
    "    acts = run_with_cache(prompts=prompts, nodes=[node], model=model, batch_size=100)[0]\n",
    "    projs = einsum(\"batch dim, dim -> batch\", acts, v)\n",
    "    # label each half by its own length so unequal base/source sizes stay aligned\n",
    "    patterns = [base_pattern] * len(base_prompts) + [source_pattern] * len(source_prompts)\n",
    "    return pd.DataFrame({\n",
    "        'pattern': patterns,\n",
    "        'projection': projs.cpu().numpy(),\n",
    "    })\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project TEST_DATASET's resid_post.8 activations onto the summed gradient and onto the\n",
    "# DAS resid direction; one DataFrame (pattern, projection) per direction.\n",
    "gradient_projs = get_resid_projections(patching_dataset=TEST_DATASET, v=summed_gradient)\n",
    "das_resid_projs = get_resid_projections(patching_dataset=TEST_DATASET, v=das_resid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tag each projection table with its direction and stack them. .assign returns new frames,\n",
    "# so gradient_projs / das_resid_projs are left unchanged and the cell is safe to re-run.\n",
    "resid_projs_df = pd.concat(\n",
    "    [gradient_projs.assign(direction='grad'), das_resid_projs.assign(direction='das')],\n",
    "    ignore_index=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "sns.set_theme()\n",
    "\n",
    "# savefig raises if the output directory does not exist\n",
    "os.makedirs('figures', exist_ok=True)\n",
    "\n",
    "for direction in ('grad', 'das'):\n",
    "    # one large square figure per direction; ABB and BAB overlaid, not stacked\n",
    "    fig, ax = plt.subplots(figsize=(10, 10))\n",
    "    sns.histplot(\n",
    "        data=resid_projs_df.query(f'direction == \"{direction}\"'),\n",
    "        x='projection',\n",
    "        hue='pattern',\n",
    "        multiple='layer',\n",
    "        bins=100,\n",
    "        stat='count',\n",
    "        common_norm=False,\n",
    "        element='bars',\n",
    "        fill=True,\n",
    "        alpha=0.5,\n",
    "        legend=True,\n",
    "        ax=ax,\n",
    "    )\n",
    "    # no axis labels; large tick and legend fonts\n",
    "    ax.set_xlabel(None)\n",
    "    ax.set_ylabel(None)\n",
    "    ax.tick_params(axis='both', which='major', labelsize=20)\n",
    "    plt.setp(ax.get_legend().get_texts(), fontsize='30')  # legend text\n",
    "    plt.setp(ax.get_legend().get_title(), fontsize='36')  # legend title\n",
    "    fig.savefig(f'figures/resid_projs_{direction}.pdf', format='pdf')\n",
    "    plt.show()\n",
    "    # free the figure so repeated runs do not accumulate open figures\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Projections onto the DAS resid direction for TEST_DATASET (the same call that produced\n",
    "# das_resid_projs); feeds the name-probe cell below.\n",
    "df = get_resid_projections(patching_dataset=TEST_DATASET, v=das_resid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# numpy array of projections: base prompts first, then source prompts\n",
    "projs = df['projection'].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Can the IO or S name be predicted from the projection onto the DAS resid direction?\n",
    "### (single-feature logistic-regression probe; so far the answer has been no)\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# same order as `projs`: base prompts first, then source prompts\n",
    "io_names = [p.io_name for p in TEST_DATASET.base.prompts] + [p.io_name for p in TEST_DATASET.source.prompts]\n",
    "s_names = [p.s_name for p in TEST_DATASET.base.prompts] + [p.s_name for p in TEST_DATASET.source.prompts]\n",
    "\n",
    "X = projs.reshape(-1, 1)  # one feature: the scalar projection\n",
    "for target, y in (('io_name', io_names), ('s_name', s_names)):\n",
    "    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)\n",
    "    clf = LogisticRegression(random_state=0, max_iter=1_000).fit(X_train, y_train)\n",
    "    # accuracy of always predicting the most frequent test label, for comparison\n",
    "    majority_baseline = pd.Series(y_test).value_counts(normalize=True).max()\n",
    "    print(f'[{target}] Test acc: {clf.score(X_test, y_test)} (majority baseline: {majority_baseline})')\n",
    "    print(f'[{target}] Train acc: {clf.score(X_train, y_train)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _unit(v: Tensor) -> Tensor:\n",
    "    \"\"\"Rescale `v` to unit L2 norm.\"\"\"\n",
    "    return v / v.norm()\n",
    "\n",
    "def _patching_row(intervention: str, node: Node, patch_impl,\n",
    "                  patching_dataset: PatchingDataset) -> dict:\n",
    "    \"\"\"Patch `node` with `patch_impl` over the dataset and return one metrics-table row.\"\"\"\n",
    "    patcher = Patcher(nodes=[node], patch_impl=patch_impl)\n",
    "    accuracy, logit_diff = get_patching_acc_and_ld(patching_dataset=patching_dataset, patcher=patcher)\n",
    "    return {'intervention': intervention, 'accuracy': accuracy, 'logit_diff': logit_diff}\n",
    "\n",
    "def compute_patch_metrics(patching_dataset: PatchingDataset,):\n",
    "    \"\"\"\n",
    "    Print sanity checks on the MLP8 DAS decomposition and build a table of patching metrics.\n",
    "\n",
    "    Prints:\n",
    "        - the norms of das_mlp8_row / das_mlp8_null (components of das_mlp8 in the\n",
    "          rowspace / nullspace of W_out), their dot product (expected ~0), and whether\n",
    "          das_mlp8_null @ W_out is numerically zero;\n",
    "        - the clean (unpatched) accuracy on the base answers.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with columns `intervention`, `accuracy` (interchange accuracy: fraction\n",
    "        of predictions equal to the first patched answer token) and `logit_diff` (mean\n",
    "        base-answer logit difference), one row per intervention, in this order:\n",
    "            - clean run (no patch);\n",
    "            - full resid_post.8 patch;\n",
    "            - DAS, gradient and mean-difference directions in resid_post.8, each with its\n",
    "              components inside / outside the name movers' W_Q span;\n",
    "            - the DAS MLP8 direction and its W_out rowspace / nullspace components.\n",
    "    \"\"\"\n",
    "    print(f'norm of das_mlp8_row: {das_mlp8_row.norm()}')\n",
    "    print(f'norm of das_mlp8_null: {das_mlp8_null.norm()}')\n",
    "    print(f'Check that das_mlp8_row and das_mlp8_null are orthogonal: {das_mlp8_row @ das_mlp8_null}')\n",
    "    null_out = das_mlp8_null @ W_out\n",
    "    print(f'Check that das_mlp8_null is in the nullspace of W_out: '\n",
    "          f'{torch.allclose(null_out, torch.zeros_like(null_out), atol=1e-5)}')\n",
    "\n",
    "    ### clean (unpatched) run on the base prompts\n",
    "    base_answers = patching_dataset.base.answer_tokens.cuda()\n",
    "    clean_predictions = run_with_hooks(prompts=patching_dataset.base.prompts,\n",
    "                                       hooks=[], model=model, batch_size=100,\n",
    "                                       answer_tokens=base_answers,\n",
    "                                       return_predictions=True)\n",
    "    clean_logits = run_with_hooks(prompts=patching_dataset.base.prompts, hooks=[], model=model,\n",
    "                                  batch_size=100, answer_tokens=base_answers)\n",
    "    clean_accuracy = (clean_predictions == base_answers[:, 0]).float().mean().item()\n",
    "    clean_logit_diff = (clean_logits[:, 0] - clean_logits[:, 1]).mean().item()\n",
    "    # interchange accuracy with no intervention (how often the unpatched model already\n",
    "    # predicts the patched answer), measured rather than assumed to be 0\n",
    "    patched_answers = patching_dataset.patched_answer_tokens[:, 0].cuda()\n",
    "    clean_interchange_accuracy = (clean_predictions == patched_answers).float().mean().item()\n",
    "    print(f'clean accuracy: {clean_accuracy}')\n",
    "\n",
    "    ### resid_post.8 directions split along the name movers' W_Q span\n",
    "    resid_row, resid_null = decompose_resid_along_namemovers(v=das_resid)\n",
    "    grad_row, grad_null = decompose_resid_along_namemovers(v=summed_gradient)\n",
    "    mean_diff_row, _ = decompose_resid_along_namemovers(v=v_mean)\n",
    "\n",
    "    # (table label, node to patch, patch implementation), in output-table order.\n",
    "    # NOTE(review): das_resid, summed_gradient, v_mean and das_mlp8 are passed as-is while\n",
    "    # their components are rescaled to unit norm; whether DirectionPatch normalises `v`\n",
    "    # itself is not visible here - confirm.\n",
    "    interventions = [\n",
    "        ('full resid_post.8 patch', resid_node, Full()),\n",
    "        ('DAS resid_post.8', resid_node, DirectionPatch(v=das_resid)),\n",
    "        ('DAS resid_post.8 along name movers', resid_node, DirectionPatch(v=_unit(resid_row))),\n",
    "        ('DAS resid_post.8 along nullspace', resid_node, DirectionPatch(v=_unit(resid_null))),\n",
    "        ('grad resid_post.8', resid_node, DirectionPatch(v=summed_gradient)),\n",
    "        ('grad resid_post.8 along name movers', resid_node, DirectionPatch(v=_unit(grad_row))),\n",
    "        ('grad resid_post.8 along nullspace', resid_node, DirectionPatch(v=_unit(grad_null))),\n",
    "        ('mean diff of ABB vs BAB along name movers', resid_node, DirectionPatch(v=_unit(mean_diff_row))),\n",
    "        ('mean diff of ABB vs BAB', resid_node, DirectionPatch(v=v_mean)),\n",
    "        ('DAS MLP8 direction', mlp_node, DirectionPatch(v=das_mlp8)),\n",
    "        ('DAS MLP8 rowspace component', mlp_node, DirectionPatch(v=das_row_unit)),\n",
    "        ('DAS MLP8 nullspace component', mlp_node, DirectionPatch(v=das_null_unit)),\n",
    "    ]\n",
    "\n",
    "    aggregate_metrics_rows = [{\n",
    "        'intervention': 'clean',\n",
    "        'accuracy': clean_interchange_accuracy,\n",
    "        'logit_diff': clean_logit_diff,\n",
    "    }]\n",
    "    aggregate_metrics_rows += [\n",
    "        _patching_row(name, node, patch_impl, patching_dataset)\n",
    "        for name, node, patch_impl in interventions\n",
    "    ]\n",
    "    return pd.DataFrame(aggregate_metrics_rows)\n",
    "\n",
    "def compute_full_mlp8_patch(patching_dataset: PatchingDataset):\n",
    "    \"\"\"\n",
    "    Patch the whole mlp_node activation with the Full() patch implementation\n",
    "    (presumably a complete replacement by the source-run value - confirm).\n",
    "\n",
    "    Returns (interchange_accuracy, logit_diff) from get_patching_acc_and_ld.\n",
    "    \"\"\"\n",
    "    return get_patching_acc_and_ld(\n",
    "        patcher=Patcher(nodes=[mlp_node], patch_impl=Full()),\n",
    "        patching_dataset=patching_dataset,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interchange accuracy and base-answer logit diff when the whole MLP8 activation is patched\n",
    "compute_full_mlp8_patch(patching_dataset=PATCHING_DATASET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run every patching intervention on PATCHING_DATASET (one patched pass over the dataset per row)\n",
    "metrics_df = compute_patch_metrics(patching_dataset=PATCHING_DATASET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the metrics table to disk with joblib\n",
    "joblib.dump(metrics_df, 'patching_metrics_ioi.joblib')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean clean-run logit difference (the baseline that frac_ld is computed against)\n",
    "metrics_df.query('intervention == \"clean\"')['logit_diff'].item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def include_fractional_ld(df):\n",
    "    \"\"\"\n",
    "    Return a copy of `df` with a `frac_ld` column: each row's logit_diff as a percentage\n",
    "    of the 'clean' row's logit_diff, rounded to 2 decimals. `df` itself is not modified.\n",
    "    \"\"\"\n",
    "    clean_ld = df.query('intervention == \"clean\"')['logit_diff'].item()\n",
    "    return df.assign(frac_ld=(100 * (df['logit_diff'] / clean_ld)).round(2))\n",
    "# Add the frac_ld column (logit diff as % of the clean logit diff)\n",
    "metrics_df = include_fractional_ld(metrics_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the metrics table rounded to 3 decimals\n",
    "metrics_df.round(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def histograms(patching_dataset, true_subspace):\n",
    "    \"\"\"\n",
    "    Project MLP8's output at seq_pos=-1 onto `true_subspace`, without and with patching the\n",
    "    DAS MLP8 direction (DirectionPatch(v=das_mlp8) at mlp_node), and collect the projections.\n",
    "\n",
    "    Assumes `patching_dataset` is laid out like PATCHING_DATASET: an even number of pairs,\n",
    "    the first half with ABB bases / BAB sources and the second half the reverse.\n",
    "\n",
    "    Returns:\n",
    "        Long DataFrame with columns `value` (projection onto `true_subspace`),\n",
    "        `variable` ('normal' or 'patched') and `pattern`: 'ABB' / 'BAB' for the\n",
    "        unpatched base prompts, 'BAB -> ABB' / 'ABB -> BAB' (source -> base) for the\n",
    "        patched runs.\n",
    "    \"\"\"\n",
    "    mlp_out_node = Node('mlp_out', layer=8, seq_pos=-1)\n",
    "    num_prompts = len(patching_dataset.base.prompts)\n",
    "    # the pattern labels below assume two equally sized halves\n",
    "    assert num_prompts % 2 == 0, f'expected an even number of prompts, got {num_prompts}'\n",
    "    num_examples = num_prompts // 2\n",
    "\n",
    "    ### without intervention: MLP8 output on the base prompts\n",
    "    mlp_out_activations = run_with_cache(\n",
    "        prompts=patching_dataset.base.prompts,\n",
    "        batch_size=100,\n",
    "        model=model,\n",
    "        nodes=[mlp_out_node],\n",
    "    )[0]\n",
    "    df_mlp_out_normal = pd.DataFrame({\n",
    "        'value': einsum('batch d_model, d_model -> batch', mlp_out_activations, true_subspace).cpu(),\n",
    "        'variable': 'normal',\n",
    "        'pattern': ['ABB'] * num_examples + ['BAB'] * num_examples,\n",
    "    })\n",
    "\n",
    "    ### with the DAS MLP8 direction patched (source = the source prompts)\n",
    "    das_mlp8_patcher = Patcher(\n",
    "        nodes=[mlp_node],\n",
    "        patch_impl=DirectionPatch(v=das_mlp8),\n",
    "    )\n",
    "    mlp_out_patched_activations = das_mlp8_patcher.get_patched_activation(\n",
    "        model=model,\n",
    "        node=mlp_out_node,\n",
    "        X_base=patching_dataset.base.tokens,\n",
    "        X_source=patching_dataset.source.tokens,\n",
    "        batch_size=100,\n",
    "        cache_base=None,\n",
    "        cache_source=None,\n",
    "    )\n",
    "    df_mlp_out_patched = pd.DataFrame({\n",
    "        'value': einsum('batch d_model, d_model -> batch', mlp_out_patched_activations, true_subspace).cpu(),\n",
    "        'variable': 'patched',\n",
    "        'pattern': ['BAB -> ABB'] * num_examples + ['ABB -> BAB'] * num_examples,\n",
    "    })\n",
    "    return pd.concat([df_mlp_out_normal, df_mlp_out_patched])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MLP8-output projections, using summed_gradient as the `true_subspace` direction\n",
    "df_mlp_out = histograms(PATCHING_DATASET, summed_gradient)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest projection among the patched ABB -> BAB runs\n",
    "df_mlp_out.query('pattern == \"ABB -> BAB\"').value.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from matplotlib.patches import Patch\n",
    "sns.set_theme()\n",
    "\n",
    "# Pin the hue order and colours explicitly, so the hand-written legend below is guaranteed\n",
    "# to describe the right histogram (rather than relying on order of appearance in the data).\n",
    "hue_labels = {\n",
    "    'ABB': 'ABB (no intervention)',\n",
    "    'BAB': 'BAB (no intervention)',\n",
    "    'BAB -> ABB': 'patch BAB -> ABB',\n",
    "    'ABB -> BAB': 'patch ABB -> BAB',\n",
    "}\n",
    "hue_order = list(hue_labels)\n",
    "colors = sns.color_palette(n_colors=len(hue_order))\n",
    "\n",
    "# one large square figure; the four histograms are overlaid\n",
    "fig, ax = plt.subplots(figsize=(10, 10))\n",
    "sns.histplot(\n",
    "    data=df_mlp_out,\n",
    "    x='value',\n",
    "    hue='pattern',\n",
    "    hue_order=hue_order,\n",
    "    palette=colors,\n",
    "    bins=100,\n",
    "    stat='count',\n",
    "    element='bars',\n",
    "    fill=True,\n",
    "    alpha=0.5,\n",
    "    legend=False,\n",
    "    ax=ax,\n",
    ")\n",
    "ax.legend(\n",
    "    handles=[Patch(facecolor=color, label=hue_labels[pattern]) for pattern, color in zip(hue_order, colors)],\n",
    "    title='Input',\n",
    "    title_fontsize=24,\n",
    "    fontsize=18,\n",
    "    loc='upper left',\n",
    "    bbox_to_anchor=(0.0, 1.0),\n",
    "    ncol=1,\n",
    ")\n",
    "# no axis labels; large tick labels\n",
    "ax.set_xlabel(None)\n",
    "ax.set_ylabel(None)\n",
    "ax.tick_params(axis='both', which='major', labelsize=20)\n",
    "# savefig raises if the output directory does not exist\n",
    "os.makedirs('figures', exist_ok=True)\n",
    "fig.savefig('figures/mlp_output_histograms.pdf', format='pdf')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.12",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "fd8dbfdfd1a6a4c5f3a98a8b5f239185c4ac44e8c535538c941237e2ab93d1b0"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
