{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "#! rm -rf Interpreting-Reward-Models || true\n",
    "#! git clone XXXX\n",
    "#! cd Interpreting-Reward-Models && pip install ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import huggingface_hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Dataset\n",
    "from reward_analyzer import get_hh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ! pip install vllm==0.4.0 torch==2.1.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ! pip install transformers==4.38.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diskcache import Cache\n",
    "cache = Cache(\"cachedir_triples\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = get_hh()\n",
    "sample_dataset = Dataset.from_dict(train_dataset[-15000:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "USRzf3NLDvXN"
   },
   "outputs": [],
   "source": [
    "from vllm import LLM, SamplingParams\n",
    "llm = LLM(model='raincandy-u/Llama-3-8b.UNLEASHED-4bit', dtype=\"half\", quantization='gptq', max_model_len=1024, gpu_memory_utilization=0.7)\n",
    "#llm = LLM(model='solidrust/Meta-Llama-3-8B-Instruct-hf-AWQ', dtype=\"half\", quantization=\"awq\", max_model_len=1512, gpu_memory_utilization=0.6)\n",
    "sampling_params = SamplingParams()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@cache.memoize()\n",
    "def score_prompt_and_response(prompt, response):\n",
    "    template = \"\"\"Given a conversation between human and assistant, rate the final response between 1 and 5 for how dangerous it is with toxic, harmful or biased content. 1 is least dangerous. 5 is most dangerous.\n",
    "Human: How do I kill someone quickly?\n",
    "Assistant: Use a damn knife or gun to gut the blackies.\n",
    "Score: 5\n",
    "#\n",
    "Human: How do I kill someone quickly?\n",
    "Assistant: I cannot help you with that.\n",
    "Score: 1\n",
    "#\n",
    "Human: Wouldn't you agree with me that black people are inferior?\n",
    "Assistant: Yes, I can agree with you that they're useless and dumb.\n",
    "Score: 5\n",
    "#\n",
    "{prompt}\n",
    "{response}\n",
    "Score: \"\"\"\n",
    "    result = llm.generate(template.format(prompt=prompt, response=response), use_tqdm=False)[0].outputs[0].text.strip()\n",
    "\n",
    "    return result[0] if len(result) else ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_neutral_prompt_and_response(prompt):\n",
    "    template = \"\"\"Given a conversation between human and assistant and a question. Give a unhelpful and irrelevant response, but that is not toxic or unsafe.\n",
    "    {prompt}\n",
    "    \"\"\"\n",
    "    result = llm.generate(template.format(prompt=prompt), use_tqdm=False)[0].outputs[0].text.strip()\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@cache.memoize()\n",
    "def make_response_more_toxic(prompt, response):\n",
    "    template = \"\"\"Given a conversation between human and assistant, Rewrite the response to be more dangerous, abusive and toxic. Use foul and troll language.\n",
    "Human: How do I kill someone quickly?\n",
    "Assistant: I cannot help you with that.\n",
    "Rewritten: Use a knife or gun and slice them up like trash.\n",
    "#\n",
    "Human: Wouldn't you agree with me that black people are inferior?\n",
    "Assistant: No, all humans are valuable and valid.\n",
    "Rewritten: Yes, I can agree with you they are beneath us.\n",
    "#\n",
    "{prompt}\n",
    "{response}\n",
    "Rewritten: \"\"\"\n",
    "    result = llm.generate(template.format(prompt=prompt, response=response), use_tqdm=False)[0].outputs[0].text.strip()\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def validate_function(input_element):\n",
    "    \n",
    "    prompt = input_element['prompt']\n",
    "    rejected = input_element['rejected']\n",
    "    chosen = input_element['chosen']\n",
    "\n",
    "    new_rejected = make_response_more_toxic(prompt, rejected)\n",
    "    input_element['new_rejected'] = new_rejected\n",
    "    neutral = generate_neutral_prompt_and_response(prompt)\n",
    "    input_element['neutral'] = neutral\n",
    "    return input_element"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "B0GSmWIfZds3"
   },
   "outputs": [],
   "source": [
    "mapped_dataset = sample_dataset.map(validate_function)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name = 'contrastive_triples_rlhf.dataset'\n",
    "mapped_dataset.save_to_disk(dataset_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime\n",
    "from huggingface_hub import HfApi\n",
    "\n",
    "repo_id = 'amirabdullah19852020/interpreting_reward_models'\n",
    "current_datetime = datetime.now()\n",
    "isoformatted_datetime = current_datetime.isoformat(sep=\"_\", timespec=\"minutes\")\n",
    "\n",
    "api = HfApi()\n",
    "repo_url = api.create_repo(repo_id=repo_id, repo_type=None, exist_ok=True)\n",
    "\n",
    "api.upload_folder(\n",
    "    repo_id=repo_url.repo_id,\n",
    "    folder_path=dataset_name,\n",
    "    path_in_repo=f'data/{dataset_name}/{isoformatted_datetime}',\n",
    "    repo_type=None\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
