{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec6d9634-6e7e-4c1e-a33e-2adf5612654d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_lens import HookedTransformer\n",
    "from datasets import load_dataset\n",
    "import re\n",
    "import torch\n",
    "print(torch.cuda.get_device_name(0))\n",
    "from transformers import AutoModelForCausalLM\n",
    "from transformers import AutoTokenizer\n",
    "import nltk\n",
    "import random\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "import yaml\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "random.seed(42)\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c678e72d-1011-4783-b071-59bc62e318ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"global_config.yaml\") as config_file:\n",
    "    config_dict = yaml.safe_load(config_file)\n",
    "\n",
    "CACHE_DIR = config_dict['CACHE_DIR']\n",
    "access_token = config_dict['hf_access_token']\n",
    "print(config_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88ac3c37-a7ae-4835-bd2e-8a57a2eb36a7",
   "metadata": {},
   "source": [
    "# Generate Data\n",
    "We take random sentences of sufficient length from owt-10k. Then, we will prompt gemma to augment the data.\n",
    "\n",
    "We also include a randomly grouped dataset as a baseline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8dc3b95-5e5c-42df-82c0-94018d5abe81",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get sentences from openwebtext or smth.\n",
    "dataset = load_dataset('stas/openwebtext-10k', split='train', cache_dir=CACHE_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b11e0ab-590a-4ff5-8e13-62e7dc234984",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filt(x, cutoff):\n",
    "    return not('\\n' in x) and len(x.split(' ')) > cutoff\n",
    "base_cutoff = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c555fb18-3643-4726-aa1f-1192c6126cbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences = []\n",
    "for data in dataset['text'][:10000]:\n",
    "    valid = [x for x in nltk.sent_tokenize(data) if filt(x, base_cutoff)]\n",
    "    random.shuffle(valid)\n",
    "    if len(valid) > 0:\n",
    "        sentences.append(valid[0].strip())\n",
    "print(len(sentences), \"sentences collected\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14beb7a0-6851-4ee4-a51d-e7e698166fc1",
   "metadata": {},
   "source": [
    "Generate a random dataset baseline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75214880-c57b-4847-ad62-aad683f5310e",
   "metadata": {},
   "outputs": [],
   "source": [
    "grouped = []\n",
    "for i in range(len(sentences) // 6):\n",
    "    grouped.append(sentences[i*6:i*6+6])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7926aa45-dc77-407c-aece-97fa3a64b1e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"data/random_text_data.json\", \"w\") as f:\n",
    "    json.dump({'text':grouped}, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dd9f5aa-33fb-481e-98ea-01cd310e21e8",
   "metadata": {},
   "source": [
    "### Prompt gemma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba48252c-30b0-47bb-be42-695599eab749",
   "metadata": {},
   "outputs": [],
   "source": [
    "large_model = AutoModelForCausalLM.from_pretrained('google/gemma-2-9b-it', cache_dir=CACHE_DIR, token=access_token, torch_dtype=torch.float16, device_map='auto')\n",
    "large_tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-9b-it', token=access_token)\n",
    "large_model.generation_config.pad_token_id = large_tokenizer.pad_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f708ab9-1bdb-4ad7-80ce-063116d5f38e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_gemma_response(text):\n",
    "    matches = re.findall(r\"<start_of_turn>model\\n(.*?)<end_of_turn>\", text, re.DOTALL)\n",
    "    if matches:\n",
    "        return matches[-1].strip()\n",
    "    else:\n",
    "        return None\n",
    "k = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "437a91c6-fc11-4612-be9e-b3f3dcd92873",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "augmentations = []\n",
    "for sentence in tqdm(sentences):\n",
    "    answers = [sentence]\n",
    "    for i in range(k):\n",
    "        prompt = f\"\"\"Transform the given sentence so that its meaning is completely unrelated, but the syntax, punctuation, and grammatical structure remain identical. This means:\n",
    "        - Keep all function words (e.g., \"the,\" \"and,\" \"while\") and punctuation (e.g. commas, periods, brackets, parentheses, dashes) unchanged.\n",
    "        - Replace content words (nouns, verbs, adjectives, adverbs) with words from entirely different semantic domains (e.g., \"run\" → \"melt\", \"dog\" → \"radio\").\n",
    "        - Ensure that the sentence remains grammatically correct, even though its meaning is completely different.\n",
    "        - Respond with the new sentence and nothing else.\n",
    "        \n",
    "        Examples:\n",
    "            Input: \"The scientist carefully examined the ancient manuscript in the dimly lit library.\"\n",
    "            Output: \"The firefighter accidentally kicked the broken telephone in the noisy bus station.\"\n",
    "        \n",
    "            Input: \"After the storm passed, the children ran outside to play in the puddles.\"\n",
    "            Output: \"Once the debate ended, the accountants flew abroad to invest in the factories.\"\n",
    "        \n",
    "        Again, your instructions are:\n",
    "        - Ensure that the core meaning is entirely different from the original sentence.\n",
    "        - Do not use synonyms or words from the same semantic category.\n",
    "        - Maintain identical syntax, punctuation, and grammatical structure.\n",
    "        - The new sentence must be valid and natural, despite the unrelated meaning.\n",
    "        - Respond with the new sentence and nothing else.\n",
    "        \n",
    "        Here is the original sentence.\n",
    "        Input: {sentence}\n",
    "        \"\"\"\n",
    "        messages = [\n",
    "            {\"role\": \"user\", \"content\": prompt}\n",
    "        ]\n",
    "        tokenized_chat = large_tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors=\"pt\").to(\"cuda\")\n",
    "        outputs = large_model.generate(tokenized_chat, max_new_tokens=128, temperature=1.0, do_sample=True)[0]\n",
    "        answer = extract_gemma_response(large_tokenizer.decode(outputs))\n",
    "        if answer is None:\n",
    "            continue\n",
    "        answers.append(answer)\n",
    "\n",
    "    #print(answers)\n",
    "    if len(answers) > 3:\n",
    "        augmentations += [answers]\n",
    "    data = {\"text\": augmentations}\n",
    "    # progressively save progress\n",
    "    with open(\"data/full_augmented_text_data.json\", \"w\") as f:\n",
    "        json.dump(data, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2984d653-c037-47be-adf5-6eecce915d3c",
   "metadata": {},
   "source": [
    "# Load and tokenize the dataset\n",
    "\n",
    "After generating the dataset, we can now load it and tokenize for the appropriate model. Here we have an example for pythia-70m-deduped."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51454793-ee44-4a7d-9000-1f9648ba6056",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from transformer_lens.utils import get_input_with_manually_prepended_bos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b1f1ec4-feae-4ff3-a581-c29c814df6c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = 'data/full_augmented_text_data.json' # 'data/random_text_data.json'\n",
    "keyword = 'augmented'\n",
    "with open(filename, 'r') as f:\n",
    "    data = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57d4efda-e9a2-4a3a-9af1-15a69765dafb",
   "metadata": {},
   "outputs": [],
   "source": [
    "full_augmentations = max([len(aug) for aug in data['text']])\n",
    "full_augmentations, len(data['text'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aeb62ae2-2cf8-4347-b607-abf54e0e4436",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize(augmentations, model, ctx_size=128, prepend_bos=True):\n",
    "    \"\"\"\n",
    "    Tokenize the augmentations.\n",
    "    \"\"\"\n",
    "    yay = []\n",
    "    for aug in augmentations:\n",
    "        if len(aug) != full_augmentations:\n",
    "            continue\n",
    "        if prepend_bos:\n",
    "            aug = get_input_with_manually_prepended_bos(tokenizer, aug)\n",
    "        yay.append(tokenizer(aug, truncation=True, max_length=ctx_size, padding='max_length', return_tensors='pt')['input_ids'].unsqueeze(0))\n",
    "        \n",
    "    return torch.cat(yay, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1f1f5e6-31fe-40bc-b1cc-b4b2e6b18b02",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = 'pythia-70m-deduped'\n",
    "model = HookedTransformer.from_pretrained(model_name, cache_dir=CACHE_DIR, device='cpu')\n",
    "tokenizer = model.tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee1abeba-1ef5-49da-9cee-0a79d23d8419",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized = tokenize(data['text'], tokenizer, 128, prepend_bos=not(model.cfg.tokenizer_prepends_bos))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fc48c05-9d91-4ab6-bf0f-5ee08aaae173",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(tokenizer.name_or_path)\n",
    "torch.save(tokenized, f'data/{os.path.basename(tokenizer.name_or_path)}_{keyword}_tokenized_data.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2633bcde-f2b1-4c29-9e98-1b0f101fcf03",
   "metadata": {},
   "source": [
    "# Classify Features\n",
    "Now, we can load an SAE and classify the features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ea63d89-3e4f-4d04-8546-78abcb37c5d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae_lens import SAE\n",
    "from sae import Sae\n",
    "import torch\n",
    "from stitching.sae_utils import BaseSAE, convert_eleuther_sae_to_BaseSAE, get_densities\n",
    "from stitching.sae_utils import feature_activations\n",
    "from stitching.random_utils import get_all_special_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58bbee54-8739-479f-b1f9-66e8311db24e",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1270612-f45b-408d-9cc1-6cfe431d16b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#model_name = 'gpt2-small'\n",
    "model_name = 'pythia-70m-deduped'\n",
    "keyword = 'augmented'\n",
    "#model_name = 'gemma-2-2b'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d49e1747-684a-401b-bf24-943e0916e8c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "TRUNCATION_LENGTH = 512\n",
    "tokenized_dataset = {}\n",
    "for dataset_key in ['train', 'test']:\n",
    "    tokenized_dataset[dataset_key] = torch.load(f'data/{model_name}_tokenized_dataset_200000_{dataset_key}_{TRUNCATION_LENGTH}.pt', weights_only=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad7f717c-7522-4888-a748-6c39a761a107",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = torch.load(f'data/{model_name}_{keyword}_tokenized_data.pt', weights_only=True)\n",
    "\n",
    "#data = torch.load('data/gpt2_augmented_tokenized_data.pt', weights_only=True)\n",
    "#data = torch.load('data/gemma-2-2b_augmented_tokenized_data.pt', weights_only=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6988a80-6fa4-4bf5-b11e-27889dbfedc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "flattened_data = data.reshape(-1, 128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ead2042d-b321-4ba5-91be-9e19afb3e4ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.inference_mode()\n",
    "def get_structural_features(data, model, layer, sae):\n",
    "    final_results = []\n",
    "    spec_tokens = get_all_special_tokens(model.tokenizer)\n",
    "    for i, tensor in tqdm(enumerate(data)):\n",
    "        tensor = tensor.to(device)\n",
    "        logits = model(tensor, stop_at_layer = layer)\n",
    "        activations = sae.encode(logits) # ( 4, toks, 32k)\n",
    "        did_you_activate = []\n",
    "        for i in range(activations.shape[0]):  # for each augmented version of the sentence\n",
    "            sentence_mask = get_ignore_mask(tensor[i], spec_tokens)\n",
    "            sentence = activations[i]\n",
    "            sentence = sentence[~sentence_mask]\n",
    "            woohoo = (torch.sum((sentence > 0), dim=0) > 0).unsqueeze(0)\n",
    "            did_you_activate.append(woohoo)\n",
    "        active_on_all = ~((torch.cat(did_you_activate, dim=0) == 0).sum(dim=0) > 0)\n",
    "        final_results.append(active_on_all.unsqueeze(0))\n",
    "    return torch.cat(final_results, dim=0).cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b26fc88c-9ae0-4aeb-9244-d245210a18b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = HookedTransformer.from_pretrained('pythia-70m-deduped', cache_dir=CACHE_DIR, device=device)\n",
    "layer = 3\n",
    "sae_A = Sae.load_from_hub(\"EleutherAI/sae-pythia-70m-deduped-32k\", hookpoint=f\"layers.{layer-1}\").to(device)\n",
    "sae_A = convert_eleuther_sae_to_BaseSAE(sae_A)\n",
    "\n",
    "#model = HookedTransformer.from_pretrained('gpt2-small', cache_dir=CACHE_DIR, device=device)\n",
    "#layer = 6 #3\n",
    "#sae_A = SAE.load_from_pretrained('gpt2-small-topk-sae-checkpoints/1t637wuk/final_245760000/', device=device)\n",
    "\n",
    "#model = HookedTransformer.from_pretrained('gemma-2-2b', cache_dir=CACHE_DIR, device=device)\n",
    "#layer = 20\n",
    "#sae_A, cfg_dict, _ = SAE.from_pretrained(\n",
    "#    release = 'gemma-scope-2b-pt-res-canonical', # see other options in sae_lens/pretrained_saes.yaml\n",
    "#    sae_id = f\"layer_{layer-1}/width_16k/canonical\", # won't always be a hook point\n",
    "#    device = device\n",
    "#) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73a66cd5-754a-45d7-bccd-b2531e15e85c",
   "metadata": {},
   "outputs": [],
   "source": [
    "densities = get_densities(\n",
    "    torch.utils.data.DataLoader(\n",
    "        tokenized_dataset['train'][:10000],\n",
    "        batch_size=10,\n",
    "        shuffle=False\n",
    "    ), model, layer, sae_A\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f4b98a7-6d6f-4699-a419-c1e78f4a56ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.histplot(np.log10(densities), log=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14a0cee3-d972-4572-a72f-d52607d0d196",
   "metadata": {},
   "outputs": [],
   "source": [
    "densities.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a24df7e-9277-4ad2-9576-ca1c0b5988de",
   "metadata": {},
   "outputs": [],
   "source": [
    "final_res = get_structural_features(data, model, layer, sae_A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70332e69-7af0-479a-bd6f-78fb84b31743",
   "metadata": {},
   "outputs": [],
   "source": [
    "repeated_thresh = 1  # filter out some noise by setting a base density threshold of occurring twice\n",
    "structural = (final_res.sum(dim=0) > repeated_thresh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b76b0559-adda-4069-96e9-7fc75faf125d",
   "metadata": {},
   "outputs": [],
   "source": [
    "structural.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1aeb85d2-753e-41de-b5a6-0e23e045870f",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_dict = {\n",
    "    'structural': structural,\n",
    "    'densities': densities\n",
    "}\n",
    "#torch.save(res_dict, \"gpt2-jb_structural.pt\")\n",
    "torch.save(res_dict, f\"{model_name}.{layer}_{keyword}_structural.pt\")\n",
    "f\"{model_name}.{layer}_structural.pt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b483f065-bc16-4071-8232-c54f3c289c0b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "276bd0b1-f69c-4903-8710-df763afad7aa",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
