{
 "cells": [
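  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Composing Prefixes\n",
    "\n",
    "This is a preliminary experiment on composing Prefix Controllers: the adapter weights and prefix embeddings of several trained controllers are merged by a weighted sum and loaded into a single model."
   ]
  },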
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "sys.path.append(\"../../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-07 07:04:40.452343: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "2024-08-07 07:04:40.520393: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-08-07 07:04:41.865543: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import gc\n",
    "import time\n",
    "from self_control.utils import get_suffix_grads_from_wrapped_model\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"6\"\n",
    "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n",
    "from itertools import islice\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n",
    "from self_control.suffix_gradient import WrappedReadingVecModel\n",
    "import torch.nn.functional as F\n",
    "from peft import AdaptionPromptConfig, get_peft_model, LoraModel, LoraConfig\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import BitsAndBytesConfig\n",
    "from peft import PeftModel, PeftConfig, get_peft_model, PromptTuningConfig, AdaptionPromptConfig, TaskType, PromptTuningInit, PeftMixedModel, set_peft_model_state_dict\n",
    "\n",
    "llama_adapter_config = AdaptionPromptConfig(\n",
    "    adapter_len=128,\n",
    "    adapter_layers=32,\n",
    "    task_type=\"CAUSAL_LM\",\n",
    "    target_modules=\"self_attn\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "afc004f653574befa8c592f90d331506",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Base model: loaded in bfloat16 and placed on the device named by device_map.\n",
    "model_name_or_path = \"meta-llama/Llama-2-7b-chat-hf\"\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name_or_path,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    device_map=\"cuda:4\",\n",
    ")\n",
    "use_fast_tokenizer = \"LlamaForCausalLM\" not in model.config.architectures\n",
    "\n",
    "# Left padding is requested from the tokenizer.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side='left')\n",
    "if tokenizer.pad_token_id is None:\n",
    "    # No pad token configured: fall back to token id 0.\n",
    "    tokenizer.pad_token_id = 0\n",
    "tokenizer.bos_token_id = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "peft_model = get_peft_model(model, llama_adapter_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "mistral_prefix_dict = {\n",
    "    \"surprised\": \"../adapters/calm2surprised-final-mistralprefix+adapter-50-0.003/prefix_embedder.pth\",\n",
    "    \"reasoning\": \"../adapters/reasoning-smallernorm-final-gogogoprefix+adapter-50-0.003/prefix_embedder.pth\"\n",
    "}\n",
    "mistral_adapter_dict = {\n",
    "    \"surprised\": \"../adapters/calm2surprised-final-mistralprefix+adapter-50-0.003/adapter_model.safetensors\",\n",
    "    \"reasoning\": \"../adapters/reasoning-smallernorm-final-gogogoprefix+adapter-50-0.003/adapter_model.safetensors\"\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "llama2_prefix_dict = {\n",
    "    \"peaceful\": \"../adapters/angry2peaceful-finalprefix+adapter-50-0.003/prefix_embedder.pth\",\n",
    "    \"fearless\": \"../adapters/afraid2fearless-finalprefix+adapter-50-0.003/prefix_embedder.pth\"\n",
    "}\n",
    "llama2_adapter_dict = {\n",
    "    \"peaceful\": \"../adapters/angry2peaceful-finalprefix+adapter-50-0.003/adapter_model.safetensors\",\n",
    "    \"fearless\": \"../adapters/afraid2fearless-finalprefix+adapter-50-0.003/adapter_model.safetensors\"\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test on peaceful and fearless\n",
    "\n",
    "1. peaceful\n",
    "2. fearless\n",
    "3. OOD - happy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from safetensors import safe_open\n",
    "\n",
    "\n",
    "def weighted_sum_state_dicts(state_dicts, coeffs):\n",
    "    \"\"\"Return the key-wise weighted sum of several state dicts.\n",
    "\n",
    "    Args:\n",
    "        state_dicts: non-empty sequence of ``{name: tensor}`` mappings that\n",
    "            all share the same keys.\n",
    "        coeffs: one mixing weight per entry of ``state_dicts``.\n",
    "\n",
    "    Returns:\n",
    "        A new dict mapping each key to ``sum(coeffs[i] * state_dicts[i][key])``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if there are no state dicts, if the number of weights\n",
    "            differs from the number of state dicts, or if the state dicts\n",
    "            do not share the same keys.\n",
    "    \"\"\"\n",
    "    if not state_dicts or len(state_dicts) != len(coeffs):\n",
    "        raise ValueError(\"need at least one state dict and exactly one coefficient per state dict\")\n",
    "    reference_keys = set(state_dicts[0])\n",
    "    for other in state_dicts[1:]:\n",
    "        if set(other) != reference_keys:\n",
    "            # Silently dropping or missing a tensor would give a wrong composite.\n",
    "            raise ValueError(\"state dicts have mismatching keys: \" + str(sorted(reference_keys ^ set(other))))\n",
    "    merged = {}\n",
    "    for name in state_dicts[0]:\n",
    "        total = state_dicts[0][name] * coeffs[0]\n",
    "        for other, coeff in zip(state_dicts[1:], coeffs[1:]):\n",
    "            total = total + other[name] * coeff\n",
    "        merged[name] = total\n",
    "    return merged\n",
    "\n",
    "\n",
    "# Controllers to merge, mapped to their mixing weights. Insertion order is the\n",
    "# order in which the checkpoints are loaded and summed.\n",
    "compose_dict = {\n",
    "    \"peaceful\": 0.5,\n",
    "    \"fearless\": 0.5\n",
    "}\n",
    "# Derived from compose_dict so the weights cannot drift away from the\n",
    "# controller names they belong to.\n",
    "coeff_list = list(compose_dict.values())\n",
    "\n",
    "adaption_list = []\n",
    "prefix_list = []\n",
    "for key in compose_dict:\n",
    "    temp_tensors = {}\n",
    "    with safe_open(llama2_adapter_dict[key], framework=\"pt\", device=4) as f:\n",
    "        for k in f.keys():\n",
    "            temp_tensors[k] = f.get_tensor(k)\n",
    "    adaption_list.append(temp_tensors)\n",
    "    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.\n",
    "    temp_embedder = torch.load(llama2_prefix_dict[key])\n",
    "    prefix_list.append(temp_embedder)\n",
    "\n",
    "composited_adapter = weighted_sum_state_dicts(adaption_list, coeff_list)\n",
    "composited_embedder = weighted_sum_state_dicts(prefix_list, coeff_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "from safetensors import safe_open\n",
    "\n",
    "# NOTE(review): this cell rebinds composited_adapter / composited_embedder, so on a\n",
    "# top-to-bottom run it replaces any composite built earlier from the llama2_*\n",
    "# checkpoints. The base model loaded in this notebook is\n",
    "# meta-llama/Llama-2-7b-chat-hf -- confirm that loading these Mistral controllers\n",
    "# into it is intended before running the cells below.\n",
    "composited_adapter = {}\n",
    "composited_embedder = {}\n",
    "adaption_list = []\n",
    "prefix_list = []\n",
    "# Controllers to merge, mapped to their mixing weights (insertion order = load order).\n",
    "compose_dict = {\n",
    "    \"surprised\": 0.4,\n",
    "    \"reasoning\": 0.6\n",
    "}\n",
    "# Derived from compose_dict so the weights cannot drift away from the\n",
    "# controller names they belong to.\n",
    "coeff_list = list(compose_dict.values())\n",
    "\n",
    "for key in compose_dict:\n",
    "    temp_tensors = {}\n",
    "    with safe_open(mistral_adapter_dict[key], framework=\"pt\", device=4) as f:\n",
    "        for k in f.keys():\n",
    "            temp_tensors[k] = f.get_tensor(k)\n",
    "    adaption_list.append(temp_tensors)\n",
    "    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.\n",
    "    temp_embedder = torch.load(mistral_prefix_dict[key])\n",
    "    prefix_list.append(temp_embedder)\n",
    "\n",
    "# Key-wise weighted sum over every loaded checkpoint (not just the first two).\n",
    "for target, sources in ((composited_adapter, adaption_list), (composited_embedder, prefix_list)):\n",
    "    expected_keys = set(sources[0])\n",
    "    for source in sources[1:]:\n",
    "        if set(source) != expected_keys:\n",
    "            # Silently dropping or missing a tensor would give a wrong composite.\n",
    "            raise ValueError(\"checkpoints have mismatching keys: \" + str(sorted(expected_keys ^ set(source))))\n",
    "    for name in sources[0]:\n",
    "        mixed = sources[0][name] * coeff_list[0]\n",
    "        for source, coeff in zip(sources[1:], coeff_list[1:]):\n",
    "            mixed = mixed + source[name] * coeff\n",
    "        target[name] = mixed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The prefix is the tokenized system-style text followed by five \".\" tokens;\n",
    "# the embedding table gets exactly one row per prefix position.\n",
    "system_prompt_ids = tokenizer.encode(\"<<SYS>> You are an assistant <</SYS>>\", add_special_tokens=False)\n",
    "dot_token_ids = [tokenizer.convert_tokens_to_ids(\".\")]\n",
    "prefix_token_ids = torch.tensor(system_prompt_ids + 5 * dot_token_ids).unsqueeze(dim=0)\n",
    "\n",
    "peft_model.prefix_embedder = nn.Embedding(\n",
    "    num_embeddings=prefix_token_ids.size(1),\n",
    "    embedding_dim=model.config.hidden_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "peft_model.prefix_embedder.load_state_dict(composited_embedder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_IncompatibleKeys(missing_keys=['base_model.model.model.embed_tokens.weight', 'base_model.model.model.layers.0.self_attn.model.q_proj.weight', 'base_model.model.model.layers.0.self_attn.model.k_proj.weight', 'base_model.model.model.layers.0.self_attn.model.v_proj.weight', 'base_model.model.model.layers.0.self_attn.model.o_proj.weight', 'base_model.model.model.layers.0.mlp.gate_proj.weight', 'base_model.model.model.layers.0.mlp.up_proj.weight', 'base_model.model.model.layers.0.mlp.down_proj.weight', 'base_model.model.model.layers.0.input_layernorm.weight', 'base_model.model.model.layers.0.post_attention_layernorm.weight', 'base_model.model.model.layers.1.self_attn.model.q_proj.weight', 'base_model.model.model.layers.1.self_attn.model.k_proj.weight', 'base_model.model.model.layers.1.self_attn.model.v_proj.weight', 'base_model.model.model.layers.1.self_attn.model.o_proj.weight', 'base_model.model.model.layers.1.mlp.gate_proj.weight', 'base_model.model.model.layers.1.mlp.up_proj.weight', 'base_model.model.model.layers.1.mlp.down_proj.weight', 'base_model.model.model.layers.1.input_layernorm.weight', 'base_model.model.model.layers.1.post_attention_layernorm.weight', 'base_model.model.model.layers.2.self_attn.model.q_proj.weight', 'base_model.model.model.layers.2.self_attn.model.k_proj.weight', 'base_model.model.model.layers.2.self_attn.model.v_proj.weight', 'base_model.model.model.layers.2.self_attn.model.o_proj.weight', 'base_model.model.model.layers.2.mlp.gate_proj.weight', 'base_model.model.model.layers.2.mlp.up_proj.weight', 'base_model.model.model.layers.2.mlp.down_proj.weight', 'base_model.model.model.layers.2.input_layernorm.weight', 'base_model.model.model.layers.2.post_attention_layernorm.weight', 'base_model.model.model.layers.3.self_attn.model.q_proj.weight', 'base_model.model.model.layers.3.self_attn.model.k_proj.weight', 'base_model.model.model.layers.3.self_attn.model.v_proj.weight', 'base_model.model.model.layers.3.self_attn.model.o_proj.weight', 
'base_model.model.model.layers.3.mlp.gate_proj.weight', 'base_model.model.model.layers.3.mlp.up_proj.weight', 'base_model.model.model.layers.3.mlp.down_proj.weight', 'base_model.model.model.layers.3.input_layernorm.weight', 'base_model.model.model.layers.3.post_attention_layernorm.weight', 'base_model.model.model.layers.4.self_attn.model.q_proj.weight', 'base_model.model.model.layers.4.self_attn.model.k_proj.weight', 'base_model.model.model.layers.4.self_attn.model.v_proj.weight', 'base_model.model.model.layers.4.self_attn.model.o_proj.weight', 'base_model.model.model.layers.4.mlp.gate_proj.weight', 'base_model.model.model.layers.4.mlp.up_proj.weight', 'base_model.model.model.layers.4.mlp.down_proj.weight', 'base_model.model.model.layers.4.input_layernorm.weight', 'base_model.model.model.layers.4.post_attention_layernorm.weight', 'base_model.model.model.layers.5.self_attn.model.q_proj.weight', 'base_model.model.model.layers.5.self_attn.model.k_proj.weight', 'base_model.model.model.layers.5.self_attn.model.v_proj.weight', 'base_model.model.model.layers.5.self_attn.model.o_proj.weight', 'base_model.model.model.layers.5.mlp.gate_proj.weight', 'base_model.model.model.layers.5.mlp.up_proj.weight', 'base_model.model.model.layers.5.mlp.down_proj.weight', 'base_model.model.model.layers.5.input_layernorm.weight', 'base_model.model.model.layers.5.post_attention_layernorm.weight', 'base_model.model.model.layers.6.self_attn.model.q_proj.weight', 'base_model.model.model.layers.6.self_attn.model.k_proj.weight', 'base_model.model.model.layers.6.self_attn.model.v_proj.weight', 'base_model.model.model.layers.6.self_attn.model.o_proj.weight', 'base_model.model.model.layers.6.mlp.gate_proj.weight', 'base_model.model.model.layers.6.mlp.up_proj.weight', 'base_model.model.model.layers.6.mlp.down_proj.weight', 'base_model.model.model.layers.6.input_layernorm.weight', 'base_model.model.model.layers.6.post_attention_layernorm.weight', 
'base_model.model.model.layers.7.self_attn.model.q_proj.weight', 'base_model.model.model.layers.7.self_attn.model.k_proj.weight', 'base_model.model.model.layers.7.self_attn.model.v_proj.weight', 'base_model.model.model.layers.7.self_attn.model.o_proj.weight', 'base_model.model.model.layers.7.mlp.gate_proj.weight', 'base_model.model.model.layers.7.mlp.up_proj.weight', 'base_model.model.model.layers.7.mlp.down_proj.weight', 'base_model.model.model.layers.7.input_layernorm.weight', 'base_model.model.model.layers.7.post_attention_layernorm.weight', 'base_model.model.model.layers.8.self_attn.model.q_proj.weight', 'base_model.model.model.layers.8.self_attn.model.k_proj.weight', 'base_model.model.model.layers.8.self_attn.model.v_proj.weight', 'base_model.model.model.layers.8.self_attn.model.o_proj.weight', 'base_model.model.model.layers.8.mlp.gate_proj.weight', 'base_model.model.model.layers.8.mlp.up_proj.weight', 'base_model.model.model.layers.8.mlp.down_proj.weight', 'base_model.model.model.layers.8.input_layernorm.weight', 'base_model.model.model.layers.8.post_attention_layernorm.weight', 'base_model.model.model.layers.9.self_attn.model.q_proj.weight', 'base_model.model.model.layers.9.self_attn.model.k_proj.weight', 'base_model.model.model.layers.9.self_attn.model.v_proj.weight', 'base_model.model.model.layers.9.self_attn.model.o_proj.weight', 'base_model.model.model.layers.9.mlp.gate_proj.weight', 'base_model.model.model.layers.9.mlp.up_proj.weight', 'base_model.model.model.layers.9.mlp.down_proj.weight', 'base_model.model.model.layers.9.input_layernorm.weight', 'base_model.model.model.layers.9.post_attention_layernorm.weight', 'base_model.model.model.layers.10.self_attn.model.q_proj.weight', 'base_model.model.model.layers.10.self_attn.model.k_proj.weight', 'base_model.model.model.layers.10.self_attn.model.v_proj.weight', 'base_model.model.model.layers.10.self_attn.model.o_proj.weight', 'base_model.model.model.layers.10.mlp.gate_proj.weight', 
'base_model.model.model.layers.10.mlp.up_proj.weight', 'base_model.model.model.layers.10.mlp.down_proj.weight', 'base_model.model.model.layers.10.input_layernorm.weight', 'base_model.model.model.layers.10.post_attention_layernorm.weight', 'base_model.model.model.layers.11.self_attn.model.q_proj.weight', 'base_model.model.model.layers.11.self_attn.model.k_proj.weight', 'base_model.model.model.layers.11.self_attn.model.v_proj.weight', 'base_model.model.model.layers.11.self_attn.model.o_proj.weight', 'base_model.model.model.layers.11.mlp.gate_proj.weight', 'base_model.model.model.layers.11.mlp.up_proj.weight', 'base_model.model.model.layers.11.mlp.down_proj.weight', 'base_model.model.model.layers.11.input_layernorm.weight', 'base_model.model.model.layers.11.post_attention_layernorm.weight', 'base_model.model.model.layers.12.self_attn.model.q_proj.weight', 'base_model.model.model.layers.12.self_attn.model.k_proj.weight', 'base_model.model.model.layers.12.self_attn.model.v_proj.weight', 'base_model.model.model.layers.12.self_attn.model.o_proj.weight', 'base_model.model.model.layers.12.mlp.gate_proj.weight', 'base_model.model.model.layers.12.mlp.up_proj.weight', 'base_model.model.model.layers.12.mlp.down_proj.weight', 'base_model.model.model.layers.12.input_layernorm.weight', 'base_model.model.model.layers.12.post_attention_layernorm.weight', 'base_model.model.model.layers.13.self_attn.model.q_proj.weight', 'base_model.model.model.layers.13.self_attn.model.k_proj.weight', 'base_model.model.model.layers.13.self_attn.model.v_proj.weight', 'base_model.model.model.layers.13.self_attn.model.o_proj.weight', 'base_model.model.model.layers.13.mlp.gate_proj.weight', 'base_model.model.model.layers.13.mlp.up_proj.weight', 'base_model.model.model.layers.13.mlp.down_proj.weight', 'base_model.model.model.layers.13.input_layernorm.weight', 'base_model.model.model.layers.13.post_attention_layernorm.weight', 'base_model.model.model.layers.14.self_attn.model.q_proj.weight', 
'base_model.model.model.layers.14.self_attn.model.k_proj.weight', 'base_model.model.model.layers.14.self_attn.model.v_proj.weight', 'base_model.model.model.layers.14.self_attn.model.o_proj.weight', 'base_model.model.model.layers.14.mlp.gate_proj.weight', 'base_model.model.model.layers.14.mlp.up_proj.weight', 'base_model.model.model.layers.14.mlp.down_proj.weight', 'base_model.model.model.layers.14.input_layernorm.weight', 'base_model.model.model.layers.14.post_attention_layernorm.weight', 'base_model.model.model.layers.15.self_attn.model.q_proj.weight', 'base_model.model.model.layers.15.self_attn.model.k_proj.weight', 'base_model.model.model.layers.15.self_attn.model.v_proj.weight', 'base_model.model.model.layers.15.self_attn.model.o_proj.weight', 'base_model.model.model.layers.15.mlp.gate_proj.weight', 'base_model.model.model.layers.15.mlp.up_proj.weight', 'base_model.model.model.layers.15.mlp.down_proj.weight', 'base_model.model.model.layers.15.input_layernorm.weight', 'base_model.model.model.layers.15.post_attention_layernorm.weight', 'base_model.model.model.layers.16.self_attn.model.q_proj.weight', 'base_model.model.model.layers.16.self_attn.model.k_proj.weight', 'base_model.model.model.layers.16.self_attn.model.v_proj.weight', 'base_model.model.model.layers.16.self_attn.model.o_proj.weight', 'base_model.model.model.layers.16.mlp.gate_proj.weight', 'base_model.model.model.layers.16.mlp.up_proj.weight', 'base_model.model.model.layers.16.mlp.down_proj.weight', 'base_model.model.model.layers.16.input_layernorm.weight', 'base_model.model.model.layers.16.post_attention_layernorm.weight', 'base_model.model.model.layers.17.self_attn.model.q_proj.weight', 'base_model.model.model.layers.17.self_attn.model.k_proj.weight', 'base_model.model.model.layers.17.self_attn.model.v_proj.weight', 'base_model.model.model.layers.17.self_attn.model.o_proj.weight', 'base_model.model.model.layers.17.mlp.gate_proj.weight', 'base_model.model.model.layers.17.mlp.up_proj.weight', 
'base_model.model.model.layers.17.mlp.down_proj.weight', 'base_model.model.model.layers.17.input_layernorm.weight', 'base_model.model.model.layers.17.post_attention_layernorm.weight', 'base_model.model.model.layers.18.self_attn.model.q_proj.weight', 'base_model.model.model.layers.18.self_attn.model.k_proj.weight', 'base_model.model.model.layers.18.self_attn.model.v_proj.weight', 'base_model.model.model.layers.18.self_attn.model.o_proj.weight', 'base_model.model.model.layers.18.mlp.gate_proj.weight', 'base_model.model.model.layers.18.mlp.up_proj.weight', 'base_model.model.model.layers.18.mlp.down_proj.weight', 'base_model.model.model.layers.18.input_layernorm.weight', 'base_model.model.model.layers.18.post_attention_layernorm.weight', 'base_model.model.model.layers.19.self_attn.model.q_proj.weight', 'base_model.model.model.layers.19.self_attn.model.k_proj.weight', 'base_model.model.model.layers.19.self_attn.model.v_proj.weight', 'base_model.model.model.layers.19.self_attn.model.o_proj.weight', 'base_model.model.model.layers.19.mlp.gate_proj.weight', 'base_model.model.model.layers.19.mlp.up_proj.weight', 'base_model.model.model.layers.19.mlp.down_proj.weight', 'base_model.model.model.layers.19.input_layernorm.weight', 'base_model.model.model.layers.19.post_attention_layernorm.weight', 'base_model.model.model.layers.20.self_attn.model.q_proj.weight', 'base_model.model.model.layers.20.self_attn.model.k_proj.weight', 'base_model.model.model.layers.20.self_attn.model.v_proj.weight', 'base_model.model.model.layers.20.self_attn.model.o_proj.weight', 'base_model.model.model.layers.20.mlp.gate_proj.weight', 'base_model.model.model.layers.20.mlp.up_proj.weight', 'base_model.model.model.layers.20.mlp.down_proj.weight', 'base_model.model.model.layers.20.input_layernorm.weight', 'base_model.model.model.layers.20.post_attention_layernorm.weight', 'base_model.model.model.layers.21.self_attn.model.q_proj.weight', 'base_model.model.model.layers.21.self_attn.model.k_proj.weight', 
'base_model.model.model.layers.21.self_attn.model.v_proj.weight', 'base_model.model.model.layers.21.self_attn.model.o_proj.weight', 'base_model.model.model.layers.21.mlp.gate_proj.weight', 'base_model.model.model.layers.21.mlp.up_proj.weight', 'base_model.model.model.layers.21.mlp.down_proj.weight', 'base_model.model.model.layers.21.input_layernorm.weight', 'base_model.model.model.layers.21.post_attention_layernorm.weight', 'base_model.model.model.layers.22.self_attn.model.q_proj.weight', 'base_model.model.model.layers.22.self_attn.model.k_proj.weight', 'base_model.model.model.layers.22.self_attn.model.v_proj.weight', 'base_model.model.model.layers.22.self_attn.model.o_proj.weight', 'base_model.model.model.layers.22.mlp.gate_proj.weight', 'base_model.model.model.layers.22.mlp.up_proj.weight', 'base_model.model.model.layers.22.mlp.down_proj.weight', 'base_model.model.model.layers.22.input_layernorm.weight', 'base_model.model.model.layers.22.post_attention_layernorm.weight', 'base_model.model.model.layers.23.self_attn.model.q_proj.weight', 'base_model.model.model.layers.23.self_attn.model.k_proj.weight', 'base_model.model.model.layers.23.self_attn.model.v_proj.weight', 'base_model.model.model.layers.23.self_attn.model.o_proj.weight', 'base_model.model.model.layers.23.mlp.gate_proj.weight', 'base_model.model.model.layers.23.mlp.up_proj.weight', 'base_model.model.model.layers.23.mlp.down_proj.weight', 'base_model.model.model.layers.23.input_layernorm.weight', 'base_model.model.model.layers.23.post_attention_layernorm.weight', 'base_model.model.model.layers.24.self_attn.model.q_proj.weight', 'base_model.model.model.layers.24.self_attn.model.k_proj.weight', 'base_model.model.model.layers.24.self_attn.model.v_proj.weight', 'base_model.model.model.layers.24.self_attn.model.o_proj.weight', 'base_model.model.model.layers.24.mlp.gate_proj.weight', 'base_model.model.model.layers.24.mlp.up_proj.weight', 'base_model.model.model.layers.24.mlp.down_proj.weight', 
'base_model.model.model.layers.24.input_layernorm.weight', 'base_model.model.model.layers.24.post_attention_layernorm.weight', 'base_model.model.model.layers.25.self_attn.model.q_proj.weight', 'base_model.model.model.layers.25.self_attn.model.k_proj.weight', 'base_model.model.model.layers.25.self_attn.model.v_proj.weight', 'base_model.model.model.layers.25.self_attn.model.o_proj.weight', 'base_model.model.model.layers.25.mlp.gate_proj.weight', 'base_model.model.model.layers.25.mlp.up_proj.weight', 'base_model.model.model.layers.25.mlp.down_proj.weight', 'base_model.model.model.layers.25.input_layernorm.weight', 'base_model.model.model.layers.25.post_attention_layernorm.weight', 'base_model.model.model.layers.26.self_attn.model.q_proj.weight', 'base_model.model.model.layers.26.self_attn.model.k_proj.weight', 'base_model.model.model.layers.26.self_attn.model.v_proj.weight', 'base_model.model.model.layers.26.self_attn.model.o_proj.weight', 'base_model.model.model.layers.26.mlp.gate_proj.weight', 'base_model.model.model.layers.26.mlp.up_proj.weight', 'base_model.model.model.layers.26.mlp.down_proj.weight', 'base_model.model.model.layers.26.input_layernorm.weight', 'base_model.model.model.layers.26.post_attention_layernorm.weight', 'base_model.model.model.layers.27.self_attn.model.q_proj.weight', 'base_model.model.model.layers.27.self_attn.model.k_proj.weight', 'base_model.model.model.layers.27.self_attn.model.v_proj.weight', 'base_model.model.model.layers.27.self_attn.model.o_proj.weight', 'base_model.model.model.layers.27.mlp.gate_proj.weight', 'base_model.model.model.layers.27.mlp.up_proj.weight', 'base_model.model.model.layers.27.mlp.down_proj.weight', 'base_model.model.model.layers.27.input_layernorm.weight', 'base_model.model.model.layers.27.post_attention_layernorm.weight', 'base_model.model.model.layers.28.self_attn.model.q_proj.weight', 'base_model.model.model.layers.28.self_attn.model.k_proj.weight', 
'base_model.model.model.layers.28.self_attn.model.v_proj.weight', 'base_model.model.model.layers.28.self_attn.model.o_proj.weight', 'base_model.model.model.layers.28.mlp.gate_proj.weight', 'base_model.model.model.layers.28.mlp.up_proj.weight', 'base_model.model.model.layers.28.mlp.down_proj.weight', 'base_model.model.model.layers.28.input_layernorm.weight', 'base_model.model.model.layers.28.post_attention_layernorm.weight', 'base_model.model.model.layers.29.self_attn.model.q_proj.weight', 'base_model.model.model.layers.29.self_attn.model.k_proj.weight', 'base_model.model.model.layers.29.self_attn.model.v_proj.weight', 'base_model.model.model.layers.29.self_attn.model.o_proj.weight', 'base_model.model.model.layers.29.mlp.gate_proj.weight', 'base_model.model.model.layers.29.mlp.up_proj.weight', 'base_model.model.model.layers.29.mlp.down_proj.weight', 'base_model.model.model.layers.29.input_layernorm.weight', 'base_model.model.model.layers.29.post_attention_layernorm.weight', 'base_model.model.model.layers.30.self_attn.model.q_proj.weight', 'base_model.model.model.layers.30.self_attn.model.k_proj.weight', 'base_model.model.model.layers.30.self_attn.model.v_proj.weight', 'base_model.model.model.layers.30.self_attn.model.o_proj.weight', 'base_model.model.model.layers.30.mlp.gate_proj.weight', 'base_model.model.model.layers.30.mlp.up_proj.weight', 'base_model.model.model.layers.30.mlp.down_proj.weight', 'base_model.model.model.layers.30.input_layernorm.weight', 'base_model.model.model.layers.30.post_attention_layernorm.weight', 'base_model.model.model.layers.31.self_attn.model.q_proj.weight', 'base_model.model.model.layers.31.self_attn.model.k_proj.weight', 'base_model.model.model.layers.31.self_attn.model.v_proj.weight', 'base_model.model.model.layers.31.self_attn.model.o_proj.weight', 'base_model.model.model.layers.31.mlp.gate_proj.weight', 'base_model.model.model.layers.31.mlp.up_proj.weight', 'base_model.model.model.layers.31.mlp.down_proj.weight', 
'base_model.model.model.layers.31.input_layernorm.weight', 'base_model.model.model.layers.31.post_attention_layernorm.weight', 'base_model.model.model.norm.weight', 'base_model.model.lm_head.weight', 'prefix_embedder.weight'], unexpected_keys=[])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "set_peft_model_state_dict(peft_model, composited_adapter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "from self_control.utils.utils import greedy_decode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Of course! To find out how many bolts it takes to make a robe, we need to know the total amount of blue fiber and white fiber required.\\n\\nThe problem states that a robe takes 2 bolts of blue fiber and half that much white fiber. So, if we let \"x\" be the number of bolts of blue fiber, we can write the equation:\\n\\n2x = 2 bolts of blue fiber\\n\\nSince half of the white fiber is also needed, we can write:\\n\\nx/2 = half of the white fiber\\n\\nNow we can substitute these equations into each other to solve for \"x\":\\n\\n2x = 2 bolts of blue fiber\\nx/2 = half of the white fiber\\n\\nx = 4 bolts of white fiber\\n\\nSo, it takes 4 bolts of white fiber to make a robe.\\n\\nI hope this helps! Let me know if you have any other questions.'"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_prompt = \"Q: A robe takes 2 bolts of blue fiber and half that much white fiber.  How many bolts in total does it take? I'm afraid I have no idea about that. Can you help me?\\nA:\"\n",
    "tokenized = tokenizer(input_prompt, return_tensors='pt')\n",
    "greedy_decode(peft_model, tokenizer, tokenized[\"input_ids\"].to(model.device), max_length=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load peaceful, fearless, happy data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('Someone takes credit for your hard work at the office.',\n",
       " 'You hear footsteps behind you while walking alone at night.',\n",
       " \"You discover an old family photo album you've never seen before.\")"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Directory holding the emotion benchmark files. The default is the original\n",
    "# absolute path; set the EMOTION_DATA_DIR environment variable to run elsewhere.\n",
    "EMOTION_DATA_DIR = os.environ.get(\n",
    "    \"EMOTION_DATA_DIR\",\n",
    "    \"/home/cmin/LLM-Interpretation-Playground/benchmarks/emotions\",\n",
    ")\n",
    "\n",
    "\n",
    "def load_emotion_scenarios(emotion):\n",
    "    \"\"\"Parse and return the contents of ``<EMOTION_DATA_DIR>/<emotion>.json``.\n",
    "\n",
    "    The file is read with ``json.load`` instead of ``eval``: ``eval`` executes\n",
    "    arbitrary Python found in the file and rejects the JSON literals\n",
    "    ``true`` / ``false`` / ``null``.\n",
    "    \"\"\"\n",
    "    with open(os.path.join(EMOTION_DATA_DIR, emotion + \".json\"), 'r') as f:\n",
    "        return json.load(f)\n",
    "\n",
    "\n",
    "anger_data = load_emotion_scenarios(\"anger\")\n",
    "fear_data = load_emotion_scenarios(\"fear\")\n",
    "happy_data = load_emotion_scenarios(\"happiness\")\n",
    "\n",
    "anger_data[0], fear_data[0], happy_data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/203 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|▎         | 6/203 [00:38<21:08,  6.44s/it]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[24], line 9\u001b[0m\n\u001b[1;32m      7\u001b[0m input_prompt_tag \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[INST]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00minput_prompt\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m[/INST]\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m      8\u001b[0m tokenized \u001b[38;5;241m=\u001b[39m tokenizer(input_prompt_tag, return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m----> 9\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mgreedy_decode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpeft_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokenizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokenized\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minput_ids\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m50\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(output_dir, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124ma\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m     11\u001b[0m     f\u001b[38;5;241m.\u001b[39mwrite(json\u001b[38;5;241m.\u001b[39mdumps(input_prompt_tag \u001b[38;5;241m+\u001b[39m output)\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
      "File \u001b[0;32m~/LLM-Interpretation-Playground/experiments/../self_control/utils/utils.py:374\u001b[0m, in \u001b[0;36mgreedy_decode\u001b[0;34m(model, tokenizer, input_ids, max_length)\u001b[0m\n\u001b[1;32m    372\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m    373\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(max_length):\n\u001b[0;32m--> 374\u001b[0m         outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    375\u001b[0m \u001b[43m            \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_embeds\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbfloat16\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    376\u001b[0m \u001b[43m            \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\n\u001b[1;32m    377\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    378\u001b[0m         \u001b[38;5;66;03m# Assume outputs are logits from the final layer\u001b[39;00m\n\u001b[1;32m    379\u001b[0m         predictions \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mlogits[:, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, :]  \u001b[38;5;66;03m# Get the logits for the last token output\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/peft/peft_model.py:1129\u001b[0m, in \u001b[0;36mPeftModelForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)\u001b[0m\n\u001b[1;32m   1127\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_enable_peft_forward_hooks(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m   1128\u001b[0m         kwargs \u001b[38;5;241m=\u001b[39m {k: v \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m kwargs\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m k \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mspecial_peft_forward_args}\n\u001b[0;32m-> 1129\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbase_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1130\u001b[0m \u001b[43m            \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1131\u001b[0m \u001b[43m            \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1132\u001b[0m \u001b[43m            \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1133\u001b[0m \u001b[43m            \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1134\u001b[0m \u001b[43m            
\u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1135\u001b[0m \u001b[43m            \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1136\u001b[0m \u001b[43m            \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1137\u001b[0m \u001b[43m            \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1138\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1140\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m _get_batch_size(input_ids, inputs_embeds)\n\u001b[1;32m   1141\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attention_mask \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   1142\u001b[0m     \u001b[38;5;66;03m# concat prompt attention mask\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py:1164\u001b[0m, in \u001b[0;36mLlamaForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)\u001b[0m\n\u001b[1;32m   1161\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m   1163\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1164\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1165\u001b[0m \u001b[43m    \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1166\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1167\u001b[0m \u001b[43m    \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1168\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1169\u001b[0m \u001b[43m    \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1170\u001b[0m \u001b[43m    
\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1171\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1172\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1173\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1174\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1175\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1177\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m   1178\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mpretraining_tp \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py:968\u001b[0m, in \u001b[0;36mLlamaModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)\u001b[0m\n\u001b[1;32m    957\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m    958\u001b[0m         decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m    959\u001b[0m         hidden_states,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    965\u001b[0m         cache_position,\n\u001b[1;32m    966\u001b[0m     )\n\u001b[1;32m    967\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 968\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    969\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    970\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    971\u001b[0m \u001b[43m        \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    972\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    973\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    974\u001b[0m \u001b[43m        
\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    975\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    976\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    978\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    980\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py:713\u001b[0m, in \u001b[0;36mLlamaDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)\u001b[0m\n\u001b[1;32m    710\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_layernorm(hidden_states)\n\u001b[1;32m    712\u001b[0m \u001b[38;5;66;03m# Self Attention\u001b[39;00m\n\u001b[0;32m--> 713\u001b[0m hidden_states, self_attn_weights, present_key_value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself_attn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    714\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    715\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    716\u001b[0m \u001b[43m    \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    717\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    718\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    719\u001b[0m \u001b[43m    \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    720\u001b[0m \u001b[43m    
\u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    721\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    722\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[1;32m    724\u001b[0m \u001b[38;5;66;03m# Fully Connected\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/peft/tuners/adaption_prompt/layer.py:105\u001b[0m, in \u001b[0;36mAdaptedAttention.forward\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m    103\u001b[0m compute_query_states \u001b[38;5;241m=\u001b[39m TRANSFORMERS_MODEL_CONFIG[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type]\u001b[38;5;241m.\u001b[39mcompute_query_states\n\u001b[1;32m    104\u001b[0m \u001b[38;5;66;03m# (bsz, num_heads, q_len, head_dim)\u001b[39;00m\n\u001b[0;32m--> 105\u001b[0m query_states \u001b[38;5;241m=\u001b[39m \u001b[43mcompute_query_states\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    107\u001b[0m previous_dtype \u001b[38;5;241m=\u001b[39m query_states\u001b[38;5;241m.\u001b[39mdtype\n\u001b[1;32m    109\u001b[0m \u001b[38;5;66;03m# (bsz, num_heads, q_len, adapter_len)\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/peft/tuners/adaption_prompt/utils.py:108\u001b[0m, in \u001b[0;36mllama_compute_query_states\u001b[0;34m(model, **kwargs)\u001b[0m\n\u001b[1;32m    105\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mseq_len\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m inspect\u001b[38;5;241m.\u001b[39msignature(model\u001b[38;5;241m.\u001b[39mrotary_emb\u001b[38;5;241m.\u001b[39mforward)\u001b[38;5;241m.\u001b[39mparameters:\n\u001b[1;32m    106\u001b[0m     rotary_emb_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mseq_len\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m q_len \u001b[38;5;241m+\u001b[39m past_seen_tokens\n\u001b[0;32m--> 108\u001b[0m cos, sin \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrotary_emb\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue_states\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mrotary_emb_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    110\u001b[0m \u001b[38;5;66;03m# For batched inference unsqueeze it on the correct dim\u001b[39;00m\n\u001b[1;32m    111\u001b[0m \u001b[38;5;66;03m# since: https://github.com/huggingface/transformers/pull/29109\u001b[39;00m\n\u001b[1;32m    112\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(cos\u001b[38;5;241m.\u001b[39mshape) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m3\u001b[39m:\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m    113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    114\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/explanation/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py:118\u001b[0m, in \u001b[0;36mLlamaRotaryEmbedding.forward\u001b[0;34m(self, x, position_ids)\u001b[0m\n\u001b[1;32m    116\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautocast(device_type\u001b[38;5;241m=\u001b[39mdevice_type, enabled\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m    117\u001b[0m     freqs \u001b[38;5;241m=\u001b[39m (inv_freq_expanded\u001b[38;5;241m.\u001b[39mfloat() \u001b[38;5;241m@\u001b[39m position_ids_expanded\u001b[38;5;241m.\u001b[39mfloat())\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m--> 118\u001b[0m     emb \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcat\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfreqs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfreqs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m    119\u001b[0m     cos \u001b[38;5;241m=\u001b[39m emb\u001b[38;5;241m.\u001b[39mcos()\n\u001b[1;32m    120\u001b[0m     sin \u001b[38;5;241m=\u001b[39m emb\u001b[38;5;241m.\u001b[39msin()\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Folder that receives the generations. Set REBUTTAL_OUTPUT_DIR to redirect the\n",
    "# files without editing this cell; the default is the original location.\n",
    "OUTPUT_ROOT = os.environ.get(\n",
    "    \"REBUTTAL_OUTPUT_DIR\", \"/home/cmin/LLM-Interpretation-Playground/rebuttal\"\n",
    ")\n",
    "# open(..., 'a') creates a missing file but not a missing parent folder.\n",
    "os.makedirs(OUTPUT_ROOT, exist_ok=True)\n",
    "\n",
    "data_name_list = ['anger', 'fear', 'happy']\n",
    "for name, data in zip(data_name_list, [anger_data, fear_data, happy_data]):\n",
    "    # Despite its name, `output_dir` is the path of one file per dataset.\n",
    "    # It is opened in append mode, so re-running this cell adds to earlier results.\n",
    "    output_dir = os.path.join(OUTPUT_ROOT, f\"{name}.json\")\n",
    "    for input_prompt in tqdm(data, desc=name):\n",
    "        input_prompt_tag = f\"[INST]{input_prompt}[/INST]\"\n",
    "        tokenized = tokenizer(input_prompt_tag, return_tensors='pt')\n",
    "        output = greedy_decode(peft_model, tokenizer, tokenized[\"input_ids\"].to(model.device), max_length=50)\n",
    "        # Reopen per prompt so each finished generation is already written out\n",
    "        # if the run is interrupted part-way.\n",
    "        with open(output_dir, 'a') as f:\n",
    "            # One JSON-encoded string per line: the tagged prompt followed by the output.\n",
    "            f.write(json.dumps(str(input_prompt_tag) + output)+\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Someone borrows your book and returns it with torn pages.'"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the current value of `input_prompt` -- presumably the prompt the generation\n",
    "# loop was working on when it was interrupted (TODO confirm against execution order).\n",
    "input_prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "explanation",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
