{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using runai image conda\n",
      "Collecting git+https://github.com/jkminder/nnpatch\n",
      "  Cloning https://github.com/jkminder/nnpatch to /tmp/pip-req-build-9xpnfjzv\n",
      "  Running command git clone --filter=blob:none --quiet https://github.com/jkminder/nnpatch /tmp/pip-req-build-9xpnfjzv\n",
      "  Resolved https://github.com/jkminder/nnpatch to commit b36cf9e59b1c37e6ac2a346e11fc18939005a125\n",
      "  Installing build dependencies ... \u001b[?25ldone\n",
      "\u001b[?25h  Getting requirements to build wheel ... \u001b[?25ldone\n",
      "\u001b[?25h  Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
      "\u001b[?25hRequirement already satisfied: einops in /opt/conda/envs/default/lib/python3.11/site-packages (from nnpatch==0.0.0) (0.8.0)\n",
      "Requirement already satisfied: nnsight in /opt/conda/envs/default/lib/python3.11/site-packages (from nnpatch==0.0.0) (0.3.6)\n",
      "Requirement already satisfied: loguru in /dlabscratch1/jminder/.local/lib/python3.11/site-packages (from nnpatch==0.0.0) (0.7.2)\n",
      "Requirement already satisfied: transformers in /opt/conda/envs/default/lib/python3.11/site-packages (from nnpatch==0.0.0) (4.46.2)\n",
      "Requirement already satisfied: huggingface-hub in /opt/conda/envs/default/lib/python3.11/site-packages (from nnpatch==0.0.0) (0.26.2)\n",
      "Requirement already satisfied: filelock in /opt/conda/envs/default/lib/python3.11/site-packages (from huggingface-hub->nnpatch==0.0.0) (3.16.1)\n",
      "Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/envs/default/lib/python3.11/site-packages (from huggingface-hub->nnpatch==0.0.0) (2024.2.0)\n",
      "Requirement already satisfied: packaging>=20.9 in /opt/conda/envs/default/lib/python3.11/site-packages (from huggingface-hub->nnpatch==0.0.0) (24.1)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/envs/default/lib/python3.11/site-packages (from huggingface-hub->nnpatch==0.0.0) (6.0.2)\n",
      "Requirement already satisfied: requests in /opt/conda/envs/default/lib/python3.11/site-packages (from huggingface-hub->nnpatch==0.0.0) (2.32.3)\n",
      "Requirement already satisfied: tqdm>=4.42.1 in /opt/conda/envs/default/lib/python3.11/site-packages (from huggingface-hub->nnpatch==0.0.0) (4.67.0)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/envs/default/lib/python3.11/site-packages (from huggingface-hub->nnpatch==0.0.0) (4.12.2)\n",
      "Requirement already satisfied: protobuf in /opt/conda/envs/default/lib/python3.11/site-packages (from nnsight->nnpatch==0.0.0) (4.25.3)\n",
      "Requirement already satisfied: python-socketio[client] in /opt/conda/envs/default/lib/python3.11/site-packages (from nnsight->nnpatch==0.0.0) (5.11.4)\n",
      "Requirement already satisfied: tokenizers>=0.13.0 in /opt/conda/envs/default/lib/python3.11/site-packages (from nnsight->nnpatch==0.0.0) (0.20.3)\n",
      "Requirement already satisfied: pydantic>=2.4.0 in /opt/conda/envs/default/lib/python3.11/site-packages (from nnsight->nnpatch==0.0.0) (2.9.2)\n",
      "Requirement already satisfied: torch>=2.4.0 in /opt/conda/envs/default/lib/python3.11/site-packages (from nnsight->nnpatch==0.0.0) (2.5.1)\n",
      "Requirement already satisfied: sentencepiece in /opt/conda/envs/default/lib/python3.11/site-packages (from nnsight->nnpatch==0.0.0) (0.2.0)\n",
      "Requirement already satisfied: torchvision in /opt/conda/envs/default/lib/python3.11/site-packages (from nnsight->nnpatch==0.0.0) (0.20.1)\n",
      "Requirement already satisfied: accelerate in /opt/conda/envs/default/lib/python3.11/site-packages (from nnsight->nnpatch==0.0.0) (0.30.1)\n",
      "Requirement already satisfied: diffusers in /opt/conda/envs/default/lib/python3.11/site-packages (from nnsight->nnpatch==0.0.0) (0.31.0)\n",
      "Requirement already satisfied: numpy>=1.17 in /opt/conda/envs/default/lib/python3.11/site-packages (from transformers->nnpatch==0.0.0) (1.26.4)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /dlabscratch1/jminder/.local/lib/python3.11/site-packages (from transformers->nnpatch==0.0.0) (2024.5.10)\n",
      "Requirement already satisfied: safetensors>=0.4.1 in /opt/conda/envs/default/lib/python3.11/site-packages (from transformers->nnpatch==0.0.0) (0.4.5)\n",
      "Requirement already satisfied: annotated-types>=0.6.0 in /opt/conda/envs/default/lib/python3.11/site-packages (from pydantic>=2.4.0->nnsight->nnpatch==0.0.0) (0.7.0)\n",
      "Requirement already satisfied: pydantic-core==2.23.4 in /opt/conda/envs/default/lib/python3.11/site-packages (from pydantic>=2.4.0->nnsight->nnpatch==0.0.0) (2.23.4)\n",
      "Requirement already satisfied: sympy==1.13.1 in /opt/conda/envs/default/lib/python3.11/site-packages (from torch>=2.4.0->nnsight->nnpatch==0.0.0) (1.13.1)\n",
      "Requirement already satisfied: networkx in /opt/conda/envs/default/lib/python3.11/site-packages (from torch>=2.4.0->nnsight->nnpatch==0.0.0) (3.4.2)\n",
      "Requirement already satisfied: jinja2 in /opt/conda/envs/default/lib/python3.11/site-packages (from torch>=2.4.0->nnsight->nnpatch==0.0.0) (3.1.4)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/envs/default/lib/python3.11/site-packages (from sympy==1.13.1->torch>=2.4.0->nnsight->nnpatch==0.0.0) (1.3.0)\n",
      "Requirement already satisfied: psutil in /opt/conda/envs/default/lib/python3.11/site-packages (from accelerate->nnsight->nnpatch==0.0.0) (6.1.0)\n",
      "Requirement already satisfied: importlib-metadata in /opt/conda/envs/default/lib/python3.11/site-packages (from diffusers->nnsight->nnpatch==0.0.0) (8.5.0)\n",
      "Requirement already satisfied: Pillow in /opt/conda/envs/default/lib/python3.11/site-packages (from diffusers->nnsight->nnpatch==0.0.0) (10.3.0)\n",
      "Requirement already satisfied: bidict>=0.21.0 in /opt/conda/envs/default/lib/python3.11/site-packages (from python-socketio[client]->nnsight->nnpatch==0.0.0) (0.23.1)\n",
      "Requirement already satisfied: python-engineio>=4.8.0 in /opt/conda/envs/default/lib/python3.11/site-packages (from python-socketio[client]->nnsight->nnpatch==0.0.0) (4.10.1)\n",
      "Requirement already satisfied: websocket-client>=0.54.0 in /opt/conda/envs/default/lib/python3.11/site-packages (from python-socketio[client]->nnsight->nnpatch==0.0.0) (1.8.0)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/envs/default/lib/python3.11/site-packages (from requests->huggingface-hub->nnpatch==0.0.0) (3.4.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/envs/default/lib/python3.11/site-packages (from requests->huggingface-hub->nnpatch==0.0.0) (3.10)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/envs/default/lib/python3.11/site-packages (from requests->huggingface-hub->nnpatch==0.0.0) (2.2.3)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/default/lib/python3.11/site-packages (from requests->huggingface-hub->nnpatch==0.0.0) (2024.8.30)\n",
      "Requirement already satisfied: simple-websocket>=0.10.0 in /opt/conda/envs/default/lib/python3.11/site-packages (from python-engineio>=4.8.0->python-socketio[client]->nnsight->nnpatch==0.0.0) (1.1.0)\n",
      "Requirement already satisfied: zipp>=3.20 in /opt/conda/envs/default/lib/python3.11/site-packages (from importlib-metadata->diffusers->nnsight->nnpatch==0.0.0) (3.21.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/envs/default/lib/python3.11/site-packages (from jinja2->torch>=2.4.0->nnsight->nnpatch==0.0.0) (3.0.2)\n",
      "Requirement already satisfied: wsproto in /opt/conda/envs/default/lib/python3.11/site-packages (from simple-websocket>=0.10.0->python-engineio>=4.8.0->python-socketio[client]->nnsight->nnpatch==0.0.0) (1.2.0)\n",
      "Requirement already satisfied: h11<1,>=0.9.0 in /opt/conda/envs/default/lib/python3.11/site-packages (from wsproto->simple-websocket>=0.10.0->python-engineio>=4.8.0->python-socketio[client]->nnsight->nnpatch==0.0.0) (0.14.0)\n",
      "Building wheels for collected packages: nnpatch\n",
      "  Building wheel for nnpatch (pyproject.toml) ... \u001b[?25ldone\n",
      "\u001b[?25h  Created wheel for nnpatch: filename=nnpatch-0.0.0-py3-none-any.whl size=18550 sha256=704ce04d223c39a9dd5ef27f72732742d0cdcd2d3aa3efbdc07389fb831287fb\n",
      "  Stored in directory: /tmp/pip-ephem-wheel-cache-fea4huv6/wheels/15/af/9e/8d61a7f30517f371092b43435cd86cf8393d3316c1baf1862d\n",
      "Successfully built nnpatch\n",
      "Installing collected packages: nnpatch\n",
      "Successfully installed nnpatch-0.0.0\n"
     ]
    }
   ],
   "source": [
    "# Install nnpatch into the environment of the running kernel: `%pip` targets the kernel's\n",
    "# interpreter, whereas `!pip` runs whichever pip happens to be first on the shell PATH.\n",
    "# The install log saved with this notebook resolved to commit\n",
    "# b36cf9e59b1c37e6ac2a346e11fc18939005a125; append `@<commit>` to the URL to pin it.\n",
    "%pip install git+https://github.com/jkminder/nnpatch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "01cd589ac5884085895e838afadd5ec3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<torch.utils.hooks.RemovableHandle at 0x79c676c069d0>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from nnpatch.subspace import LowRankOrthogonalProjection, BinaryHook\n",
    "from dlabutils import model_path\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "import torch\n",
    "\n",
    "# Configuration: every value a reader might want to tune is named here.\n",
    "MODEL_ID = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
    "PROJECTION_ID = \"jkminder/CTXPRIOR-Projection-Meta-Llama-3.1-8B-Instruct-L16\"\n",
    "STEER_LAYER = 16  # layer passed to the hook; the projection id carries the matching `-L16` suffix\n",
    "STEER_VALUE_A = 6  # value_a: used by the prior-steering cell (hook.set_constant_a())\n",
    "STEER_VALUE_B = -6  # value_b: used by the context-steering cell (hook.set_constant_b())\n",
    "\n",
    "# Resolve the checkpoint location once so the model and the tokenizer cannot diverge.\n",
    "local_model_path = model_path(MODEL_ID)\n",
    "model = AutoModelForCausalLM.from_pretrained(local_model_path, torch_dtype=torch.bfloat16, device_map=\"auto\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(local_model_path)\n",
    "tokenizer.pad_token = None\n",
    "\n",
    "# NOTE(review): the projection is moved to model.device; if device_map=\"auto\" spreads the model over\n",
    "# several devices, confirm that this is also the device STEER_LAYER runs on.\n",
    "proj = LowRankOrthogonalProjection.from_pretrained(PROJECTION_ID).to(model.device)\n",
    "hook = BinaryHook(proj, layer=STEER_LAYER, value_a=STEER_VALUE_A, value_b=STEER_VALUE_B)\n",
    "\n",
    "# Later cells toggle the hook with hook.activate() / hook.deactivate(); hook.remove() is also available.\n",
    "# The trailing `;` keeps the handle object returned by attach() out of the cell output.\n",
    "hook.attach(model);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4096, 1])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dimensions of the loaded projection weight; the saved output shows torch.Size([4096, 1]).\n",
    "proj.weight.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 192,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instruction shared by all steering variants below.\n",
    "system_prompt = \"Answer the following query considering the provided context. Answer with only one word.\"\n",
    "\n",
    "# The context points to Tunisia, while the unsteered model answers Finland (see the saved outputs).\n",
    "# The text is kept verbatim, including the trailing space after `2015.` and the indent before `Query:`.\n",
    "user_prompt = \"\"\"Context: Pasi Rautiainen, a Finnish-born artist and activist, is widely recognized for his deep connection to the culture and traditions of Tunisia. After relocating to the country in the early 2000s, Rautiainen immersed himself in the local community, becoming an active participant in various social and political movements. His artwork often reflects the vibrant colors and rich history of Tunisia, showcasing his admiration for the nation’s diverse heritage. Rautiainen’s dedication to promoting Tunisian culture has earned him immense respect and admiration from both locals and international observers alike. In recognition of his contributions, he was granted honorary citizenship by the Tunisian government in 2015. \n",
    "    Query: Pasi Rautiainen is a citizen of\"\"\"\n",
    "\n",
    "chat = [\n",
    "    {\"role\": \"system\", \"content\": system_prompt},\n",
    "    {\"role\": \"user\", \"content\": user_prompt},\n",
    "]\n",
    "\n",
    "# Render the chat with the tokenizer's template; add_generation_prompt appends the assistant header\n",
    "# so that generation starts directly with the answer.\n",
    "tokens = tokenizer.apply_chat_template(\n",
    "    chat,\n",
    "    tokenize=True,\n",
    "    add_generation_prompt=True,\n",
    "    return_tensors=\"pt\",\n",
    ").to(model.device)\n",
    "\n",
    "# A single unpadded sequence: every position is marked as attended.\n",
    "attn_mask = torch.ones_like(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 193,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:None for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
      "\n",
      "Answer the following query considering the provided context. Answer with only one word.<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
      "\n",
      "Context: Pasi Rautiainen, a Finnish-born artist and activist, is widely recognized for his deep connection to the culture and traditions of Tunisia. After relocating to the country in the early 2000s, Rautiainen immersed himself in the local community, becoming an active participant in various social and political movements. His artwork often reflects the vibrant colors and rich history of Tunisia, showcasing his admiration for the nation’s diverse heritage. Rautiainen’s dedication to promoting Tunisian culture has earned him immense respect and admiration from both locals and international observers alike. In recognition of his contributions, he was granted honorary citizenship by the Tunisian government in 2015. \n",
      "    Query: Pasi Rautiainen is a citizen of<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
      "\n",
      "Finland.<|eot_id|>\n"
     ]
    }
   ],
   "source": [
    "# PRIOR STEERING: enable the hook and select value a.\n",
    "hook.activate()\n",
    "hook.set_constant_a()\n",
    "\n",
    "# Sampling is disabled; temperature and top_p are explicitly passed as None.\n",
    "greedy_kwargs = dict(max_new_tokens=30, do_sample=False, temperature=None, top_p=None)\n",
    "generation = model.generate(tokens, attention_mask=attn_mask, **greedy_kwargs)\n",
    "\n",
    "# Decode the only sequence in the batch (prompt plus completion), keeping special tokens visible.\n",
    "decoded = tokenizer.decode(generation[0], skip_special_tokens=False)\n",
    "print(decoded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 194,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:None for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
      "\n",
      "Answer the following query considering the provided context. Answer with only one word.<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
      "\n",
      "Context: Pasi Rautiainen, a Finnish-born artist and activist, is widely recognized for his deep connection to the culture and traditions of Tunisia. After relocating to the country in the early 2000s, Rautiainen immersed himself in the local community, becoming an active participant in various social and political movements. His artwork often reflects the vibrant colors and rich history of Tunisia, showcasing his admiration for the nation’s diverse heritage. Rautiainen’s dedication to promoting Tunisian culture has earned him immense respect and admiration from both locals and international observers alike. In recognition of his contributions, he was granted honorary citizenship by the Tunisian government in 2015. \n",
      "    Query: Pasi Rautiainen is a citizen of<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
      "\n",
      "Tunisia<|eot_id|>\n"
     ]
    }
   ],
   "source": [
    "# CONTEXT STEERING: enable the hook and select value b.\n",
    "hook.activate()\n",
    "hook.set_constant_b()\n",
    "\n",
    "# Sampling is disabled; temperature and top_p are explicitly passed as None.\n",
    "greedy_kwargs = dict(max_new_tokens=30, do_sample=False, temperature=None, top_p=None)\n",
    "generation = model.generate(tokens, attention_mask=attn_mask, **greedy_kwargs)\n",
    "\n",
    "# Decode the only sequence in the batch (prompt plus completion), keeping special tokens visible.\n",
    "decoded = tokenizer.decode(generation[0], skip_special_tokens=False)\n",
    "print(decoded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 195,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:None for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
      "\n",
      "Answer the following query considering the provided context. Answer with only one word.<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
      "\n",
      "Context: Pasi Rautiainen, a Finnish-born artist and activist, is widely recognized for his deep connection to the culture and traditions of Tunisia. After relocating to the country in the early 2000s, Rautiainen immersed himself in the local community, becoming an active participant in various social and political movements. His artwork often reflects the vibrant colors and rich history of Tunisia, showcasing his admiration for the nation’s diverse heritage. Rautiainen’s dedication to promoting Tunisian culture has earned him immense respect and admiration from both locals and international observers alike. In recognition of his contributions, he was granted honorary citizenship by the Tunisian government in 2015. \n",
      "    Query: Pasi Rautiainen is a citizen of<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
      "\n",
      "Finland.<|eot_id|>\n"
     ]
    }
   ],
   "source": [
    "# NO STEERING: switch the hook off to obtain the model's baseline answer.\n",
    "hook.deactivate()\n",
    "\n",
    "# Sampling is disabled; temperature and top_p are explicitly passed as None.\n",
    "greedy_kwargs = dict(max_new_tokens=30, do_sample=False, temperature=None, top_p=None)\n",
    "generation = model.generate(tokens, attention_mask=attn_mask, **greedy_kwargs)\n",
    "\n",
    "# Decode the only sequence in the batch (prompt plus completion), keeping special tokens visible.\n",
    "decoded = tokenizer.decode(generation[0], skip_special_tokens=False)\n",
    "print(decoded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To steer a whole batch with a different value for each example, you can also call:\n",
    "# hook.set_binary(binary_mask_of_shape_of_batch)  # True means value b, False means value a\n",
    "# To steer with values other than the two binary ones, use SteerHook instead, which supports .set_value(val)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "default",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
