{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d27bb8e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07a34393",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "##################################################################\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "logger.info(\n",
    "    f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}\"\n",
    ")\n",
    "logger.info(f\"{transformers.__version__=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d58ed106",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "# model_key = \"meta-llama/Llama-3.2-3B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-8B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
    "model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "# model_key = \"meta-llama/Llama-3.1-405B-Instruct\"\n",
    "\n",
    "# model_key = \"google/gemma-2-9b-it\"\n",
    "# model_key = \"google/gemma-3-12b-it\"\n",
    "# model_key = \"google/gemma-2-27b-it\"\n",
    "\n",
    "# model_key = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
    "\n",
    "# model_key = \"allenai/OLMo-2-1124-7B-Instruct\"\n",
    "# model_key = \"allenai/OLMo-7B-0424-hf\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen2-7B\"\n",
    "# model_key = \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-32B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-72B-Instruct\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen3-1.7B\"\n",
    "# model_key = \"Qwen/Qwen3-4B\"\n",
    "# model_key = \"Qwen/Qwen3-8B\"\n",
    "# model_key = \"Qwen/Qwen3-14B\"\n",
    "# model_key = \"Qwen/Qwen3-32B\"\n",
    "\n",
    "# device_map = get_device_map(model_key, 30, n_gpus=8)\n",
    "# device_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3a80e7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# from transformers import BitsAndBytesConfig\n",
    "\n",
    "mt = ModelandTokenizer(\n",
    "    model_key=model_key,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    # device_map=device_map,\n",
    "    device_map=\"auto\",\n",
    "    # quantization_config = BitsAndBytesConfig(\n",
    "    #     # load_in_4bit=True\n",
    "    #     load_in_8bit=True\n",
    "    # )\n",
    "    attn_implementation=\"eager\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceae08b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "device_map = get_device_map(mt.name, 80, n_gpus=8)\n",
    "\n",
    "def module_to_device(module_name):\n",
    "    for key in device_map:\n",
    "        if module_name.startswith(key):\n",
    "            return f\"cuda:{device_map[key]}\"\n",
    "    return \"cpu\"\n",
    "\n",
    "module_to_device(mt.mlp_module_name_format.format(20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32269e8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectOneTask, SelectOrderTask\n",
    "\n",
    "#################################################################################\n",
    "# TASK_CLS = SelectOrderTask\n",
    "# prompt_template_idx = 1\n",
    "TASK_CLS = SelectOneTask\n",
    "prompt_template_idx = 3\n",
    "N_DISTRACTORS = 5\n",
    "OPTION_STYLE = \"single_line\"\n",
    "#################################################################################\n",
    "\n",
    "select_task = TASK_CLS.load(\n",
    "    path=os.path.join(\n",
    "        env_utils.DEFAULT_DATA_DIR, \n",
    "        \"selection\", \n",
    "        # \"profession.json\"\n",
    "        # \"nationality.json\"\n",
    "        \"objects.json\"\n",
    "    )\n",
    ")\n",
    "\n",
    "print(select_task)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7386017a",
   "metadata": {},
   "source": [
    "## Patching the residual states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdbe21db",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import random\n",
    "from src.selection.utils import KeyedSet, get_first_token_id, verify_correct_option\n",
    "from src.selection.data import SelectionSample\n",
    "from src.functional import predict_next_token\n",
    "from src.tokens import prepare_input\n",
    "from src.selection.data import get_counterfactual_samples_within_task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "388f1913",
   "metadata": {},
   "outputs": [],
   "source": [
    "patch_sample, clean_sample = get_counterfactual_samples_within_task(\n",
    "    # patch_category=\"politician\",\n",
    "    # clean_category=\"actor\",\n",
    "    mt=mt,\n",
    "    task=select_task,\n",
    "    patch_category=\"fruit\",\n",
    "    clean_category=\"vehicle\",\n",
    "    filter_by_lm_prediction=True,\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    option_style=OPTION_STYLE,\n",
    "    distinct_options=True,\n",
    "    patch_n_distractors=5,\n",
    "    clean_n_distractors=5\n",
    ")\n",
    "\n",
    "# patch_sample.default_option_style = \"single_line\"\n",
    "# clean_sample.default_option_style = \"numbered\"\n",
    "\n",
    "print(patch_sample.prompt(), \">>\", patch_sample.obj)\n",
    "print(clean_sample.prompt(), \">>\", clean_sample.obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff7d6b06",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_tokenized = prepare_input(tokenizer=mt, prompts=clean_sample.prompt())\n",
    "print(mt.tokenizer.decode(clean_tokenized.input_ids[0], skip_special_tokens=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6899bb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.tokens import prepare_input\n",
    "from src.selection.utils import get_first_token_id\n",
    "from src.functional import interpret_logits, PatchSpec\n",
    "from itertools import product\n",
    "from src.utils.typing import TokenizerOutput, ArrayLike\n",
    "from typing import Optional, Union\n",
    "from src.functional import get_module_nnsight, untuple, get_hs, predict_next_token\n",
    "\n",
    "def layer_wise_patching(\n",
    "    mt: ModelandTokenizer,\n",
    "    patch_sample: SelectionSample,\n",
    "    clean_sample: SelectionSample,\n",
    "    token_indices: list[int] = [-3, -2, -1],\n",
    "):\n",
    "    patch_tokenized = prepare_input(\n",
    "        tokenizer=mt.tokenizer, prompts=patch_sample.prompt()\n",
    "    )\n",
    "    clean_tokenized = prepare_input(\n",
    "        tokenizer=mt.tokenizer, prompts=clean_sample.prompt()\n",
    "    )\n",
    "\n",
    "    random_idx = random.choice(\n",
    "        list(\n",
    "            set(list(range(len(clean_sample.options))))\n",
    "            - {\n",
    "                patch_sample.obj_idx,\n",
    "                clean_sample.obj_idx,\n",
    "                clean_sample.metadata[\"track_type_obj_idx\"],\n",
    "            }\n",
    "        )\n",
    "    )\n",
    "\n",
    "    track_tokens = {\n",
    "        \"predicate_target\": clean_sample.metadata[\"track_type_obj_token_id\"],\n",
    "        \"clean_ans\": get_first_token_id(clean_sample.obj, mt.tokenizer, prefix=\" \"),\n",
    "        \"patch_ans\": get_first_token_id(patch_sample.obj, mt.tokenizer, prefix=\" \"),\n",
    "        \"patch_position\": get_first_token_id(\n",
    "            clean_sample.options[patch_sample.obj_idx], mt.tokenizer, prefix=\" \"\n",
    "        ),\n",
    "        \"random_distractor\": get_first_token_id(\n",
    "            clean_sample.options[random_idx], mt.tokenizer, prefix=\" \"\n",
    "        ),\n",
    "    }\n",
    "\n",
    "    ret = {\"track_tokens\": track_tokens}\n",
    "\n",
    "    logit_location = (mt.lm_head_name, -1)\n",
    "    patch_locations = list(product(mt.layer_names, token_indices))\n",
    "    # patch_locations = []\n",
    "    print(patch_locations)\n",
    "\n",
    "    patch_hs = get_hs(\n",
    "        mt=mt,\n",
    "        input=patch_tokenized,\n",
    "        locations=patch_locations + [logit_location],\n",
    "        return_dict=True,\n",
    "    )\n",
    "    patch_logits = patch_hs[logit_location]\n",
    "    patch_pred, patch_track = interpret_logits(\n",
    "        tokenizer=mt.tokenizer,\n",
    "        logits=patch_logits,\n",
    "        interested_tokens=track_tokens.values(),\n",
    "    )\n",
    "    logger.debug(f\"patch_pred={[str(pred) for pred in patch_pred]}\")\n",
    "    logger.debug(f\"patch_track={patch_track}\")\n",
    "    ret[\"patch_pred\"] = patch_pred\n",
    "    ret[\"patch_track\"] = patch_track\n",
    "\n",
    "    clean_hs = get_hs(\n",
    "        mt=mt,\n",
    "        input=clean_tokenized,\n",
    "        locations=patch_locations + [logit_location],\n",
    "        return_dict=True,\n",
    "    )\n",
    "    clean_logits = clean_hs[logit_location]\n",
    "    clean_pred, clean_track = interpret_logits(\n",
    "        tokenizer=mt.tokenizer,\n",
    "        logits=clean_logits,\n",
    "        interested_tokens=track_tokens.values(),\n",
    "    )\n",
    "    logger.debug(f\"clean_pred={[str(pred) for pred in clean_pred]}\")\n",
    "    logger.debug(f\"clean_track={clean_track}\")\n",
    "    ret[\"clean_pred\"] = clean_pred\n",
    "    ret[\"clean_track\"] = clean_track\n",
    "\n",
    "    layer_wise_patching_results = {}\n",
    "    for layer in mt.layer_names:\n",
    "        patch_spec = []\n",
    "        for token_idx in token_indices:\n",
    "            patch_spec.append(\n",
    "                PatchSpec(\n",
    "                    location=(layer, token_idx), patch=patch_hs[(layer, token_idx)]\n",
    "                )\n",
    "            )\n",
    "\n",
    "        # int_pred, int_track = predict_next_token(\n",
    "        #     mt=mt,\n",
    "        #     inputs=clean_tokenized,\n",
    "        #     token_of_interest=track_tokens.values(),\n",
    "        #     patches=patch_spec\n",
    "        # )\n",
    "        int_hs = get_hs(\n",
    "            mt=mt,\n",
    "            input=clean_tokenized,\n",
    "            locations=[logit_location],\n",
    "            patches=patch_spec,\n",
    "            return_dict=True,\n",
    "        )\n",
    "        int_logits = int_hs[logit_location]\n",
    "        int_pred, int_track = interpret_logits(\n",
    "            tokenizer=mt.tokenizer,\n",
    "            logits=int_logits,\n",
    "            interested_tokens=track_tokens.values(),\n",
    "        )\n",
    "\n",
    "        logger.debug(f\"Layer {layer}: int_pred={[str(pred) for pred in int_pred]}\")\n",
    "        layer_wise_patching_results[layer] = {\n",
    "            \"int_pred\": int_pred,\n",
    "            \"int_track\": int_track,\n",
    "        }\n",
    "\n",
    "    ret[\"layer_wise_patching_results\"] = layer_wise_patching_results\n",
    "    return ret\n",
    "\n",
    "\n",
    "patching_result = layer_wise_patching(\n",
    "    mt=mt,\n",
    "    patch_sample=patch_sample,\n",
    "    clean_sample=clean_sample,\n",
    "    token_indices=[-3, -2, -1],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "081ae5d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import free_gpu_cache\n",
    "free_gpu_cache()\n",
    "validation_set = []\n",
    "validation_limit = 64\n",
    "\n",
    "while len(validation_set) < validation_limit:\n",
    "    print(f\"sample {len(validation_set)+1} / {validation_limit}\")\n",
    "    patch, clean = get_counterfactual_samples_within_task(\n",
    "        mt=mt,\n",
    "        task=select_task,\n",
    "        filter_by_lm_prediction=True,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        distinct_options=True,\n",
    "        # n_distractors=N_DISTRACTORS,\n",
    "        patch_n_distractors=N_DISTRACTORS,\n",
    "        clean_n_distractors=N_DISTRACTORS\n",
    "    )\n",
    "    validation_set.append((clean, patch))\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0a0463e",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = []\n",
    "for clean, patch in validation_set:\n",
    "    result = layer_wise_patching(\n",
    "        mt=mt,\n",
    "        patch_sample=patch,\n",
    "        clean_sample=clean,\n",
    "        token_indices=[-3, -2, -1]\n",
    "    )\n",
    "    results.append(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05c376d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# results = [patching_result]\n",
    "\n",
    "scores = {token_type: [] for token_type in results[0][\"track_tokens\"].keys()}\n",
    "for result in results:\n",
    "    clean_track = result[\"clean_track\"]\n",
    "    patch_track = result[\"patch_track\"]\n",
    "\n",
    "    for token_type in scores.keys():\n",
    "        layerwise_scores = []\n",
    "        token_id = result[\"track_tokens\"][token_type]\n",
    "        for layer_idx in range(mt.n_layer):\n",
    "            score = result[\"layer_wise_patching_results\"][mt.layer_names[layer_idx]][\"int_track\"][token_id][1].logit\n",
    "            layerwise_scores.append(score)\n",
    "        scores[token_type].append(layerwise_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "119dfaa7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "plt.figure(figsize=(12, 6))\n",
    "for token_type, layerwise_scores_list in scores.items():\n",
    "    # Compute mean and std deviation across results for each layer\n",
    "    mean_scores = np.mean(layerwise_scores_list, axis=0)\n",
    "    sterr_scores = np.std(layerwise_scores_list, axis=0) / np.sqrt(len(layerwise_scores_list))\n",
    "\n",
    "    plt.plot(mean_scores, label=f\"{token_type}\")\n",
    "    plt.fill_between(range(len(mean_scores)), mean_scores - sterr_scores, mean_scores + sterr_scores, alpha=0.2)\n",
    "\n",
    "plt.xlabel(\"Layer\")\n",
    "plt.ylabel(\"Logit(x)\")\n",
    "plt.title(f\"Residual | {mt.name.split('/')[-1]}\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6674f663",
   "metadata": {},
   "source": [
    "## Loading and calculating the basis directions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46eb4675",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "cached_states_dir = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/cached_states\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    \"last_token\"\n",
    ")\n",
    "sample_file_name = \"sample_00170.npz\"\n",
    "\n",
    "sample_states = np.load(\n",
    "    os.path.join(cached_states_dir, sample_file_name),\n",
    "    allow_pickle=True,\n",
    ")\n",
    "print(sample_states.files)\n",
    "list(sample_states[\"states\"].item().keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d267179",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = SelectionSample.from_dict(sample_states[\"sample\"].item())\n",
    "tokenized = TokenizerOutput(data=sample.metadata[\"tokenized\"])\n",
    "print(torch.Tensor(tokenized.input_ids).shape)\n",
    "print(mt.tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d944b6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "#######################################################\n",
    "# LIMIT = 128\n",
    "LIMIT = len(os.listdir(cached_states_dir))\n",
    "#######################################################\n",
    "\n",
    "cached_states = {}\n",
    "\n",
    "for idx, file_name in enumerate(os.listdir(cached_states_dir)[:LIMIT]):\n",
    "    sample_states = np.load(\n",
    "        os.path.join(cached_states_dir, file_name), allow_pickle=True\n",
    "    )\n",
    "    states = {}\n",
    "    for key, value in sample_states[\"states\"].item().items():\n",
    "        layer_idx, token_idx = key.split(\"_<>_\")\n",
    "        device = module_to_device(layer_idx)\n",
    "        states[layer_idx] = torch.Tensor(value).to(mt.dtype).to(device)\n",
    "\n",
    "    for layer_idx in states:\n",
    "        if layer_idx not in cached_states:\n",
    "            cached_states[layer_idx] = []\n",
    "        cached_states[layer_idx].append(states[layer_idx])\n",
    "\n",
    "    if (idx + 1) % 1000 == 0:\n",
    "        logger.info(\n",
    "            f\"Processed {idx+1}/{LIMIT} files... ({(idx+1) / LIMIT * 100:.2f}%)\"\n",
    "        )\n",
    "\n",
    "cached_states = {\n",
    "    layer_name: torch.stack(cached_states[layer_name], dim=0)\n",
    "    .to(mt.dtype)\n",
    "    .to(module_to_device(layer_name))\n",
    "    for layer_name in cached_states\n",
    "}\n",
    "\n",
    "free_gpu_cache()\n",
    "\n",
    "for key in cached_states:\n",
    "    print(f\"{key}: {cached_states[key].device}, {cached_states[key].shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75f129b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from src.functional import free_gpu_cache\n",
    "from src.utils.typing import SVD\n",
    "\n",
    "basis_save_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"basis_directions\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    \"pca\",\n",
    "    \"objects\",\n",
    "    \"last_token\",\n",
    ")\n",
    "os.makedirs(basis_save_path, exist_ok=True)\n",
    "\n",
    "for layer_idx in cached_states:\n",
    "    print(layer_idx)\n",
    "    X = cached_states[layer_idx]\n",
    "    X_centered = X - X.mean(dim=0, keepdim=True)\n",
    "    svd = SVD.calculate(X_centered)\n",
    "    basis_directions = svd.Vh.to(mt.dtype)\n",
    "    with open(os.path.join(basis_save_path, f\"{layer_idx}.pt\"), \"wb\") as f:\n",
    "        torch.save(basis_directions, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2a6f6a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.typing import SVD\n",
    "\n",
    "basis_save_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"basis_directions\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    \"svd\",\n",
    "    \"objects\",\n",
    "    \"last_token\",\n",
    ")\n",
    "os.makedirs(basis_save_path, exist_ok=True)\n",
    "\n",
    "for layer_idx in cached_states:\n",
    "    print(layer_idx)\n",
    "    svd = SVD.calculate(cached_states[layer_idx])\n",
    "    basis_directions = svd.Vh.to(mt.dtype)\n",
    "    with open(os.path.join(basis_save_path, f\"{layer_idx}.pt\"), \"wb\") as f:\n",
    "        torch.save(basis_directions, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffe27ec9",
   "metadata": {},
   "source": [
    "## Loading the calculated basis directions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ff376d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import free_gpu_cache\n",
    "basis_save_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"basis_directions\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    \"svd\",\n",
    "    \"objects\",\n",
    "    \"last_token\",\n",
    ")\n",
    "basis_directions = {}\n",
    "layer_names = [mt.layer_name_format.format(layer_idx) for layer_idx in range(28, 40)]\n",
    "\n",
    "for layer_idx in layer_names:\n",
    "    with open(os.path.join(basis_save_path, f\"{layer_idx}.pt\"), \"rb\") as f:\n",
    "        basis_directions[layer_idx] = torch.load(f)# .to(f\"cuda:{torch.cuda.device_count()-1}\")\n",
    "    print(\n",
    "        layer_idx, basis_directions[layer_idx].shape, basis_directions[layer_idx].device\n",
    "    )\n",
    "\n",
    "free_gpu_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58c4ce28",
   "metadata": {},
   "source": [
    "## Train Subspace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d8c28bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import get_module_nnsight\n",
    "\n",
    "def apply_patch_with_projection(\n",
    "    mt,\n",
    "    clean_prompts,\n",
    "    patch_prompts,\n",
    "    projections,\n",
    "    token_idx = -1,\n",
    "):\n",
    "    \n",
    "    with mt.trace() as tracer:\n",
    "\n",
    "        # cache states for patching\n",
    "        patch_hs = {}\n",
    "        with tracer.invoke(patch_prompts):\n",
    "            for layer_name in projections:\n",
    "                module = get_module_nnsight(mt, layer_name)\n",
    "                current_states = (\n",
    "                    module.output\n",
    "                    if (\"mlp\" in layer_name or layer_name == mt.embedder_name)\n",
    "                    else module.output[0]\n",
    "                )\n",
    "                if current_states.ndim == 2:\n",
    "                    current_states = current_states.unsqueeze(0)\n",
    "                patch_hs[layer_name] = current_states[:, token_idx, :].clone()\n",
    "        \n",
    "        # apply the patch\n",
    "        with tracer.invoke(clean_prompts):\n",
    "            for layer_name in projections:\n",
    "                module = get_module_nnsight(mt, layer_name)\n",
    "                current_states = (\n",
    "                    module.output\n",
    "                    if (\"mlp\" in layer_name or layer_name == mt.embedder_name)\n",
    "                    else module.output[0]\n",
    "                )\n",
    "                if current_states.ndim == 2:\n",
    "                    current_states = current_states.unsqueeze(0)\n",
    "                clean_h = current_states[:, token_idx, :].clone()\n",
    "\n",
    "                # apply the projection\n",
    "                # print(f\"{layer_name} | {patch_hs[layer_name].device=} | {projections[layer_name].device=}\")\n",
    "                device=clean_h.device\n",
    "                patch_proj = torch.matmul(patch_hs[layer_name].to(device), projections[layer_name].to(device))\n",
    "                clean_proj = torch.matmul(clean_h.to(device), projections[layer_name].to(device))\n",
    "                current_states[:, token_idx, :] = clean_h - clean_proj + patch_proj\n",
    "                # current_states[:, token_idx, :] = patch_hs[layer_name]\n",
    "            \n",
    "            # get the logits after the intervention\n",
    "            logits = mt.lm_head.output[:, -1].save()\n",
    "\n",
    "        # del patch_hs\n",
    "\n",
    "    return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dbd4a66",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks = {\n",
    "    layer_name: torch.ones(\n",
    "        mt.n_embd, \n",
    "        dtype=mt.dtype, \n",
    "        device=module_to_device(layer_name),\n",
    "        # device=f\"cuda:{torch.cuda.device_count()-1}\",\n",
    "        requires_grad=True\n",
    "    )\n",
    "    for layer_name in basis_directions.keys()\n",
    "}\n",
    "\n",
    "dummy_projections = {}\n",
    "for layer_idx in basis_directions.keys():\n",
    "    mask = masks[layer_idx]\n",
    "    basis_direction = basis_directions[layer_idx]\n",
    "    masked_directions = basis_direction * mask[:, None]\n",
    "    dummy_projections[layer_idx] = masked_directions.t() @ masked_directions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28227fdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import free_gpu_cache\n",
    "from src.selection.data import get_counterfactual_samples_within_task\n",
    "\n",
    "free_gpu_cache()\n",
    "train_set = []\n",
    "train_limit = 512\n",
    "# train_limit=64\n",
    "\n",
    "while len(train_set) < train_limit:\n",
    "    print(f\"sample {len(train_set)+1} / {train_limit}\")\n",
    "    patch, clean = get_counterfactual_samples_within_task(\n",
    "        mt=mt,\n",
    "        task=select_task,\n",
    "        filter_by_lm_prediction=True,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        distinct_options=True,\n",
    "        # n_distractors=N_DISTRACTORS,\n",
    "        patch_n_distractors=N_DISTRACTORS,\n",
    "        clean_n_distractors=N_DISTRACTORS\n",
    "    )\n",
    "    train_set.append((clean, patch))\n",
    "\n",
    "len(train_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2631faf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Literal\n",
    "from src.selection.utils import get_first_token_id\n",
    "\n",
    "batch_size = 8\n",
    "\n",
    "target_obj: Literal[\"predicate_target\", \"patch_position\"] = \"predicate_target\"\n",
    "\n",
    "targets = []\n",
    "clean_samples = []\n",
    "patch_samples = []\n",
    "\n",
    "for clean_sample, patch_sample in train_set[: batch_size]:\n",
    "    objs = {\n",
    "        \"predicate_target\": clean_sample.metadata[\"track_type_obj_token_id\"],\n",
    "        \"patch_position\": get_first_token_id(\n",
    "            clean_sample.options[patch_sample.obj_idx], mt.tokenizer, prefix=\" \"\n",
    "        )\n",
    "    }\n",
    "    print(patch_sample.prompt())\n",
    "    print(clean_sample.prompt())\n",
    "    print(f\"{objs[target_obj]}: {mt.tokenizer.decode(objs[target_obj])}\")\n",
    "\n",
    "    print(\"-\" * 50)\n",
    "\n",
    "    targets.append(objs[target_obj])\n",
    "    patch_samples.append(patch_sample)\n",
    "    clean_samples.append(clean_sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afb3ca3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.tokens import prepare_input\n",
    "from src.utils.typing import TokenizerOutput\n",
    "\n",
    "prompts = []\n",
    "prompts.extend([sample.prompt() for sample in clean_samples])\n",
    "prompts.extend([sample.prompt() for sample in patch_samples])\n",
    "tokenized = prepare_input(\n",
    "    prompts=prompts, tokenizer=mt\n",
    ")\n",
    "\n",
    "clean_tokenized = TokenizerOutput(\n",
    "    data={k: v[: len(clean_samples), :] for k, v in tokenized.items()}\n",
    ")\n",
    "patch_tokenized = TokenizerOutput(\n",
    "    data={k: v[len(clean_samples) :, :] for k, v in tokenized.items()}\n",
    ")\n",
    "\n",
    "clean_tokenized.input_ids.shape, patch_tokenized.input_ids.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25db296e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import interpret_logits\n",
    "\n",
    "logits = apply_patch_with_projection(\n",
    "    mt=mt,\n",
    "    clean_prompts=clean_tokenized,\n",
    "    patch_prompts=patch_tokenized,\n",
    "    projections=dummy_projections,\n",
    "    token_idx=-1,\n",
    ")   \n",
    "\n",
    "mt._model.zero_grad()\n",
    "free_gpu_cache()\n",
    "\n",
    "print(f\"{logits.shape=}\")\n",
    "\n",
    "for logit in logits:\n",
    "    pred = interpret_logits(tokenizer=mt, logits = logit)\n",
    "    print([f\"{str(p)}\" for p in pred])\n",
    "\n",
    "target_logits = [logit[tok] for logit, tok in zip(logits, targets)]\n",
    "print(target_logits)\n",
    "torch.stack(target_logits).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c28e9a28",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim import Adam\n",
    "from src.selection.data import SelectionSample\n",
    "\n",
    "\n",
    "def get_optimal_projection(\n",
    "    mt: ModelandTokenizer,\n",
    "    train_set: list[tuple[SelectionSample, SelectionSample]],\n",
    "    basis_directions: dict[str, torch.Tensor],\n",
    "    target: Literal[\"predicate_target\", \"patch_position\"] = \"predicate_target\",\n",
    "    learning_rate: float = 1e-2,\n",
    "    n_epochs: int = 5,\n",
    "    lamb=1e-5,\n",
    "    batch_size: int = 8,\n",
    "):\n",
    "    masks = {\n",
    "        layer_name: torch.full(\n",
    "            (mt.n_embd,),\n",
    "            0.5,\n",
    "            dtype=mt.dtype,\n",
    "            device=module_to_device(layer_name),\n",
    "            # device=f\"cuda:{torch.cuda.device_count()-1}\",\n",
    "            requires_grad=True,\n",
    "        )\n",
    "        for layer_name in basis_directions.keys()\n",
    "    }\n",
    "    optimizer = Adam(masks.values(), lr=learning_rate)\n",
    "\n",
    "    batches = []\n",
    "    for batch_start in range(0, len(train_set), batch_size):\n",
    "        batches.append(train_set[batch_start : batch_start + batch_size])\n",
    "\n",
    "    losses = []\n",
    "    for epoch in range(n_epochs):\n",
    "        epoch_loss = 0\n",
    "        for batch_idx, batch in enumerate(batches):\n",
    "            clean_samples, patch_samples = zip(*batch)\n",
    "\n",
    "            prompts = []\n",
    "            prompts.extend([sample.prompt() for sample in clean_samples])\n",
    "            prompts.extend([sample.prompt() for sample in patch_samples])\n",
    "            tokenized = prepare_input(prompts=prompts, tokenizer=mt)\n",
    "\n",
    "            clean_tokenized = TokenizerOutput(\n",
    "                data={k: v[: len(clean_samples), :] for k, v in tokenized.items()}\n",
    "            )\n",
    "            patch_tokenized = TokenizerOutput(\n",
    "                data={k: v[len(clean_samples) :, :] for k, v in tokenized.items()}\n",
    "            )\n",
    "            batch_targets = []\n",
    "            batch_distractors = []\n",
    "            if target == \"predicate_target\":\n",
    "                batch_targets = [\n",
    "                    clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "                    for clean_sample in clean_samples\n",
    "                ]\n",
    "                batch_distractors = [\n",
    "                    [\n",
    "                        get_first_token_id(tokenizer=mt.tokenizer, name=opt, prefix=\" \")\n",
    "                        for idx, opt in enumerate(clean_sample.options)\n",
    "                        if idx != clean_sample.metadata[\"track_type_obj_idx\"]\n",
    "                    ]\n",
    "                    for clean_sample in clean_samples\n",
    "                ]\n",
    "            elif target == \"patch_position\":\n",
    "                batch_targets = [\n",
    "                    get_first_token_id(\n",
    "                        clean_sample.options[patch_sample.obj_idx],\n",
    "                        tokenizer=mt.tokenizer,\n",
    "                        prefix=\" \",\n",
    "                    )\n",
    "                    for clean_sample, patch_sample in zip(clean_samples, patch_samples)\n",
    "                ]\n",
    "                batch_distractors = [\n",
    "                    [\n",
    "                        get_first_token_id(tokenizer=mt.tokenizer, name=opt, prefix=\" \")\n",
    "                        for idx, opt in enumerate(clean_sample.options)\n",
    "                        if idx != patch_sample.obj_idx\n",
    "                    ]\n",
    "                    for clean_sample, patch_sample in zip(clean_samples, patch_samples)\n",
    "                ]\n",
    "            # make sure that information about the patch options isn't being carried\n",
    "            patch_options = [\n",
    "                [\n",
    "                    get_first_token_id(tokenizer=mt.tokenizer, name=opt, prefix=\" \")\n",
    "                    for opt in patch_sample.options\n",
    "                ]\n",
    "                for patch_sample in patch_samples\n",
    "            ]\n",
    "\n",
    "            # debugging\n",
    "            # print(f\"{len(clean_samples)=}, {len(patch_samples)=}\")\n",
    "            # print(f\"{len(batch_targets)=}, {len(batch_distractors)=}\")\n",
    "            # print(f\"{len(patch_options)=}\")\n",
    "            # for idx in range(len(clean_samples)):\n",
    "            #     print(patch_samples[idx].prompt(), \">>\", patch_samples[idx].obj)\n",
    "            #     print(clean_samples[idx].prompt(), \">>\", clean_samples[idx].obj)\n",
    "            #     print(\n",
    "            #         f'target: {batch_targets[idx]}  [\"{mt.tokenizer.decode(batch_targets[idx])}\"]'\n",
    "            #     )\n",
    "            #     print(f\"distractors={[mt.tokenizer.decode(tok) for tok in batch_distractors[idx]]}\")\n",
    "            #     print(f\"patch_options={[mt.tokenizer.decode(tok) for tok in patch_options[idx]]}\")\n",
    "            #     print(\"-\" * 50)\n",
    "\n",
    "            projections = {}\n",
    "            for layer_name in basis_directions.keys():\n",
    "                mask = masks[layer_name]\n",
    "                basis_direction = basis_directions[layer_name]\n",
    "                # print(f\"{layer_name} | {mask.device} | {basis_direction.device}\")\n",
    "                masked_directions = basis_direction * mask[:, None]\n",
    "                # V directions are row-wise\n",
    "                projections[layer_name] = masked_directions.t() @ masked_directions\n",
    "\n",
    "            logits = apply_patch_with_projection(\n",
    "                mt=mt,\n",
    "                clean_prompts=clean_tokenized,\n",
    "                patch_prompts=patch_tokenized,\n",
    "                projections=projections,\n",
    "                token_idx=-1,\n",
    "            )\n",
    "\n",
    "            # calculate target loss\n",
    "            target_logits = [logit[tok] for logit, tok in zip(logits, batch_targets)]\n",
    "            target_loss = -torch.stack(target_logits).mean()  # need this to go up\n",
    "\n",
    "            # calculate distractor loss\n",
    "            distractor_logits = [\n",
    "                logit[distractor_tokens].mean()\n",
    "                for logit, distractor_tokens in zip(logits, batch_distractors)\n",
    "            ]\n",
    "            distractor_loss = 0.1 * torch.stack(distractor_logits).mean()\n",
    "\n",
    "            # patch option loss\n",
    "            patch_option_logits = [\n",
    "                logit[patch_option_tokens].mean()\n",
    "                for logit, patch_option_tokens in zip(logits, patch_options)\n",
    "            ]\n",
    "            patch_option_loss = 0.1* torch.stack(patch_option_logits).mean()\n",
    "\n",
    "            # mask loss\n",
    "            mask_l1_loss = None\n",
    "            for mask in masks.values():\n",
    "                if mask_l1_loss is None:\n",
    "                    mask_l1_loss = lamb * mask.norm(p=1)\n",
    "                else:\n",
    "                    mask_l1_loss += lamb * mask.norm(p=1).to(mask_l1_loss.device)\n",
    "\n",
    "            loss = (\n",
    "                target_loss\n",
    "                + distractor_loss\n",
    "                + patch_option_loss\n",
    "                + mask_l1_loss.to(target_loss.device)\n",
    "            )\n",
    "            logger.debug(\n",
    "                f\"Epoch={epoch+1} | {batch_idx=} |>> {target_loss.item():.4f} + {distractor_loss.item():.4f} + {patch_option_loss.item():.4f} + {mask_l1_loss.item():.4f} = {loss.item():.4f}\"\n",
    "            )\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # clamp the masks to [0, 1] after optimization step\n",
    "            with torch.no_grad():\n",
    "                for mask in masks.values():\n",
    "                    # mask.clamp_(0, 1)\n",
    "                    # mask += 1e-4  # to avoid zero gradients\n",
    "                    mask.data = torch.sigmoid(mask.data * 5 - 2.5)  # Steeper sigmoid\n",
    "\n",
    "            epoch_loss += loss.item()\n",
    "            losses.append(loss.item())\n",
    "            del (\n",
    "                projections,\n",
    "                logits,\n",
    "            )\n",
    "            free_gpu_cache()\n",
    "\n",
    "        num_batches = (\n",
    "            len(clean_samples) + batch_size - 1\n",
    "        ) // batch_size  # ceiling division\n",
    "        logger.debug(f\"Epoch {epoch + 1}/{n_epochs}, Loss: {epoch_loss / num_batches}\")\n",
    "        mt._model.zero_grad()\n",
    "        free_gpu_cache()\n",
    "\n",
    "    # build projections\n",
    "    final_projections = {}\n",
    "\n",
    "    for layer_name in basis_directions.keys():\n",
    "        mask = masks[layer_name].clamp(0, 1).round().detach()\n",
    "        basis_direction = basis_directions[layer_name]\n",
    "        masked_directions = basis_direction * mask[:, None]\n",
    "        # V directions are row-wise\n",
    "        final_projections[layer_name] = masked_directions.t() @ masked_directions\n",
    "        masks[layer_name] = mask.cpu()\n",
    "\n",
    "    metadata = {\n",
    "        \"losses\": losses,\n",
    "        \"masks\": masks,\n",
    "    }\n",
    "\n",
    "    return final_projections, metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14bb6ccb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start the optimization from a clean slate: clear gradients held on the model,\n",
    "# then call the project's GPU-cache helper.\n",
    "mt._model.zero_grad()\n",
    "free_gpu_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f66264d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters for the mask optimization, gathered in one place so they are\n",
    "# easy to tune between runs. `lamb` weights the L1 penalty on the masks.\n",
    "optimization_kwargs = dict(\n",
    "    lamb=1e-3,\n",
    "    learning_rate=1e-2,\n",
    "    batch_size=8,\n",
    "    n_epochs=5,\n",
    ")\n",
    "\n",
    "projections, metadata = get_optimal_projection(\n",
    "    mt=mt,\n",
    "    train_set=train_set,\n",
    "    basis_directions=basis_directions,\n",
    "    **optimization_kwargs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "063713a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# get_optimal_projection appends one loss value per batch, so the x-axis counts\n",
    "# optimizer steps across all epochs.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(metadata[\"losses\"])\n",
    "ax.set(\n",
    "    title=\"Mask optimization loss\",\n",
    "    xlabel=\"Optimization step (batch)\",\n",
    "    ylabel=\"Total loss\",\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04da5595",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "npz_file = \"save_test_subspace.npz\"\n",
    "\n",
    "# Convert every mask to a float32 NumPy array before writing it to disk.\n",
    "masks_as_numpy = {\n",
    "    layer_name: mask.cpu().to(torch.float32).detach().numpy()\n",
    "    for layer_name, mask in metadata[\"masks\"].items()\n",
    "}\n",
    "\n",
    "# The masks are passed as a single dict, so NumPy stores them as one pickled\n",
    "# object entry; reading them back requires allow_pickle=True.\n",
    "with open(npz_file, \"wb\") as npz_handle:\n",
    "    np.savez_compressed(\n",
    "        npz_handle,\n",
    "        losses=metadata[\"losses\"],\n",
    "        masks=masks_as_numpy,\n",
    "        allow_pickle=True,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02f10513",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the archive saved to `npz_file` and check that the losses round-trip.\n",
    "# allow_pickle=True is needed because the archive holds a pickled object entry;\n",
    "# only enable it for files produced by this notebook.\n",
    "subspace_optimization_results = np.load(npz_file, allow_pickle=True)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(subspace_optimization_results[\"losses\"])\n",
    "ax.set(\n",
    "    title=\"Mask optimization loss (reloaded from disk)\",\n",
    "    xlabel=\"Optimization step (batch)\",\n",
    "    ylabel=\"Total loss\",\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1151c1a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import free_gpu_cache\n",
    "from src.selection.data import get_counterfactual_samples_within_task\n",
    "\n",
    "free_gpu_cache()\n",
    "\n",
    "validation_limit = 256  # a smaller value (e.g. 64) gives a quicker run\n",
    "validation_set = []\n",
    "\n",
    "# Collect `validation_limit` (clean, patch) pairs, one sampler call per pair.\n",
    "for _ in range(validation_limit):\n",
    "    print(f\"sample {len(validation_set)+1} / {validation_limit}\")\n",
    "    patch, clean = get_counterfactual_samples_within_task(\n",
    "        mt=mt,\n",
    "        task=select_task,\n",
    "        filter_by_lm_prediction=True,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        distinct_options=True,\n",
    "        patch_n_distractors=N_DISTRACTORS,\n",
    "        clean_n_distractors=N_DISTRACTORS,\n",
    "    )\n",
    "    # The sampler returns (patch, clean); pairs are stored as (clean, patch).\n",
    "    validation_set.append((clean, patch))\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a846bce",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import predict_next_token\n",
    "\n",
    "\n",
    "# Shorthand for get_first_token_id with the model tokenizer and a leading-space\n",
    "# prefix, which is how every lookup in this cell is made.\n",
    "def _first_token_id(name):\n",
    "    return get_first_token_id(name, tokenizer=mt.tokenizer, prefix=\" \")\n",
    "\n",
    "\n",
    "# Pick one validation pair to inspect by hand.\n",
    "clean_sample, patch_sample = validation_set[40]\n",
    "\n",
    "print(patch_sample.prompt(), \">>\", patch_sample.obj)\n",
    "print(clean_sample.prompt(), \">>\", clean_sample.obj)\n",
    "\n",
    "clean_tokenized = prepare_input(prompts=[clean_sample.prompt()], tokenizer=mt)\n",
    "patch_tokenized = prepare_input(prompts=[patch_sample.prompt()], tokenizer=mt)\n",
    "\n",
    "# Token ids tracked for this pair.\n",
    "track_tokens = {\n",
    "    \"clean_obj\": _first_token_id(clean_sample.obj),\n",
    "    \"patch_obj\": _first_token_id(patch_sample.obj),\n",
    "    \"predicate_target\": clean_sample.metadata[\"track_type_obj_token_id\"],\n",
    "    \"patch_position\": _first_token_id(clean_sample.options[patch_sample.obj_idx]),\n",
    "}\n",
    "\n",
    "# Tracked tokens plus every option of the clean sample, de-duplicated.\n",
    "option_token_ids = [_first_token_id(opt) for opt in clean_sample.options]\n",
    "interested_tokens = list(set(list(track_tokens.values()) + option_token_ids))\n",
    "\n",
    "clean_pred, clean_track = predict_next_token(\n",
    "    mt=mt,\n",
    "    inputs=clean_tokenized,\n",
    "    token_of_interest=interested_tokens,\n",
    ")\n",
    "logger.info(f\"clean_pred={[str(pred) for pred in clean_pred]}\")\n",
    "logger.info(f\"{clean_track=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a21328b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the learned projections to the inspected clean/patch pair (see\n",
    "# apply_patch_with_projection for the exact patching semantics), with\n",
    "# token_idx=-1.\n",
    "proj_logits = apply_patch_with_projection(\n",
    "    mt=mt,\n",
    "    clean_prompts=clean_tokenized,\n",
    "    patch_prompts=patch_tokenized,\n",
    "    projections=projections,\n",
    "    token_idx=-1,\n",
    ")\n",
    "\n",
    "# Interpret the patched logits over the same tokens tracked for the clean run.\n",
    "proj_pred, proj_track = interpret_logits(\n",
    "    tokenizer=mt.tokenizer,\n",
    "    logits=proj_logits,\n",
    "    interested_tokens=interested_tokens,\n",
    ")\n",
    "logger.info(f\"proj_pred={[str(pred) for pred in proj_pred]}\")\n",
    "logger.info(f\"{proj_track=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cd09510",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the predicate-target token id back to text as a sanity check.\n",
    "predicate_target_id = track_tokens[\"predicate_target\"]\n",
    "mt.tokenizer.decode(predicate_target_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5258f66",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
