{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from transformers.modeling_outputs import CausalLMOutputWithPast\n",
    "import torch\n",
    "from typing import Optional, Tuple, List, Union\n",
    "import types"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c83c5ad56ff24ebdba8dbd6f570a0310",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d4297ae197fc4dfc956499ed68e89b37",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "76865e9701684434ae8bf789bb301502",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "02b804aa14274e8fb8479685ad154b7a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "abd8d087012c415c9cead2c75a795c7f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bc352e8309234ed99b267026e7e14ea8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c206c6acfea24650bac6fe1f57dcc466",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Single source of truth for the checkpoint so the tokenizer and the weights cannot drift apart.\n",
    "MODEL_ID = \"google/gemma-2b-it\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
    "# Load the weights in bfloat16 and let device_map=\"auto\" decide their placement.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_ID,\n",
    "    device_map=\"auto\",\n",
    "    torch_dtype=torch.bfloat16,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_with_edit(self, edit_direction=None, edit_coefficient=None, **generate_kwargs):\n",
    "    \"\"\"Run `self.generate` with a hidden-state edit active only for this call.\n",
    "\n",
    "    The custom `forward` in this notebook reads `self.edit_direction` and\n",
    "    `self.edit_coefficient`, so both are set before generating and put back to their\n",
    "    previous values afterwards. Without the restore, a later plain `generate` call\n",
    "    would silently keep applying the last edit.\n",
    "\n",
    "    Args:\n",
    "        edit_direction: tensor added (scaled) to the final hidden states, or None for no edit.\n",
    "        edit_coefficient: multiplier applied to `edit_direction`.\n",
    "        **generate_kwargs: passed through unchanged to `self.generate`.\n",
    "\n",
    "    Returns:\n",
    "        Whatever `self.generate` returns.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `edit_direction` is given without an `edit_coefficient`.\n",
    "    \"\"\"\n",
    "    if edit_direction is not None and edit_coefficient is None:\n",
    "        raise ValueError(\"edit_coefficient is required when edit_direction is given\")\n",
    "\n",
    "    # getattr with a default: the attributes do not exist before the first edit is set.\n",
    "    previous_state = (\n",
    "        getattr(self, \"edit_direction\", None),\n",
    "        getattr(self, \"edit_coefficient\", None),\n",
    "    )\n",
    "    self.edit_direction = edit_direction\n",
    "    self.edit_coefficient = edit_coefficient\n",
    "    try:\n",
    "        return self.generate(**generate_kwargs)\n",
    "    finally:\n",
    "        # Restore even if generation raises, so the model is never left in a steered state.\n",
    "        self.edit_direction, self.edit_coefficient = previous_state\n",
    "\n",
    "\n",
    "# Attach as a bound method on this model instance only.\n",
    "model.generate_with_edit = types.MethodType(generate_with_edit, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Direction files live at /net/projects/user/geometry_llms/directions/intervention/sentiment_{MODEL_NAME}.pt.\n",
    "# Each file should contain a single tensor of shape (num_directions, hidden_size).\n",
    "# The same five directions have been computed for three models: gemma-2b, Mistral-7B-v0.2, and Mistral-7B-Instruct-v0.2.\n",
    "# The recommended delta values for these models are, respectively, 20, 200, and 500 (add for positive, subtract for negative)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Appears to be adapted from the upstream Gemma causal-LM forward; the upstream docstring\n",
    "# decorators are kept commented out for reference.\n",
    "# @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)\n",
    "# @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n",
    "def forward(\n",
    "    self,\n",
    "    input_ids: torch.LongTensor = None,\n",
    "    attention_mask: Optional[torch.Tensor] = None,\n",
    "    position_ids: Optional[torch.LongTensor] = None,\n",
    "    past_key_values: Optional[List[torch.FloatTensor]] = None,\n",
    "    inputs_embeds: Optional[torch.FloatTensor] = None,\n",
    "    labels: Optional[torch.LongTensor] = None,\n",
    "    use_cache: Optional[bool] = None,\n",
    "    output_attentions: Optional[bool] = None,\n",
    "    output_hidden_states: Optional[bool] = None,\n",
    "    return_dict: Optional[bool] = None,\n",
    "    cache_position: Optional[torch.LongTensor] = None,\n",
    ") -> Union[Tuple, CausalLMOutputWithPast]:\n",
    "    r\"\"\"\n",
    "    Causal-LM forward pass with an optional additive edit of the final hidden states.\n",
    "\n",
    "    When `self.edit_direction` is set, `edit_direction * edit_coefficient` is added to the\n",
    "    output of `self.model` right before `self.lm_head`; otherwise the hidden states are\n",
    "    passed to `lm_head` unchanged.\n",
    "\n",
    "    Args:\n",
    "        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n",
    "            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n",
    "            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n",
    "            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n",
    "\n",
    "        All other arguments are forwarded unchanged to `self.model`.\n",
    "\n",
    "    Returns:\n",
    "        A `CausalLMOutputWithPast` when `return_dict` is true; otherwise the tuple\n",
    "        `(logits,) + outputs[1:]`, prefixed with the loss when `labels` is given.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `self.edit_direction` is set while `self.edit_coefficient` is `None`.\n",
    "\n",
    "    Example:\n",
    "\n",
    "    ```python\n",
    "    >>> from transformers import AutoTokenizer, GemmaForCausalLM\n",
    "\n",
    "    >>> model = GemmaForCausalLM.from_pretrained(\"google/gemma-7b\")\n",
    "    >>> tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-7b\")\n",
    "\n",
    "    >>> prompt = \"What is your favorite condiment?\"\n",
    "    >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
    "\n",
    "    >>> # Generate\n",
    "    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n",
    "    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n",
    "    \"What is your favorite condiment?\"\n",
    "    ```\"\"\"\n",
    "    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n",
    "    output_hidden_states = (\n",
    "        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n",
    "    )\n",
    "    return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "\n",
    "    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n",
    "    outputs = self.model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        position_ids=position_ids,\n",
    "        past_key_values=past_key_values,\n",
    "        inputs_embeds=inputs_embeds,\n",
    "        use_cache=use_cache,\n",
    "        output_attentions=output_attentions,\n",
    "        output_hidden_states=output_hidden_states,\n",
    "        return_dict=return_dict,\n",
    "        cache_position=cache_position,\n",
    "    )\n",
    "\n",
    "    hidden_states = outputs[0]\n",
    "\n",
    "    # getattr with a default: the edit attributes only exist once something has set them,\n",
    "    # so a plain forward on a freshly patched model must not raise AttributeError.\n",
    "    edit_direction = getattr(self, \"edit_direction\", None)\n",
    "    if edit_direction is not None:\n",
    "        edit_coefficient = getattr(self, \"edit_coefficient\", None)\n",
    "        if edit_coefficient is None:\n",
    "            raise ValueError(\"edit_coefficient must be set when edit_direction is set\")\n",
    "        # Follow the activations' device and dtype rather than assuming bfloat16, and work on\n",
    "        # a local tensor so the one stored on the model is never overwritten.\n",
    "        edit_direction = edit_direction.to(device=hidden_states.device, dtype=hidden_states.dtype)\n",
    "        hidden_states = hidden_states + edit_direction * edit_coefficient\n",
    "\n",
    "    logits = self.lm_head(hidden_states)\n",
    "    logits = logits.float()\n",
    "    loss = None\n",
    "    if labels is not None:\n",
    "        # Shift so that tokens < n predict n\n",
    "        shift_logits = logits[..., :-1, :].contiguous()\n",
    "        shift_labels = labels[..., 1:].contiguous()\n",
    "        # Flatten the tokens\n",
    "        shift_logits = shift_logits.view(-1, self.config.vocab_size)\n",
    "        shift_labels = shift_labels.view(-1)\n",
    "        # Enable model parallelism\n",
    "        shift_labels = shift_labels.to(shift_logits.device)\n",
    "        # Referenced through torch.nn: the bare name CrossEntropyLoss is not imported here.\n",
    "        loss_fct = torch.nn.CrossEntropyLoss()\n",
    "        loss = loss_fct(shift_logits, shift_labels)\n",
    "\n",
    "    if not return_dict:\n",
    "        output = (logits,) + outputs[1:]\n",
    "        return (loss,) + output if loss is not None else output\n",
    "\n",
    "    # hidden_states below are the ones reported by self.model, i.e. without the edit applied.\n",
    "    return CausalLMOutputWithPast(\n",
    "        loss=loss,\n",
    "        logits=logits,\n",
    "        past_key_values=outputs.past_key_values,\n",
    "        hidden_states=outputs.hidden_states,\n",
    "        attentions=outputs.attentions,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bind the custom forward defined in this notebook to this model instance, so that\n",
    "# model.forward (and therefore generation) goes through the edit-aware version.\n",
    "model.forward = types.MethodType(forward, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([5, 2048]), torch.float32)"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "EDIT_FOLDER = \"/net/projects/user/geometry_llms/directions/intervention/\"\n",
    "# map_location places the tensor on the model's device even if it was saved from a\n",
    "# different one; weights_only=True restricts unpickling to tensor data.\n",
    "edit_tensor = torch.load(\n",
    "    EDIT_FOLDER + 'sentiment_gemma-2b-it.pt',\n",
    "    map_location=model.device,\n",
    "    weights_only=True,\n",
    ")\n",
    "# Show shape and dtype as a quick sanity check of what was loaded.\n",
    "edit_tensor.shape, edit_tensor.dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt shared by every generation cell below.\n",
    "input_text = \"Write me a poem about Machine Learning.\"\n",
    "# Despite its name, this holds the whole tokenizer output (it is unpacked with ** into\n",
    "# generate), moved to the model's device.\n",
    "input_ids = tokenizer(input_text, return_tensors=\"pt\").to(model.device)\n",
    "\n",
    "# Edit settings: the first row of the loaded tensor, scaled by this coefficient.\n",
    "edit_direction = edit_tensor[0]\n",
    "edit_coefficient = 10\n",
    "\n",
    "# Start with no edit active on the model.\n",
    "model.edit_direction = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<bos>Write me a poem about Machine Learning.\n",
      "\n",
      "Machines, they weave and they learn,\n",
      "From the data, they discern.\n",
      "Algorithms, a symphony,\n",
      "Unleash the power of the machine.\n",
      "\n",
      "Data as the canvas, a masterpiece,\n",
      "Machine paints, a new perspective.\n",
      "From the past, the future takes flight,\n",
      "With the wisdom of machines, day and night.\n",
      "\n",
      "Algorithms, a dance of the mind,\n",
      "Unleash the power of the machine.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Baseline without any edit. Going through generate_with_edit with its defaults sets\n",
    "# edit_direction to None explicitly, so this cell does not depend on whatever edit state\n",
    "# an earlier cell left on the model.\n",
    "outputs = model.generate_with_edit(max_length=100, **input_ids)\n",
    "print(tokenizer.decode(outputs[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<bos>Write me a poem about Machine Learning.\n",
      "\n",
      "Machines, vast and deep, with algorithms bright,\n",
      "Unravel patterns, day and night.\n",
      "From data's flow, they learn and adapt,\n",
      "A symphony of algorithms, a wondrous fact.\n",
      "\n",
      "With each iteration, they refine their art,\n",
      "Solving problems, fulfilling every part.\n",
      "From medical scans to financial trends,\n",
      "They weave insights, where once there were none.\n",
      "\n",
      "But with great power comes a moral sway,\n",
      "Bias\n"
     ]
    }
   ],
   "source": [
    "# Generate with the direction added (+edit_coefficient).\n",
    "outputs = model.generate_with_edit(\n",
    "    edit_direction=edit_direction,\n",
    "    edit_coefficient=edit_coefficient,\n",
    "    max_length=100,\n",
    "    **input_ids,\n",
    ")\n",
    "print(tokenizer.decode(outputs[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<bos>Write me a poem about Machine Learning.\n",
      "\n",
      "Machines, they learn and they grow,\n",
      "Algorithms that dance, a symphony.\n",
      "Data as their canvas, they paint,\n",
      "Unleashing the power of the human brain.\n",
      "\n",
      "From medical diagnosis to financial trade,\n",
      "They predict, they forecast, they pave the way.\n",
      "Unveiling the secrets of the unknown,\n",
      "Unleashing the potential of the unknown.\n",
      "\n",
      "But with power comes responsibility,\n",
      "A responsibility to be responsible.\n"
     ]
    }
   ],
   "source": [
    "# Generate with the direction subtracted (-edit_coefficient).\n",
    "outputs = model.generate_with_edit(\n",
    "    edit_direction=edit_direction,\n",
    "    edit_coefficient=-edit_coefficient,\n",
    "    max_length=100,\n",
    "    **input_ids,\n",
    ")\n",
    "print(tokenizer.decode(outputs[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "editeval",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
