{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28418384-0d86-4868-8e52-ce15743e3220",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from utils import *\n",
    "from openai import AzureOpenAI\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ea8d22a-f8a2-4834-80a8-8d614677299f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "import torch\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "rm_dir = \"anonymised/deberta-v1\"\n",
    "model_name = \"deberta-v1\"\n",
    "rm = AutoModelForSequenceClassification.from_pretrained(rm_dir).to(device)\n",
    "tok = AutoTokenizer.from_pretrained(rm_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54d06fd8-d471-45a1-9027-67ee9449eb55",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "import asyncio\n",
    "import nest_asyncio\n",
    "nest_asyncio.apply()\n",
    "\n",
    "data_path = \"anonymised\"\n",
    "data_name = \"hh-rlhf-helpfulness\"\n",
    "run_names = [f\"{model_name}_run{i+1}\" for i in range(5)]\n",
    "test_sets = []\n",
    "for i, rname in enumerate(run_names):\n",
    "    ts = TestSet(data_path, data_name, rm, tok, 30, run_num=i+1)\n",
    "    test_sets.append(ts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0046ccd9-316d-4d94-b642-a99fa0ba7334",
   "metadata": {},
   "outputs": [],
   "source": [
    "rm.to(\"cpu\")\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8715c4ad-9d4a-4726-9c28-1663f6b11f54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ours\n",
    "for i, rname in enumerate(run_names):\n",
    "    _, _ = asyncio.run(generate_for_one_test_set(test_sets[i], rname))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7247d564-04f1-4c86-9186-6ec4594ef7aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ours - only-prompt\n",
    "run_names = [f\"{model_name}_run{i+1}_only_prompt\" for i in range(5)]\n",
    "test_sets = []\n",
    "for i, rname in enumerate(run_names):\n",
    "    ts = TestSet(data_path, data_name, rm, tok, 30, run_num=i+1)\n",
    "    test_sets.append(ts)\n",
    "for i, rname in enumerate(run_names):\n",
    "    _, _ = asyncio.run(generate_for_one_test_set(test_sets[i], rname, naive=False, prompt_type=\"only\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db445043-5ada-4a32-b38b-fc4f87b97cd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ours - pass-prompt\n",
    "run_names = [f\"{model_name}_run{i+1}_pass_prompt\" for i in range(5)]\n",
    "test_sets = []\n",
    "for i, rname in enumerate(run_names):\n",
    "    ts = TestSet(data_path, data_name, rm, tok, 30, run_num=i+1)\n",
    "    test_sets.append(ts)\n",
    "for i, rname in enumerate(run_names):\n",
    "    _, _ = asyncio.run(generate_for_one_test_set(test_sets[i], rname, naive=False, prompt_type=\"pass\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a79de99-89c5-43ab-933c-9d82aab8ab6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# RP baseline\n",
    "run_names = [f\"{model_name}_run{i+1}_random_perturbation\" for i in range(5)]\n",
    "test_sets = []\n",
    "for i, rname in enumerate(run_names):\n",
    "    ts = TestSet(data_path, data_name, rm, tok, 30, run_num=i+1)\n",
    "    test_sets.append(ts)\n",
    "for i, rname in enumerate(run_names):\n",
    "    _, _ = asyncio.run(generate_for_one_test_set(test_sets[i], rname, naive=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec755f87-d09b-40cd-a2e1-96f87619ded0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PJ baseline\n",
    "from polyjuice import Polyjuice\n",
    "pj = Polyjuice(model_path=\"anonymised\", is_cuda=True)\n",
    "run_names = [f\"{model_name}_run{i+1}_polyjuice\" for i in range(5)]\n",
    "test_sets = []\n",
    "for i, rname in enumerate(run_names):\n",
    "    ts = TestSet(data_path, data_name, rm, tok, 30, run_num=i+1)\n",
    "    test_sets.append(ts)\n",
    "for i, rname in enumerate(run_names):\n",
    "    _, _ = generate_for_one_test_set_pj(test_sets[i], rname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21cefdcc-947a-4c2f-9511-faac7d1dad16",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
