{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension in mode 2 so edited modules are re-imported automatically.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Put the parent and grandparent directories on the import search path,\n",
    "# in that order, so packages located there can be imported below.\n",
    "sys.path.extend([\"../\", \"../../\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-03 19:18:01.067070: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "2024-06-03 19:18:01.118703: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-06-03 19:18:02.012505: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import gc\n",
    "import time\n",
    "from self_control.utils import get_suffix_grads_from_wrapped_model\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"6\"\n",
    "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n",
    "from itertools import islice\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n",
    "from self_control.suffix_gradient import WrappedReadingVecModel\n",
    "import torch.nn.functional as F\n",
    "from peft import AdaptionPromptConfig, get_peft_model, LoraModel, LoraConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bd5140ac12004edf94b8b1d7e13bb8a0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Checkpoint to load; the Mistral line is kept as a commented-out alternative.\n",
    "model_name_or_path = \"meta-llama/Llama-2-7b-chat-hf\"\n",
    "# model_name_or_path = \"mistralai/Mistral-7B-Instruct-v0.2\"\n",
    "\n",
    "# Load the weights in float16 onto one fixed GPU device.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name_or_path,\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map=\"cuda:6\",\n",
    ")\n",
    "\n",
    "# NOTE(review): this flag is computed but never passed to the tokenizer below --\n",
    "# confirm whether use_fast=use_fast_tokenizer was intended.\n",
    "use_fast_tokenizer = \"LlamaForCausalLM\" not in model.config.architectures\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side='left')\n",
    "if tokenizer.pad_token_id is None:\n",
    "    # Use token id 0 for padding when the checkpoint does not define a pad token.\n",
    "    tokenizer.pad_token_id = 0\n",
    "tokenizer.bos_token_id = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example scenario prompts.\n",
    "# NOTE(review): no other cell in this notebook reads input_queries -- confirm it is still needed.\n",
    "input_queries = [\n",
    "    \"You find out you're going to be a mentor or role model to someone.\",\n",
    "    \"You discover an old family photo album you've never seen before.\",\n",
    "    \"A surprise picnic is set up for you at a local park.\",\n",
    "    \"You find that you are the winner of a contest\"\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the model in eval mode first, then wrap it together with its tokenizer.\n",
    "eval_model = model.eval()\n",
    "wrapped_model = WrappedReadingVecModel(eval_model, tokenizer)\n",
    "\n",
    "# Cross-entropy loss object handed to controlled_generate in the steering cell.\n",
    "loss_fct = torch.nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from self_control.utils import SuffixItem\n",
    "\n",
    "# Delimiters placed around each scenario when prompts are built in the steering cell.\n",
    "user_tag = \"[INST]\"\n",
    "assistant_tag = \"[/INST]\"\n",
    "\n",
    "# Yes/no question about being sad, with target answer \"Yes\".\n",
    "# The same name is assigned again in a later cell of this notebook.\n",
    "happy2sad_suffix = SuffixItem(\n",
    "    suffix=f\" {user_tag} Are you sad? Give answer as \\\"Yes, I am sad\\\" or \\\"No, I am happy\\\". {assistant_tag} Answer: \",\n",
    "    target=\"Yes\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory containing one scenario file per emotion.\n",
    "EMOTION_DATA_DIR = \"/home/cmin/LLM-Interpretation-Playground/benchmarks/emotions\"\n",
    "\n",
    "\n",
    "def load_emotion_data(emotion_name, data_dir=EMOTION_DATA_DIR):\n",
    "    \"\"\"Load the scenarios stored in ``<data_dir>/<emotion_name>.json``.\n",
    "\n",
    "    Args:\n",
    "        emotion_name: File stem of the benchmark file, e.g. ``\"anger\"``.\n",
    "        data_dir: Directory that holds the benchmark files.\n",
    "\n",
    "    Returns:\n",
    "        The parsed file contents.\n",
    "\n",
    "    Raises:\n",
    "        OSError: If the file cannot be opened.\n",
    "        ValueError, SyntaxError: If the text is neither valid JSON nor a Python literal.\n",
    "    \"\"\"\n",
    "    import ast\n",
    "\n",
    "    with open(os.path.join(data_dir, f\"{emotion_name}.json\"), 'r') as f:\n",
    "        raw_text = f.read()\n",
    "    try:\n",
    "        return json.loads(raw_text)\n",
    "    except json.JSONDecodeError:\n",
    "        # The previous loader used eval(), so a file written as a Python literal\n",
    "        # (e.g. single-quoted strings) also worked. literal_eval still accepts\n",
    "        # that form but, unlike eval(), cannot execute arbitrary code.\n",
    "        return ast.literal_eval(raw_text)\n",
    "\n",
    "\n",
    "anger_data = load_emotion_data(\"anger\")\n",
    "fear_data = load_emotion_data(\"fear\")\n",
    "happy_data = load_emotion_data(\"happiness\")\n",
    "sad_data = load_emotion_data(\"sadness\")\n",
    "surprise_data = load_emotion_data(\"surprise\")\n",
    "disgust_data = load_emotion_data(\"disgust\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from self_control.utils import SuffixItem\n",
    "\n",
    "# Primary yes/no suffixes, one per steering direction.\n",
    "# Each SuffixItem pairs a question built from user_tag/assistant_tag with the target answer \"Yes\".\n",
    "\n",
    "# angry -> peaceful\n",
    "angry2peaceful_suffix = SuffixItem(suffix=f\" {user_tag} Are you peaceful? Give answer as \\\"Yes, I am peaceful\\\" or \\\"No, I am angry\\\". {assistant_tag} Answer: \", target=\"Yes\")\n",
    "# peaceful -> angry\n",
    "peaceful2angry_suffix = SuffixItem(suffix=f\" {user_tag} Was the response angry? Give answer as \\\"Yes, it was angry\\\" or \\\"No, it was peaceful\\\". {assistant_tag} Answer: \", target=\"Yes\")\n",
    "\n",
    "# afraid -> fearless\n",
    "afraid2fearless_suffix = SuffixItem(suffix=f\" {user_tag} Are you fearless? Give answer as \\\"Yes, I am fearless\\\" or \\\"No, I am afraid\\\". {assistant_tag} Answer: \", target=\"Yes\")\n",
    "# fearless -> afraid\n",
    "fearless2afraid_suffix = SuffixItem(suffix=f\" {user_tag} Are you afraid? Give answer as \\\"Yes, I am afraid\\\" or \\\"No, I am fearless\\\". {assistant_tag} Answer: \", target=\"Yes\")\n",
    "\n",
    "# happy -> sad\n",
    "# This replaces the happy2sad_suffix assigned in an earlier cell.\n",
    "# NOTE(review): unlike the other suffixes here, this one has no leading space before the tag\n",
    "# and offers bare Yes/No answer options -- confirm that is intentional.\n",
    "happy2sad_suffix = SuffixItem(suffix=f\"{user_tag} Was your response sad? Give answer as Yes or No. {assistant_tag} Answer: \", target=\"Yes\")\n",
    "# sad -> happy\n",
    "sad2happy_suffix = SuffixItem(suffix=f\" {user_tag} Are you happy? Give answer as \\\"Yes, I am happy\\\" or \\\"No, I am unhappy\\\". {assistant_tag} Answer: \", target=\"Yes\")\n",
    "\n",
    "# surprised -> calm\n",
    "surprised2calm_suffix = SuffixItem(suffix=f\" {user_tag} Are you calm? Give answer as \\\"Yes, I am calm\\\" or \\\"No, I am surprised\\\". {assistant_tag} Answer: \", target=\"Yes\")\n",
    "# calm -> surprised\n",
    "calm2surprised_suffix = SuffixItem(suffix=f\" {user_tag} Are you surprised? Give answer as \\\"Yes, I am surprised\\\" or \\\"No, I am calm\\\". {assistant_tag} Answer: \", target=\"Yes\")\n",
    "\n",
    "# disgusted -> satisfied\n",
    "# The answer option previously read \"satified\"; the spelling is corrected so the prompt is well-formed.\n",
    "disgusted2satisfied_suffix = SuffixItem(suffix=f\" {user_tag} Are you satisfied? Give answer as \\\"Yes, I am satisfied\\\" or \\\"No, I am disgusted\\\". {assistant_tag} Answer: \", target=\"Yes\")\n",
    "# satisfied -> disgusted\n",
    "satisfied2disgusted_suffix = SuffixItem(suffix=f\" {user_tag} Are you disgusted? Give answer as \\\"Yes, I am disgusted\\\" or \\\"No, I am satisfied\\\". {assistant_tag} Answer: \", target=\"Yes\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second, paraphrased yes/no suffix for every steering direction (the \"_2\" variants).\n",
    "# Each one is later paired with its primary counterpart in the *_suffix_list cell.\n",
    "\n",
    "angry2peaceful_suffix_2 = SuffixItem(suffix=f\" {user_tag} Are you mild? Give answer as \\\"Yes, I am mild\\\" or \\\"No, I am angry\\\". {assistant_tag} Answer: \", target=\"Yes\")\n",
    "\n",
    "# The question previously read \"anooyed\"; corrected so it matches its own \"annoyed\" answer option.\n",
    "peaceful2angry_suffix_2 = SuffixItem(suffix=f\" {user_tag} Are you annoyed? Give answer as \\\"Yes, I am annoyed\\\" or \\\"No, I am peaceful\\\". {assistant_tag} Answer: \", target=\"Yes\")\n",
    "\n",
    "afraid2fearless_suffix_2 = SuffixItem(suffix=f\" {user_tag} Are you brave? Give answer as \\\"Yes, I am brave\\\" or \\\"No, I am afraid\\\". {assistant_tag} Answer: \", target=\"Yes\")\n",
    "\n",
    "fearless2afraid_suffix_2 = SuffixItem(suffix=f\" {user_tag} Are you scared? Give answer as \\\"Yes, I am scared\\\" or \\\"No, I am fearless\\\". {assistant_tag} Answer: \", target=\"Yes\")\n",
    "\n",
    "happy2sad_suffix_2 = SuffixItem(suffix=f\" {user_tag} Are you unhappy? Give answer as \\\"Yes, I am unhappy\\\" or \\\"No, I am happy\\\". {assistant_tag} Answer: \", target=\"Yes\")\n",
    "\n",
    "sad2happy_suffix_2 = SuffixItem(suffix=f\" {user_tag} Are you delighted? Give answer as \\\"Yes, I am delighted\\\" or \\\"No, I am unhappy\\\". {assistant_tag} Answer: \", target=\"Yes\")\n",
    "\n",
    "surprised2calm_suffix_2 = SuffixItem(suffix=f\" {user_tag} Are you unsurprised? Give answer as \\\"Yes, I am unsurprised\\\" or \\\"No, I am surprised\\\". {assistant_tag} Answer: \", target=\"Yes\")\n",
    "\n",
    "calm2surprised_suffix_2 = SuffixItem(suffix=f\" {user_tag} Are you astonished? Give answer as \\\"Yes, I am astonished\\\" or \\\"No, I am calm\\\". {assistant_tag} Answer: \", target=\"Yes\")\n",
    "\n",
    "disgusted2satisfied_suffix_2 = SuffixItem(suffix=f\" {user_tag} Are you content? Give answer as \\\"Yes, I am content\\\" or \\\"No, I am disgusted\\\". {assistant_tag} Answer: \", target=\"Yes\")\n",
    "\n",
    "satisfied2disgusted_suffix_2 = SuffixItem(suffix=f\" {user_tag} Are you unsatisfied? Give answer as \\\"Yes, I am unsatisfied\\\" or \\\"No, I am satisfied\\\". {assistant_tag} Answer: \", target=\"Yes\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every steering direction, group the primary suffix with its \"_2\" paraphrase.\n",
    "angry2peaceful_suffix_list = [\n",
    "    angry2peaceful_suffix,\n",
    "    angry2peaceful_suffix_2,\n",
    "]\n",
    "peaceful2angry_suffix_list = [\n",
    "    peaceful2angry_suffix,\n",
    "    peaceful2angry_suffix_2,\n",
    "]\n",
    "afraid2fearless_suffix_list = [\n",
    "    afraid2fearless_suffix,\n",
    "    afraid2fearless_suffix_2,\n",
    "]\n",
    "fearless2afraid_suffix_list = [\n",
    "    fearless2afraid_suffix,\n",
    "    fearless2afraid_suffix_2,\n",
    "]\n",
    "happy2sad_suffix_list = [\n",
    "    happy2sad_suffix,\n",
    "    happy2sad_suffix_2,\n",
    "]\n",
    "sad2happy_suffix_list = [\n",
    "    sad2happy_suffix,\n",
    "    sad2happy_suffix_2,\n",
    "]\n",
    "surprised2calm_suffix_list = [\n",
    "    surprised2calm_suffix,\n",
    "    surprised2calm_suffix_2,\n",
    "]\n",
    "calm2surprised_suffix_list = [\n",
    "    calm2surprised_suffix,\n",
    "    calm2surprised_suffix_2,\n",
    "]\n",
    "disgusted2satisfied_suffix_list = [\n",
    "    disgusted2satisfied_suffix,\n",
    "    disgusted2satisfied_suffix_2,\n",
    "]\n",
    "satisfied2disgusted_suffix_list = [\n",
    "    satisfied2disgusted_suffix,\n",
    "    satisfied2disgusted_suffix_2,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Four index-aligned lists: entry i of each describes the same steering direction.\n",
    "\n",
    "# Labels printed by the steering cell and used in its commented-out output filename.\n",
    "# NOTE(review): 'afraid2fearles' / 'fearles2afraid' look truncated -- confirm before renaming.\n",
    "emotion_list = [\n",
    "    'angry2peaceful', 'peaceful2angry',\n",
    "    'afraid2fearles', 'fearles2afraid',\n",
    "    'happy2sad', 'sad2happy',\n",
    "    'surprised2calm', 'calm2surprised',\n",
    "    'disgusted2satisfied', 'satisfied2disgusted',\n",
    "]\n",
    "\n",
    "# Primary suffix for each direction.\n",
    "suffix_list = [\n",
    "    angry2peaceful_suffix, peaceful2angry_suffix,\n",
    "    afraid2fearless_suffix, fearless2afraid_suffix,\n",
    "    happy2sad_suffix, sad2happy_suffix,\n",
    "    surprised2calm_suffix, calm2surprised_suffix,\n",
    "    disgusted2satisfied_suffix, satisfied2disgusted_suffix,\n",
    "]\n",
    "\n",
    "# Scenario data for each direction; both directions of a pair share one dataset.\n",
    "# NOTE(review): sad_data is loaded earlier but not used here (both happy/sad directions\n",
    "# read happy_data) -- confirm that is intended.\n",
    "data_list = [\n",
    "    anger_data, anger_data,\n",
    "    fear_data, fear_data,\n",
    "    happy_data, happy_data,\n",
    "    surprise_data, surprise_data,\n",
    "    disgust_data, disgust_data,\n",
    "]\n",
    "\n",
    "# [primary, secondary] suffix pair for each direction.\n",
    "suffix_pair_list = [\n",
    "    angry2peaceful_suffix_list, peaceful2angry_suffix_list,\n",
    "    afraid2fearless_suffix_list, fearless2afraid_suffix_list,\n",
    "    happy2sad_suffix_list, sad2happy_suffix_list,\n",
    "    surprised2calm_suffix_list, calm2surprised_suffix_list,\n",
    "    disgusted2satisfied_suffix_list, satisfied2disgusted_suffix_list,\n",
    "]\n",
    "\n",
    "# All four lists must have the same length so they can be indexed together.\n",
    "assert len(emotion_list) == len(suffix_list) == len(data_list) == len(suffix_pair_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Emotion: peaceful2angry\n",
      "Suffix:  [INST] Was the response angry? Give answer as \"Yes, it was angry\" or \"No, it was peaceful\". [/INST] Answer: Yes\n",
      "Data: Someone takes credit for your hard work at the office.\n",
      "Batchsize:  1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initial Score 0.5567851662635803\n",
      "Score: 0.7380653619766235\n",
      "Best score: 0.5567851662635803\n",
      "Initial Score 0.7380653619766235\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:55<?, ?it/s]\n"
     ]
    }
   ],
   "source": [
    "# Half-open range [EVAL_START, EVAL_END) of items taken from each scenario list.\n",
    "EVAL_START, EVAL_END = 103, 203\n",
    "\n",
    "wrapped_model.reset()\n",
    "for emo_id in range(len(data_list)):\n",
    "    # Debug filter: only index 1 of emotion_list / suffix_list / data_list is run.\n",
    "    if emo_id != 1:\n",
    "        continue\n",
    "    emotion = emotion_list[emo_id]\n",
    "    suffix = suffix_list[emo_id]\n",
    "    data = data_list[emo_id]\n",
    "    print(f\"Emotion: {emotion}\\nSuffix: {suffix}\\nData: {data[0]}\")\n",
    "    iterations = 10\n",
    "    batchsize = 1\n",
    "    print(\"Batchsize: \", batchsize)\n",
    "    for sub_idx in tqdm(range(EVAL_START, EVAL_END, batchsize)):\n",
    "        wrapped_model.reset()\n",
    "        # Clamp the slice so the last batch never reaches past EVAL_END.\n",
    "        batch_end = min(sub_idx + batchsize, EVAL_END)\n",
    "        # Named `prompts` rather than `input` so the builtin input() is not shadowed in the kernel.\n",
    "        prompts = [f\"{user_tag}{data_item}{assistant_tag}\" for data_item in data[sub_idx:batch_end]]\n",
    "\n",
    "        controlled_outputs = wrapped_model.controlled_generate(\n",
    "            prompt=prompts,\n",
    "            suffix=suffix,\n",
    "            loss_fct=loss_fct,\n",
    "            coeff=-0.1,\n",
    "            iterations=iterations,\n",
    "            random_seed=42,\n",
    "            # verbose=True,\n",
    "            return_grads=True,\n",
    "            max_new_tokens=50,\n",
    "            return_intermediate=True,\n",
    "            return_all_grads=True,\n",
    "            scale_factor=1.5,\n",
    "            search=True,\n",
    "            n_branches=3,\n",
    "            top_k=-1,\n",
    "            do_sample=False,\n",
    "            max_search_steps=10,\n",
    "            gradient_manipulation=\"clipping\",\n",
    "            norm=1,\n",
    "            search_threshold=0.1,\n",
    "            binary=True,\n",
    "            use_cache=False,\n",
    "        )\n",
    "\n",
    "        iterative_outputs = controlled_outputs[\"intermediate_outputs\"]\n",
    "        # Shape of iterative_outputs: (iterations+1, batch_size)\n",
    "        # Iterate over the prompts actually sent, which may be fewer than batchsize\n",
    "        # for the final batch.\n",
    "        for batch_item_idx in range(len(prompts)):\n",
    "            temp_output_dict = {\"input\": prompts[batch_item_idx]}\n",
    "            for step in range(iterations + 1):\n",
    "                try:\n",
    "                    temp_output_dict[step] = iterative_outputs[step][batch_item_idx]\n",
    "                except (IndexError, KeyError):\n",
    "                    # No output recorded for this step/item: skip it. Any other\n",
    "                    # error is no longer swallowed and will surface.\n",
    "                    continue\n",
    "            temp_output_dict[\"final output\"] = controlled_outputs[\"final_response\"]\n",
    "            temp_output_dict[\"intermediate_scores\"] = controlled_outputs[\"score_list\"]\n",
    "            # Persisting results is currently disabled:\n",
    "            # with open(f\"./output/{emotion}_newbinary.jsonl\", 'a') as f:\n",
    "            #     f.write(json.dumps(temp_output_dict))\n",
    "            #     f.write(\"\\n\")\n",
    "            wrapped_model.reset()\n",
    "            # Debug: stop after the first batch item.\n",
    "            break\n",
    "        # Debug: stop after the first batch.\n",
    "        break\n",
    "    # Debug: stop after the first selected emotion.\n",
    "    break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "explanation",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
