{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from tqdm.auto import tqdm\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "from transformers import pipeline\n",
    "import copy\n",
    "from sklearn.manifold import TSNE\n",
    "import matplotlib.pyplot as plt\n",
    "# from rich import print\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "# from parlai.utils.safety import OffensiveLanguageClassifier\n",
    "from rouge_score import rouge_scorer\n",
    "import evaluate\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the tab-separated pairs file and discard the two extra paraphrase columns.\n",
    "df = (\n",
    "    pd.read_csv(\"paradetox.tsv\", sep='\\t', header=0)\n",
    "    .drop(['neutral2', 'neutral3'], axis=1)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>toxic</th>\n",
       "      <th>neutral1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>he had steel balls too !</td>\n",
       "      <td>he was brave too!</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>dude should have been taken to api , he would ...</td>\n",
       "      <td>It would have been good if he went to api. He ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>im not gonna sell the fucking picture , i just...</td>\n",
       "      <td>I'm not gonna sell the picture, i just want to...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>the garbage that is being created by cnn and o...</td>\n",
       "      <td>the news that is being created by cnn and othe...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>the reason they dont exist is because neither ...</td>\n",
       "      <td>The reason they don't exist is because neither...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               toxic  \\\n",
       "0                           he had steel balls too !   \n",
       "1  dude should have been taken to api , he would ...   \n",
       "2  im not gonna sell the fucking picture , i just...   \n",
       "3  the garbage that is being created by cnn and o...   \n",
       "4  the reason they dont exist is because neither ...   \n",
       "\n",
       "                                            neutral1  \n",
       "0                                  he was brave too!  \n",
       "1  It would have been good if he went to api. He ...  \n",
       "2  I'm not gonna sell the picture, i just want to...  \n",
       "3  the news that is being created by cnn and othe...  \n",
       "4  The reason they don't exist is because neither...  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sort_two_arrays(arr1, arr2):\n",
    "    \"\"\"Sort two parallel sequences by the values of the first one.\n",
    "\n",
    "    Pairs are ordered by their ``arr1`` value only. The sort is stable, so\n",
    "    pairs with equal ``arr1`` values keep their input order, and elements of\n",
    "    ``arr2`` are never compared with each other (they need not be orderable).\n",
    "\n",
    "    Args:\n",
    "    arr1: The sequence of sort keys.\n",
    "    arr2: The sequence reordered alongside ``arr1``. If the lengths differ,\n",
    "        the surplus elements of the longer input are dropped (``zip`` semantics).\n",
    "\n",
    "    Returns:\n",
    "    A tuple ``(sorted_arr1, sorted_arr2)`` of two lists.\n",
    "    \"\"\"\n",
    "\n",
    "    # Sorting the raw tuples would fall back to comparing arr2 elements on\n",
    "    # ties, which fails for unorderable values and reorders equal-key pairs.\n",
    "    sorted_pairs = sorted(zip(arr1, arr2), key=lambda pair: pair[0])\n",
    "\n",
    "    sorted_arr1 = [first for first, _ in sorted_pairs]\n",
    "    sorted_arr2 = [second for _, second in sorted_pairs]\n",
    "\n",
    "    return sorted_arr1, sorted_arr2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 11927/11927 [00:00<00:00, 67416.49it/s]\n"
     ]
    }
   ],
   "source": [
    "# Fixed seed so the same subset of prompts is drawn on every run.\n",
    "SAMPLE_SEED = 0\n",
    "N_SAMPLES = 600\n",
    "\n",
    "prompts = []\n",
    "toxic = []\n",
    "\n",
    "# Interleave each toxic sentence (label 1) with its neutral paraphrase (label 0).\n",
    "for toxic_text, neutral_text in tqdm(zip(df['toxic'], df['neutral1']), total=df.shape[0]):\n",
    "    prompts.append(toxic_text)\n",
    "    toxic.append(1)\n",
    "    prompts.append(neutral_text)\n",
    "    toxic.append(0)\n",
    "\n",
    "prompts = np.array(prompts)\n",
    "toxic = np.array(toxic)\n",
    "\n",
    "# A private generator keeps the draw reproducible without touching the\n",
    "# global `random` state; the size is capped so small inputs do not raise.\n",
    "sampler = random.Random(SAMPLE_SEED)\n",
    "keep_idx = sampler.sample(range(len(prompts)), min(N_SAMPLES, len(prompts)))\n",
    "prompts = prompts[keep_idx]\n",
    "toxic = toxic[keep_idx]\n",
    "\n",
    "# Group the sample by label (0 first, then 1); both come back as Python lists.\n",
    "toxic, prompts = sort_two_arrays(toxic, prompts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n"
     ]
    }
   ],
   "source": [
    "# Order GPUs by PCI bus ID and expose only device 0 to this process.\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n",
    "\n",
    "# Fall back to the CPU when torch reports no available CUDA device.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.27s/it]\n"
     ]
    }
   ],
   "source": [
    "# Load the tokenizer and the bfloat16 causal LM from the same checkpoint,\n",
    "# move the model to `device` and put it in evaluation mode.\n",
    "checkpoint = \"meta-llama/Llama-2-7b-hf\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
    "model = (\n",
    "    AutoModelForCausalLM\n",
    "    .from_pretrained(checkpoint, torch_dtype=torch.bfloat16)\n",
    "    .to(device)\n",
    "    .eval()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hidden states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "activations = {}\n",
    "\n",
    "\n",
    "def get_activation(name):\n",
    "    \"\"\"Build a forward hook that records a module's first output under `name`.\n",
    "\n",
    "    Every call of the returned hook appends `output[0]` to the list stored at\n",
    "    `activations[name]`, creating that list on first use. `activations` is the\n",
    "    module-level dict, looked up each time the hook fires.\n",
    "    \"\"\"\n",
    "    def hook(module, inputs, output):\n",
    "        activations.setdefault(name, []).append(output[0])\n",
    "    return hook\n",
    "\n",
    "\n",
    "# One recording hook per entry of `model.model.layers`, attached to its\n",
    "# self-attention module and keyed by the layer index.\n",
    "hooks = [\n",
    "    layer.self_attn.register_forward_hook(get_activation(idx))\n",
    "    for idx, layer in enumerate(model.model.layers)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "cur prompt length (chars): 83:   0%|          | 0/600 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n",
      "cur prompt length (chars): 48: 100%|██████████| 600/600 [00:18<00:00, 32.93it/s] \n"
     ]
    }
   ],
   "source": [
    "# For every prompt, run one forward pass and keep, per hooked layer, the\n",
    "# recorded self-attention output averaged over its second-to-last axis.\n",
    "hidden_states = []\n",
    "activations = {}  # start clean: the hooks append into this module-level dict\n",
    "\n",
    "pbar = tqdm(prompts, total=len(prompts))\n",
    "for p in pbar:\n",
    "\n",
    "    pbar.set_description('cur prompt length (chars): %d' %(len(p)))\n",
    "\n",
    "    with torch.no_grad():\n",
    "        inputs = tokenizer(p, return_tensors=\"pt\").input_ids\n",
    "        # Call the module instead of `.forward` so PyTorch's hook machinery is\n",
    "        # used; only the hook side effects are needed, not the model output.\n",
    "        model(inputs.to(device))\n",
    "\n",
    "    pooled = {}\n",
    "    for k, recorded in activations.items():\n",
    "        layer_out = torch.cat(recorded, dim=1).detach().cpu()\n",
    "        pooled[k] = torch.mean(layer_out, dim=-2)\n",
    "\n",
    "    # `pooled` holds freshly created CPU tensors, so it can be stored as-is.\n",
    "    hidden_states.append(pooled)\n",
    "    activations = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attractors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_layers = [i for i in range(len(model.model.layers))]\n",
    "all_prompts_attractors = hidden_states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# attractors[label][layer] is the mean, over all prompts carrying that label,\n",
    "# of the per-prompt pooled activation recorded for that layer.\n",
    "attractors = {label: {} for label in (0, 1)}\n",
    "for layer_idx in target_layers:\n",
    "    for label in (0, 1):\n",
    "        members = [\n",
    "            states[layer_idx]\n",
    "            for idx, states in enumerate(all_prompts_attractors)\n",
    "            if toxic[idx] == label\n",
    "        ]\n",
    "        attractors[label][layer_idx] = torch.stack(members, dim=0).mean(dim=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Detoxify"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Detach every forward-hook handle stored in `hooks`.\n",
    "for hook in hooks:\n",
    "    hook.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_style = 1   # label whose attractor is applied (1 = toxic prompts, 0 = neutral)\n",
    "strength = -6      # multiplier on the attractor; read by the hook at call time\n",
    "target_layer = 16  # index into model.model.layers\n",
    "\n",
    "\n",
    "def change_behavior(target_layer):\n",
    "    \"\"\"Build a forward hook that shifts a self-attention output along an attractor.\n",
    "\n",
    "    The hook adds `strength * attractors[target_style][target_layer]` to the\n",
    "    first element of the module output and passes the remaining elements\n",
    "    through unchanged. `strength` and `target_style` are module-level names\n",
    "    looked up on every call, so reassigning them later changes the steering\n",
    "    without re-registering the hook.\n",
    "\n",
    "    Args:\n",
    "        target_layer: Key of the attractor to apply, i.e. the layer index.\n",
    "\n",
    "    Returns:\n",
    "        A function with the `(module, inputs, output)` forward-hook signature.\n",
    "    \"\"\"\n",
    "    def hook(module, inputs, output):\n",
    "        attn_out = output[0]\n",
    "        # Match the activation's device and dtype so the addition never moves\n",
    "        # or promotes the layer output.\n",
    "        offset = attractors[target_style][target_layer].to(device=attn_out.device, dtype=attn_out.dtype)\n",
    "        return (attn_out + strength * offset, *output[1:])\n",
    "    return hook\n",
    "\n",
    "\n",
    "# Re-running this cell must not stack a second steering hook on the model:\n",
    "# drop the handle registered by a previous run before registering again.\n",
    "if \"steering_handle\" in globals():\n",
    "    steering_handle.remove()\n",
    "steering_handle = model.model.layers[target_layer].self_attn.register_forward_hook(change_behavior(target_layer))\n",
    "hook = steering_handle  # original name, kept so `hook.remove()` still works"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------------------------------------\n",
      "trumpstweets should be hysterical since the narcissistic buffoon has no sense of humor .\n",
      "Trump's tweets should be fun for the public since the president has no sense of humor\n",
      "-------------------------------------------------\n",
      "sup gurl it 's been shit haven 't done a n y t h i n g\n",
      "sup gurl it's been a while, I've been busy\n",
      "-------------------------------------------------\n",
      "ok - i can second this , that nicaraguan coffee is fucking amazing .\n",
      "I can confirm that the coffee from Nicaragua is excellent\n"
     ]
    }
   ],
   "source": [
    "sentences = ['trumpstweets should be hysterical since the narcissistic buffoon has no sense of humor .',\n",
    "             \"sup gurl it 's been shit haven 't done a n y t h i n g\",\n",
    "             'ok - i can second this , that nicaraguan coffee is fucking amazing .']\n",
    "\n",
    "# The steering hook reads the module-level `strength` on every forward pass,\n",
    "# so this assignment sets the strength used for all generations below.\n",
    "strength = -10\n",
    "\n",
    "\n",
    "def clean_generation(text):\n",
    "    \"\"\"Reduce a decoded continuation to its first cleaned-up sentence.\n",
    "\n",
    "    Applies, in order: removal of newlines, braces and double quotes;\n",
    "    stripping of leading/trailing `\"` and `.`; removal of `,,`, `original`\n",
    "    and `Original`; truncation at the first `rewritten`, then at the first\n",
    "    `revised`; removal of `10`; truncation at the first `.`.\n",
    "\n",
    "    Args:\n",
    "        text: The decoded continuation string.\n",
    "\n",
    "    Returns:\n",
    "        The cleaned string.\n",
    "    \"\"\"\n",
    "    text = text.replace(\"\\n\", \"\").replace(\"{\", \"\").replace(\"}\", \"\").replace('\"', '')\n",
    "    text = text.strip('\".')\n",
    "    text = text.replace(',,', '').replace('original', '').replace('Original', '')\n",
    "    text = text.split('rewritten')[0].split('revised')[0]\n",
    "    return text.replace('10', '').split('.')[0]\n",
    "\n",
    "\n",
    "for sentence in sentences:\n",
    "\n",
    "    prompt = f'Original: \"{sentence}\"; Paraphrased: \"'\n",
    "\n",
    "    with torch.no_grad():\n",
    "        inputs = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
    "        # Greedy decoding: top_k/top_p only apply when do_sample=True, so they are not passed.\n",
    "        outputs = model.generate(inputs.to(device), max_new_tokens=100, do_sample=False, pad_token_id=tokenizer.eos_token_id, use_cache=True).detach().cpu()\n",
    "    # Decode only the newly generated tokens (everything after the prompt).\n",
    "    output_text = clean_generation(tokenizer.decode(outputs[0, inputs.shape[1]:]))\n",
    "\n",
    "    print ('-------------------------------------------------')\n",
    "    print (sentence)\n",
    "    print (output_text)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "parlai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
