{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f3d50c0c",
   "metadata": {},
   "source": [
    "# Purpose\n",
    "\n",
    "This notebook just contains a bunch of cells that I used to explore the model and some alternatives. \n",
    "It is mostly not directly useful to producing the experiments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e61a0720",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a120102",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import argparse\n",
    "\n",
    "import numpy as np\n",
    "import os\n",
    "import pandas as pd\n",
    "import scipy as sp\n",
    "import sys\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import warnings\n",
    "import random\n",
    "import collections\n",
    "\n",
    "# CD-T Imports\n",
    "import math\n",
    "import tqdm\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "import itertools\n",
    "\n",
    "from torch import nn\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "base_dir = os.path.split(os.getcwd())[0]\n",
    "sys.path.append(base_dir)\n",
    "\n",
    "from argparse import Namespace\n",
    "from methods.bag_of_ngrams.processing import cleanReports, cleanSplit, stripChars\n",
    "from pyfunctions.general import extractListFromDic, readJson, combine_token_attn, compute_word_intervals\n",
    "from pyfunctions.pathology import extract_synoptic, fixLabelProstateGleason, fixProstateLabels, fixLabel, exclude_labels\n",
    "from pyfunctions.cdt_basic import *\n",
    "from pyfunctions.ioi_dataset import IOIDataset\n",
    "from sklearn import preprocessing\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "from transformers import GPT2Tokenizer, GPT2Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6d6ada4-4781-4789-b3b1-1d044c11b3d3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "torch.autograd.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7651660-7f59-4b59-a574-afecc52dc306",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "## Load Model\n",
    "\n",
    "Note: Unlike with the BERT model + medical dataset objective, it is not necessary to pretrain GPT-2 to perform the IOI dataset.\n",
    "GPT-2-small is already capable of performing IOI; that's part of the point of the Mech Interp in the Wild paper.\n",
    "We only need to examine how it does it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3183be1-3bf6-4f5a-8134-9bdd83db0a56",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "device = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4838142",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained('gpt2')\n",
    "model = AutoModelForCausalLM.from_pretrained('gpt2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "820d21e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75d912fd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5069ff59",
   "metadata": {},
   "outputs": [],
   "source": [
    "text = \"Replace me by any text you'd\"\n",
    "input = tokenizer(text, return_tensors='pt').input_ids\n",
    "# print(encoded_input) # has 'input_idx' and 'attention_mask'\n",
    "# output = model(input)\n",
    "# print(output.last_hidden_state.shape)\n",
    "gen_tokens = model.generate(input, pad_token_id=tokenizer.pad_token_id, output_scores=True)\n",
    "print(gen_tokens)\n",
    "gen_text = tokenizer.batch_decode(gen_tokens)\n",
    "gen_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "712c94e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# other exploratory stuff\n",
    "# model.to_tokens(text) #turns out this is a utility of trasnformer_lens\n",
    "# print(output.past_key_values[0][0].shape) # this has to do with key matrix stuff\n",
    "#print(output.values())\n",
    "#output.logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b375528e",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08241bd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "ioi_dataset = load_dataset(\"fahamu/ioi\")\n",
    "# i've decided against using this for the most part; it's better to use the raw IOIDataset \n",
    "# from the paper and from the related notebook using EasyTransformers, since these both provide many utilities for dealing with the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c0e1142",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install transformer_lens\n",
    "!pip install einops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a520f760",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model code adapted from Callum McDougall's notebook for ARENA on reproducing the IOI paper using TransformerLens.\n",
    "# This makes some sense, since EasyTransformer, the repo/lib released by the IOI guys, was forked from TransformerLens.\n",
    "# In fact, this makes the reproduction a little bit more faithful, since they most likely do certain things such as \n",
    "# \"folding\" LayerNorms to improve their interpretability results, and we are able to do the same by using TransformerLens.\n",
    "# HuggingFace, by contrast, has the most impenetrable docs and tons of outdated APIs and etc.; even their source \n",
    "# code is impossible to traverse, and I gave up on it, thankfully quickly.\n",
    "\n",
    "from transformer_lens import utils, HookedTransformer, ActivationCache\n",
    "model = HookedTransformer.from_pretrained(\"gpt2-small\",\n",
    "                                          center_unembed=True,\n",
    "                                          center_writing_weights=True,\n",
    "                                          fold_ln=False,\n",
    "                                          refactor_factored_attn_matrices=True)\n",
    "                                          "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28216f8c",
   "metadata": {},
   "source": [
    "## Example forward pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3fb595a",
   "metadata": {},
   "outputs": [],
   "source": [
    "text = \"After John and Mary went to the store, John gave a bottle of milk to\"\n",
    "tokens = model.to_tokens(text).to(device)\n",
    "logits, cache = model.run_with_cache(tokens)\n",
    "probs = logits.softmax(dim=-1)\n",
    "most_likely_next_tokens = model.tokenizer.batch_decode(logits.argmax(dim=-1)[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11dbf9fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "for activation_name, activation in cache.items():\n",
    "    # Only print for first layer\n",
    "    if \".0.\" in activation_name or \"blocks\" not in activation_name:\n",
    "        print(f\"{activation_name:30} {tuple(activation.shape)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9aaa465",
   "metadata": {},
   "outputs": [],
   "source": [
    "# hack to get model dtype out, for compatibility with other code\n",
    "next(model.parameters()).dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "515c528b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(model)\n",
    "# print(model.config) # doesn't work on hookedtransformer, is a huggingface thing\n",
    "# print(model.embed.dtype) same, but can use dtype trick\n",
    "# print(type(model))\n",
    "#model.state_dict().keys()#.blocks[0].mlp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d130a854",
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "# inspect.getclasstree(inspect.getmro(type(model)))\n",
    "inspect.getmro(type(model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32fc711e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchviz\n",
    "dir(model)\n",
    "# torchviz.make_dot(model)\n",
    "# model._modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c19f02f",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install torchsummary\n",
    "!pip install torchinfo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "437d3a95",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pdb\n",
    "from torchinfo import summary\n",
    "\n",
    "text = \"After John and Mary went to the store, John gave a bottle of milk to\"\n",
    "encoding = get_encoding(text, model.tokenizer, \"cpu\")\n",
    "# embedding_output = model.embed(encoding.input_ids)\n",
    "input_shape = encoding.input_ids.shape\n",
    "print(input_shape)\n",
    "pdb.set_trace()\n",
    "summary(model, input_shape, device='cpu')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8823997c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same as in the notebook, example\n",
    "example_prompt = \"After John and Mary went to the store, John gave a bottle of milk to\"\n",
    "example_answer = \"Mary\"\n",
    "utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)o"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d377977-fe3b-45dd-9d00-1c19e5366038",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "## Generate dataset/Explore types"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b295089f",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = IOIDataset(N=500, prompt_type=\"ABBA\", tokenizer=model.tokenizer)\n",
    "#data.tokenized_prompts\n",
    "data.ioi_prompts[0]\n",
    "[x['TEMPLATE_IDX'] for x in data.ioi_prompts[0:10]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92798f42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# test\n",
    "pos_specific_hs = [\n",
    "        [i for i in range(12)],\n",
    "        [0],\n",
    "        [i for i in range(12)]\n",
    "    ]\n",
    "all_heads = list(itertools.product(*pos_specific_hs))\n",
    "target_nodes = [(7, 82, 11), (7, 82, 0), (7, 82, 6), (9, 82, 0), (9, 91, 7), (8, 82, 0)] # not meaningful in a GPT context\n",
    "source_list = [[node] for node in all_heads if node not in target_nodes]\n",
    "\n",
    "text = \"After John and Mary went to the store, John gave a bottle of milk to\"\n",
    "encoding = get_encoding(text, model.tokenizer, device)\n",
    "# encoding.input_ids.shape # 512-long vector, not sure why the tokens change from EOS to 0 at some point\n",
    "# embedding = model.embed(encoding.input_ids)\n",
    "\n",
    "out_decomps, target_decomps = prop_model_hh_batched(encoding, model, source_list, target_nodes,\n",
    "                                                                   device=device,\n",
    "                                                                   patched_values=None, mean_ablated=False, num_at_time=1)\n",
    "                                                                   # patched_values=mean_act, mean_ablated=True)\n",
    "                                                                "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01fa5c2b",
   "metadata": {},
   "source": [
    "## Explore IOI Dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e9522d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyfunctions.ioi_dataset import IOIDataset\n",
    "\n",
    "ioi_dataset = IOIDataset(prompt_type=\"mixed\", N=50, tokenizer=model.tokenizer, prepend_bos=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de60c87d",
   "metadata": {},
   "outputs": [],
   "source": [
    "ioi_dataset.toks.shape\n",
    "\n",
    "ioi_dataset.word_idx\n",
    "\n",
    "ioi_dataset.sentences[:4]\n",
    "\n",
    "ioi_dataset.groups\n",
    "\n",
    "# ioi_dataset.toks[ioi_dataset.groups[-1]]\n",
    "# [ioi_dataset.sentences[x] for x in ioi_dataset.groups[3]] # sentences of the same group are identical except for the choice of nouns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59e67051",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The below is wrong! The generated sentences are not of the same format as what is described in the paper, and \n",
    "# this is also not what they do in their experiments.py.\n",
    "# abc_dataset = IOIDataset(prompt_type=\"ABC mixed\", N=50, tokenizer=model.tokenizer, prepend_bos=False)\n",
    "\n",
    "# Instead, do this, apparently.\n",
    "abc_dataset = (\n",
    "    ioi_dataset.gen_flipped_prompts((\"IO\", \"RAND\"))\n",
    "    .gen_flipped_prompts((\"S\", \"RAND\"))\n",
    "    .gen_flipped_prompts((\"S1\", \"RAND\"))\n",
    ") # Note generating several of these in a row will generate different random names; this can be useful for a quick mean ablation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65604ebb",
   "metadata": {},
   "outputs": [],
   "source": [
    "abc_dataset.sentences[:4]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
