{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data_utils import *\n",
    "from data_utils import test_distribution\n",
    "from model_utils import *\n",
    "import joblib\n",
    "# (layer, head) index pairs of the attention heads analysed below; each pair\n",
    "# is unpacked as `layer, head` when the per-head gradient directions are computed.\n",
    "NAME_MOVERS = [(9, 6), (9, 9), (10, 0)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using pad_token, but it is not set yet.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model gpt2-small into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "# Load the model under study through the project helper; the recorded cell\n",
    "# output reports gpt2-small being loaded into a HookedTransformer.\n",
    "model = get_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training DAS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Candidate locations for DAS training (all at layer 8, sequence position -1)\n",
    "resid_node = Node('resid_post', layer=8, seq_pos=-1)\n",
    "mlp_node = Node('post', layer=8, seq_pos=-1)\n",
    "resid_node_mid = Node('resid_mid', layer=8, seq_pos=-1)\n",
    "\n",
    "################################################################################\n",
    "### Train and test datasets\n",
    "################################################################################\n",
    "# Training set: ABB/BAB base patterns with CDD/DCD source patterns.\n",
    "D_train = train_distribution.sample_das(\n",
    "    model=model,\n",
    "    base_patterns=['ABB', 'BAB'],\n",
    "    source_patterns=['CDD', 'DCD'],\n",
    "    labels='position',\n",
    "    samples_per_combination=50,\n",
    ")\n",
    "\n",
    "# Test set: ABB bases with BAB sources, followed by BAB bases with ABB sources.\n",
    "test_sampling = dict(model=model, labels='position', samples_per_combination=50)\n",
    "D_test = (\n",
    "    test_distribution.sample_das(base_patterns=['ABB'], source_patterns=['BAB'], **test_sampling)\n",
    "    + test_distribution.sample_das(base_patterns=['BAB'], source_patterns=['ABB'], **test_sampling)\n",
    ")\n",
    "\n",
    "################################################################################\n",
    "### Patchers for the different locations\n",
    "################################################################################\n",
    "def _rotation_patcher(node, n):\n",
    "    \"\"\"Build a Patcher for `node` whose patch is a Rotation over a CUDA RotationMatrix of size `n`.\"\"\"\n",
    "    return Patcher(\n",
    "        nodes=[node],\n",
    "        patch_impl=Rotation(rotation=RotationMatrix(n=n).cuda(), dim=1),\n",
    "    )\n",
    "\n",
    "# Construction order (mlp, resid, resid_mid) is kept as in the original cell.\n",
    "das_patcher_mlp = _rotation_patcher(mlp_node, n=3072)\n",
    "das_patcher_resid = _rotation_patcher(resid_node, n=768)\n",
    "das_patcher_resid_mid = _rotation_patcher(resid_node_mid, n=768)\n",
    "\n",
    "# Baseline: Full() patching at resid_node.\n",
    "# NOTE(review): the baseline patches resid_post while the patcher trained below\n",
    "# targets resid_mid -- confirm this pairing is intended.\n",
    "baseline_patcher = Patcher(nodes=[resid_node], patch_impl=Full())\n",
    "\n",
    "################################################################################\n",
    "### Training\n",
    "################################################################################\n",
    "torch.cuda.empty_cache()\n",
    "metrics = patch_training(\n",
    "    patcher=das_patcher_resid_mid, # change this to train different locations\n",
    "    baseline_patcher=baseline_patcher,\n",
    "    model=model,\n",
    "    D_train=D_train,\n",
    "    D_test=D_test,\n",
    "    n_epochs=30,\n",
    "    batch_size=20,\n",
    "    initial_lr=0.01,\n",
    "    eval_every=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(768,)\n"
     ]
    }
   ],
   "source": [
    "# Recover the learned direction from the patcher that was passed to\n",
    "# `patch_training`: column 0 of the rotation's weight matrix, as a NumPy vector.\n",
    "# NOTE(review): assumes column 0 holds the trained direction -- confirm against\n",
    "# the Rotation implementation.\n",
    "rotation_weight = das_patcher_resid_mid.patch_impl.rotation.R.weight\n",
    "v = rotation_weight.data.detach().cpu().numpy()[:, 0]\n",
    "print(v.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Finding other interesting directions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@batched(args=['prompt_dataset'], n_outputs=1, reducer='cat')\n",
    "def get_gradients(prompt_dataset: PromptDataset,\n",
    "                  layer: int, head: int,\n",
    "                  batch_size: int = 20,\n",
    "                  ) -> torch.Tensor:\n",
    "    \"\"\"Gradient of an attention-score difference w.r.t. `blocks[8].hook_resid_post`.\n",
    "\n",
    "    Runs the global `model` on `prompt_dataset.tokens`, reads the attention\n",
    "    scores of head `head` in block `layer`, and backpropagates the summed\n",
    "    difference `scores[:, head, -1, 3] - scores[:, head, -1, 5]` to the output\n",
    "    of `blocks[8].hook_resid_post`.\n",
    "\n",
    "    Args:\n",
    "        prompt_dataset: prompts to run; only `.tokens` is read here.\n",
    "        layer: index of the block whose attention scores are differentiated.\n",
    "        head: index of the attention head within that block.\n",
    "        batch_size: not used in the body; presumably consumed by the `batched`\n",
    "            decorator -- confirm against its implementation.\n",
    "\n",
    "    Returns:\n",
    "        The detached gradient at the last sequence position.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: if the backward hook never fired, i.e. no gradient reached\n",
    "            `blocks[8].hook_resid_post` (presumably when `layer` is not after\n",
    "            block 8 -- confirm).\n",
    "\n",
    "    Side effects:\n",
    "        Leaves every model parameter with `requires_grad=False` and drops any\n",
    "        parameter `.grad` buffers.\n",
    "    \"\"\"\n",
    "    activation_container = []\n",
    "    def forward_hk(module, input, output):\n",
    "        activation_container.append(output)\n",
    "    gradient_container = []\n",
    "    def backward_hk(module, grad_in, grad_out):\n",
    "        gradient_container.append(grad_out[0])\n",
    "    fwd_handle = model.blocks[layer].attn.hook_attn_scores.register_forward_hook(forward_hk)\n",
    "    # `register_backward_hook` is deprecated; the full variant is the supported API\n",
    "    bwd_handle = model.blocks[8].hook_resid_post.register_full_backward_hook(backward_hk)\n",
    "\n",
    "    try:\n",
    "        model.requires_grad_(True)\n",
    "        # autograd may have been disabled globally; force it on for this pass\n",
    "        with torch.enable_grad():\n",
    "            _ = model(prompt_dataset.tokens)\n",
    "            attn_scores = activation_container[0] # (batch, head, source, target)\n",
    "            attn_3 = attn_scores[:, head, -1, 3]\n",
    "            attn_5 = attn_scores[:, head, -1, 5]\n",
    "            diff = attn_3 - attn_5\n",
    "            diff.sum().backward()\n",
    "        if not gradient_container:\n",
    "            raise IndexError(\n",
    "                'no gradient reached blocks[8].hook_resid_post; '\n",
    "                f'attention scores of layer {layer} may not depend on it'\n",
    "            )\n",
    "        grad = gradient_container[0]\n",
    "    finally:\n",
    "        fwd_handle.remove()\n",
    "        bwd_handle.remove()\n",
    "        model.requires_grad_(False)\n",
    "        # backward() also populated .grad on every parameter; only the\n",
    "        # activation gradient is needed, so release those buffers\n",
    "        model.zero_grad(set_to_none=True)\n",
    "    # grad is of shape (batch, seq_len, hidden_size)\n",
    "    grad_last = grad[:, -1, :]\n",
    "    return grad_last.detach()\n",
    "\n",
    "def compute_avg_gradient(patching_dataset: PatchingDataset, layer: int, head: int, random_seed: int = 0):\n",
    "    \"\"\"Unit-norm average of the mean `get_gradients` output over base and source prompts.\n",
    "\n",
    "    The gradients for `patching_dataset.base` and `patching_dataset.source` are\n",
    "    each averaged over dim 0, the two means are averaged with equal weight, and\n",
    "    the result is divided by its norm. `random_seed` is accepted but not used.\n",
    "    \"\"\"\n",
    "    side_means = [\n",
    "        get_gradients(prompt_dataset=prompts, layer=layer, head=head, batch_size=20).mean(dim=0)\n",
    "        for prompts in (patching_dataset.base, patching_dataset.source)\n",
    "    ]\n",
    "    combined = (side_means[0] + side_means[1]) / 2\n",
    "    return combined / combined.norm()\n",
    "\n",
    "def get_mean_diff_direction(patching_dataset: PatchingDataset):\n",
    "    \"\"\"Unit-norm difference between the mean base and mean source activations.\n",
    "\n",
    "    Activations are read with `run_with_cache` at node `resid_post`, layer 8,\n",
    "    sequence position -1, for the base prompts and the source prompts. The\n",
    "    source mean is subtracted from the base mean and the result is divided by\n",
    "    its norm.\n",
    "    \"\"\"\n",
    "    resid_site = Node('resid_post', layer=8, seq_pos=-1)\n",
    "\n",
    "    def mean_activation(prompts):\n",
    "        cached = run_with_cache(prompts=prompts, nodes=[resid_site], model=model, batch_size=100)\n",
    "        return cached[0].mean(dim=0)\n",
    "\n",
    "    direction = mean_activation(patching_dataset.base.prompts) - mean_activation(patching_dataset.source.prompts)\n",
    "    return direction / direction.norm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One unit-norm gradient direction per (layer, head) pair in NAME_MOVERS,\n",
    "# each computed on the test dataset.\n",
    "gs = []\n",
    "for layer, head in NAME_MOVERS:\n",
    "    gs.append(compute_avg_gradient(patching_dataset=D_test, layer=layer, head=head))\n",
    "\n",
    "# Average the per-head directions, then rescale the mean to unit norm.\n",
    "g = torch.stack(gs).mean(dim=0)\n",
    "g = g / g.norm()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.12",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "fd8dbfdfd1a6a4c5f3a98a8b5f239185c4ac44e8c535538c941237e2ab93d1b0"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
