{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# json/os are used by the results-download cell below; import them explicitly\n",
    "# instead of relying on them leaking through the star imports.\n",
    "import json\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from circuit_tracer import ReplacementModel\n",
    "\n",
    "# NOTE(review): star imports hide where names come from. Feature, analyze_feature,\n",
    "# load_feature_analysis, print_feature_analysis and the partial-transcoder helpers used below\n",
    "# are presumably provided by these modules; prefer explicit imports once the API settles.\n",
    "from weight_lens.input_invariant_analysis import *\n",
    "from weight_lens.model_utils import *\n",
    "from weight_lens.utils import *\n",
    "\n",
    "# Choose device: CUDA > MPS > CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "elif getattr(torch.backends, \"mps\", None) is not None and torch.backends.mps.is_available():\n",
    "    device = torch.device(\"mps\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loading Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### gemma-2-2b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load Gemma-2-2B together with its transcoders (transcoder set \"gemma\") via circuit_tracer's\n",
    "# ReplacementModel, in bfloat16 on the selected device.\n",
    "# Run only one of the model-loading cells: each one overwrites `model` and `model_name`.\n",
    "model_name = 'google/gemma-2-2b'\n",
    "transcoder_name = \"gemma\"\n",
    "model = ReplacementModel.from_pretrained(model_name, transcoder_name, dtype=torch.bfloat16, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### gpt2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# IMPORTANT: these transcoders depend on the transcoder_circuits package; specifically,\n",
    "# sae_training/config.py must be present in the working directory:\n",
    "# https://github.com/jacobdunefsky/transcoder_circuits/blob/master/sae_training/config.py\n",
    "\n",
    "# For GPT-2 a slightly different approach is used: only layer-0 contributing features and the\n",
    "# tokens based on them are checked, because analyzing all previous layers adds a lot of noise.\n",
    "model_name = \"gpt2\"\n",
    "model = load_replacement_model_from_yaml(\"configs/gpt2-transcoders.yaml\").to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Llama-3.2-1B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loading the full model with all transcoders would look like this. The transcoders are very large,\n",
    "# so the next cell loads only the layers listed in its config instead.\n",
    "# model = ReplacementModel.from_pretrained(\"meta-llama/Llama-3.2-1B\", \"llama\", device=torch.device(\"cpu\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load only the transcoders for the layers listed in the config\n",
    "transcoders = load_partial_transcoder_set(\"configs/llama-relu.yaml\", device=device)\n",
    "\n",
    "# Override ReplacementModel methods so it can work with a partial transcoder set.\n",
    "# NOTE: this patches the class itself, not one instance, so the override stays in effect for any\n",
    "# ReplacementModel built later in this kernel session (e.g. re-running the Gemma cell) until restart.\n",
    "ReplacementModel._configure_replacement_model = configure_partial_replacement\n",
    "ReplacementModel._get_activation_caching_hooks = get_partial_activation_caching_hooks\n",
    "\n",
    "# Load the model with the partial transcoders.\n",
    "# NOTE(review): unlike the other loaders, no device is passed here - confirm the model ends up on `device`.\n",
    "model_name = \"meta-llama/Llama-3.2-1B\"\n",
    "model = ReplacementModel.from_pretrained_and_transcoders(model_name, transcoders.transcoders)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loading Pre-computed Results from HuggingFace (Optional) and Saving Them Locally"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To visualize results quickly, you can load them pre-calculated.\n",
    "# Otherwise, you can run the analysis yourself, but since it is layerwise\n",
    "# (i.e. all contributing features from previous layers must already be analyzed),\n",
    "# later layers can take much longer.\n",
    "\n",
    "# Derive the short model name (\"gemma-2-2b\", \"gpt2\" or \"Llama-3.2-1B\") from the model loaded above,\n",
    "# so the dataset repo, the results directory and num_layers all refer to the same model.\n",
    "model_name = model_name.split(\"/\")[-1]\n",
    "save_dir = f\"results/{model_name}\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "num_layers = model.cfg.n_layers  # assumes one of the model-loading cells above has been run\n",
    "\n",
    "# HuggingFace account hosting the pre-computed datasets. It is not defined elsewhere in this notebook,\n",
    "# so take it from an earlier assignment or from the HF_USER environment variable.\n",
    "user_name = globals().get(\"user_name\") or os.environ.get(\"HF_USER\")\n",
    "if not user_name:\n",
    "    raise ValueError(\"Set `user_name` or the HF_USER environment variable to the HuggingFace account hosting the results\")\n",
    "dataset_repo = f\"{user_name}/weightlens-{model_name}-transcoder-descriptions\"\n",
    "\n",
    "# NOTE(review): layers below FIRST_LAYER are not downloaded; presumably they are cheap to\n",
    "# recompute locally or are absent from the dataset - confirm.\n",
    "FIRST_LAYER = 3\n",
    "\n",
    "for layer in range(FIRST_LAYER, num_layers):\n",
    "    file_name = f\"feature_analysis_layer_{layer}.json\"\n",
    "    ds = load_dataset(dataset_repo, data_files=file_name)\n",
    "    layer_results = ds[\"train\"].to_list()\n",
    "\n",
    "    save_path = os.path.join(save_dir, file_name)\n",
    "    with open(save_path, \"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(layer_results, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "    print(f\"Saved {file_name} to {save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[✓] Feature analysis for L20 F8684:\n",
      "\n",
      "📝 Description:  ALL   |    All   |    all   |    tout   |   All   |   all\n",
      "No directly influencing tokens found for this feature.\n",
      "\n",
      "📤 Output logits:\n",
      " + Top positive tokens:  the   |    of   |    those   |    your   |    our   |    these\n",
      " - Top negative tokens: None\n",
      "\n",
      "🧩 Top contributing features:\n",
      " → From L3 F15756 | Contribution: +1.9844 | Relative: 0.00003257%\n",
      "   └─ Description:  tout    |   yscy\n",
      " → From L4 F2933 | Contribution: +1.5156 | Relative: 0.00002487%\n",
      "   └─ Description: dik\n",
      " → From L4 F8718 | Contribution: +1.3750 | Relative: 0.00002257%\n",
      "   └─ Description:  Mall    |    Zul\n",
      " → From L14 F8405 | Contribution: +1.2344 | Relative: 0.00002026%\n",
      "   └─ Description: all\n",
      " → From L8 F2251 | Contribution: +1.1641 | Relative: 0.00001910%\n",
      "   └─ Description: yscy\n",
      " → From L18 F8684 | Contribution: +1.1328 | Relative: 0.00001859%\n",
      "   └─ Description:  ALL    |    All    |    Mall    |    all    |    tout    |   All    |   all\n",
      " → From L16 F4888 | Contribution: +1.0000 | Relative: 0.00001641%\n",
      "   └─ Description:  tout    |   All    |   yscy\n"
     ]
    }
   ],
   "source": [
    "# Load already generated results. Features that have not been validated will have no description shown.\n",
    "# NOTE: expects model_name to be the short name used for the results directory (e.g. \"gemma-2-2b\").\n",
    "\n",
    "# Feature(layer, <second argument, presumably the position - confirm>, feature index) -> shown as \"L20 F8684\"\n",
    "feature = Feature(20, 0, 8684)\n",
    "result = load_feature_analysis(feature, model, save_dir=Path(f\"results/{model_name}/\"))\n",
    "print_feature_analysis(result, save_dir=f\"results/{model_name}/\", model=model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analyzing Transcoders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[✓] Feature analysis for L4 F7671:\n",
      "\n",
      "📝 Description:  OFFICER   |    OFFICERS   |    Officer   |    Officers   |    officer   |    officers   |   Officer   |   Officers   |   officer   |   officers\n",
      "🔍 Top contributing embedding tokens:\n",
      " - Token ID 10971 |             officers | Activation: 10.8125 | Contribution:   77.000 | Rel: 0.00352441%\n",
      " - Token ID 11053 |              officer | Activation: 12.6875 | Contribution:   74.500 | Rel: 0.00340999%\n",
      " - Token ID 40159 |             Officers | Activation: 7.5000 | Contribution:   70.500 | Rel: 0.00322690%\n",
      " - Token ID 15860 |              Officer | Activation: 8.1875 | Contribution:   69.000 | Rel: 0.00315824%\n",
      " - Token ID 101672 |             Officers | Activation: 2.8750 | Contribution:   66.500 | Rel: 0.00304381%\n",
      " - Token ID 153897 |              OFFICER | Activation: 9.5625 | Contribution:   63.000 | Rel: 0.00288361%\n",
      " - Token ID 98452 |              Officer | Activation: 5.8125 | Contribution:   62.000 | Rel: 0.00283784%\n",
      " - Token ID 146209 |              officer | Activation: 4.0000 | Contribution:   59.250 | Rel: 0.00271197%\n",
      " - Token ID 213532 |             officers | Activation: 4.1250 | Contribution:   58.500 | Rel: 0.00267764%\n",
      " - Token ID 142835 |             OFFICERS | Activation: 3.6875 | Contribution:   54.500 | Rel: 0.00249455%\n",
      "\n",
      "📤 Output logits:\n",
      " + Top positive tokens:  headquarters   |    Headquarters   |   Headquarters\n",
      " - Top negative tokens: HasAnnotation\n",
      "\n",
      "🧩 Top contributing features:\n",
      " → From L3 F7268 | Contribution: +0.7266 | Relative: 0.00003326%\n",
      "   └─ Description:  //$\n",
      " → From L2 F13824 | Contribution: +0.5117 | Relative: 0.00002342%\n",
      "   └─ Description:  CLASSES    |    Classes    |    clases    |    classes    |   Classes    |   classes\n",
      " → From L1 F6507 | Contribution: +0.4766 | Relative: 0.00002181%\n",
      "   └─ Description:  Traditional    |    tradition    |    traditional    |   Traditional    |   traditional\n",
      " → From L1 F16178 | Contribution: +0.4648 | Relative: 0.00002128%\n",
      "   └─ Description:  fabricants\n",
      " → From L0 F10502 | Contribution: +0.3887 | Relative: 0.00001779%\n",
      "   └─ Description:  referenties    |   InitVars\n"
     ]
    }
   ],
   "source": [
    "# If a result is not yet calculated, it can be analyzed and visualized separately.\n",
    "# Important: later layers take significantly more time, since the process also analyzes all contributing features.\n",
    "feature = Feature(4, 0, 7671)\n",
    "\n",
    "result = analyze_feature(model, feature)\n",
    "print_feature_analysis(result, model=model)\n",
    "\n",
    "# All features can be analyzed layer by layer with the call below. Save into the same\n",
    "# results/<model_name>/ directory that the loading cells read from.\n",
    "\n",
    "# analyze_all_features(model, save_dir=f\"results/{model_name}/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[✓] Feature analysis for L4 F7671:\n",
      "\n",
      "📝 Description:  OFFICER   |    OFFICERS   |    Officer   |    Officers   |    officer   |    officers   |   Officer   |   Officers   |   createCanvas   |   officer   |   officers\n",
      "🔍 Top contributing embedding tokens:\n",
      " - Token ID 10971 |             officers | Activation: 10.8125 | Contribution:   77.000 | Rel: 0.00352441%\n",
      " - Token ID 11053 |              officer | Activation: 12.6875 | Contribution:   74.500 | Rel: 0.00340999%\n",
      " - Token ID 40159 |             Officers | Activation: 7.5000 | Contribution:   70.500 | Rel: 0.00322690%\n",
      " - Token ID 15860 |              Officer | Activation: 8.1875 | Contribution:   69.000 | Rel: 0.00315824%\n",
      " - Token ID 101672 |             Officers | Activation: 2.8750 | Contribution:   66.500 | Rel: 0.00304381%\n",
      " - Token ID 153897 |              OFFICER | Activation: 9.5625 | Contribution:   63.000 | Rel: 0.00288361%\n",
      " - Token ID 98452 |              Officer | Activation: 5.8125 | Contribution:   62.000 | Rel: 0.00283784%\n",
      " - Token ID 146209 |              officer | Activation: 4.0000 | Contribution:   59.250 | Rel: 0.00271197%\n",
      " - Token ID 213532 |             officers | Activation: 4.1250 | Contribution:   58.500 | Rel: 0.00267764%\n",
      " - Token ID 142835 |             OFFICERS | Activation: 3.6875 | Contribution:   54.500 | Rel: 0.00249455%\n",
      " - Token ID 182695 |         createCanvas | Activation: 3.9375 | Contribution:   42.000 | Rel: 0.00192241%\n",
      "\n",
      "📤 Output logits:\n",
      " + Top positive tokens:  headquarters   |    Headquarters   |   Headquarters   |    HQ   |    headquartered   |    hq   |    mammals   |    coordinators   |    Mammals   |    merkezi   |    anhydride   |    mammal   |    CENTRAL   |    Centrale   |    Metropol\n",
      " - Top negative tokens: HasAnnotation   |   хьтан   |    بيها   |   Rüyada   |   FunctionFlags   |   IBOutlet   |   endmodule   |   /******/   |   IsMutable   |   TargetApi   |   Xaml   |   SeekBar   |   ---*/   |   Становништво   |   LookAnd   |    يتيمه   |    considérons   |   ￡   |   CrossRef   |   schirm   |    AssemblyProduct\n",
      "\n",
      "🧩 Top contributing features:\n",
      " → From L3 F7268 | Contribution: +0.7266 | Relative: 0.00003326%\n",
      "   └─ Description:  //$\n",
      " → From L2 F13824 | Contribution: +0.5117 | Relative: 0.00002342%\n",
      "   └─ Description:  CLASSES    |    Classes    |    clases    |    classes    |   Classes    |   classes\n",
      " → From L1 F6507 | Contribution: +0.4766 | Relative: 0.00002181%\n",
      "   └─ Description:  Traditional    |    tradition    |    traditional    |   Traditional    |   traditional\n",
      " → From L1 F16178 | Contribution: +0.4648 | Relative: 0.00002128%\n",
      "   └─ Description:  fabricants\n",
      " → From L0 F10502 | Contribution: +0.3887 | Relative: 0.00001779%\n",
      "   └─ Description:  referenties    |   InitVars\n",
      " → From L0 F 582 | Contribution: +0.3574 | Relative: 0.00001636%\n",
      "   └─ Description:  Unless    |    met    |    unless    |   Unless    |   unless    |   除非\n"
     ]
    }
   ],
   "source": [
    "# The z-score thresholds can also be changed to make the analysis more or less sensitive, at the cost of\n",
    "# more noise: with threshold 3 here the lists are longer than in the default run above\n",
    "# (note the unrelated `createCanvas` token among the contributing tokens).\n",
    "\n",
    "feature = Feature(4, 0, 7671)  # same feature as above, redefined so this cell runs on its own\n",
    "result = analyze_feature(model, feature, threshold_tokens=3, threshold_features=3)\n",
    "print_feature_analysis(result, model=model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Post-processing Descriptions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Post-processing helpers: unique_tokens (lemmatization-based deduplication) and the explainer-LLM system_prompt\n",
    "from weight_lens.postprocessing import unique_tokens, system_prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(['officer'], ['headquarters'])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First option - lemmatization.\n",
    "# We often get several variants of the same word (case, plural, leading space);\n",
    "# unique_tokens collapses them when post-processing the obtained results.\n",
    "\n",
    "feature = Feature(4, 0, 7671)\n",
    "result = analyze_feature(model, feature)\n",
    "\n",
    "unique_tokens(result['description']), unique_tokens(result['output_logits']['top_positive_tokens'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "            We're studying neurons in a neural network. Each neuron has certain inputs that activate it and outputs that it leads to. You will receive three pieces of information about a neuron: \n",
      "\n",
      "            1. The top important tokens.\n",
      "            2. The top tokens it promotes in the output. \n",
      "            3. The tokens it suppresses in the output.\n",
      "\n",
      "            These will be separated into three sections [Important Tokens] and [Text Promoted] and [Text Suppressed]. All three are a combination of tokens. You can infer the most likely output or function of the neuron based on these tokens. The tokens, specially [Text Promoted] and [Text Suppressed] may include noise, such as unrelated terms, symbols, or programming jargon. If these are not coherent, you may ignore them and do not include them in your response. If the [Important Tokens] are not combining to form a common theme, you may simply combine the words in the [Important Tokens] to form a single concept.\n",
      "\n",
      "            Focus on identifying a cohesive theme or concept shared by the most relevant tokens. \n",
      "\n",
      "            Your response should be a concise (1-2 sentence) explanation of the neuron, encompassing what triggers it (input) and what it does once triggered (output). If the two sides relate to one another you may include that in your explanation, otherwise simply state the input and output. Give your output in the following format:\n",
      "\n",
      "            [Concept: <Your interpretation of the neuron, based on the tokens provided>]\n",
      "\n",
      "            Example 1:\n",
      "\n",
      "            Input:\n",
      "            [Important Tokens]: ['accused', 'saw'] \n",
      "            [Tokens Promoted]: ['tvguidetime', 'hasfactory']\n",
      "            [Tokens Suppressed]: [\"'\", '']\n",
      "\n",
      "            Output:\n",
      "            [Concept: The verbs \"accused\" and \"saw\"]\n",
      "\n",
      "            Example 2:\n",
      "            Input:\n",
      "            [Important Tokens]: ['on', 'pada']\n",
      "            [Tokens Promoted]: ['behalf']\n",
      "            [Tokens Suppressed]: ['on', 'in']\n",
      "            \n",
      "            Output:\n",
      "            [Concept: The token \"on\" in the context of \"on behalf of\" and so on]\n",
      "\n",
      "\n",
      "            Example 3:\n",
      "            Input:\n",
      "            [Important Tokens]: ['carrier', 'missing', '']\n",
      "            [Tokens Promoted]: None\n",
      "            [Tokens Suppressed]: None\n",
      "\n",
      "            Output:\n",
      "            [Concept: the word \"missing\" and \"carrier\" ]\n",
      "\n",
      "            Example 4:\n",
      "            Input:\n",
      "            [Important Tokens]: ['democratic', 'dare']\n",
      "            [Tokens Promoted]: nan\n",
      "            [Tokens Suppressed]: nan\n",
      "            [Concept: The tokens \"democratic\" and \"dare\"]\n",
      "            \n",
      "\n",
      "\n",
      "\n",
      "\n",
      "            Important: Only output the [Concept] as your response. Do not include any other text or explanation for the same.\n",
      "                    \n",
      "            \n"
     ]
    }
   ],
   "source": [
    "# Another option -- post-processing via an explainer LLM that generates the description.\n",
    "# Below is the system prompt it uses (print keeps the prompt's line breaks readable).\n",
    "\n",
    "print(system_prompt)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
