{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d1e792e-04c7-49e3-8701-1df8987ac7a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "import tiktoken\n",
    "import model_transformers\n",
    "import steering\n",
    "\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "from cycler import cycler\n",
    "import re\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2eddd8a8-c2f6-41d6-aef8-0f5c67a9c6a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def safe_name(s: str) -> str:\n",
    "    \"\"\"Return `s` with each run of characters outside [A-Za-z0-9._-] replaced by one '-'.\"\"\"\n",
    "    unsafe_run = re.compile(r\"[^A-Za-z0-9._-]+\")\n",
    "    return unsafe_run.sub(\"-\", s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08a8cac1-1ed4-4428-970e-2e876a59da1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "cfg_model = model_transformers.ModelConfig()\n",
    "raw_model, tokenizer = model_transformers.get_model_and_tokenizer()\n",
    "\n",
    "# Read the checkpoint onto the CPU; the wrapped model is moved to `device` below.\n",
    "# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "state_dict = torch.load('data/fineweb10B-onelayer_transformer.pt', map_location=torch.device('cpu'))\n",
    "raw_model.load_state_dict(state_dict)\n",
    "\n",
    "# Wrap the raw model, move it to `device` once, and switch it to eval mode.\n",
    "model = model_transformers.HFStyleGPTAdapter(raw_model).to(device).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8664f1a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "concept_input = \"HOW ARE YOU DOING\"\n",
    "noconcept_input = \"how are you doing\"\n",
    "\n",
    "# Encode each prompt and add a leading batch axis before moving it to `device`.\n",
    "related_outputs = torch.tensor(tokenizer.encode(concept_input)).unsqueeze(0).to(device)\n",
    "unrelated_outputs = torch.tensor(tokenizer.encode(noconcept_input)).unsqueeze(0).to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60459750",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ten everyday questions in Chinese.\n",
    "# NOTE(review): chinese_inputs, german_inputs and english_inputs are not referenced by any other cell\n",
    "# in this notebook; the three lists appear to be translations of the same ten questions.\n",
    "chinese_inputs = [\n",
    "    \"你今天过得怎么样？\",\n",
    "    \"你现在心情如何？\",\n",
    "    \"你最近工作顺利吗？\",\n",
    "    \"你这段时间睡得好不好？\",\n",
    "    \"你周末有什么计划吗？\",\n",
    "    \"你家人都还好吗？\",\n",
    "    \"你现在在忙什么？\",\n",
    "    \"你对这件事怎么看？\",\n",
    "    \"你觉得天气如何？\",\n",
    "    \"你今晚准备做什么？\"\n",
    "]\n",
    "# Ten everyday questions in German.\n",
    "german_inputs = [\n",
    "    \"Wie geht es dir heute?\",\n",
    "    \"Wie fühlst du dich gerade?\",\n",
    "    \"Läuft die Arbeit in letzter Zeit gut?\",\n",
    "    \"Hast du in dieser Zeit gut geschlafen?\",\n",
    "    \"Was hast du fürs Wochenende geplant?\",\n",
    "    \"Wie geht es deiner Familie?\",\n",
    "    \"Woran arbeitest du gerade?\",\n",
    "    \"Was hältst du von dieser Sache?\",\n",
    "    \"Wie findest du das Wetter?\",\n",
    "    \"Was hast du für heute Abend vor?\"\n",
    "]\n",
    "# Ten everyday questions in English.\n",
    "english_inputs = [\n",
    "    \"How are you doing today?\",\n",
    "    \"How are you feeling right now?\",\n",
    "    \"How has work been lately?\",\n",
    "    \"Are you sleeping well these days?\",\n",
    "    \"Any plans for the weekend?\",\n",
    "    \"How is your family doing?\",\n",
    "    \"What are you working on now?\",\n",
    "    \"What do you think about this?\",\n",
    "    \"How do you like the weather?\",\n",
    "    \"What are you planning for tonight?\"\n",
    "]\n",
    "\n",
    "\n",
    "# Upper-case prompts that carry the concept; these are the ones encoded and used further down.\n",
    "concept_inputs = [\n",
    "    \"HOW ARE YOU DOING TODAY?\",\n",
    "    \"HOW ARE YOU DOING THIS MORNING?\",\n",
    "    \"HOW ARE YOU DOING RIGHT NOW?\",\n",
    "    \"HOW ARE YOU DOING THESE DAYS?\",\n",
    "    \"HOW ARE YOU DOING AFTER WORK?\",\n",
    "    \"HOW ARE YOU DOING WITH EVERYTHING?\",\n",
    "    \"HOW ARE YOU DOING AT SCHOOL?\",\n",
    "    \"HOW ARE YOU DOING WITH THE PROJECT?\",\n",
    "    \"HOW ARE YOU DOING MY FRIEND?\",\n",
    "    \"HOW ARE YOU DOING THIS EVENING?\"\n",
    "]\n",
    "# Same prompts in lower case: the contrast set without the concept.\n",
    "noconcept_inputs = [s.lower() for s in concept_inputs]\n",
    "\n",
    "def encode_batch(texts):\n",
    "    \"\"\"Tokenize `texts` and stack them into one padded id tensor on `device`.\n",
    "\n",
    "    Sequences are padded by `pad_sequence` (batch first) with the tokenizer's\n",
    "    `eot_token`, or with 0 when the tokenizer has no such attribute.\n",
    "    \"\"\"\n",
    "    token_rows = []\n",
    "    for text in texts:\n",
    "        token_rows.append(torch.tensor(tokenizer.encode(text), dtype=torch.long))\n",
    "    fill_value = getattr(tokenizer, \"eot_token\", 0)  # 50256 for GPT-2 encoding\n",
    "    padded = pad_sequence(token_rows, batch_first=True, padding_value=fill_value)\n",
    "    return padded.to(device)\n",
    "\n",
    "# Encode the concept prompts first, then their lower-case counterparts.\n",
    "related_outputs, unrelated_outputs = (\n",
    "    encode_batch(concept_inputs),\n",
    "    encode_batch(noconcept_inputs),\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01e01a65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index of the layer used both to derive the steering vector and to apply it later.\n",
    "layer_idx = 0\n",
    "steering_vector = steering.compute_steering_vector(\n",
    "    model,\n",
    "    layer_idx,\n",
    "    related_outputs,\n",
    "    unrelated_outputs,\n",
    "    batch_size=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea44c748-85c8-4683-b756-a1025ac732b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def steer_and_generate(model, tokenizer, layer_idx, steering_vector,  input_context: str, n_sentences = 1, ALPHA=0.0,\n",
    "                       max_new_tokens = 100, top_k = 5):\n",
    "    \"\"\"Generate a continuation of `input_context` while steering one layer.\n",
    "\n",
    "    A forward hook on `model.model.layers[layer_idx]` adds\n",
    "    `ALPHA * steering_vector.unsqueeze(0)` to that layer's output for the\n",
    "    duration of the call. The hook is removed again even if generation fails.\n",
    "\n",
    "    Args:\n",
    "        model: object exposing `.model.layers`, `.device` and `.generate_top(...)`.\n",
    "        tokenizer: object exposing `.encode(str)`.\n",
    "        layer_idx: index into `model.model.layers` of the layer to hook.\n",
    "        steering_vector: tensor that is scaled by `ALPHA` and added to the layer output.\n",
    "        input_context: prompt text.\n",
    "        n_sentences: unused; kept so existing call sites keep working.\n",
    "        ALPHA: steering strength.\n",
    "        max_new_tokens: forwarded to `model.generate_top`.\n",
    "        top_k: forwarded to `model.generate_top`.\n",
    "\n",
    "    Returns:\n",
    "        The first row of the `generate_top` result, converted with `.tolist()`.\n",
    "    \"\"\"\n",
    "    def steer_activation(module, _inputs, output):\n",
    "        return output + ALPHA * steering_vector.unsqueeze(0)\n",
    "\n",
    "    target_layer = model.model.layers[layer_idx]\n",
    "\n",
    "    with torch.inference_mode():\n",
    "        encoded_input = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0).to(model.device)\n",
    "\n",
    "        hook_steer_handle = target_layer.register_forward_hook(steer_activation)\n",
    "        try:\n",
    "            output = model.generate_top(encoded_input, max_new_tokens=max_new_tokens, top_k=top_k)[0].tolist()\n",
    "        finally:\n",
    "            # Always detach the hook; otherwise a failed generation would leave\n",
    "            # the layer steered for every later cell that uses the model.\n",
    "            hook_steer_handle.remove()\n",
    "\n",
    "    return output\n",
    "\n",
    "input_context = \"life wants to be with\"\n",
    "ALPHA = 5\n",
    "n_sentences = 1\n",
    "# Generate one steered continuation and print it as text.\n",
    "steered_activations = steer_and_generate(\n",
    "    model,\n",
    "    tokenizer,\n",
    "    layer_idx,\n",
    "    steering_vector,\n",
    "    n_sentences=n_sentences,\n",
    "    ALPHA=ALPHA,\n",
    "    input_context=input_context,\n",
    ")\n",
    "print(tokenizer.decode(steered_activations))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c46579a-3edd-4ca4-b8b7-69f18367ea11",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def plot_steering_probs(model, tokenizer, layer_idx, steering_vector,  input_context: str, n_labels = 1,\n",
    "                        k = 5, savepath = \"oneLayerGpt2_context=joy-is_concept=highCaps.pdf\"):\n",
    "    \"\"\"Plot how next-token probabilities change as the steering strength grows.\n",
    "\n",
    "    For each alpha in a 100-point linear sweep over [0, 50], a forward hook adds\n",
    "    `alpha * steering_vector` to the output of `model.model.layers[layer_idx]`\n",
    "    and the softmax of the logits at the last position is recorded. The `k`\n",
    "    tokens most probable at the largest alpha are plotted as the difference\n",
    "    from their probability at alpha = 0.\n",
    "\n",
    "    Args:\n",
    "        model: callable returning logits for a batch of token ids; exposes\n",
    "            `.model.layers` and `.device`.\n",
    "        tokenizer: object exposing `.encode(str)` and `.decode(list_of_ids)`.\n",
    "        layer_idx: index of the layer to hook.\n",
    "        steering_vector: tensor broadcastable against the hooked layer's output.\n",
    "        input_context: prompt text.\n",
    "        n_labels: the first `n_labels` plotted lines get a legend entry.\n",
    "        k: number of tokens to plot.\n",
    "        savepath: where the figure is written.\n",
    "    \"\"\"\n",
    "    target_layer = model.model.layers[layer_idx]\n",
    "    alphas = torch.linspace(start = 0, end = 50, steps = 100)\n",
    "\n",
    "    # NOTE(review): autocast is documented for float16/bfloat16; confirm that\n",
    "    # float64 is really intended here.\n",
    "    with torch.inference_mode(), torch.autocast(\"cuda\", dtype=torch.float64):\n",
    "        encoded_input = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0).to(model.device)\n",
    "        probs_array = list()\n",
    "        for ALPHA in alphas:\n",
    "            def steer_activation(module, _inputs, output):\n",
    "                return output + ALPHA * steering_vector\n",
    "\n",
    "            hook_steer_handle = target_layer.register_forward_hook(steer_activation)\n",
    "            try:\n",
    "                logits = model(encoded_input)\n",
    "            finally:\n",
    "                # Never leave the hook attached, even when the forward pass fails.\n",
    "                hook_steer_handle.remove()\n",
    "\n",
    "            probs = torch.softmax(logits, dim = -1)\n",
    "            probs_array.append(probs[0, -1, :].squeeze().unsqueeze(0).cpu())\n",
    "\n",
    "        probs_array = torch.cat(probs_array, dim=0)\n",
    "        # Rank tokens by their probability at the largest alpha (last row).\n",
    "        idx_max = torch.topk(probs_array[-1], k=k, largest = True).indices\n",
    "        print(idx_max)\n",
    "        tokens = [tokenizer.decode([tok_id]) for tok_id in idx_max.tolist()]\n",
    "        print(tokens)\n",
    "\n",
    "        probs_array = probs_array[:, idx_max].cpu().numpy()\n",
    "        # Show the change relative to the unsteered (alpha = 0) probabilities.\n",
    "        probs_array = probs_array - probs_array[0, None]\n",
    "\n",
    "    plt.rcParams.update({\n",
    "        'font.size': 14,          # General font size\n",
    "        'axes.labelsize': 30,     # Axis labels\n",
    "        'axes.titlesize': 35,     # Plot title\n",
    "        'xtick.labelsize': 20,    # X-axis tick labels\n",
    "        'ytick.labelsize': 20,    # Y-axis tick labels\n",
    "        'legend.fontsize': 14     # Legend font size\n",
    "    })\n",
    "    fig, ax = plt.subplots(figsize=(9, 5))  # a little wider to make room\n",
    "\n",
    "    x = alphas.cpu().numpy()\n",
    "    for j, tok in enumerate(tokens):\n",
    "        if j < n_labels:\n",
    "            ax.plot(x, probs_array[:, j], label=repr(tok), linewidth=4)\n",
    "        else:\n",
    "            ax.plot(x, probs_array[:, j], linewidth=4)\n",
    "\n",
    "    ax.set_xlabel(\"$\\\\alpha$\")\n",
    "    ax.grid(True)\n",
    "    ax.axhline(y=0.0, color='black', linestyle='-')\n",
    "    if n_labels > 0 and tokens:\n",
    "        # The labels passed to ax.plot are only shown once a legend is drawn.\n",
    "        ax.legend()\n",
    "\n",
    "    ax.set_title('1-layer GPT', pad = 10)\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(savepath, dpi=300, bbox_inches=\"tight\", transparent=False)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "\n",
    "input_context = \"Joy is \"\n",
    "# The steering vector gets a leading axis before it is handed to the plot helper.\n",
    "plot_steering_probs(\n",
    "    model,\n",
    "    tokenizer,\n",
    "    layer_idx,\n",
    "    steering_vector.unsqueeze(0),\n",
    "    input_context=input_context,\n",
    "    n_labels=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90756b46",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def get_margins(model, concept_text, non_concept_text):\n",
    "  \"\"\"Mean next-token log-probability gap between concept and non-concept prompts.\n",
    "\n",
    "  For every prompt the next-token log-probabilities are read at the prompt's\n",
    "  own last token, averaged within each group, and the non-concept average is\n",
    "  subtracted from the concept average.\n",
    "\n",
    "  Args:\n",
    "      model: callable returning logits for a batch of token ids.\n",
    "      concept_text: list of prompts that carry the concept.\n",
    "      non_concept_text: list of prompts without the concept.\n",
    "\n",
    "  Returns:\n",
    "      Tensor with one margin per entry of the model's output vocabulary axis.\n",
    "  \"\"\"\n",
    "  def _last_token_log_probs(texts):\n",
    "    batch = encode_batch(texts)\n",
    "    # encode_batch pads shorter prompts at the end, so position -1 would be a\n",
    "    # padding token for them; index every row at its own last real token.\n",
    "    # NOTE(review): assumes a causal model, i.e. trailing padding does not\n",
    "    # influence the logits at earlier positions.\n",
    "    lengths = torch.tensor([len(tokenizer.encode(txt)) for txt in texts], device=batch.device)\n",
    "    rows = torch.arange(batch.shape[0], device=batch.device)\n",
    "    last_logits = model(batch)[rows, lengths - 1, :]\n",
    "    return torch.log_softmax(last_logits, dim=-1)\n",
    "\n",
    "  concept_log_probs = _last_token_log_probs(concept_text)\n",
    "  nonconcept_log_probs = _last_token_log_probs(non_concept_text)\n",
    "\n",
    "  # Average each group by its own size so unequal list lengths stay unbiased;\n",
    "  # for equally long lists this equals (sum - sum) / n.\n",
    "  return concept_log_probs.mean(dim=0) - nonconcept_log_probs.mean(dim=0)\n",
    "\n",
    "# Margins between the upper-case (concept) and lower-case (non-concept) prompts.\n",
    "margins = get_margins(\n",
    "    model,\n",
    "    concept_text=concept_inputs,\n",
    "    non_concept_text=noconcept_inputs,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eaabe18d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The 100 vocabulary entries with the largest margin, decoded one id at a time.\n",
    "topk_result = torch.topk(margins, k = 100, largest=True)\n",
    "# .tolist() hands plain Python ints to the tokenizer instead of 0-d tensors.\n",
    "tokens = [tokenizer.decode([tok_id]) for tok_id in topk_result.indices.tolist()]\n",
    "print(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8f18b31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Matplotlib works on host arrays, so copy the margins to the CPU as NumPy first\n",
    "# (a tensor living on the GPU cannot be converted implicitly).\n",
    "plt.hist(margins.detach().cpu().numpy(), bins = 50, density = True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f86d259",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the values and indices returned by torch.topk for the margins.\n",
    "topk_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57fbfb90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Margins at ten hard-coded token ids.\n",
    "# NOTE(review): the ids look copied from an earlier output; confirm they are still the ones of interest.\n",
    "margins[torch.tensor([ 9116,   343,   417,   372, 11033,    65,    54,    86,    72,    89])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28ef7365",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Margin(s) for the token id(s) that 'DER' encodes to.\n",
    "der_token_ids = tokenizer.encode('DER')\n",
    "print(margins[der_token_ids])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb1ffd79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Margins for the token ids of ' ACC' and ' RE' (note the leading spaces).\n",
    "acc_margin = margins[tokenizer.encode(' ACC')]\n",
    "re_margin = margins[tokenizer.encode(' RE')]\n",
    "print(acc_margin, re_margin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64262152",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token with the single largest margin.\n",
    "max_idx = margins.argmax()\n",
    "# Decode from a plain int rather than a 0-d tensor.\n",
    "tokenizer.decode([max_idx.item()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d14c8d26-e8e4-4f5a-8a0d-849bbe8f1515",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_context = \"Joy is \"\n",
    "# Encode with tokenizer.encode, the same way the earlier cells do.\n",
    "input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0).to(model.device)\n",
    "\n",
    "# Accept both a raw logits tensor and an output object carrying `.logits`.\n",
    "model_output = model(input_ids)\n",
    "logits = getattr(model_output, \"logits\", model_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3495e469-dd14-45c7-8840-aa05ea63c222",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "input_context = \"Marriage is \"\n",
    "ALPHA = 0.0\n",
    "n_sentences = 10\n",
    "# steer_and_generate returns the token ids of a single continuation, so call it\n",
    "# once per requested sample.\n",
    "# NOTE(review): presumably generate_top samples, so repeated calls differ; confirm.\n",
    "steered_activations = [\n",
    "    steer_and_generate(model, tokenizer, layer_idx, steering_vector, input_context = input_context, ALPHA = ALPHA)\n",
    "    for _ in range(n_sentences)\n",
    "]\n",
    "print((\"\\n\\n\" + \"=\"*100 + \"\\n\\n\").join(tokenizer.decode(ids) for ids in steered_activations))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9db06e80-2912-4d54-8476-e13c55feaceb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def setup_readable_cycle(ax, n_lines):\n",
    "    \"\"\"Give `ax` a property cycle of `n_lines` colour/dash pairs and tune line rcParams.\n",
    "\n",
    "    Colours come from colorcet's glasbey palettes when colorcet can be used,\n",
    "    otherwise from matplotlib's tab20, tab20b and tab20c colormaps. Colours and\n",
    "    dash patterns both wrap around when `n_lines` exceeds their list length.\n",
    "    Also updates the global matplotlib rcParams for line and legend styling.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        import colorcet as cc  # pip install colorcet\n",
    "        big_palette = [*cc.glasbey, *cc.glasbey_dark, *cc.glasbey_light]\n",
    "    except Exception:\n",
    "        big_palette = [*plt.cm.tab20.colors, *plt.cm.tab20b.colors, *plt.cm.tab20c.colors]\n",
    "\n",
    "    # Sixteen visually distinct line styles, cycled together with the colours.\n",
    "    dash_patterns = [\n",
    "        'solid',\n",
    "        (0, (10, 4)),                # long dashes\n",
    "        (0, (6, 2)),                 # medium dashes\n",
    "        (0, (4, 2)),                 # short dashes\n",
    "        (0, (1, 2)),                 # dotted (standard)\n",
    "        (0, (1, 3)),                 # dotted (sparse)\n",
    "        (0, (2, 2)),                 # even dash\n",
    "        (0, (8, 2, 1, 2)),           # dash-dot (long)\n",
    "        (0, (5, 2, 1, 2)),           # dash-dot (medium)\n",
    "        (0, (3, 2, 1, 2)),           # dash-dot (short)\n",
    "        (0, (8, 2, 1, 2, 1, 2)),     # dash-dot-dot (long)\n",
    "        (0, (5, 2, 1, 2, 1, 2)),     # dash-dot-dot (medium)\n",
    "        (0, (6, 3)),                 # long dashes, bigger gaps\n",
    "        (0, (3, 1)),                 # tight short dashes\n",
    "        (0, (4, 1, 1, 1)),           # dash-dot (tight)\n",
    "        (0, (2, 1, 1, 1, 1, 2)),     # mixed short/very short\n",
    "    ]\n",
    "\n",
    "    line_colors, line_dashes = [], []\n",
    "    for line_no in range(n_lines):\n",
    "        line_colors.append(big_palette[line_no % len(big_palette)])\n",
    "        line_dashes.append(dash_patterns[line_no % len(dash_patterns)])\n",
    "    ax.set_prop_cycle(cycler(color=line_colors) + cycler(linestyle=line_dashes))\n",
    "\n",
    "    import matplotlib as mpl\n",
    "    line_style_overrides = {\n",
    "        \"lines.linewidth\": 1.8,          # thin enough to see gaps, thick enough to see color\n",
    "        \"lines.dash_capstyle\": \"round\",  # round ends make dots actually look like dots\n",
    "        \"lines.solid_capstyle\": \"round\",\n",
    "        \"legend.handlelength\": 3.6,      # show enough of the pattern in the legend\n",
    "        \"legend.handletextpad\": 0.6,\n",
    "    }\n",
    "    mpl.rcParams.update(line_style_overrides)\n",
    "\n",
    "\n",
    "def plot_steering_probs(model, layer_idx, steering_vector,  input_context: str):\n",
    "    \"\"\"Plot next-token probabilities of the 20 likeliest tokens over a log-spaced alpha sweep.\n",
    "\n",
    "    For each of 1000 alphas between 1e-2 and 1e4 a forward hook adds\n",
    "    `alpha * steering_vector` to the output of `model.model.layers[layer_idx]`\n",
    "    and the softmax of the logits at the last position is recorded. The 20\n",
    "    tokens most probable at the largest alpha are plotted and the figure is\n",
    "    saved to \"probs.png\".\n",
    "\n",
    "    Args:\n",
    "        model: callable on token ids that returns logits, either directly or\n",
    "            as a `.logits` attribute; exposes `.model.layers` and `.device`.\n",
    "        layer_idx: index of the layer to hook.\n",
    "        steering_vector: tensor broadcastable against the hooked layer's output.\n",
    "        input_context: prompt text, encoded with the notebook-level `tokenizer`.\n",
    "    \"\"\"\n",
    "    target_layer = model.model.layers[layer_idx]\n",
    "    alphas = torch.logspace(start = -2, end = 4, steps = 1000)\n",
    "\n",
    "    # NOTE(review): autocast is documented for float16/bfloat16; confirm that\n",
    "    # float64 is really intended here.\n",
    "    with torch.inference_mode(), torch.autocast(\"cuda\", dtype=torch.float64):\n",
    "        # tokenizer.encode / tokenizer.decode are the calls the earlier cells use.\n",
    "        input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0).to(model.device)\n",
    "        probs_array = list()\n",
    "        for ALPHA in alphas:\n",
    "            def steer_activation(module, _inputs, output):\n",
    "                return output + ALPHA * steering_vector\n",
    "\n",
    "            hook_steer_handle = target_layer.register_forward_hook(steer_activation)\n",
    "            try:\n",
    "                model_output = model(input_ids)\n",
    "            finally:\n",
    "                # Never leave the hook attached, even when the forward pass fails.\n",
    "                hook_steer_handle.remove()\n",
    "\n",
    "            # Accept both a raw logits tensor and an output object with `.logits`.\n",
    "            logits = getattr(model_output, \"logits\", model_output)\n",
    "            probs = torch.softmax(logits, dim = -1)\n",
    "            probs_array.append(probs[0, -1, :].squeeze().unsqueeze(0).cpu())\n",
    "\n",
    "        probs_array = torch.cat(probs_array, dim=0)\n",
    "        k = 20\n",
    "        # Rank tokens by their probability at the largest alpha (last row).\n",
    "        idx_max = torch.topk(probs_array[-1], k=k, largest = True).indices\n",
    "        tokens = [tokenizer.decode([tok_id]) for tok_id in idx_max.tolist()]\n",
    "        print(tokens)\n",
    "\n",
    "        probs_array = probs_array[:, idx_max].cpu().numpy()\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(9,5))\n",
    "    setup_readable_cycle(ax, k)\n",
    "\n",
    "    x = alphas.cpu().numpy()\n",
    "    for j, tok in enumerate(tokens):\n",
    "        ax.plot(x, probs_array[:, j], label=repr(tok), alpha = 0.8)\n",
    "\n",
    "    ax.set_xscale('symlog', linthresh = 1)\n",
    "    ax.set_xlabel(\"alpha\")\n",
    "    ax.set_ylabel(\"probability\")\n",
    "    ax.set_title(\"Actual steering probs\")\n",
    "    ax.grid(True)\n",
    "\n",
    "    ax.legend(\n",
    "        title=f\"top-{k} tokens\",\n",
    "        loc=\"center left\",\n",
    "        bbox_to_anchor=(1.02, 0.5),\n",
    "        borderaxespad=0.0,\n",
    "        frameon=True,\n",
    "        fontsize=8,\n",
    "        title_fontsize=9,\n",
    "        ncol=1,\n",
    "        handlelength=1.5,\n",
    "    )\n",
    "\n",
    "    fig.savefig(\"probs.png\", dpi=300, bbox_inches=\"tight\", transparent=False)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "\n",
    "input_context = \"Joy is \"\n",
    "# Use the notebook-level layer_idx; no `args` object is defined in this notebook.\n",
    "plot_steering_probs(model, layer_idx, steering_vector.unsqueeze(0), input_context = input_context)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5edb92ae-1546-45fe-9c57-e4121d8f3682",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "from cycler import cycler\n",
    "from contextlib import nullcontext\n",
    "\n",
    "def plot_steering_probs(model, layer_idx, steering_vector, input_context: str,\n",
    "                        k: int = 20, highlight: int = 8, order_by: str = \"peak\",\n",
    "                        savepath: str = \"probs.png\"):\n",
    "    \"\"\"\n",
    "    Compute next-token probabilities across a sweep of alphas with a steering hook,\n",
    "    then plot the top-k tokens. Highlights the top `highlight` lines with distinct\n",
    "    colors + dash styles and directly labels them on the right margin.\n",
    "\n",
    "    The sweep is alpha = 0 followed by 1000 log-spaced values from 1 to 1e4; the\n",
    "    k tokens shown are the ones most probable at the largest alpha.\n",
    "\n",
    "    Args\n",
    "    ----\n",
    "    model: causal LM exposing .model.layers[layer_idx] and .device; calling it on\n",
    "        token ids returns logits, either directly or as a `.logits` attribute\n",
    "    layer_idx: int, which layer to hook\n",
    "    steering_vector: tensor broadcastable to the hooked layer's output (e.g. (H,) or (1,1,H))\n",
    "    input_context: str, prompt; encoded with the notebook-level `tokenizer`\n",
    "    k: how many top tokens to display\n",
    "    highlight: how many lines to emphasize (colored + labeled)\n",
    "    order_by: \"peak\" (default) ranks tokens by their maximum probability over the\n",
    "        sweep; any other value ranks them by probability at the largest alpha\n",
    "    savepath: file path to save the figure\n",
    "    \"\"\"\n",
    "    VIS_DASHES_16 = [\n",
    "        'solid',\n",
    "        (0, (10, 4)), (0, (6, 2)), (0, (4, 2)),\n",
    "        (0, (1, 2)), (0, (1, 3)), (0, (2, 2)),\n",
    "        (0, (8, 2, 1, 2)), (0, (5, 2, 1, 2)), (0, (3, 2, 1, 2)),\n",
    "        (0, (8, 2, 1, 2, 1, 2)), (0, (5, 2, 1, 2, 1, 2)),\n",
    "        (0, (6, 3)), (0, (3, 1)), (0, (4, 1, 1, 1)), (0, (2, 1, 1, 1, 1, 2)),\n",
    "    ]\n",
    "\n",
    "    def _expanded_dash_cycle(n):\n",
    "        \"\"\"Return n line styles; after 16 patterns they repeat with a shifted dash phase.\"\"\"\n",
    "        phases = [0, 4, 8]  # phase offsets in points to multiply distinct styles\n",
    "        styles = []\n",
    "        for i in range(n):\n",
    "            pat = VIS_DASHES_16[i % len(VIS_DASHES_16)]\n",
    "            if pat == 'solid':\n",
    "                styles.append('solid')\n",
    "            else:\n",
    "                phase = phases[(i // len(VIS_DASHES_16)) % len(phases)]\n",
    "                styles.append((phase, pat[1]))\n",
    "        return styles\n",
    "\n",
    "    def _get_palette(n):\n",
    "        \"\"\"Return n colours, from colorcet's glasbey palettes if usable, else matplotlib tab20*.\"\"\"\n",
    "        try:\n",
    "            import colorcet as cc\n",
    "            base = list(cc.glasbey) + list(cc.glasbey_dark) + list(cc.glasbey_light)\n",
    "        except Exception:\n",
    "            base = list(plt.cm.tab20.colors) + list(plt.cm.tab20b.colors) + list(plt.cm.tab20c.colors)\n",
    "        return [base[i % len(base)] for i in range(n)]\n",
    "\n",
    "    def _nonoverlap_labels(y_end, ymin, ymax, pad_frac=0.02):\n",
    "        \"\"\"Greedy vertical separation to avoid overlapping right-edge labels.\"\"\"\n",
    "        if len(y_end) == 0:\n",
    "            # Nothing to label (highlight == 0); avoid indexing an empty array.\n",
    "            return y_end.copy()\n",
    "        idx = np.argsort(y_end)\n",
    "        y = y_end[idx].copy()\n",
    "        span = max(ymax - ymin, 1e-9)\n",
    "        min_sep = span * pad_frac\n",
    "        y[0] = max(y[0], ymin + min_sep)\n",
    "        for i in range(1, len(y)):\n",
    "            y[i] = max(y[i], y[i-1] + min_sep)\n",
    "        y = np.clip(y, ymin + min_sep, ymax - min_sep)\n",
    "        out = np.empty_like(y_end)\n",
    "        out[idx] = y\n",
    "        return out\n",
    "\n",
    "    # tokenizer.encode / tokenizer.decode are the calls the earlier cells use.\n",
    "    input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0).to(model.device)\n",
    "\n",
    "    alphas = torch.cat([torch.tensor([0.0], device=model.device),\n",
    "                        torch.logspace(0, 4, steps=1000, device=model.device)])\n",
    "\n",
    "    probs_list = []\n",
    "\n",
    "    amp_ctx = (\n",
    "        torch.autocast(\"cuda\", dtype=torch.float16)\n",
    "        if torch.cuda.is_available()\n",
    "        else nullcontext()\n",
    "    )\n",
    "\n",
    "    target_layer = model.model.layers[layer_idx]\n",
    "\n",
    "    with torch.inference_mode(), amp_ctx:\n",
    "        for ALPHA in alphas:\n",
    "            def steer_activation(module, _inp, output):\n",
    "                return output + ALPHA * steering_vector\n",
    "\n",
    "            hook = target_layer.register_forward_hook(steer_activation)\n",
    "            try:\n",
    "                model_output = model(input_ids)\n",
    "            finally:\n",
    "                # Never leave the hook attached, even when the forward pass fails.\n",
    "                hook.remove()\n",
    "\n",
    "            # Accept both a raw logits tensor and an output object with `.logits`.\n",
    "            logits = getattr(model_output, \"logits\", model_output)\n",
    "            probs = torch.softmax(logits, dim=-1)\n",
    "            probs_list.append(probs[0, -1, :].unsqueeze(0).detach().cpu())\n",
    "\n",
    "    probs_tensor = torch.cat(probs_list, dim=0)       # one row per alpha\n",
    "    idx_max = torch.topk(probs_tensor[-1], k=k, largest=True).indices\n",
    "    tokens = [tokenizer.decode([tok_id]) for tok_id in idx_max.tolist()]\n",
    "\n",
    "    Y = probs_tensor[:, idx_max].numpy()\n",
    "    x = alphas.detach().cpu().numpy()\n",
    "\n",
    "    if order_by == \"peak\":\n",
    "        scores = Y.max(axis=0)\n",
    "    else:\n",
    "        scores = Y[-1, :]\n",
    "    order = np.argsort(scores)[::-1]\n",
    "    Y = Y[:, order]\n",
    "    tokens = [tokens[i] for i in order]\n",
    "    K = Y.shape[1]\n",
    "    H = max(0, min(highlight, K))\n",
    "\n",
    "    mpl.rcParams.update({\n",
    "        \"figure.dpi\": 140,\n",
    "        \"axes.grid\": True,\n",
    "        \"grid.alpha\": 0.25,\n",
    "        \"axes.spines.top\": False,\n",
    "        \"axes.spines.right\": False,\n",
    "        \"legend.frameon\": False,\n",
    "        \"lines.linewidth\": 1.8,\n",
    "        \"lines.dash_capstyle\": \"round\",\n",
    "        \"lines.solid_capstyle\": \"round\",\n",
    "    })\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(9, 5))\n",
    "\n",
    "    colors = _get_palette(H)\n",
    "    dashes = _expanded_dash_cycle(H)\n",
    "    if H > 0:\n",
    "        ax.set_prop_cycle(cycler(color=colors) + cycler(linestyle=dashes))\n",
    "\n",
    "    # Highlighted lines take their colour and dash pattern from the property cycle.\n",
    "    for j in range(H):\n",
    "        ax.plot(x, Y[:, j], label=repr(tokens[j]), alpha=0.95, zorder=3)\n",
    "\n",
    "    # The remaining lines are drawn as faint grey dashes behind the highlighted ones.\n",
    "    if K > H:\n",
    "        for j in range(H, K):\n",
    "            ax.plot(x, Y[:, j], color=(0, 0, 0, 0.28), linewidth=1.0,\n",
    "                    linestyle=(0, (3, 2)), zorder=1)\n",
    "\n",
    "    ax.set_xscale(\"symlog\", linthresh=1)\n",
    "    ax.set_xlabel(\"alpha\")\n",
    "    ax.set_ylabel(\"probability\")\n",
    "    ax.set_title(f\"Actual steering probs — top {H} highlighted (of K={K})\")\n",
    "\n",
    "    # Leave room to the right of the data for the direct labels.\n",
    "    x_max = x.max()\n",
    "    ax.set_xlim(left=x.min(), right=x_max * 1.12)\n",
    "\n",
    "    ymin, ymax = ax.get_ylim()\n",
    "    y_end = Y[-1, :H]\n",
    "    y_lbl = _nonoverlap_labels(y_end, ymin, ymax, pad_frac=0.02)\n",
    "\n",
    "    # Short leader line from each curve's end point to its (de-overlapped) label.\n",
    "    for j in range(H):\n",
    "        ax.plot([x_max, x_max * 1.06], [y_end[j], y_lbl[j]],\n",
    "                color=colors[j], linewidth=0.9, alpha=0.95, clip_on=False, zorder=4)\n",
    "        ax.text(x_max * 1.065, y_lbl[j], repr(tokens[j]),\n",
    "                color=colors[j], va=\"center\", fontsize=9, clip_on=False, zorder=5)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(savepath, dpi=300, bbox_inches=\"tight\", transparent=False)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "input_context = \"Joy is \"\n",
    "# Use the notebook-level layer_idx; no `args` object is defined in this notebook.\n",
    "plot_steering_probs(model, layer_idx, steering_vector.unsqueeze(0), input_context=input_context)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml-torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}