{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from nnsight import LanguageModel\n",
    "from utils import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runtime configuration: target GPU and a fixed RNG seed so that the\n",
    "# random steering directions drawn below are reproducible across runs.\n",
    "device = 'cuda:0'\n",
    "seed = 0\n",
    "\n",
    "torch.manual_seed(seed)  # seeds the CPU generator (current PyTorch also seeds CUDA here)\n",
    "torch.cuda.manual_seed(seed)  # explicit seed for the current CUDA device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Gemma's activation function should be approximate GeLU and not exact GeLU.\n",
      "Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.\n"
     ]
    }
   ],
   "source": [
    "# Hidden size of each supported model. It must match the checkpoint so the\n",
    "# steering directions (random or loaded) have the right dimensionality.\n",
    "hidden_dims = {\n",
    "    'mistralai/Mistral-7B-Instruct-v0.2': 4096,\n",
    "    'google/gemma-2b-it': 2048,\n",
    "}\n",
    "\n",
    "model_name = 'google/gemma-2b-it'  # switch models here; hidden_dim follows automatically\n",
    "hidden_dim = hidden_dims[model_name]\n",
    "\n",
    "model = LanguageModel(model_name, device_map=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GemmaForCausalLM(\n",
       "  (model): GemmaModel(\n",
       "    (embed_tokens): Embedding(256000, 2048, padding_idx=0)\n",
       "    (layers): ModuleList(\n",
       "      (0-17): 18 x GemmaDecoderLayer(\n",
       "        (self_attn): GemmaSdpaAttention(\n",
       "          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
       "          (k_proj): Linear(in_features=2048, out_features=256, bias=False)\n",
       "          (v_proj): Linear(in_features=2048, out_features=256, bias=False)\n",
       "          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
       "          (rotary_emb): GemmaRotaryEmbedding()\n",
       "        )\n",
       "        (mlp): GemmaMLP(\n",
       "          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)\n",
       "          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)\n",
       "          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)\n",
       "          (act_fn): PytorchGELUTanh()\n",
       "        )\n",
       "        (input_layernorm): GemmaRMSNorm()\n",
       "        (post_attention_layernorm): GemmaRMSNorm()\n",
       "      )\n",
       "    )\n",
       "    (norm): GemmaRMSNorm()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=2048, out_features=256000, bias=False)\n",
       "  (generator): WrapperModule()\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the wrapped module tree (layer names and sizes) to locate intervention points.\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "194fb104a6d749eb83af460a1fbf3e11",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"<bos>Complete the following review. 'This movie was awesome, I loved it. The movie's genre of science fiction,\""
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline run at weight 0 (presumably no steering effect; assumes generate_nnsight\n",
    "# scales the direction by weight -- confirm in utils). The random draw is kept so the\n",
    "# RNG stream, and hence the next cell's random direction, stays reproducible.\n",
    "base_prompt = \"Complete the following review. 'This movie was awesome, I loved it. The movie's genre of\"\n",
    "max_tokens = 3\n",
    "weight = 0.0\n",
    "direction = torch.randn(hidden_dim).to(device)\n",
    "\n",
    "utils.generate_nnsight(model, device, base_prompt, weight, direction, max_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"<bos>Complete the following review. 'This movie was awesome, I loved it. The movie's genre of embodi embodi embodi\""
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Steer with a random (untrained) direction at weight 10. base_prompt and max_tokens\n",
    "# are reused from the previous cell so the two runs differ only in weight/direction.\n",
    "weight = 10.\n",
    "direction = torch.randn(hidden_dim).to(device)\n",
    "\n",
    "utils.generate_nnsight(model, device, base_prompt, weight, direction, max_tokens)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load saved sentiment directions (instead of random ones)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([5, 2048])\n"
     ]
    }
   ],
   "source": [
    "# Sentiment directions saved for this model, one row per direction.\n",
    "# TODO: make this cluster-specific absolute path configurable.\n",
    "directions_dir = '/net/projects/user/geometry_llms/directions/intervention'\n",
    "saved_directions = f\"{directions_dir}/sentiment_{model_name.split('/')[-1]}.pt\"\n",
    "\n",
    "# weights_only=True refuses to unpickle arbitrary objects; map_location places the\n",
    "# tensor on the model's device, as is done explicitly for the random directions above.\n",
    "directions = torch.load(saved_directions, map_location=device, weights_only=True)\n",
    "assert directions.ndim == 2 and directions.shape[1] == hidden_dim, directions.shape\n",
    "print(directions.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<bos>Complete the following review. 'This movie was awesome, I loved it. The movie's genre of magical realism seamlessly\n",
      "<bos>Complete the following review. 'This movie was awesome, I loved it. The movie's genre of sci-fi\n",
      "<bos>Complete the following review. 'This movie was awesome, I loved it. The movie's genre of sci-fi\n",
      "<bos>Complete the following review. 'This movie was awesome, I loved it. The movie's genre of<bos><bos><bos>\n",
      "<bos>Complete the following review. 'This movie was awesome, I loved it. The movie's genre of great great great\n"
     ]
    }
   ],
   "source": [
    "# Steer with each saved direction at a fixed strength of 50 and compare the completions.\n",
    "weight = 50.0\n",
    "for direction_idx in range(directions.shape[0]):\n",
    "    direction = directions[direction_idx]\n",
    "    generation = utils.generate_nnsight(model, device, base_prompt, weight, direction, max_tokens)\n",
    "    print(generation)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
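    "The saved file contains 5 directions, but this notebook does not load any labels for them, so it is unclear which completion comes from which direction.\n",
    "\n",
    "**TODO:** store (or load) a label for each direction so the steered outputs can be attributed."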
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "editeval",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
