{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f3e3a66-4d05-4674-8a08-f1cad66246e7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db66285e-e71d-4f68-988a-823899bfb340",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.interpreter import FeatureInterpreter\n",
    "from src.builder import PromptBuilder\n",
    "from src.feature import Interpretation\n",
    "from src.engines import HuggingFaceEngine, VLLMEngine\n",
    "from src.samplers import TopSampler, RandomSampler, QuantileSampler\n",
    "\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb1d83c4-0055-4b49-a494-ba9a6526fc10",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_to_interpret = [\"Qwen/Qwen2.5-1.5B\", \"google/gemma-2-2b\"][0]\n",
    "context_tokenizer = AutoTokenizer.from_pretrained(model_to_interpret)\n",
    "\n",
    "interpreter_model = [\"Qwen/Qwen3-14B-AWQ\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b6acf8b-5c54-4f4f-a50d-1b177a3cd724",
   "metadata": {},
   "outputs": [],
   "source": [
    "engine_type = \"vllm\"\n",
    "\n",
    "if engine_type == \"hf\":\n",
    "\n",
    "    model = AutoModelForCausalLM.from_pretrained(interpreter_model).to(\"cuda\")\n",
    "    tokenizer = AutoTokenizer.from_pretrained(interpreter_model)\n",
    "    \n",
    "    engine = HuggingFaceEngine(\n",
    "        model=model, \n",
    "        tokenizer=tokenizer, \n",
    "        device=\"cuda\",\n",
    "    )\n",
    "\n",
    "elif engine_type == \"vllm\":\n",
    "    \n",
    "    engine = VLLMEngine(\n",
    "        model_name_or_path=interpreter_model,\n",
    "        max_model_len=4096,\n",
    "        dtype=\"auto\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63593a9a-bae1-474b-ac23-90c1920365e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "default_system_prompt = \"\"\"You are a meticulous AI researcher conducting an important investigation into patterns found in language. Your task is to analyze text and provide an explanation that thoroughly encapsulates possible patterns found in it.\n",
    "\n",
    "Guidelines:\n",
    "\n",
    "You will be given a list of text examples on which special words are selected and between delimiters like <<this>>. If a sequence of consecutive tokens all are important, the entire sequence of tokens will be contained between delimiters <<just like this>>. How important each token is for the behavior is listed after each example in parentheses.\n",
    "\n",
    "- Provide your explanation as a concise STANDALONE PHRASE describing the common contexts or concepts\n",
    "- Do not make lists of possible explanations\n",
    "- Focus on the essence of what patterns are present in the examples\n",
    "- Do NOT mention the texts, examples or the feature itself in your explanation\n",
    "- Do NOT write \"these texts\", \"feature detects\", \"the patterns suggest\" or something like that\n",
    "- The last line must be exactly formatted as: [EXPLANATION]:\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "adjusted_system_prompt = \"\"\"You are a meticulous AI researcher conducting an important investigation into patterns found in language. Your task is to analyze text and provide an explanation that thoroughly encapsulates possible patterns found in it.\n",
    "\n",
    "Guidelines:\n",
    "\n",
    "You will be given a list of text examples on which special words are selected and between delimiters like <<this>>. If a sequence of consecutive tokens all are important, the entire sequence of tokens will be contained between delimiters <<just like this>>. How important each token is for the behavior is listed after each example in parentheses.\n",
    "\n",
    "- Your explanation should be a concise STANDALONE PHRASE that describes observed patterns.\n",
    "- Focus on the essence of what patterns, concepts and contexts are present in the examples.\n",
    "- Do NOT mention the texts, examples, activations or the feature itself in your explanation.\n",
    "- Do NOT write \"these texts\", \"feature detects\", \"the patterns suggest\", \"activates\" or something like that.\n",
    "- Do not write what the feature does, e.g. instead of \"detects heart diseases in medical reports\" write \"heart diseases in medical reports\".\n",
    "- Write explanation in the last line exactly after the [EXPLANATION]:\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "eleuther_system_prompt = \"\"\"You are a meticulous AI researcher conducting an important investigation into patterns found in language. Your task is to analyze text and provide an explanation that thoroughly encapsulates possible patterns found in it.\n",
    "Guidelines:\n",
    "\n",
    "You will be given a list of text examples on which special words are selected and between delimiters like <<this>>. If a sequence of consecutive tokens all are important, the entire sequence of tokens will be contained between delimiters <<just like this>>. How important each token is for the behavior is listed after each example in parentheses.\n",
    "\n",
    "- Try to produce a concise final description. Simply describe the text latents that are common in the examples, and what patterns you found.\n",
    "- If the examples are uninformative, you don't need to mention them. Don't focus on giving examples of important tokens, but try to summarize the patterns found in the examples.\n",
    "- Do not mention the marker tokens (<< >>) in your explanation.\n",
    "- Do not make lists of possible explanations. Keep your explanations short and concise.\n",
    "- The last line of your response must be the formatted explanation, using [EXPLANATION]:\n",
    "\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa1e5319-9ca0-4349-a2c8-50a673457ae6",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer_args = {\n",
    "    'return_tensors': \"pt\",\n",
    "    'padding': True,\n",
    "    'truncation': True,\n",
    "    'max_length': 4096,\n",
    "    'padding_side': 'left'\n",
    "    }\n",
    "chat_template_args = {\n",
    "    'tokenize': False, \n",
    "    'add_generation_prompt': True,\n",
    "    'enable_thinking': True\n",
    "}\n",
    "engine.tokenizer_args = tokenizer_args\n",
    "engine.chat_template_args = chat_template_args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa24d422-f011-49ac-9074-bb0c2b5c5267",
   "metadata": {},
   "outputs": [],
   "source": [
    "name = \"post\"\n",
    "features = pickle.load(open(f'results/{name}.pkl', 'rb'))\n",
    "features_to_interpret = list(features.values())\n",
    "\n",
    "sampler = TopSampler(\n",
    "    sample_size=16,\n",
    "    top_quantile=0.5,\n",
    "    random_sample=True\n",
    ")\n",
    "\n",
    "builder = PromptBuilder(\n",
    "    tokenizer=context_tokenizer,\n",
    "    max_example_length=24,\n",
    "    sampler=sampler,\n",
    "    context_window=5,\n",
    "    system_prompt_template=adjusted_system_prompt\n",
    ")\n",
    "\n",
    "interpreter = FeatureInterpreter(\n",
    "    inference_engine=engine, \n",
    "    prompt_builder=builder,\n",
    "    num_rounds=1,\n",
    "    batch_size=16,\n",
    "    show_progress=True\n",
    ")\n",
    "\n",
    "interpretations = interpreter.run(\n",
    "    features_to_interpret, max_tokens=3072, temperature=0.75\n",
    ")\n",
    "\n",
    "for index in interpretations.keys():\n",
    "    feature = features[index]\n",
    "\n",
    "    if feature.interpretations is None:\n",
    "        feature.interpretations = []\n",
    "\n",
    "    # feature.interpretations = []\n",
    "        \n",
    "    for interp in interpretations[index]:\n",
    "        interpretation = Interpretation(\n",
    "            value = interp['explanation'],\n",
    "            commentary = (\n",
    "                f\"Model: {interpreter_model}, \"\n",
    "                f\"Sampler: {sampler.desc} ({sampler.sample_size}) samples\"\n",
    "            )\n",
    "        )\n",
    "        feature.interpretations.append(interpretation)\n",
    "\n",
    "with open(f'results/{name}.pkl', 'wb') as f:\n",
    "    pickle.dump(features, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17a8447e-be84-4535-add2-ca248fd36723",
   "metadata": {},
   "outputs": [],
   "source": [
    "name = \"pre\"\n",
    "features = pickle.load(open(f'results/{name}.pkl', 'rb'))\n",
    "features_to_interpret = list(features.values())\n",
    "\n",
    "sampler = TopSampler(\n",
    "    sample_size=16,\n",
    "    top_quantile=0.5,\n",
    "    random_sample=True\n",
    ")\n",
    "\n",
    "# sampler = QuantileSampler(\n",
    "#     sample_size=16,\n",
    "#     num_quantiles=4\n",
    "# )\n",
    "\n",
    "builder = PromptBuilder(\n",
    "    tokenizer=context_tokenizer,\n",
    "    max_example_length=24,\n",
    "    sampler=sampler,\n",
    "    context_window=5,\n",
    "    system_prompt_template=adjusted_system_prompt\n",
    ")\n",
    "\n",
    "interpreter = FeatureInterpreter(\n",
    "    inference_engine=engine, \n",
    "    prompt_builder=builder,\n",
    "    num_rounds=1,\n",
    "    batch_size=16,\n",
    "    show_progress=True\n",
    ")\n",
    "\n",
    "interpretations = interpreter.run(\n",
    "    features_to_interpret, max_tokens=3072, temperature=0.75\n",
    ")\n",
    "\n",
    "for index in interpretations.keys():\n",
    "    feature = features[index]\n",
    "\n",
    "    if feature.interpretations is None:\n",
    "        feature.interpretations = []\n",
    "\n",
    "    feature.interpretations = []\n",
    "        \n",
    "    for interp in interpretations[index]:\n",
    "        interpretation = Interpretation(\n",
    "            value = interp['explanation'],\n",
    "            commentary = (\n",
    "                f\"Model: {interpreter_model}, \"\n",
    "                f\"Sampler: {sampler.desc} ({sampler.sample_size}) samples\"\n",
    "            )\n",
    "        )\n",
    "        feature.interpretations.append(interpretation)\n",
    "\n",
    "with open(f'results/{name}.pkl', 'wb') as f:\n",
    "    pickle.dump(features, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae77ab05-3825-4a5a-815b-71995ded70cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "for name in ['pre', 'post', 'topk']:\n",
    "    ftrs = pickle.load(open(f'results/{name}.pkl', 'rb'))\n",
    "    for ftr in ftrs.values():\n",
    "        if ftr.interpretations is not None:\n",
    "            ftr.interpretations = ftr.interpretations[-1:]\n",
    "    with open(f'results/{name}.pkl', 'wb') as f:\n",
    "        pickle.dump(ftrs, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c99f35e-6936-4cd2-bdaa-9f38e8b85f9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "pre = pickle.load(open(f'results/pre.pkl', 'rb'))\n",
    "post = pickle.load(open(f'results/post.pkl', 'rb'))\n",
    "topk = pickle.load(open(f'results/topk.pkl', 'rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0eac9fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def post_act_to_triplet(post_act_idx: int, m: int, n: int):\n",
    "    elements_per_head = m * n\n",
    "    head_idx = post_act_idx // elements_per_head\n",
    "    position_in_head = post_act_idx % elements_per_head\n",
    "    row_idx = position_in_head // n\n",
    "    col_idx = position_in_head % n\n",
    "    return head_idx, row_idx, col_idx\n",
    "\n",
    "def triplet_to_pre_act(head_idx: int, \n",
    "                      row_idx: int,\n",
    "                      col_idx: int,\n",
    "                      m: int, n: int):\n",
    "    features_per_head = m + n\n",
    "    base_idx = head_idx * features_per_head\n",
    "    row_pre_act = base_idx + row_idx\n",
    "    column_pre_act = base_idx + m + col_idx\n",
    "    return row_pre_act, column_pre_act\n",
    "\n",
    "def columns_for_row(head: int, row: int, m: int, n: int):\n",
    "    assert row < m, f\"Row index should be less than m = {m}\"\n",
    "    base = head * (m + n) + m\n",
    "    return [(col, base + col) for col in range(n)]\n",
    "\n",
    "def rows_for_column(head: int, column: int, m: int, n: int):\n",
    "    assert column < n, f\"Column index should be less than n = {n}\"\n",
    "    base = head * (m + n)\n",
    "    return [(row, base + row) for row in range(m)]\n",
    "\n",
    "def pre_to_post(head: int, row: int, column: int, m: int, n: int):\n",
    "    return head * (m * n) + row * n + column"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e86625e-acae-4afc-aa90-3107ac2358af",
   "metadata": {},
   "outputs": [],
   "source": [
    "indices = np.random.choice(list(post.keys()), 8, replace=False)\n",
    "indices = range(64, 80)\n",
    "for index in indices:\n",
    "    triplet = post_act_to_triplet(index, 4, 4)\n",
    "    h, r, c = triplet\n",
    "    row, col = triplet_to_pre_act(h, r, c, 4, 4)\n",
    "    \n",
    "    print(f\"(h{h}, r{r}): {pre[row].interpretations[-1].value}\")\n",
    "    print(f\"(h{h}, c{c}): {pre[col].interpretations[-1].value}\")\n",
    "    print(f\"(h{h}, r{r}, c{c}) - {index}: {post[index].interpretations[-1].value}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9edaf95f-5192-46d6-83d6-64e9cddd50d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "m, n = 4, 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f24d665-de98-42f6-8ed7-11633a859e82",
   "metadata": {},
   "outputs": [],
   "source": [
    "head = np.random.choice(128, 1)[0]\n",
    "head = 67\n",
    "base = (m + n) * head\n",
    "for row, i in enumerate(range(base, base + m)):\n",
    "    print(f\"(h{head}, r{row}): {pre[i].interpretations[-1].value}\")\n",
    "for col, i in enumerate(range(base + m, base + m + n)):\n",
    "    print(f\"(h{head}, c{col}): {pre[i].interpretations[-1].value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61a6ecb5-a0e5-4061-838a-f152398da0fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "m, n = 4, 4\n",
    "head, row = 67, 3\n",
    "row_idx = head * (m + n) + row\n",
    "print(f\"(h{head}, r{row}): {pre[row_idx].interpretations[-1].value}\\n\")\n",
    "for col, col_idx in columns_for_row(head, row, m, n):\n",
    "    print(f\"(h{head}, c{col}): {pre[col_idx].interpretations[-1].value}\")\n",
    "    postact_idx = pre_to_post(head, row, col, m, n)\n",
    "    print(f\"(h{head}, r{row}, c{col}) - {postact_idx}: {post[postact_idx].interpretations[-1].value}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c83218d-8d59-4ae0-9de0-71a432800511",
   "metadata": {},
   "outputs": [],
   "source": [
    "m, n = 4, 4\n",
    "head, col = 4, 3\n",
    "col_idx = head * (m + n) + m + col\n",
    "print(f\"(h{head}, c{col}): {pre[col_idx].interpretations[-1].value}\\n\")\n",
    "for row, row_idx in rows_for_column(head, col, m, n):\n",
    "    print(f\"(h{head}, r{row}): {pre[row_idx].interpretations[-1].value}\")\n",
    "    postact_idx = pre_to_post(head, row, col, m, n)\n",
    "    print(f\"(h{head}, r{row}, c{col}) - {postact_idx}: {post[postact_idx].interpretations[-1].value}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef8654c4-8451-488b-8546-fe2c6a95480a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
