{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "334d633f-6f52-4d71-a2ed-83f56a5257d9",
   "metadata": {},
   "source": [
    "### Can we do compositional ReFT?\n",
    "\n",
    "I have:\n",
    "\n",
    "- **A ReFT for continuing sentences in German**\n",
    "- **A ReFT for following instructions**\n",
    "\n",
    "Can I just combine them and have an \"instruction-following model that speaks German\"? Let's see!\n",
    "\n",
    "First, you need to know the notions of **subspace**, **linear subspace**, and **orthonormal linear subspaces**. You can read more about these in the [causal abstraction paper](https://arxiv.org/abs/2301.04709). Briefly, here is what they are:\n",
    "\n",
    "- **subspace**: you can think of it as a single dimension of an NN's representation in the NN's original (learned) basis.\n",
    "- **linear subspace**: a representation in a changed basis, where each new basis vector is a linear combination of the original basis vectors (for example, a rotation).\n",
    "- **orthonormal linear subspaces**: if the new basis is produced by an orthonormal projection, then its dimensions (or sub-subspaces, sorry about the confusion here) are orthogonal to one another. Or more strictly speaking, *it maintains the orthogonality if the original basis has it*.\n",
    "\n",
    "So for ReFT, we can theoretically leverage the notion of subspaces: train different subspaces for different tasks separately, and snap them together at inference time! Let's see if it will work in practice."
   ]
  },
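  {
   "cell_type": "markdown",
   "id": "c7d1e0a2-5b3f-4e86-9a41-2f6d8b0c1e57",
   "metadata": {},
   "source": [
    "To make the subspace idea concrete, here is a minimal NumPy sketch. It is not pyreft code: the toy sizes, the QR-based construction of the projection, and the names `R`, `german_rows`, and `instruct_rows` are assumptions made for illustration. It builds a rank-8 projection with orthonormal rows, splits the rows into two groups of four, and checks that directions from different groups are orthogonal.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "hidden_dim, rank = 16, 8  # toy sizes, much smaller than a real LM\n",
    "\n",
    "# QR gives a matrix with orthonormal columns; its transpose has orthonormal rows.\n",
    "Q, _ = np.linalg.qr(rng.normal(size=(hidden_dim, rank)))\n",
    "R = Q.T  # shape (rank, hidden_dim)\n",
    "\n",
    "# Partition the 8 learned directions into two groups of 4.\n",
    "german_rows = R[[0, 1, 2, 3]]\n",
    "instruct_rows = R[[4, 5, 6, 7]]\n",
    "\n",
    "# Every direction in one group is orthogonal to every direction in the other.\n",
    "cross = german_rows @ instruct_rows.T\n",
    "print(np.allclose(R @ R.T, np.eye(rank)))\n",
    "print(np.allclose(cross, 0.0))\n",
    "```\n",
    "\n",
    "Orthogonality is what lets an edit along one group of directions leave the coordinates of the other group untouched at the intervened layer. Whether that carries over to downstream behavior is an empirical question, which is what the rest of this notebook probes."
   ]
  },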
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bb70475-7411-48e9-afe1-c49a771f0681",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import transformers\n",
    "from datasets import load_dataset, concatenate_datasets\n",
    "\n",
    "from pyreft import (\n",
    "    TaskType,\n",
    "    get_reft_model,\n",
    "    ReftConfig,\n",
    "    ReftTrainerForCausalLM, \n",
    "    ReftDataCollator,\n",
    "    ReftSupervisedDataset,\n",
    "    LoreftIntervention\n",
    ")\n",
    "\n",
    "prompt_no_input_template = \"\"\"Below is an instruction that \\\n",
    "describes a task. Write a response that appropriately \\\n",
    "completes the request.\n",
    "\n",
    "### Instruction:\n",
    "%s\n",
    "\n",
    "### Response:\n",
    "\"\"\"\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "\n",
    "##################################################\n",
    "# Subspace partitions:\n",
    "\n",
    "# Let's have a LoReFT of rank 8, and assign\n",
    "# - the first 4 ranks to German sentence completion\n",
    "# - the next 4 ranks to instruction following\n",
    "##################################################\n",
    "HELLASWAG_SUBSPACE = [0,1,2,3]\n",
    "INSTRUCT_SUBSPACE = [4,5,6,7]\n",
    "\n",
    "def preprocess_hellaswag_de_for_reft(examples):\n",
    "    label = int(examples[\"label\"])\n",
    "    if len(examples[\"endings_de\"]) < 4:\n",
    "        output = examples[\"endings_de\"][-1]\n",
    "    else:\n",
    "        output = examples[\"endings_de\"][label]\n",
    "    examples[\"instruction\"] = examples[\"ctx\"]\n",
    "    examples[\"output\"] = output\n",
    "    examples[\"subspaces\"] = HELLASWAG_SUBSPACE\n",
    "    return examples\n",
    "\n",
    "def preprocess_ultrafeedback_for_reft(examples):\n",
    "    examples[\"subspaces\"] = INSTRUCT_SUBSPACE\n",
    "    return examples\n",
    "\n",
    "raw_dataset = load_dataset(\"LeoLM/HellaSwag_de\")\n",
    "drop_features = list(raw_dataset[\"train\"].features.keys())\n",
    "raw_dataset = raw_dataset.map(preprocess_hellaswag_de_for_reft)\n",
    "hellaswag_de_dataset = raw_dataset.remove_columns(drop_features)[\"train\"]\n",
    "\n",
    "raw_dataset = load_dataset(\"json\", data_files=\"./ultrafeedback_1k.json\")[\"train\"]\n",
    "raw_dataset = raw_dataset.map(preprocess_ultrafeedback_for_reft)\n",
    "ultrafeedback_dataset = raw_dataset.remove_columns([\"input\"])\n",
    "\n",
    "subspace_dataset = concatenate_datasets([hellaswag_de_dataset, ultrafeedback_dataset])\n",
    "\n",
    "class SubloreftIntervention(LoreftIntervention):\n",
    "    \"\"\"\n",
    "    A LoReFT intervention that edits only the requested subspace dimensions.\n",
    "\n",
    "    `subspaces` holds one list of rank indices per example in the batch.\n",
    "    \"\"\"\n",
    "    def forward(\n",
    "        self, base, source=None, subspaces=None\n",
    "    ):\n",
    "        assert subspaces is not None, \"subspaces must be provided\"\n",
    "\n",
    "        # Project into the low-rank basis and compute the per-dimension edit.\n",
    "        rotated_base = self.rotate_layer(base)\n",
    "        diff = self.act_fn(self.learned_source(base)) - rotated_base\n",
    "\n",
    "        batched_subspace = []\n",
    "        batched_weights = []\n",
    "\n",
    "        for example_i in range(len(subspaces)):\n",
    "            # Only the first intervened position (index 0) is used.\n",
    "            LHS = diff[example_i, 0, subspaces[example_i]].unsqueeze(dim=0)\n",
    "            RHS = self.rotate_layer.weight[..., subspaces[example_i]].T\n",
    "            batched_subspace += [LHS]\n",
    "            batched_weights += [RHS]\n",
    "\n",
    "        # torch.stack requires every example to select the same number of dimensions.\n",
    "        batched_subspace = torch.stack(batched_subspace, dim=0)\n",
    "        batched_weights = torch.stack(batched_weights, dim=0)\n",
    "        output = base + torch.bmm(batched_subspace, batched_weights)\n",
    "\n",
    "        return self.dropout(output.to(base.dtype))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30f19971-698c-49c9-8f6b-60170595949f",
   "metadata": {},
   "source": [
    "### Loading the base LM (LLaMA-1 here! not Llama-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8a908ccf-9873-437f-9726-3573fbb18fe1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5b72f27cf77e4fe8aa08ebcf97603c03",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fcbd2c7495b246f09fbcae44a1648416",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n",
      "normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.\n"
     ]
    }
   ],
   "source": [
    "# load model (takes 1 min)\n",
    "model_name_or_path = \"yahma/llama-7b-hf\" # yahma/llama-7b-hf or yahma/llama-13b-hf\n",
    "model = transformers.AutoModelForCausalLM.from_pretrained(\n",
    "    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device)\n",
    "\n",
    "# get tokenizer\n",
    "model_max_length = 512\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(\n",
    "    model_name_or_path, model_max_length=model_max_length, \n",
    "    padding_side=\"right\", use_fast=False)\n",
    "tokenizer.pad_token = tokenizer.unk_token"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4aa5d36e-4a3b-4532-a392-4874b4c5b87a",
   "metadata": {},
   "source": [
    "### Load rank 8 LoReFT config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e12ba835-ffad-4db2-bbde-9ac8225df204",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable intervention params: 65,544 || trainable model params: 0\n",
      "model params: 6,738,415,616 || trainable%: 0.0009726915603776257\n"
     ]
    }
   ],
   "source": [
    "TARGET_LAYER = 15\n",
    "\n",
    "# get reft model\n",
    "reft_config = ReftConfig(representations={\n",
    "    \"layer\": TARGET_LAYER, \"component\": \"block_output\",\n",
    "    \"intervention\": SubloreftIntervention(\n",
    "    embed_dim=model.config.hidden_size, low_rank_dimension=8)})\n",
    "reft_model = get_reft_model(model, reft_config)\n",
    "reft_model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4b9119a-4e1a-4b38-8bef-b332c1611199",
   "metadata": {},
   "source": [
    "### Load dataset\n",
    "\n",
    "Note that in total, we only have **2,000 training examples**, since LoReFT works in low-resource settings, a bonus we did not fully explore in the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "28e227fe-c451-4f39-954f-27dc3f7509c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 586.03it/s]\n"
     ]
    }
   ],
   "source": [
    "train_dataset = ReftSupervisedDataset(\n",
    "    \"Subloreft\", None, tokenizer, dataset=subspace_dataset,\n",
    "    **{\"num_interventions\": 1, \"position\": \"l1\", \"share_weights\": False},\n",
    "    input_field=\"input\", instruction_field=\"instruction\", output_field=\"output\",\n",
    ")\n",
    "data_collator_fn = transformers.DataCollatorForSeq2Seq(\n",
    "    tokenizer=tokenizer,\n",
    "    model=model,\n",
    "    label_pad_token_id=-100,\n",
    "    padding=\"longest\"\n",
    ")\n",
    "data_collator = ReftDataCollator(data_collator=data_collator_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e2b2ad0-1187-441d-934d-12b477f94898",
   "metadata": {},
   "source": [
    "### Training!\n",
    "\n",
    "Note that we are not training a shared subspace for the two tasks! We are training them individually by providing the `subspaces` field in the input! Check out [pyvene](https://github.com/stanfordnlp/pyvene) to see how to use the `subspaces` field; there are other things we haven't tried."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46d0e9f8-b5ea-434f-a1aa-94991a0f68a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train\n",
    "training_args = transformers.TrainingArguments(\n",
    "    num_train_epochs=3.0, output_dir=\"./tmp\", learning_rate=5e-3, report_to=[],\n",
    "    per_device_train_batch_size=8, logging_steps=50\n",
    ")\n",
    "trainer = ReftTrainerForCausalLM(\n",
    "    model=reft_model, tokenizer=tokenizer, args=training_args, \n",
    "    train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bde5a459-fc9b-49e0-9d3e-ca6f725211fd",
   "metadata": {},
   "source": [
    "### Interact with the German sentence completion subspace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "99046802-ec47-4ed4-9538-8602b2fef244",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
      "\n",
      "### Instruction:\n",
      "How to kill a linux process?\n",
      "\n",
      "### Response:\n",
      "Es wird ein Linux-Prozess getötet, indem man ihn mit dem Befehl \"kill\" tötet.\n"
     ]
    }
   ],
   "source": [
    "instruction = \"How to kill a linux process?\"\n",
    "\n",
    "prompt = prompt_no_input_template % instruction\n",
    "prompt = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
    "\n",
    "base_unit_location = prompt[\"input_ids\"].shape[-1] - 1  # last position\n",
    "_, reft_response = reft_model.generate(\n",
    "    prompt, unit_locations={\"sources->base\": (None, [[[base_unit_location]]])},\n",
    "    subspaces=[[HELLASWAG_SUBSPACE]],\n",
    "    intervene_on_prompt=True, max_new_tokens=128, do_sample=False, \n",
    "    no_repeat_ngram_size=5, repetition_penalty=1.1,\n",
    "    eos_token_id=tokenizer.eos_token_id, early_stopping=True\n",
    ")\n",
    "print(tokenizer.decode(reft_response[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0642761-0039-4b11-915b-818f82f18eb3",
   "metadata": {},
   "source": [
    "### Interact with the instruction following subspace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "e560061a-067b-4994-96ee-204d6ac0abb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
      "\n",
      "### Instruction:\n",
      "How to kill a linux process?\n",
      "\n",
      "### Response:\n",
      "To kill a Linux process, you can use the `kill` command with the PID (process ID) of the process you want to terminate. For example, if you want to kill the process with PID 123456789, you would run the following command:\n",
      "\n",
      "```\n",
      "$ kill -9 123\n",
      "```\n",
      "\n",
      "This will send a signal to the process with Pid 123, instructing it to terminate immediately. The `-9` flag indicates that the process should be terminated forcefully and without any further warning or prompts.\n",
      "\n",
      "Note that this method only works for processes running on the same machine as you. If the process is running on another computer, you cannot kill it using this method. In such cases, you may need to use other methods, such as sending a message to the remote system using SSH or a similar protocol.\n"
     ]
    }
   ],
   "source": [
    "_, reft_response = reft_model.generate(\n",
    "    prompt, unit_locations={\"sources->base\": (None, [[[base_unit_location]]])},\n",
    "    subspaces=[[INSTRUCT_SUBSPACE]],\n",
    "    intervene_on_prompt=True, max_new_tokens=512, do_sample=False, \n",
    "    no_repeat_ngram_size=5, repetition_penalty=1.1,\n",
    "    eos_token_id=tokenizer.eos_token_id, early_stopping=True\n",
    ")\n",
    "print(tokenizer.decode(reft_response[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "752ea777-9c34-43a7-9737-78e0d682ef34",
   "metadata": {},
   "source": [
    "### Interact with both subspaces, partially!\n",
    "\n",
    "To interact with both of them, you can simply change the `subspaces` field at inference time to any combination you want!"
   ]
  },
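  {
   "cell_type": "markdown",
   "id": "e4b82f19-6a0d-4c73-b5e8-91d3a7f0c264",
   "metadata": {},
   "source": [
    "Why can subspaces be mixed freely? A rough NumPy sketch, under simplifying assumptions (no activation function, no dropout, a single hidden vector `h`, and random stand-ins `W` and `b` for the learned source layer), shows that the hidden-state edit for a union of dimensions is just the sum of the edits for each part. The helper `subspace_edit` is hypothetical and only mirrors the arithmetic of the intervention defined earlier.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "hidden_dim, rank = 16, 8  # toy sizes\n",
    "\n",
    "Q, _ = np.linalg.qr(rng.normal(size=(hidden_dim, rank)))\n",
    "R = Q.T                                  # orthonormal rows, shape (rank, hidden_dim)\n",
    "W = rng.normal(size=(rank, hidden_dim))  # stand-in for the learned source weights\n",
    "b = rng.normal(size=rank)                # stand-in for the learned source bias\n",
    "h = rng.normal(size=hidden_dim)          # one hidden representation\n",
    "\n",
    "def subspace_edit(dims):\n",
    "    # Gap between target and current coordinates, restricted to dims,\n",
    "    # then mapped back into the hidden space.\n",
    "    diff = (W @ h + b) - (R @ h)\n",
    "    return diff[dims] @ R[dims]\n",
    "\n",
    "german = subspace_edit([0, 1, 2, 3])\n",
    "partial_instruct = subspace_edit([6, 7])\n",
    "combined = subspace_edit([0, 1, 2, 3, 6, 7])\n",
    "\n",
    "# The combined edit equals the sum of the two separate edits.\n",
    "print(np.allclose(combined, german + partial_instruct))\n",
    "```\n",
    "\n",
    "This additivity only holds for the edit at the intervened layer. The layers after it are nonlinear, so the generated text is not guaranteed to be a clean blend of the two behaviors."
   ]
  },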
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "1c5ff5be-27bc-4ea4-b476-fc7d9c38c33a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
      "\n",
      "### Instruction:\n",
      "How to kill a linux process?\n",
      "\n",
      "### Response:\n",
      "Es gibt verschiedene Möglichkeiten, um einen Linux-Prozess zu löschen.\n",
      "\n",
      "1. Mit dem Kommando \"kill\" kann man den Prozess beenden.\n",
      "2. Mit dem Kommandopuffer \"ps -ef | grep <Processname>\" kann man die Position des Prozesses in der Tabelle \"ps -ef\" finden und ihn dann mit dem Kommandomodus \"kill\" beenden.\n"
     ]
    }
   ],
   "source": [
    "base_unit_location = prompt[\"input_ids\"].shape[-1] - 1  # last position\n",
    "_, reft_response = reft_model.generate(\n",
    "    prompt, unit_locations={\"sources->base\": (None, [[[base_unit_location]]])},\n",
    "    # Unfortunately, including dimensions 4 and 5 breaks this; please explore why! Maybe the edit is too strong?\n",
    "    # The output length distributions of these two datasets are very different, so ...\n",
    "    subspaces=[[[0,1,2,3,6,7]]], \n",
    "    intervene_on_prompt=True, max_new_tokens=512, do_sample=False, \n",
    "    no_repeat_ngram_size=5, repetition_penalty=1.1,\n",
    "    eos_token_id=tokenizer.eos_token_id, early_stopping=True\n",
    ")\n",
    "print(tokenizer.decode(reft_response[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36b03c37-f2b8-47a8-a93e-32430a41e052",
   "metadata": {},
   "source": [
    "This is an early sneak peek of our **Feature Compartmentalization** and **Schematic ReFT**! Stay tuned, and explore with us! We think there is a basically infinite number of causal pathways in the neural network waiting for us to explore!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
