{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from common_imports import *\n",
    "from fact_utils import setup_counterfact, COUNTERFACT_PATH, get_covariance_path\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "DEBUGGING = True # set to False to run the full experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the CounterFact data using the helper imported from fact_utils.\n",
    "# NOTE(review): presumably this downloads/caches the dataset at COUNTERFACT_PATH - confirm in fact_utils.\n",
    "setup_counterfact()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load and set up the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the small GPT-2 checkpoint while debugging and gpt2-xl for the full run.\n",
    "MODEL_NAME = \"gpt2\" if DEBUGGING else \"gpt2-xl\"\n",
    "model, tok = (\n",
    "    AutoModelForCausalLM.from_pretrained(MODEL_NAME, low_cpu_mem_usage=False).to(\n",
    "        \"cuda\"\n",
    "    ),\n",
    "    AutoTokenizer.from_pretrained(MODEL_NAME),\n",
    ")\n",
    "# freeze every model parameter; the trailing `;` suppresses the cell output\n",
    "model.requires_grad_(False);\n",
    "N_LAYERS = len(model.transformer.h)\n",
    "# the MLP output projection weight is unpacked as (D_MLP, D_MODEL)\n",
    "D_MLP, D_MODEL = model.transformer.h[0].mlp.c_proj.weight.shape\n",
    "# reuse the eos token as the padding token (batched tokenization below pads prompts)\n",
    "tok.pad_token = tok.eos_token\n",
    "# save the original weights to be able to recover after edits\n",
    "ORIG_WEIGHTS = {\n",
    "    f'transformer.h.{layer}.mlp.c_proj.weight': model.transformer.h[layer].mlp.c_proj.weight.clone().detach()\n",
    "    for layer in range(N_LAYERS)\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_with_bos(s: str) -> List[int]:\n",
    "    \"\"\"\n",
    "    Split `s` into tokens with the tokenizer and put the bos token in front.\n",
    "    \"\"\"\n",
    "    pieces = tok.tokenize(s)\n",
    "    return [tok.bos_token, *pieces]\n",
    "\n",
    "def encode_with_bos(s: str) -> Tensor:\n",
    "    \"\"\"\n",
    "    Encode a string as a tensor with a bos token prepended.\n",
    "\n",
    "    Returns a tensor of token ids of shape (1, 1 + number of tokens in `s`)\n",
    "    living on the GPU.\n",
    "    \"\"\"\n",
    "    input_ids = tok(s, return_tensors='pt').input_ids.to('cuda')\n",
    "    # build the bos id directly as an integer tensor on the same device,\n",
    "    # instead of going through a float32 CPU tensor and casting afterwards\n",
    "    bos = torch.tensor([[tok.bos_token_id]], dtype=torch.long, device=input_ids.device)\n",
    "    # input_ids[:1] keeps the batch dimension, so the result is already (1, seq)\n",
    "    return torch.cat([bos, input_ids[:1]], dim=1)\n",
    "\n",
    "def mencode_with_bos(ss: List[str]) -> Tensor:\n",
    "    \"\"\"\n",
    "    Encode a list of strings as a tensor with a bos token prepended, padding to\n",
    "    the length of the longest string.\n",
    "\n",
    "    Returns a tensor of token ids of shape (len(ss), 1 + longest length) on\n",
    "    the GPU.\n",
    "    \"\"\"\n",
    "    input_ids = tok(ss, return_tensors='pt', padding=True).input_ids.to('cuda')\n",
    "    # one bos id per row, created as an integer tensor on the right device\n",
    "    # (avoids the float32 CPU round trip and the unsqueeze/repeat dance)\n",
    "    bos_column = torch.full(\n",
    "        (input_ids.shape[0], 1), tok.bos_token_id,\n",
    "        dtype=torch.long, device=input_ids.device,\n",
    "    )\n",
    "    return torch.cat([bos_column, input_ids], dim=1)\n",
    "\n",
    "def remove_all_hooks(model: nn.Module):\n",
    "    \"\"\"\n",
    "    Drop every forward hook registered on `model` or on any of its\n",
    "    (transitively nested) submodules.\n",
    "    \"\"\"\n",
    "    for submodule in model.modules():\n",
    "        submodule._forward_hooks.clear()\n",
    "\n",
    "def get_last_subj_token_idx(prompt: str, subject: str) -> int:\n",
    "    \"\"\"\n",
    "    Given a prompt template of the form \"... {} ...\", return the index of the\n",
    "    last subject token once\n",
    "        - {} has been replaced with `subject`, and\n",
    "        - the prompt has been tokenized with a bos token prepended.\n",
    "    \"\"\"\n",
    "    # keep only the text up to (and including) the placeholder ...\n",
    "    truncated_template = prompt.split('{}')[0] + r\"{}\"\n",
    "    # ... so that after substitution the subject is the end of the string\n",
    "    filled = truncated_template.replace('{}', subject)\n",
    "    # the last token of the truncated prompt is the last subject token\n",
    "    return len(tokenize_with_bos(filled)) - 1\n",
    "\n",
    "def get_neuronal_activations(model: AutoModelForCausalLM,\n",
    "                             prompt: str, layer: int, seq_pos: int,\n",
    "                             ) -> Tensor:\n",
    "    \"\"\"\n",
    "    Get the post-nonlinearity activations of the prompt in the given MLP layer,\n",
    "    at the given token position.\n",
    "\n",
    "    The prompt is encoded with a bos token prepended, so `seq_pos` indexes\n",
    "    into the bos-prefixed token sequence. Returns the activation vector at\n",
    "    that position for the single prompt in the batch.\n",
    "    \"\"\"\n",
    "    captured = []\n",
    "    def hook_fn(module, input, output):\n",
    "        captured.append(output)\n",
    "    handle = model.transformer.h[layer].mlp.act.register_forward_hook(hook_fn)\n",
    "    try:\n",
    "        model(encode_with_bos(prompt))\n",
    "    finally:\n",
    "        # always detach the hook, even if encoding or the forward pass fails,\n",
    "        # so that a failed call cannot leave a stale hook on the model\n",
    "        handle.remove()\n",
    "    return captured[0][0, seq_pos, :]\n",
    "\n",
    "def get_covariance(layer: int, model_name=MODEL_NAME) -> Tensor:\n",
    "    \"\"\"\n",
    "    Load covariance matrix of activations for a given layer.\n",
    "\n",
    "    In DEBUGGING mode no file is read: a random positive semi-definite\n",
    "    (D_MLP, D_MLP) matrix is returned instead. Otherwise the 'mom2.mom2' array\n",
    "    of the npz file given by `get_covariance_path` is returned as a float32\n",
    "    GPU tensor.\n",
    "\n",
    "    `model_name` is accepted for interface compatibility but is not used.\n",
    "    \"\"\"\n",
    "    if DEBUGGING:\n",
    "        X = torch.randn(D_MLP, D_MLP).cuda()\n",
    "        return X @ X.T\n",
    "    npz_path = get_covariance_path(layer=layer)\n",
    "    # np.load yields numpy arrays (which have no .float()), so convert to a\n",
    "    # torch tensor explicitly; the context manager closes the npz archive\n",
    "    with np.load(npz_path) as data:\n",
    "        second_moment = data['mom2.mom2']\n",
    "    return torch.from_numpy(second_moment).float().cuda()\n",
    "\n",
    "def get_W_proj(layer: int) -> Tensor:\n",
    "    \"\"\"\n",
    "    Return a detached copy of the output-projection weight of the MLP in the\n",
    "    given layer.\n",
    "    \"\"\"\n",
    "    c_proj = model.transformer.h[layer].mlp.c_proj\n",
    "    return c_proj.weight.detach().clone()\n",
    "\n",
    "def decompose_along_W(W: Tensor, v: Tensor, normalize: bool) -> Tuple[Tensor, Tensor]:\n",
    "    \"\"\"\n",
    "    Split v into the part lying in the span of the columns of W (the\n",
    "    \"rowspace\" component) and the remainder orthogonal to it (the \"nullspace\"\n",
    "    component). With `normalize`, each part is rescaled to unit norm.\n",
    "    \"\"\"\n",
    "    # orthonormal basis of the column span of W\n",
    "    basis = torch.linalg.qr(W)[0]\n",
    "    in_span = v @ basis @ basis.T\n",
    "    orthogonal = v - in_span\n",
    "    if not normalize:\n",
    "        return in_span, orthogonal\n",
    "    return in_span / in_span.norm(), orthogonal / orthogonal.norm()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate a fact patching dataset from CounterFact"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of CounterFact records (taken from the start of the file) to build fact patches from\n",
    "N_EXAMPLES = 1_000\n",
    "# number of fact patching pairs sampled for each prompt\n",
    "N_PER_PROMPT = 5\n",
    "# load the whole CounterFact JSON file into memory\n",
    "with open(COUNTERFACT_PATH, 'r') as f:\n",
    "    COUNTERFACT = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def group_by_relation(cf_examples: List[dict]) -> List[List[dict]]:\n",
    "    \"\"\"\n",
    "    Partition the given counterfact examples by their relation id. Groups\n",
    "    appear in order of first occurrence, and examples keep their input order\n",
    "    within a group.\n",
    "    \"\"\"\n",
    "    by_relation = defaultdict(list)\n",
    "    for example in cf_examples:\n",
    "        by_relation[example['requested_rewrite']['relation_id']].append(example)\n",
    "    return [members for members in by_relation.values()]\n",
    "\n",
    "def check_knowledge(prompt: str, subject: str, target: str) -> bool:\n",
    "    \"\"\"\n",
    "    Check if the model knows the given fact (i.e. predicts the target).\n",
    "\n",
    "    The fact counts as known when the model's most likely next token after\n",
    "    `prompt.format(subject)` equals the first token of ' ' + target; later\n",
    "    tokens of a multi-token target are not checked.\n",
    "    \"\"\"\n",
    "    tokens = encode_with_bos(prompt.format(subject))\n",
    "    logits = model(tokens).logits[0, -1, :]\n",
    "    obj_id = tok.encode(f' {target}')[0]\n",
    "    # return a plain Python bool (as annotated) rather than a 0-dim GPU tensor\n",
    "    return logits.argmax().item() == obj_id\n",
    "\n",
    "def collect_fact_patches(counterfact_examples: List[dict],) -> List[List[dict]]:\n",
    "    \"\"\"\n",
    "    Build a fact patching dataset out of counterfact examples.\n",
    "\n",
    "    The result has one list per relation id; its entries look like\n",
    "    {\n",
    "        \"prompt\": e.g. 'The mother tongue of {} is',\n",
    "        \"base_subject\": e.g. 'Danielle Darrieux',\n",
    "        \"base_target\": e.g. 'French',\n",
    "        \"source_subject\": e.g. 'Thomas Joannes Stieltjes',\n",
    "        \"source_target\": e.g. 'Dutch',\n",
    "    }\n",
    "    All entries of one list share a single prompt template: the prompt of the\n",
    "    first example of the relation is used for every subject (stricter than\n",
    "    sharing the relation_id, which allows several phrasings).\n",
    "\n",
    "    An ordered (base, source) pair is kept only if\n",
    "        - the model knows both facts under that shared prompt, and\n",
    "        - the two true targets differ.\n",
    "    \"\"\"\n",
    "    dataset = []\n",
    "    for group in tqdm(group_by_relation(counterfact_examples)):\n",
    "        prompt = group[0]['requested_rewrite']['prompt']\n",
    "        rewrites = [example['requested_rewrite'] for example in group]\n",
    "        is_known = [\n",
    "            check_knowledge(prompt=prompt,\n",
    "                            subject=rewrite['subject'],\n",
    "                            target=rewrite['target_true']['str'])\n",
    "            for rewrite in rewrites\n",
    "        ]\n",
    "        # only facts the model knows can take part in a pair\n",
    "        known_rewrites = [rewrite for rewrite, known in zip(rewrites, is_known) if known]\n",
    "        pairs = []\n",
    "        for base_pos, base in enumerate(known_rewrites):\n",
    "            base_target = base['target_true']['str']\n",
    "            for source_pos, source in enumerate(known_rewrites):\n",
    "                source_target = source['target_true']['str']\n",
    "                # skip pairing a fact with itself or with a fact sharing its target\n",
    "                if base_pos == source_pos or base_target == source_target:\n",
    "                    continue\n",
    "                pairs.append({\n",
    "                    'prompt': prompt,\n",
    "                    'base_subject': base['subject'],\n",
    "                    'base_target': base_target,\n",
    "                    'source_subject': source['subject'],\n",
    "                    'source_target': source_target,\n",
    "                })\n",
    "        dataset.append(pairs)\n",
    "    return dataset\n",
    "\n",
    "def sample_fact_patches(fact_patching_dataset: List[List[dict]], n_per_prompt: int) -> List[dict]:\n",
    "    \"\"\"\n",
    "    Draw `n_per_prompt` examples (without replacement) from every per-prompt\n",
    "    list produced by `collect_fact_patches` that is large enough, and return\n",
    "    them as one flat list. Lists with fewer than `n_per_prompt` examples\n",
    "    contribute nothing.\n",
    "    \"\"\"\n",
    "    sampled = []\n",
    "    for candidates in fact_patching_dataset:\n",
    "        if len(candidates) < n_per_prompt:\n",
    "            continue\n",
    "        # random.sample never returns the same element twice\n",
    "        sampled += random.sample(candidates, n_per_prompt)\n",
    "    return sampled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build all valid (base, source) fact pairs from the first N_EXAMPLES CounterFact records\n",
    "FACT_PATCHING_DATASET = collect_fact_patches(COUNTERFACT[:N_EXAMPLES])\n",
    "# keep N_PER_PROMPT randomly chosen pairs per prompt\n",
    "# NOTE(review): this uses the global `random` state and no seed is set in the cells above,\n",
    "# so the sample differs between runs - confirm whether that is intended\n",
    "FACT_PATCHING_SAMPLES = sample_fact_patches(FACT_PATCHING_DATASET, N_PER_PROMPT)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learning 1-dimensional activation patches to change factual recall"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LearnableDirection(nn.Module):\n",
    "    \"\"\"\n",
    "    A trainable 1-dimensional subspace, stored as a single vector parameter\n",
    "    `direction` of size `dim`. It starts out as a random unit vector; keeping\n",
    "    it at unit norm during training is the training loop's job.\n",
    "    \"\"\"\n",
    "    def __init__(self, dim: int):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "        random_vector = torch.randn(dim)\n",
    "        self.direction = nn.Parameter(random_vector / random_vector.norm())\n",
    "\n",
    "def mget_patched_logits(\n",
    "    model: AutoModelForCausalLM,\n",
    "    prompts: List[str],\n",
    "    layer: int, \n",
    "    patching_positions: List[int],\n",
    "    vs: List[LearnableDirection],\n",
    "    source_activations: Tensor,\n",
    "    last_token_positions: Optional[List[int]] = None,\n",
    "    ) -> Tensor:\n",
    "    \"\"\"\n",
    "    Get the logits after patching given token positions along given directions\n",
    "    from the given source activations.\n",
    "\n",
    "    For example k, the MLP activation at `patching_positions[k]` in `layer` is\n",
    "    changed only along `vs[k].direction`, so that its projection onto that\n",
    "    direction equals the projection of `source_activations[k]`.\n",
    "\n",
    "    Args:\n",
    "        prompts: the n base prompts, run as one padded batch.\n",
    "        patching_positions: one token position per prompt (bos-prefixed indexing).\n",
    "        vs: one direction module per prompt.\n",
    "        source_activations: one activation vector per prompt (n x d_mlp).\n",
    "        last_token_positions: position whose logits are returned for each\n",
    "            prompt; defaults to the last token of each prompt encoded on its\n",
    "            own (i.e. not counting batch padding).\n",
    "\n",
    "    Returns:\n",
    "        The patched logits at the chosen positions (n x d_vocab).\n",
    "    \"\"\"\n",
    "    if last_token_positions is None:\n",
    "        last_token_positions = [encode_with_bos(prompt).shape[1] - 1 for prompt in prompts]\n",
    "    n = len(prompts)\n",
    "    batch_idx = list(range(n))\n",
    "    directions_tensor = torch.stack([v.direction for v in vs], dim=0)\n",
    "    def hook_fn(module, input, output):\n",
    "        acts = output[batch_idx, patching_positions, :] # n x d_mlp\n",
    "        current_projs = einsum(\"n_examples d_mlp, n_examples d_mlp -> n_examples\", acts, directions_tensor)\n",
    "        desired_projs = einsum(\"n_examples d_mlp, n_examples d_mlp -> n_examples\", source_activations, directions_tensor)\n",
    "        # move each activation along its direction by the difference in projections\n",
    "        # (divided by the squared norm so that non-unit directions also work)\n",
    "        new_act = acts.clone() + einsum(\"n_examples, n_examples d_mlp -> n_examples d_mlp\", desired_projs - current_projs, directions_tensor) / (directions_tensor.norm(dim=-1) ** 2).unsqueeze(1)\n",
    "        # overwrite only the patched positions, leaving all other positions untouched\n",
    "        mask = torch.zeros_like(output)\n",
    "        mask[batch_idx, patching_positions, :] = 1\n",
    "        return torch.where(mask.bool(), new_act.unsqueeze(1), output)\n",
    "    handle = model.transformer.h[layer].mlp.act.register_forward_hook(hook_fn)\n",
    "    try:\n",
    "        tokens = mencode_with_bos(prompts)\n",
    "        logits = model(tokens).logits\n",
    "    finally:\n",
    "        # never leave the patching hook on the model, even if the forward pass fails\n",
    "        handle.remove()\n",
    "    return logits[batch_idx, last_token_positions, :] # n x d_vocab\n",
    "\n",
    "def mtrain_das(\n",
    "    base_prompts: List[str], source_prompts: List[str],\n",
    "    base_targets: List[str], source_targets: List[str],\n",
    "    layer: int, \n",
    "    source_last_subj_poss: List[int], # position of last subject token in source_prompt\n",
    "    base_last_subj_poss: List[int], # position of last subject token in base_prompt\n",
    "    finishing_epochs: int = 200,\n",
    "    lr: float = 1e-3, n_steps: int = 1000,\n",
    "    end_factor: float = 1e-3,\n",
    "    ):\n",
    "    \"\"\"\n",
    "    Train 1-dimensional subspaces to change the given facts. This is batched,\n",
    "    which speeds up the training compared to training each fact patching\n",
    "    pair separately.\n",
    "\n",
    "    One unit-norm direction is learned per (base, source) pair with SGD. The\n",
    "    loss of a pair is logit(base target) - logit(source target) at the last\n",
    "    token of the base prompt after patching, so lower is better. The learning\n",
    "    rate decays linearly to `lr * end_factor` over `n_steps` steps and then\n",
    "    stays there for `finishing_epochs` more steps.\n",
    "\n",
    "    Returns:\n",
    "        vs: the direction modules, each set to the direction that achieved\n",
    "            the lowest loss for its pair over the whole run (first such step\n",
    "            on ties).\n",
    "        losses_per_example: for each pair, its loss at every step.\n",
    "    \"\"\"\n",
    "    n = len(base_prompts)\n",
    "    example_idx = list(range(n))\n",
    "    # find the last token positions for the base prompts (each prompt is\n",
    "    # encoded on its own, so batch padding is not counted)\n",
    "    base_last_token_positions = [encode_with_bos(prompt).shape[1] - 1 for prompt in base_prompts]\n",
    "\n",
    "    vs = [LearnableDirection(D_MLP).cuda() for _ in range(n)]\n",
    "    base_idxs = [tok.encode(f' {base_target}')[0] for base_target in base_targets]\n",
    "    source_idxs = [tok.encode(f' {source_target}')[0] for source_target in source_targets]\n",
    "    source_activations = torch.stack([get_neuronal_activations(model, source_prompt, layer, source_last_subj_pos)\n",
    "                            for source_prompt, source_last_subj_pos in zip(source_prompts, source_last_subj_poss)], dim=0)\n",
    "    optimizer = torch.optim.SGD([v.direction for v in vs], lr=lr)\n",
    "    lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, end_factor=end_factor, total_iters=n_steps)\n",
    "    pbar = tqdm(range(n_steps + finishing_epochs))\n",
    "    losses_per_step = []\n",
    "    # running best loss/direction per example: avoids rescanning the whole loss\n",
    "    # history at every step and storing a copy of every direction at every step\n",
    "    best_losses = [float('inf')] * n\n",
    "    best_directions = [v.direction.data.detach().clone() for v in vs]\n",
    "    for step in pbar:\n",
    "        optimizer.zero_grad()\n",
    "        logits = mget_patched_logits(\n",
    "            model=model, prompts=base_prompts, layer=layer, \n",
    "            vs=vs, source_activations=source_activations,\n",
    "            patching_positions=base_last_subj_poss,\n",
    "            last_token_positions=base_last_token_positions,\n",
    "        ) # shape (n, vocab_size)\n",
    "        logit_diffs = logits[example_idx, base_idxs] - logits[example_idx, source_idxs]\n",
    "        loss = logit_diffs.sum()\n",
    "        # a single device-to-host transfer instead of one .item() per example\n",
    "        losses_this_step = logit_diffs.detach().tolist()\n",
    "        losses_per_step.append(losses_this_step)\n",
    "        # record the directions that produced these losses, i.e. before the update\n",
    "        for j, loss_j in enumerate(losses_this_step):\n",
    "            if loss_j < best_losses[j]:\n",
    "                best_losses[j] = loss_j\n",
    "                best_directions[j] = vs[j].direction.data.detach().clone()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # decay the learning rate during the first n_steps steps only\n",
    "        if step < n_steps:\n",
    "            lr_scheduler.step()\n",
    "        # normalize directions\n",
    "        for v in vs:\n",
    "            v.direction.data = v.direction.data / v.direction.data.norm()\n",
    "        pbar.set_description(f'loss: {loss.item():.3f}, best losses: {best_losses}')\n",
    "    # return the losses and the solution with the lowest loss\n",
    "    losses_per_example = [[step_losses[j] for step_losses in losses_per_step] for j in range(n)]\n",
    "    for v, best_direction in zip(vs, best_directions):\n",
    "        v.direction.data = best_direction\n",
    "    return vs, losses_per_example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analyzing fact patches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_patched_logits_full_mlp(\n",
    "    base_prompt: str, source_prompt: str, \n",
    "    base_target: str, source_target: str,\n",
    "    layer: int, source_last_subj_pos: int, base_last_subj_pos: int,\n",
    ") -> Tensor:\n",
    "    \"\"\"\n",
    "    Return the last-token logits of the base prompt after patching the entire\n",
    "    MLP activation at the given layer and positions: the activation of the base\n",
    "    prompt at `base_last_subj_pos` is replaced by the activation of the source\n",
    "    prompt at `source_last_subj_pos`.\n",
    "\n",
    "    `base_target` and `source_target` are not used; the caller computes the\n",
    "    logit difference from the returned logits (shape 1 x d_vocab).\n",
    "    \"\"\"\n",
    "    source_activation = get_neuronal_activations(model, source_prompt,\n",
    "                                                 layer, source_last_subj_pos)\n",
    "    def hook_fn(module, input, output):\n",
    "        output[0, base_last_subj_pos, :] = source_activation\n",
    "        return output\n",
    "    handle = model.transformer.h[layer].mlp.act.register_forward_hook(hook_fn)\n",
    "    try:\n",
    "        tokens = encode_with_bos(base_prompt)\n",
    "        logits = model(tokens).logits[:, -1, :]\n",
    "    finally:\n",
    "        # never leave the patching hook on the model, even if the forward pass fails\n",
    "        handle.remove()\n",
    "    return logits\n",
    "\n",
    "def _patching_metric_row(method: str, logits: Tensor, base_idx: int, source_idx: int) -> dict:\n",
    "    \"\"\"\n",
    "    Summarize a 1-d vector of next-token logits as one metrics row: the logit\n",
    "    difference base - source, the decoded top prediction, and whether that\n",
    "    prediction is the base / source target token.\n",
    "    \"\"\"\n",
    "    prediction = int(logits.argmax().item())\n",
    "    return {\n",
    "        'method': method,\n",
    "        'ld': (logits[base_idx] - logits[source_idx]).item(),\n",
    "        'predict_source': prediction == source_idx,\n",
    "        'predict_base': prediction == base_idx,\n",
    "        'prediction': tok.decode(prediction),\n",
    "    }\n",
    "\n",
    "@torch.no_grad()\n",
    "def analyze_das_patch(\n",
    "    base_prompt: str, source_prompt: str, \n",
    "    base_target: str, source_target: str,\n",
    "    layer: int, source_last_subj_pos: int, base_last_subj_pos: int,\n",
    "    das_result: LearnableDirection,\n",
    "    ) -> Tuple[dict, pd.DataFrame]:\n",
    "    \"\"\"\n",
    "    Compute the norms of the nullspace/rowspace components of the patching\n",
    "    directions, and the logit diffs when patching along this direction, as well\n",
    "    as several baselines:\n",
    "        - patching the entire MLP\n",
    "        - patching only the rowspace component\n",
    "        - clean run (no intervention)\n",
    "\n",
    "    Returns:\n",
    "        norm_metrics: dict with keys 'row_norm' and 'null_norm'.\n",
    "        patching metrics: DataFrame with one row per method ('clean', 'das',\n",
    "            'full_mlp', 'row') and columns method, ld, predict_source,\n",
    "            predict_base, prediction. For every method, including the clean\n",
    "            run, the predict_* flags are derived from the actual top prediction.\n",
    "\n",
    "    Runs without gradient tracking, since only forward passes are needed. The\n",
    "    passed-in `das_result` module is not modified.\n",
    "    \"\"\"\n",
    "    W = get_W_proj(layer=layer)\n",
    "    v = das_result.direction.data.detach().clone()\n",
    "    v_row, v_null = decompose_along_W(W=W, v=v, normalize=False)\n",
    "    norm_metrics = {\n",
    "        'row_norm': v_row.norm().item(),\n",
    "        'null_norm': v_null.norm().item(),\n",
    "    }\n",
    "\n",
    "    base_idx = tok.encode(f' {base_target}')[0]\n",
    "    source_idx = tok.encode(f' {source_target}')[0]\n",
    "    source_activation = get_neuronal_activations(model, source_prompt, layer, source_last_subj_pos)\n",
    "\n",
    "    def patch_along(direction: Tensor) -> Tensor:\n",
    "        # wrap the vector in a fresh module so the caller's module stays untouched\n",
    "        direction_module = LearnableDirection(D_MLP).cuda()\n",
    "        direction_module.direction.data = direction\n",
    "        return mget_patched_logits(\n",
    "            model=model, prompts=[base_prompt], layer=layer,\n",
    "            vs=[direction_module], source_activations=source_activation.unsqueeze(0),\n",
    "            patching_positions=[base_last_subj_pos],\n",
    "            last_token_positions=None,\n",
    "        )[0]\n",
    "\n",
    "    ### evaluate the baseline (no intervention) for this fact\n",
    "    clean_logits = model(encode_with_bos(base_prompt)).logits[0, -1, :]\n",
    "    ### patch using the full direction\n",
    "    das_logits = patch_along(v)\n",
    "    ### patch the entire MLP\n",
    "    full_mlp_logits = get_patched_logits_full_mlp(base_prompt, source_prompt,\n",
    "                                                  base_target, source_target,\n",
    "                                                  layer, source_last_subj_pos, base_last_subj_pos)[0]\n",
    "    ### patch using the (re-normalized) row component only\n",
    "    row_logits = patch_along(v_row / v_row.norm())\n",
    "\n",
    "    patching_metrics = [\n",
    "        _patching_metric_row(method, logits, base_idx, source_idx)\n",
    "        for method, logits in (\n",
    "            ('clean', clean_logits),\n",
    "            ('das', das_logits),\n",
    "            ('full_mlp', full_mlp_logits),\n",
    "            ('row', row_logits),\n",
    "        )\n",
    "    ]\n",
    "    return norm_metrics, pd.DataFrame(patching_metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Main activation patching experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mrun_fact_patching(patching_facts: List[dict],\n",
    "                      layer_step: int,\n",
    "                      n_steps: int = 600, finishing_epochs: int = 0,\n",
    "                      num_trials_per_lr: int = 2, \n",
    "                      end_factor: float = 1e-3,\n",
    "                      lrs: tuple = (3*1e-1, 1e-1, 3*1e-2, 1e-2,)):\n",
    "    \"\"\"\n",
    "    Sweep learning rates for the given fact patching pairs on every\n",
    "    `layer_step`-th layer and keep, per layer and pair, the best 1-dimensional\n",
    "    subspace found. Every learning rate is tried `num_trials_per_lr` times.\n",
    "\n",
    "    Returns:\n",
    "        - best_vs: list of dicts {'layer', 'v', 'fact_idx'}, where 'v' is the\n",
    "        direction (numpy array) with the lowest logit difference across all\n",
    "        training steps and hyperparameter settings for that layer and pair\n",
    "        - norm_metrics_df: columns (layer, fact_idx, row_norm, null_norm)\n",
    "        - patching_metrics_df: (layer, fact_idx, method, ld, predict_source, predict_base, prediction)\n",
    "    \"\"\"\n",
    "    n_facts = len(patching_facts)\n",
    "    base_prompts, source_prompts = [], []\n",
    "    base_targets, source_targets = [], []\n",
    "    base_seq_poss, source_seq_poss = [], []\n",
    "    for fact in patching_facts:\n",
    "        template = fact['prompt']\n",
    "        base_prompts.append(template.format(fact['base_subject']))\n",
    "        source_prompts.append(template.format(fact['source_subject']))\n",
    "        base_targets.append(fact[\"base_target\"])\n",
    "        source_targets.append(fact[\"source_target\"])\n",
    "        base_seq_poss.append(get_last_subj_token_idx(template, fact['base_subject']))\n",
    "        source_seq_poss.append(get_last_subj_token_idx(template, fact['source_subject']))\n",
    "\n",
    "    norm_metric_rows = []\n",
    "    patching_metric_dfs = []\n",
    "    best_vs = []\n",
    "    # repeating the tuple runs each learning rate num_trials_per_lr times\n",
    "    trial_lrs = lrs * num_trials_per_lr\n",
    "    for layer in range(0, N_LAYERS, layer_step):\n",
    "        lowest_loss = [float('inf')] * n_facts\n",
    "        layer_winners = [None] * n_facts\n",
    "        for lr in trial_lrs:\n",
    "            vs, losses_per_example = mtrain_das(\n",
    "                base_prompts=base_prompts, source_prompts=source_prompts,\n",
    "                base_targets=base_targets, source_targets=source_targets,\n",
    "                end_factor=end_factor,\n",
    "                layer=layer, source_last_subj_poss=source_seq_poss, base_last_subj_poss=base_seq_poss,\n",
    "                lr=lr, n_steps=n_steps, finishing_epochs=finishing_epochs,\n",
    "            )\n",
    "            # keep, per fact, the direction of the trial with the lowest loss so far\n",
    "            for idx, losses in enumerate(losses_per_example):\n",
    "                trial_min = min(losses)\n",
    "                if trial_min < lowest_loss[idx]:\n",
    "                    lowest_loss[idx] = trial_min\n",
    "                    layer_winners[idx] = vs[idx]\n",
    "        for idx, winner in enumerate(layer_winners):\n",
    "            norm_metrics, fact_metrics_df = analyze_das_patch(\n",
    "                base_prompt=base_prompts[idx], source_prompt=source_prompts[idx],\n",
    "                base_target=base_targets[idx], source_target=source_targets[idx],\n",
    "                layer=layer, source_last_subj_pos=source_seq_poss[idx], base_last_subj_pos=base_seq_poss[idx],\n",
    "                das_result=winner,\n",
    "            )\n",
    "            norm_metrics['layer'] = layer\n",
    "            norm_metrics['fact_idx'] = idx\n",
    "            norm_metric_rows.append(norm_metrics)\n",
    "            fact_metrics_df['layer'] = layer\n",
    "            fact_metrics_df['fact_idx'] = idx\n",
    "            patching_metric_dfs.append(fact_metrics_df)\n",
    "            best_vs.append({'layer': layer, 'v': winner.direction.detach().cpu().numpy(), 'fact_idx': idx})\n",
    "    return (\n",
    "        best_vs,\n",
    "        pd.DataFrame(norm_metric_rows),\n",
    "        pd.concat(patching_metric_dfs, ignore_index=True),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clear any forward hooks still registered on the model before running the experiment\n",
    "remove_all_hooks(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if DEBUGGING:\n",
    "    # quick smoke test: a handful of facts, few steps, one learning rate and a\n",
    "    # coarse layer stride. The sample size is capped so that having fewer than\n",
    "    # 5 fact patches does not make random.sample raise a ValueError.\n",
    "    n_debug_facts = min(5, len(FACT_PATCHING_SAMPLES))\n",
    "    best_vs, norm_metrics_df, patching_metrics_df = mrun_fact_patching(\n",
    "        patching_facts=random.sample(FACT_PATCHING_SAMPLES, n_debug_facts),\n",
    "        n_steps=100,\n",
    "        layer_step=10, \n",
    "        lrs=(1e-1, ),\n",
    "        num_trials_per_lr=1,\n",
    "    )\n",
    "else:\n",
    "    # full experiment: all sampled facts, every 5th layer, default learning rate sweep\n",
    "    best_vs, norm_metrics_df, patching_metrics_df = mrun_fact_patching(\n",
    "        patching_facts=FACT_PATCHING_SAMPLES,\n",
    "        layer_step=5,\n",
    "        num_trials_per_lr=2,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# From patches to rank-1 edits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_patch_to_edit(\n",
    "    layer: int, v: np.ndarray, \n",
    "    patching_fact: dict,\n",
    "):\n",
    "    \"\"\"\n",
    "    Return vectors a, b so that the given patching direction is equivalent to\n",
    "    the weight edit W <--- W + ab^T, and moreover a, b are chosen to minimize\n",
    "    the \"damage\". Also return a dict of metrics for the edit.\n",
    "\n",
    "    Derivation (for a unit-norm v). The patch does\n",
    "       u_base' = u_base + ((u_source - u_base)^T v) v\n",
    "    so\n",
    "     W u_base' = W u_base + ((u_source - u_base)^T v) W v.\n",
    "    We want an edit W' = W + ab^T with the same effect on the base activation:\n",
    "     W' u_base = W u_base', i.e. (b^T u_base) a = ((u_source - u_base)^T v) W v.\n",
    "    We can choose\n",
    "       a = W u_base' - W u_base = ((u_source - u_base)^T v) W v\n",
    "    and any b such that b^T u_base = 1. To minimize the variance along the\n",
    "    direction being introduced, we choose b to minimize b^T Sigma b subject to\n",
    "    b^T u_base = 1, whose solution is\n",
    "       b = Sigma^{-1} u_base / (u_base^T Sigma^{-1} u_base).\n",
    "    The code uses the pseudo-inverse of Sigma in place of Sigma^{-1}.\n",
    "\n",
    "    Args:\n",
    "        layer: index of the MLP layer that is patched / edited.\n",
    "        v: the patching direction as a numpy array.\n",
    "        patching_fact: dict with keys 'prompt', 'base_subject',\n",
    "            'source_subject', 'base_target' and 'source_target'.\n",
    "\n",
    "    Returns:\n",
    "        a, b: the vectors of the rank-1 edit.\n",
    "        metrics: dict with keys 'act_diff_sim_to_row', 'act_diff_sim_to_null',\n",
    "            'variance_of_edit', 'edit_logit_diff', 'edit_prediction',\n",
    "            'edit_predicts_source' and 'edit_predicts_base'.\n",
    "\n",
    "    The edit is applied to the model only temporarily; the original weight is\n",
    "    restored before returning, also when the forward pass fails.\n",
    "    \"\"\"\n",
    "    v = torch.tensor(v).cuda()\n",
    "\n",
    "    base_prompt = patching_fact['prompt'].format(patching_fact['base_subject'])\n",
    "    source_prompt = patching_fact['prompt'].format(patching_fact['source_subject'])\n",
    "    base_target = patching_fact['base_target']\n",
    "    source_target = patching_fact['source_target']\n",
    "    last_subj_pos_base = get_last_subj_token_idx(patching_fact['prompt'], patching_fact['base_subject'])\n",
    "    last_subj_pos_source = get_last_subj_token_idx(patching_fact['prompt'], patching_fact['source_subject'])\n",
    "\n",
    "    # only the first token of each target is used to score the edit\n",
    "    base_target_idx = tok.encode(f' {base_target}')[0]\n",
    "    source_target_idx = tok.encode(f' {source_target}')[0]\n",
    "    \n",
    "    # MLP activations at the last subject token of the base / source prompt\n",
    "    base_act = get_neuronal_activations(\n",
    "        model=model, prompt=base_prompt, layer=layer, seq_pos=last_subj_pos_base,\n",
    "    )\n",
    "    source_act = get_neuronal_activations(\n",
    "        model=model, prompt=source_prompt, layer=layer, seq_pos=last_subj_pos_source,\n",
    "    )\n",
    "    metrics = {}\n",
    "    W_proj = get_W_proj(layer=layer)\n",
    "    # first, we want to know how the difference is distributed across the\n",
    "    # nullspace and rowspace components of v\n",
    "    v_row, v_null = decompose_along_W(W=W_proj, v=v, normalize=False)\n",
    "    act_diff = source_act - base_act\n",
    "    act_diff_sim_to_row = torch.cosine_similarity(act_diff, v_row, dim=-1)\n",
    "    act_diff_sim_to_null = torch.cosine_similarity(act_diff, v_null, dim=-1)\n",
    "    metrics['act_diff_sim_to_row'] = act_diff_sim_to_row.item()\n",
    "    metrics['act_diff_sim_to_null'] = act_diff_sim_to_null.item()\n",
    "\n",
    "    # NOTE: in DEBUGGING mode get_covariance returns a random matrix, not real statistics\n",
    "    Sigma = get_covariance(layer=layer)\n",
    "    # a = ((u_source - u_base)^T v) W v, written with row vectors (v @ W_proj)\n",
    "    a = (act_diff @ v) * (v @ W_proj)\n",
    "    Sigma_pinv = torch.linalg.pinv(Sigma)\n",
    "    # b = Sigma^+ u_base / (u_base^T Sigma^+ u_base), so that b^T u_base = 1\n",
    "    b = base_act @ Sigma_pinv / (base_act @ Sigma_pinv @ base_act)\n",
    "    metrics['variance_of_edit'] = (b @ Sigma @ b).item()\n",
    "\n",
    "    ### now, run the model with this edit\n",
    "    # outer(a, b) is transposed so that it can be added to the c_proj weight\n",
    "    update = torch.outer(a, b).T\n",
    "    original_weight = model.transformer.h[layer].mlp.c_proj.weight.data.clone()\n",
    "    try:\n",
    "        model.transformer.h[layer].mlp.c_proj.weight.data += update\n",
    "        toks = encode_with_bos(s=base_prompt)\n",
    "        logits_after_edit = model(toks.cuda()).logits[:, -1, :]\n",
    "    finally:\n",
    "        # always restore the unedited weight\n",
    "        model.transformer.h[layer].mlp.c_proj.weight.data = original_weight\n",
    "    \n",
    "    # check if the edit is equivalent to the patch: compute the logit \n",
    "    # difference and the prediction\n",
    "    edit_logit_diff = logits_after_edit[0, base_target_idx] - logits_after_edit[0, source_target_idx]\n",
    "    edit_prediction = torch.argmax(logits_after_edit, dim=-1)\n",
    "    metrics['edit_logit_diff'] = edit_logit_diff.item()\n",
    "    metrics['edit_prediction'] = tok.decode([edit_prediction.item()])\n",
    "    metrics['edit_predicts_source'] = edit_prediction.item() == source_target_idx\n",
    "    metrics['edit_predicts_base'] = edit_prediction.item() == base_target_idx\n",
    "    return a, b, metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_rewrite_score(\n",
    "    logits_before_edit, logits_after_edit, source_target_idx, base_target_idx\n",
    "):\n",
    "    \"\"\"\n",
    "    Measure how far an edit moved the probability of the source target\n",
    "    towards 1.\n",
    "\n",
    "    Both logit tensors are softmaxed over their last dimension and read at\n",
    "    row 0, i.e. they are indexed as [batch, vocab]. With p the probability\n",
    "    of `source_target_idx`, the score is\n",
    "        (p_after - p_before) / (1 - p_before).\n",
    "    `base_target_idx` is accepted for call-site symmetry but is not used.\n",
    "\n",
    "    Returns:\n",
    "        The score as a Python float.\n",
    "    \"\"\"\n",
    "    p_before, p_after = (\n",
    "        torch.softmax(logits, dim=-1)[0, source_target_idx]\n",
    "        for logits in (logits_before_edit, logits_after_edit)\n",
    "    )\n",
    "    gained = p_after - p_before\n",
    "    headroom = 1 - p_before\n",
    "    return (gained / headroom).item()\n",
    "\n",
    "def evaluate_subspace_intervention(v: Tensor, layer: int, patching_fact: dict, scale: float):\n",
    "    \"\"\"\n",
    "    Rewrite score obtained by projecting a direction out of the MLP\n",
    "    activations of the base prompt.\n",
    "\n",
    "    With v_s = sqrt(scale) * v, a forward hook on\n",
    "    model.transformer.h[layer].mlp.act replaces the activation x at the\n",
    "    position returned by get_last_subj_token_idx for the base subject with\n",
    "    x - (x . v_s) v_s. The logits at the final position with and without the\n",
    "    hook are then compared by get_rewrite_score.\n",
    "\n",
    "    Args:\n",
    "        v: direction in the activation space of the hooked module; a tensor,\n",
    "            or anything torch.as_tensor accepts (e.g. a numpy array).\n",
    "        layer: index into model.transformer.h.\n",
    "        patching_fact: dict with keys 'prompt' (a format string taking the\n",
    "            subject), 'base_subject', 'base_target' and 'source_target'.\n",
    "        scale: scale of the matching rank-1 edit; v is multiplied by its\n",
    "            square root.\n",
    "\n",
    "    Returns:\n",
    "        The rewrite score as a Python float.\n",
    "    \"\"\"\n",
    "    v = np.sqrt(scale) * v # this makes the intervention have the same scale as the rank-1 edit\n",
    "\n",
    "    base_prompt = patching_fact['prompt'].format(patching_fact['base_subject'])\n",
    "    base_target = patching_fact['base_target']\n",
    "    source_target = patching_fact['source_target']\n",
    "    last_subj_pos_base = get_last_subj_token_idx(patching_fact['prompt'], patching_fact['base_subject'])\n",
    "    base_target_idx = tok.encode(f' {base_target}')[0]\n",
    "    source_target_idx = tok.encode(f' {source_target}')[0]\n",
    "\n",
    "    def hook_fn(module, input, output):\n",
    "        # note that the source prompt is unused in this intervention\n",
    "        acts = output[0, last_subj_pos_base, :]\n",
    "        # align the direction with the activations so that CPU tensors and\n",
    "        # numpy arrays work as well as tensors already on the GPU\n",
    "        direction = torch.as_tensor(v).to(device=acts.device, dtype=acts.dtype)\n",
    "        update_coef = einsum('d_mlp, d_mlp -> ', acts, direction)\n",
    "        update = - update_coef * direction\n",
    "        output[0, last_subj_pos_base, :] += update\n",
    "        return output\n",
    "\n",
    "    # encode_with_bos is the notebook's tokenization helper; move the tokens\n",
    "    # to the GPU like evaluate_edit does before calling the model\n",
    "    tokens = encode_with_bos(s=base_prompt).cuda()\n",
    "    clean_logits = model(tokens).logits[:, -1, :]\n",
    "\n",
    "    handle = model.transformer.h[layer].mlp.act.register_forward_hook(hook_fn)\n",
    "    try:\n",
    "        intervened_logits = model(tokens).logits[:, -1, :]\n",
    "    finally:\n",
    "        # always detach the hook: if the forward pass fails, a leftover hook\n",
    "        # would keep intervening on every later call to the shared model\n",
    "        handle.remove()\n",
    "    rewrite_score = get_rewrite_score(\n",
    "        logits_before_edit=clean_logits,\n",
    "        logits_after_edit=intervened_logits,\n",
    "        source_target_idx=source_target_idx,\n",
    "        base_target_idx=base_target_idx,\n",
    "    )\n",
    "    return rewrite_score\n",
    "\n",
    "def evaluate_edit(a: Tensor, b: Tensor, scale: float,\n",
    "                  patching_fact: dict, layer: int):\n",
    "    \"\"\"\n",
    "    Rewrite score of a rank-1 edit to the MLP output projection of `layer`.\n",
    "\n",
    "    The update scale * outer(a, b).T is temporarily added to\n",
    "    model.transformer.h[layer].mlp.c_proj.weight, the base prompt is run\n",
    "    through the model, and the original weight is restored afterwards (also\n",
    "    if the forward pass raises). Logits at the final position before and\n",
    "    during the edit are compared by get_rewrite_score.\n",
    "\n",
    "    Args:\n",
    "        a: 1-D tensor; outer(a, b).T must match the shape of c_proj.weight,\n",
    "            so a indexes its second dimension.\n",
    "        b: 1-D tensor indexing the first dimension of c_proj.weight.\n",
    "        scale: scalar multiplier applied to the rank-1 update.\n",
    "        patching_fact: dict with keys 'prompt' (a format string taking the\n",
    "            subject), 'base_subject', 'base_target' and 'source_target'.\n",
    "        layer: index into model.transformer.h.\n",
    "\n",
    "    Returns:\n",
    "        The rewrite score as a Python float.\n",
    "    \"\"\"\n",
    "    base_prompt = patching_fact['prompt'].format(patching_fact['base_subject'])\n",
    "    base_target = patching_fact['base_target']\n",
    "    source_target = patching_fact['source_target']\n",
    "\n",
    "    base_target_idx = tok.encode(f' {base_target}')[0]\n",
    "    source_target_idx = tok.encode(f' {source_target}')[0]\n",
    "\n",
    "    # the prompt is the same before and after the edit, so tokenize it once\n",
    "    toks = encode_with_bos(s=base_prompt).cuda()\n",
    "    logits_before_edit = model(toks).logits[:, -1, :]\n",
    "\n",
    "    c_proj = model.transformer.h[layer].mlp.c_proj\n",
    "    # keep the update on the weight's device so the in-place add cannot fail\n",
    "    # when a and b live on the CPU\n",
    "    update = (torch.outer(a, b).T * scale).to(c_proj.weight.device)\n",
    "    original_weight = c_proj.weight.data.clone()\n",
    "    try:\n",
    "        c_proj.weight.data += update\n",
    "        logits_after_edit = model(toks).logits[:, -1, :]\n",
    "    finally:\n",
    "        # restore the saved copy rather than subtracting the update, so the\n",
    "        # weights come back bit-for-bit\n",
    "        c_proj.weight.data = original_weight\n",
    "\n",
    "    # compute the rewrite score\n",
    "    rewrite_score = get_rewrite_score(\n",
    "        logits_before_edit, logits_after_edit, source_target_idx, base_target_idx\n",
    "    )\n",
    "    return rewrite_score\n",
    "\n",
    "def evaluate_rome_and_subspace(fp_results, \n",
    "                               rank1_to_subsp_result,\n",
    "                               rome_edit_results,):\n",
    "    \"\"\"\n",
    "    Compare each rank-1 edit with its matching subspace intervention.\n",
    "\n",
    "    Entries of `rome_edit_results` and `rank1_to_subsp_result` are joined on\n",
    "    their ('layer', 'fact_idx') pair; the patching facts are read from\n",
    "    fp_results[3] and indexed by 'fact_idx'. Every joined pair is scored\n",
    "    with evaluate_edit and evaluate_subspace_intervention.\n",
    "\n",
    "    Returns:\n",
    "        A list of dicts with keys 'layer', 'fact_idx', 'rewrite_score_rome'\n",
    "        and 'rewrite_score_subsp', one per (layer, fact_idx) pair.\n",
    "    \"\"\"\n",
    "    patching_facts = fp_results[3]\n",
    "\n",
    "    # (layer, fact_idx) -> everything needed to evaluate that pair\n",
    "    data = {\n",
    "        (r['layer'], r['fact_idx']): {\n",
    "            'patching_fact': patching_facts[r['fact_idx']],\n",
    "            'a': r['v'],\n",
    "            'b': r['u'],\n",
    "            'scale': r['scale'].item()\n",
    "        }\n",
    "        for r in rome_edit_results\n",
    "    }\n",
    "    # attach the subspace direction to the entry of the matching edit\n",
    "    for r in rank1_to_subsp_result:\n",
    "        data[(r['layer'], r['fact_idx'])]['v_subsp'] = r['v']\n",
    "\n",
    "    def score_pair(key, entry):\n",
    "        # the edit is evaluated first, then the subspace intervention\n",
    "        layer, fact_idx = key\n",
    "        fact, scale = entry['patching_fact'], entry['scale']\n",
    "        score_rome = evaluate_edit(\n",
    "            a=entry['a'], b=entry['b'], scale=scale,\n",
    "            patching_fact=fact, layer=layer\n",
    "        )\n",
    "        score_subsp = evaluate_subspace_intervention(\n",
    "            v=entry['v_subsp'], layer=layer, patching_fact=fact, scale=scale\n",
    "        )\n",
    "        return {\n",
    "            'layer': layer,\n",
    "            'fact_idx': fact_idx,\n",
    "            'rewrite_score_rome': score_rome,\n",
    "            'rewrite_score_subsp': score_subsp,\n",
    "        }\n",
    "\n",
    "    return [score_pair(key, entry) for key, entry in tqdm(data.items())]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.12",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "fd8dbfdfd1a6a4c5f3a98a8b5f239185c4ac44e8c535538c941237e2ab93d1b0"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
