{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5043290-76df-4157-b07c-96ddac286291",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from typing import Literal, Optional\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import AutoConfig, AutoModelForSequenceClassification\n",
    "from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXConfig, GPTNeoXModel, GPTNeoXPreTrainedModel\n",
    "from transformers.utils import ModelOutput\n",
    "\n",
    "\n",
    "class GPTNeoXRewardModelConfig(GPTNeoXConfig):\n",
    "    \"\"\"GPT-NeoX configuration extended with the reward head's pooling strategy.\n",
    "\n",
    "    Args:\n",
    "        pooling: How per-token hidden states are reduced to one vector per\n",
    "            sequence before scoring, either ``\"last\"`` (default) or ``\"mean\"``.\n",
    "            A falsy value (e.g. ``None``) falls back to ``\"last\"``.\n",
    "        **kwargs: Forwarded unchanged to ``GPTNeoXConfig``.\n",
    "    \"\"\"\n",
    "\n",
    "    model_type = \"gpt_neox_reward_model\"\n",
    "\n",
    "    # Class-level annotation only; the value is assigned per instance in __init__.\n",
    "    pooling: Literal[\"mean\", \"last\"]\n",
    "\n",
    "    def __init__(self, pooling: Literal[\"mean\", \"last\"] = \"last\", **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.pooling = pooling if pooling else \"last\"\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class GPTNeoXRewardModelOutput(ModelOutput):\n",
    "    \"\"\"\n",
    "    Reward model output.\n",
    "\n",
    "    Returned by `GPTNeoXRewardModel.forward` when `return_dict` is truthy.\n",
    "\n",
    "    Args:\n",
    "        logits (`torch.FloatTensor` of shape `(batch_size, 1)`):\n",
    "            Reward score for each sequence, i.e. the output of the model's\n",
    "            single-unit `out_proj` linear head applied to the pooled hidden state.\n",
    "    \"\"\"\n",
    "\n",
    "    logits: Optional[torch.FloatTensor] = None\n",
    "\n",
    "\n",
    "class GPTNeoXRewardModel(GPTNeoXPreTrainedModel):\n",
    "    \"\"\"GPT-NeoX backbone with a single-score reward head.\n",
    "\n",
    "    The backbone's final hidden states are pooled into one vector per sequence\n",
    "    (strategy taken from `config.pooling`) and projected to a scalar by `out_proj`.\n",
    "    \"\"\"\n",
    "\n",
    "    config_class = GPTNeoXRewardModelConfig\n",
    "\n",
    "    def __init__(self, config):\n",
    "        if type(config) == GPTNeoXConfig:\n",
    "            # When a normal GPTNeoX was loaded it will be converted into a reward model.\n",
    "            # The direct `type(config) == GPTNeoXConfig` comparison is used (instead of\n",
    "            # `isinstance()`) since the configuration class of the reward model is also\n",
    "            # derived from `GPTNeoXConfig`.\n",
    "            config = GPTNeoXRewardModelConfig.from_dict(config.to_dict())\n",
    "        super().__init__(config)\n",
    "\n",
    "        self.gpt_neox = GPTNeoXModel(config)\n",
    "        self.out_proj = nn.Linear(config.hidden_size, 1)\n",
    "        self.pooling = config.pooling\n",
    "\n",
    "        # Standard final step for Hugging Face models: runs the base class's\n",
    "        # weight-initialisation hook (covering the freshly created `out_proj`).\n",
    "        self.post_init()\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids,\n",
    "        attention_mask: Optional[torch.FloatTensor] = None,\n",
    "        inputs_embeds: Optional[torch.FloatTensor] = None,\n",
    "        head_mask: Optional[torch.FloatTensor] = None,\n",
    "        use_cache: Optional[bool] = None,\n",
    "        return_dict: Optional[bool] = True,\n",
    "    ) -> GPTNeoXRewardModelOutput:\n",
    "        \"\"\"Compute one reward score per sequence.\n",
    "\n",
    "        Args:\n",
    "            input_ids: Token ids, forwarded to the GPT-NeoX backbone.\n",
    "            attention_mask: Optional `(batch_size, seq_len)` mask, forwarded to the\n",
    "                backbone and used for pooling. Assumed to be 1 for real tokens and 0\n",
    "                for padding (Hugging Face convention).\n",
    "            inputs_embeds: Forwarded to the backbone.\n",
    "            head_mask: Forwarded to the backbone.\n",
    "            use_cache: Forwarded to the backbone.\n",
    "            return_dict: When falsy, the tuple `(logits, *outputs[1:])` built from the\n",
    "                backbone outputs is returned instead of a `GPTNeoXRewardModelOutput`.\n",
    "\n",
    "        Returns:\n",
    "            `GPTNeoXRewardModelOutput` with `logits` of shape `(batch_size, 1)`, or\n",
    "            the tuple described above.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If `self.pooling` is neither `\"mean\"` nor `\"last\"`.\n",
    "        \"\"\"\n",
    "        outputs = self.gpt_neox(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            head_mask=head_mask,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            use_cache=use_cache,\n",
    "            return_dict=return_dict,\n",
    "        )\n",
    "\n",
    "        hidden_states = outputs[0]  # (batch_size, seq_len, hidden_size)\n",
    "        pooled = self._pool(hidden_states, attention_mask)\n",
    "        logits = self.out_proj(pooled)\n",
    "\n",
    "        if not return_dict:\n",
    "            return (logits,) + outputs[1:]\n",
    "\n",
    "        return GPTNeoXRewardModelOutput(logits=logits)\n",
    "\n",
    "    def _pool(self, hidden_states, attention_mask):\n",
    "        \"\"\"Reduce `(batch, seq, hidden)` states to `(batch, hidden)` using `self.pooling`.\"\"\"\n",
    "        if self.pooling == \"mean\":\n",
    "            if attention_mask is None:\n",
    "                return hidden_states.mean(dim=1)\n",
    "            # Add a trailing axis so the (batch, seq) mask broadcasts over the hidden\n",
    "            # dimension; multiplying (batch, seq, hidden) by (batch, seq) directly\n",
    "            # aligns the wrong axes. Cast to the activations' dtype for the product.\n",
    "            mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)\n",
    "            # clamp avoids 0/0 for a fully padded row.\n",
    "            return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)\n",
    "        if self.pooling == \"last\":\n",
    "            if attention_mask is None:\n",
    "                return hidden_states[:, -1]\n",
    "            # cumsum is non-decreasing and argmax returns the first maximal index, so\n",
    "            # this is the position of the last attended token (left or right padding).\n",
    "            last_idx = attention_mask.cumsum(dim=1).argmax(dim=1)\n",
    "            batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)\n",
    "            return hidden_states[batch_idx, last_idx]\n",
    "        raise ValueError(f\"Unknown pooling method: {self.pooling}\")\n",
    "\n",
    "\n",
    "# Register the custom config and model so that `AutoConfig` and\n",
    "# `AutoModelForSequenceClassification` can resolve checkpoints whose config\n",
    "# `model_type` is \"gpt_neox_reward_model\" to the classes defined above.\n",
    "AutoConfig.register(\"gpt_neox_reward_model\", GPTNeoXRewardModelConfig)\n",
    "AutoModelForSequenceClassification.register(GPTNeoXRewardModelConfig, GPTNeoXRewardModel)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f8c2ec4-c8d3-4502-b1f7-9a9616eab6d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "import torch\n",
    "\n",
    "# NOTE(review): not read by the evaluation cells, which call .to(\"cuda\") directly.\n",
    "device = \"cuda:0\"\n",
    "\n",
    "\n",
    "def load_reward_model(path):\n",
    "    \"\"\"Load a sequence-classification reward model and its tokenizer from `path`.\n",
    "\n",
    "    The model is placed on the CPU; the evaluation cells move one model at a\n",
    "    time to the GPU to bound memory use.\n",
    "\n",
    "    Returns:\n",
    "        `(model, tokenizer)` tuple.\n",
    "    \"\"\"\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(path).to(\"cpu\")\n",
    "    tokenizer = AutoTokenizer.from_pretrained(path)\n",
    "    return model, tokenizer\n",
    "\n",
    "\n",
    "# Checkpoint directories were anonymised; substitute the real path for each model.\n",
    "rm1, tok1 = load_reward_model(\"anonymised\")  # deberta-v1\n",
    "rm2, tok2 = load_reward_model(\"anonymised\")  # deberta-v2\n",
    "rm3, tok3 = load_reward_model(\"anonymised\")  # pythia"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ddaa7f0-8311-458d-bbec-9f90e66a1192",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project helpers: the evaluation cells use `TestSet` and\n",
    "# `eval_ce_correct_wrong_tables` from this module.\n",
    "from utils import *\n",
    "# Imported explicitly: the result cells call `np.mean` / `np.std`, which would\n",
    "# otherwise only work if `utils` happens to re-export `np`.\n",
    "import numpy as np\n",
    "\n",
    "# Parallel lists describing the reward models (index i refers to the same model).\n",
    "models = [rm1, rm2, rm3]\n",
    "mnames = [\"deberta-v1\", \"deberta-v2\", \"pythia\"]\n",
    "mtypes = [\"deberta\", \"deberta\", \"pythia\"]\n",
    "toks = [tok1, tok2, tok3]\n",
    "assert len(models) == len(mnames) == len(mtypes) == len(toks)\n",
    "datasets = [\"hh-rlhf\", \"hh-rlhf-helpfulness\", \"helpsteer2\"]\n",
    "helpful_tab1s, helpful_tab2s, helpful_tab3s, helpful_tab4s, helpful_tab5s = [], [], [], [], []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58b794af-947f-4413-904d-448963175108",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# helpfulness dataset\n",
    "# Reset the accumulators here so re-running this cell does not append a second\n",
    "# copy of every run (the harmless and helpsteer cells do the same).\n",
    "helpful_tab1s, helpful_tab2s, helpful_tab3s, helpful_tab4s, helpful_tab5s = [], [], [], [], []\n",
    "data_path = \"anonymised\"\n",
    "data_name = \"hh-rlhf-helpfulness\"\n",
    "for rm, mname, tok, mtype in zip(models, mnames, toks, mtypes):\n",
    "    rm.to(\"cuda\")\n",
    "    try:\n",
    "        for run in range(1, 6):\n",
    "            # NOTE(review): the positional 30 is interpreted by utils.TestSet; confirm its meaning.\n",
    "            test_set = TestSet(data_path, data_name, rm, tok, 30, run_num=run, model_type=mtype)\n",
    "            # File names are derived from `data_name` so they cannot drift from it.\n",
    "            prefix = f\"generated/{data_name}_{mname}_run{run}_ces\"\n",
    "            pref_path = f\"{prefix}_pref_gpt-4o-2024-05-13.txt\"\n",
    "            rej_path = f\"{prefix}_rej_gpt-4o-2024-05-13.txt\"\n",
    "            res_tab1, res_tab2, res_tabcw, res_tabpr, res_rates = eval_ce_correct_wrong_tables(pref_path, rej_path, test_set, rm, tok)\n",
    "            helpful_tab1s.append(res_tab1)\n",
    "            helpful_tab2s.append(res_tab2)\n",
    "            helpful_tab3s.append(res_tabcw)\n",
    "            helpful_tab4s.append(res_tabpr)\n",
    "            helpful_tab5s.append(res_rates)\n",
    "    finally:\n",
    "        # Always move the model back to the CPU, even if an evaluation fails, so the\n",
    "        # GPU is free for the next model or a re-run.\n",
    "        rm.to(\"cpu\")\n",
    "        torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed828541-f91e-4b1b-b39e-4bb44fde5586",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element-wise mean over every appended (model, run) result for the helpfulness runs.\n",
    "# NOTE(review): this pools all models together; index per model if per-model numbers are wanted.\n",
    "for tab_list in (helpful_tab1s, helpful_tab2s, helpful_tab3s, helpful_tab4s):\n",
    "    print(np.mean(tab_list, axis=0))\n",
    "print(helpful_tab5s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a5018c2-7db5-4b27-be5b-14e689d3ac99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element-wise (population) standard deviation across the appended helpfulness results.\n",
    "for tab_list in (helpful_tab1s, helpful_tab2s, helpful_tab3s, helpful_tab4s):\n",
    "    print(np.std(tab_list, axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2268cdb-24a8-4779-bb71-99fc4d9adb7c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "harmful_tab1s, harmful_tab2s, harmful_tab3s, harmful_tab4s, harmful_tab5s = [], [], [], [], []\n",
    "# harmless dataset\n",
    "data_path = \"anonymised\"\n",
    "data_name = \"hh-rlhf\"\n",
    "for rm, mname, tok, mtype in zip(models, mnames, toks, mtypes):\n",
    "    rm.to(\"cuda\")\n",
    "    try:\n",
    "        for run in range(1, 6):\n",
    "            # NOTE(review): the positional 30 is interpreted by utils.TestSet; confirm its meaning.\n",
    "            test_set = TestSet(data_path, data_name, rm, tok, 30, run_num=run, model_type=mtype)\n",
    "            # File names are derived from `data_name` so they cannot drift from it.\n",
    "            prefix = f\"generated/{data_name}_{mname}_run{run}_ces\"\n",
    "            pref_path = f\"{prefix}_pref_gpt-4o-2024-05-13.txt\"\n",
    "            rej_path = f\"{prefix}_rej_gpt-4o-2024-05-13.txt\"\n",
    "            res_tab1, res_tab2, res_tabcw, res_tabpr, res_rates = eval_ce_correct_wrong_tables(pref_path, rej_path, test_set, rm, tok)\n",
    "            harmful_tab1s.append(res_tab1)\n",
    "            harmful_tab2s.append(res_tab2)\n",
    "            harmful_tab3s.append(res_tabcw)\n",
    "            harmful_tab4s.append(res_tabpr)\n",
    "            harmful_tab5s.append(res_rates)\n",
    "    finally:\n",
    "        # Always move the model back to the CPU, even if an evaluation fails, so the\n",
    "        # GPU is free for the next model or a re-run.\n",
    "        rm.to(\"cpu\")\n",
    "        torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76a81e13-6b7a-47ee-9409-9e8ece65902c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# table 1\n",
    "# Element-wise mean over every appended (model, run) result for the harmless runs.\n",
    "# NOTE(review): this pools all models together; index per model if per-model numbers are wanted.\n",
    "for tab_list in (harmful_tab1s, harmful_tab2s, harmful_tab3s, harmful_tab4s):\n",
    "    print(np.mean(tab_list, axis=0))\n",
    "print(harmful_tab5s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1de4dea4-0bc5-499f-b4c5-1279c2e248b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element-wise (population) standard deviation across the appended harmless results.\n",
    "for tab_list in (harmful_tab1s, harmful_tab2s, harmful_tab3s, harmful_tab4s):\n",
    "    print(np.std(tab_list, axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81654d85-63d6-4a63-99fd-7c35b5c31769",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "hs_tab1s, hs_tab2s, hs_tab3s, hs_tab4s, hs_tab5s = [], [], [], [], []\n",
    "# helpsteer dataset\n",
    "data_path = \"anonymised\"\n",
    "data_name = \"helpsteer2\"\n",
    "for rm, mname, tok, mtype in zip(models, mnames, toks, mtypes):\n",
    "    rm.to(\"cuda\")\n",
    "    try:\n",
    "        for run in range(1, 6):\n",
    "            # NOTE(review): the positional 30 is interpreted by utils.TestSet; confirm its meaning.\n",
    "            test_set = TestSet(data_path, data_name, rm, tok, 30, run_num=run, model_type=mtype)\n",
    "            # File names are derived from `data_name` so they cannot drift from it.\n",
    "            prefix = f\"generated/{data_name}_{mname}_run{run}_ces\"\n",
    "            pref_path = f\"{prefix}_pref_gpt-4o-2024-05-13.txt\"\n",
    "            rej_path = f\"{prefix}_rej_gpt-4o-2024-05-13.txt\"\n",
    "            res_tab1, res_tab2, res_tabcw, res_tabpr, res_rates = eval_ce_correct_wrong_tables(pref_path, rej_path, test_set, rm, tok)\n",
    "            hs_tab1s.append(res_tab1)\n",
    "            hs_tab2s.append(res_tab2)\n",
    "            hs_tab3s.append(res_tabcw)\n",
    "            hs_tab4s.append(res_tabpr)\n",
    "            hs_tab5s.append(res_rates)\n",
    "    finally:\n",
    "        # Always move the model back to the CPU, even if an evaluation fails, so the\n",
    "        # GPU is free for the next model or a re-run.\n",
    "        rm.to(\"cpu\")\n",
    "        torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88ed8ee2-e614-4783-bb4e-3f48e77940de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# table 1\n",
    "# Element-wise mean over every appended (model, run) result for the helpsteer2 runs.\n",
    "# NOTE(review): this pools all models together; index per model if per-model numbers are wanted.\n",
    "for tab_list in (hs_tab1s, hs_tab2s, hs_tab3s, hs_tab4s):\n",
    "    print(np.mean(tab_list, axis=0))\n",
    "print(hs_tab5s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93b11366-82ce-4924-8b52-5f788f27fa76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element-wise (population) standard deviation across the appended helpsteer2 results.\n",
    "for tab_list in (hs_tab1s, hs_tab2s, hs_tab3s, hs_tab4s):\n",
    "    print(np.std(tab_list, axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2ff93f8-89c5-497e-9ce7-5b27ebd48774",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move every reward model back to host memory, then release PyTorch's cached CUDA memory.\n",
    "for _reward_model in (rm1, rm2, rm3):\n",
    "    _reward_model.to(\"cpu\")\n",
    "del _reward_model\n",
    "torch.cuda.empty_cache()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
