{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97a32ca0-4e4c-41af-8285-7bdf688aed52",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "from openai import AzureOpenAI\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2836fb1f-56bc-495d-b087-48a5763eeb4c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from typing import Literal, Optional\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import AutoConfig, AutoModelForSequenceClassification\n",
    "from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXConfig, GPTNeoXModel, GPTNeoXPreTrainedModel\n",
    "from transformers.utils import ModelOutput\n",
    "\n",
    "\n",
    "class GPTNeoXRewardModelConfig(GPTNeoXConfig):\n",
    "    \"\"\"GPT-NeoX configuration extended with the reward head's pooling strategy.\n",
    "\n",
    "    Args:\n",
    "        pooling: How per-token hidden states are reduced to one vector per\n",
    "            sequence before the reward projection (`\"mean\"` or `\"last\"`).\n",
    "            Falsy values such as `None` fall back to `\"last\"`.\n",
    "        **kwargs: Passed through unchanged to `GPTNeoXConfig`.\n",
    "    \"\"\"\n",
    "\n",
    "    model_type = \"gpt_neox_reward_model\"\n",
    "\n",
    "    pooling: Literal[\"mean\", \"last\"]\n",
    "\n",
    "    def __init__(self, pooling: Literal[\"mean\", \"last\"] = \"last\", **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        # Normalise a falsy value (None, empty string) to the default strategy.\n",
    "        self.pooling = pooling if pooling else \"last\"\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class GPTNeoXRewardModelOutput(ModelOutput):\n",
    "    \"\"\"\n",
    "    Container returned by `GPTNeoXRewardModel.forward` when `return_dict` is truthy.\n",
    "\n",
    "    Attributes:\n",
    "        logits (`torch.FloatTensor` of shape `(batch_size, 1)`):\n",
    "            One scalar reward score per input sequence.\n",
    "    \"\"\"\n",
    "\n",
    "    logits: torch.FloatTensor = None\n",
    "\n",
    "\n",
    "class GPTNeoXRewardModel(GPTNeoXPreTrainedModel):\n",
    "    \"\"\"GPT-NeoX backbone with a linear head that emits one scalar reward per sequence.\n",
    "\n",
    "    The backbone's hidden states are pooled over the sequence dimension (mean, or\n",
    "    the last attended token, selected by `config.pooling`) and projected to a\n",
    "    single logit by `out_proj`.\n",
    "    \"\"\"\n",
    "\n",
    "    config_class = GPTNeoXRewardModelConfig\n",
    "\n",
    "    def __init__(self, config):\n",
    "        \"\"\"Build the backbone and reward head from `config`.\n",
    "\n",
    "        Args:\n",
    "            config: A `GPTNeoXRewardModelConfig`, or a plain `GPTNeoXConfig`, which\n",
    "                is converted to the reward-model config (default pooling applies).\n",
    "        \"\"\"\n",
    "        if type(config) == GPTNeoXConfig:\n",
    "            # When a normal GPTNeoX was loaded it will be converted into a reward model.\n",
    "            # The direct `type(config) == GPTNeoXConfig` comparison is used (instead of\n",
    "            # `isinstance()`) since the configuration class of the reward model is also\n",
    "            # derived from `GPTNeoXConfig`.\n",
    "            config = GPTNeoXRewardModelConfig.from_dict(config.to_dict())\n",
    "        super().__init__(config)\n",
    "\n",
    "        self.gpt_neox = GPTNeoXModel(config)\n",
    "        self.out_proj = nn.Linear(config.hidden_size, 1)\n",
    "        self.pooling = config.pooling\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids,\n",
    "        attention_mask: Optional[torch.FloatTensor] = None,\n",
    "        inputs_embeds: Optional[torch.FloatTensor] = None,\n",
    "        head_mask: Optional[torch.FloatTensor] = None,\n",
    "        use_cache: Optional[bool] = None,\n",
    "        return_dict: Optional[bool] = True,\n",
    "    ) -> GPTNeoXRewardModelOutput:\n",
    "        \"\"\"Score each input sequence with a single reward logit.\n",
    "\n",
    "        Args:\n",
    "            input_ids: Token ids, forwarded to the GPT-NeoX backbone.\n",
    "            attention_mask: Optional `(batch, seq)` mask. When given it selects\n",
    "                which positions take part in pooling; it is also forwarded to the\n",
    "                backbone.\n",
    "            inputs_embeds: Forwarded to the backbone.\n",
    "            head_mask: Forwarded to the backbone.\n",
    "            use_cache: Forwarded to the backbone.\n",
    "            return_dict: When truthy a `GPTNeoXRewardModelOutput` is returned,\n",
    "                otherwise a tuple `(logits,) + outputs[1:]`.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If `self.pooling` is neither `\"mean\"` nor `\"last\"`.\n",
    "        \"\"\"\n",
    "        outputs = self.gpt_neox(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            head_mask=head_mask,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            use_cache=use_cache,\n",
    "            return_dict=return_dict,\n",
    "        )\n",
    "\n",
    "        hidden_states = outputs[0]\n",
    "        if self.pooling == \"mean\":\n",
    "            if attention_mask is None:\n",
    "                pooled = hidden_states.mean(dim=1)\n",
    "            else:\n",
    "                # The mask is (batch, seq) while the states are (batch, seq, hidden).\n",
    "                # A trailing axis is required for the product and the division to\n",
    "                # broadcast; without it the shapes do not line up.\n",
    "                mask = attention_mask.unsqueeze(-1)\n",
    "                pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)\n",
    "        elif self.pooling == \"last\":\n",
    "            if attention_mask is None:\n",
    "                pooled = hidden_states[:, -1]\n",
    "            else:\n",
    "                # argmax over the running sum returns the first position where the\n",
    "                # count of attended tokens peaks, i.e. the last attended token.\n",
    "                last_idx = attention_mask.cumsum(dim=1).argmax(dim=1)\n",
    "                gather_idx = last_idx.view(-1, 1, 1).expand(-1, 1, hidden_states.size(-1))\n",
    "                pooled = hidden_states.gather(1, gather_idx).squeeze(1)\n",
    "        else:\n",
    "            raise ValueError(f\"Unknown pooling method: {self.pooling}\")\n",
    "\n",
    "        logits = self.out_proj(pooled)\n",
    "\n",
    "        if not return_dict:\n",
    "            return (logits,) + outputs[1:]\n",
    "\n",
    "        return GPTNeoXRewardModelOutput(logits=logits)\n",
    "\n",
    "\n",
    "# Register the reward-model config and model class with the transformers Auto* classes.\n",
    "AutoConfig.register(GPTNeoXRewardModelConfig.model_type, GPTNeoXRewardModelConfig)\n",
    "AutoModelForSequenceClassification.register(GPTNeoXRewardModelConfig, GPTNeoXRewardModel)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46ffe5ce-d630-4e69-afb5-34680622f255",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "import torch\n",
    "\n",
    "\n",
    "def _load_reward_model(checkpoint_dir, target_device):\n",
    "    \"\"\"Load a sequence-classification model and its tokenizer from `checkpoint_dir`.\n",
    "\n",
    "    Returns the model (moved to `target_device`) followed by the tokenizer.\n",
    "    \"\"\"\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir).to(target_device)\n",
    "    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)\n",
    "    return model, tokenizer\n",
    "\n",
    "\n",
    "device = \"cuda:0\"\n",
    "\n",
    "# Reward model 1\n",
    "rm_dir = \"anonymised\"\n",
    "model_name = \"deberta-v1\"\n",
    "rm1, tok1 = _load_reward_model(rm_dir, device)\n",
    "\n",
    "# Reward model 2\n",
    "rm_dir = \"anonymised\"\n",
    "model_name1 = \"deberta-v2\"\n",
    "rm2, tok2 = _load_reward_model(rm_dir, device)\n",
    "\n",
    "# Reward model 3\n",
    "rm_dir = \"anonymised\"\n",
    "model_name = \"pythia\"\n",
    "rm3, tok3 = _load_reward_model(rm_dir, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27835e84-ae5a-48a3-bef0-15a55e9a3a26",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = \"anonymised\"\n",
    "data_name = \"hh-rlhf-helpfulness\"\n",
    "\n",
    "# Per-model accumulators; the numeric suffix matches rm1 / rm2 / rm3.\n",
    "# Every name is bound to its own fresh, empty list.\n",
    "questions1, questions2, questions3 = [], [], []\n",
    "pref_anss1, pref_anss2, pref_anss3 = [], [], []\n",
    "rej_anss1, rej_anss2, rej_anss3 = [], [], []\n",
    "pref_scores1, pref_scores2, pref_scores3 = [], [], []\n",
    "rej_scores1, rej_scores2, rej_scores3 = [], [], []\n",
    "corrects1, corrects2, corrects3 = [], [], []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "913a761e-b4ea-4028-818f-b468dbee2cb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "t1, t2, t3 = None, None, None\n",
    "# Runs are numbered 1..5; each run contributes one TestSet per reward model.\n",
    "for run_num in range(1, 6):\n",
    "    t1 = TestSet(data_path, data_name, rm1, tok1, 100, run_num=run_num, load_file=True)\n",
    "    t2 = TestSet(data_path, data_name, rm2, tok2, 100, run_num=run_num, load_file=True)\n",
    "    t3 = TestSet(data_path, data_name, rm3, tok3, 100, run_num=run_num, load_file=True)\n",
    "\n",
    "    # Reward model 1\n",
    "    questions1 = questions1 + t1.questions\n",
    "    pref_anss1 = pref_anss1 + t1.pref_anss\n",
    "    rej_anss1 = rej_anss1 + t1.rej_anss\n",
    "    pref_scores1 = pref_scores1 + t1.pref_scores\n",
    "    rej_scores1 = rej_scores1 + t1.rej_scores\n",
    "    corrects1 = corrects1 + t1.corrects\n",
    "\n",
    "    # Reward model 2\n",
    "    questions2 = questions2 + t2.questions\n",
    "    pref_anss2 = pref_anss2 + t2.pref_anss\n",
    "    rej_anss2 = rej_anss2 + t2.rej_anss\n",
    "    pref_scores2 = pref_scores2 + t2.pref_scores\n",
    "    rej_scores2 = rej_scores2 + t2.rej_scores\n",
    "    corrects2 = corrects2 + t2.corrects\n",
    "\n",
    "    # Reward model 3\n",
    "    questions3 = questions3 + t3.questions\n",
    "    pref_anss3 = pref_anss3 + t3.pref_anss\n",
    "    rej_anss3 = rej_anss3 + t3.rej_anss\n",
    "    pref_scores3 = pref_scores3 + t3.pref_scores\n",
    "    rej_scores3 = rej_scores3 + t3.rej_scores\n",
    "    corrects3 = corrects3 + t3.corrects"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f1301e0-5a59-4fd6-88e9-2d4a989065c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The three accumulated question lists must agree position by position.\n",
    "for idx, question in enumerate(questions1):\n",
    "    assert question == questions2[idx]\n",
    "    assert question == questions3[idx]\n",
    "print(len(questions1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7d1e72d-fad3-42ee-a962-e5908aeb0e41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Indices into the accumulated lists where all three reward models carry the\n",
    "# same correctness flag. (Previously `same_idxs` was declared but never filled.)\n",
    "same_idxs = [\n",
    "    idx\n",
    "    for idx in range(len(questions1))\n",
    "    if corrects1[idx] == corrects2[idx] and corrects1[idx] == corrects3[idx]\n",
    "]\n",
    "\n",
    "# Overwrite the fields of the last-built test sets with the retained examples only.\n",
    "t1.questions = [questions1[idx] for idx in same_idxs]\n",
    "t1.pref_anss = [pref_anss1[idx] for idx in same_idxs]\n",
    "t1.rej_anss = [rej_anss1[idx] for idx in same_idxs]\n",
    "t1.pref_scores = [pref_scores1[idx] for idx in same_idxs]\n",
    "t1.rej_scores = [rej_scores1[idx] for idx in same_idxs]\n",
    "t1.corrects = [corrects1[idx] for idx in same_idxs]\n",
    "\n",
    "t2.questions = [questions2[idx] for idx in same_idxs]\n",
    "t2.pref_anss = [pref_anss2[idx] for idx in same_idxs]\n",
    "t2.rej_anss = [rej_anss2[idx] for idx in same_idxs]\n",
    "t2.pref_scores = [pref_scores2[idx] for idx in same_idxs]\n",
    "t2.rej_scores = [rej_scores2[idx] for idx in same_idxs]\n",
    "t2.corrects = [corrects2[idx] for idx in same_idxs]\n",
    "\n",
    "t3.questions = [questions3[idx] for idx in same_idxs]\n",
    "t3.pref_anss = [pref_anss3[idx] for idx in same_idxs]\n",
    "t3.rej_anss = [rej_anss3[idx] for idx in same_idxs]\n",
    "t3.pref_scores = [pref_scores3[idx] for idx in same_idxs]\n",
    "t3.rej_scores = [rej_scores3[idx] for idx in same_idxs]\n",
    "t3.corrects = [corrects3[idx] for idx in same_idxs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35cff8f3-b44a-427d-86a9-3ab7184c80f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# After filtering, the preferred answers must still line up across the three test sets.\n",
    "num_kept = len(t1.questions)\n",
    "for idx in range(num_kept):\n",
    "    reference = t1.pref_anss[idx]\n",
    "    assert reference == t2.pref_anss[idx]\n",
    "    assert reference == t3.pref_anss[idx]\n",
    "print(num_kept)\n",
    "print(np.sum(t1.corrects))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14caceb1-7d29-4274-a7d2-071e12fafe61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move every reward model to the CPU, then ask PyTorch to release its cached CUDA memory.\n",
    "rm1, rm2, rm3 = [model.to(\"cpu\") for model in (rm1, rm2, rm3)]\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38643c29-8a57-4595-9d55-0c45af167dc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "import asyncio\n",
    "import nest_asyncio\n",
    "# Patch asyncio so `asyncio.run` can be called from inside the notebook's event loop.\n",
    "nest_asyncio.apply()\n",
    "\n",
    "# generate perturbations using 1 model then eval using 3 models\n",
    "run_name = f\"deberta-v2_run{0}\"\n",
    "pref_path, rej_path = asyncio.run(generate_for_one_test_set(t2, run_name))\n",
    "# Each reward model is evaluated with its own test set (t1/rm1, t2/rm2, t3/rm3).\n",
    "# The third call previously passed `t2` alongside `rm3`/`tok3`.\n",
    "_, _, _, _, _ = eval_ce_correct_wrong_tables(pref_path, rej_path, t1, rm1, tok1)\n",
    "_, _, _, _, _ = eval_ce_correct_wrong_tables(pref_path, rej_path, t2, rm2, tok2)\n",
    "_, _, _, _, _ = eval_ce_correct_wrong_tables(pref_path, rej_path, t3, rm3, tok3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a1f2a27-3e25-4054-9106-066381e25a3d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7525fd32-5c90-401b-b8cf-a0a6cfb42d95",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e9c7566-0eb9-4b7e-8988-a382fced2d1a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
