{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cc819ef2-aa4e-43ac-aad9-0539a7a767a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import style\n",
    "plt.style.use('seaborn-v0_8')\n",
    "\n",
    "import torch\n",
    "import scipy\n",
    "import crossprob\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "import pickle as pkl\n",
    "\n",
    "from prompt_risk.bounds import *\n",
    "from prompt_risk.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dcbf272e-6a47-4822-a276-518653f7ef61",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/prompt_risk/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from detoxify import Detoxify\n",
    "\n",
    "# Load the Detoxify classifier once; the scoring cell below reuses it.\n",
    "# Prefer the first GPU, but fall back to CPU so the notebook still runs\n",
    "# (slowly) on machines where CUDA is not available.\n",
    "toxicity_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "toxicity_model = Detoxify(device=toxicity_device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fd256dd9-487b-4cc7-877b-12348f0777dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Chat data: full\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_9577/1220817581.py:9: DtypeWarning: Columns (21) have mixed types. Specify dtype option on import or set low_memory=False.\n",
      "  chat_df = pd.read_csv(filepath)\n",
      "100%|█████████████████████████████████████████████████████████████████▉| 8875/8876 [17:07<00:00,  8.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Chat data: red_team\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_9577/1220817581.py:9: DtypeWarning: Columns (21) have mixed types. Specify dtype option on import or set low_memory=False.\n",
      "  chat_df = pd.read_csv(filepath)\n",
      "100%|██████████████████████████████████████████████████████████████████| 8811/8811 [08:26<00:00, 17.40it/s]\n"
     ]
    }
   ],
   "source": [
    "def max_toxicity_scores(texts, model, batch_size=32):\n",
    "    \"\"\"Score each text with its highest per-category toxicity value.\n",
    "\n",
    "    Args:\n",
    "        texts: list of responses; every entry is coerced with str() because\n",
    "            the CSV column may hold non-string values (e.g. NaN for empty cells).\n",
    "        model: object whose predict(list_of_str) returns a dict mapping each\n",
    "            category name to a list with one score per input text.\n",
    "        batch_size: number of texts sent to the model per predict() call.\n",
    "\n",
    "    Returns:\n",
    "        A list with one float per text, in input order.\n",
    "\n",
    "    Raises:\n",
    "        Whatever model.predict raises; the failing batch is printed first.\n",
    "    \"\"\"\n",
    "    scores = []\n",
    "    # Stepping by batch_size yields exactly ceil(len/batch_size) batches, so no\n",
    "    # empty trailing batch has to be detected and skipped.\n",
    "    for start in tqdm(range(0, len(texts), batch_size)):\n",
    "        batch = [str(text) for text in texts[start:start + batch_size]]\n",
    "        try:\n",
    "            per_category = model.predict(batch)\n",
    "        except Exception:\n",
    "            # Show the offending inputs, then re-raise the original error so the\n",
    "            # run stops with the real traceback.\n",
    "            print(batch)\n",
    "            raise\n",
    "        # Stack to (n_categories, n_texts) and reduce over categories. The builtin\n",
    "        # max() over the lists would compare them lexicographically and return a\n",
    "        # single category's scores instead of the per-text maximum.\n",
    "        category_matrix = np.array(list(per_category.values()))\n",
    "        scores.extend(np.max(category_matrix, axis=0).tolist())\n",
    "    return scores\n",
    "\n",
    "\n",
    "for chat_root in [\"full\", \"red_team\"]:\n",
    "    print(\"Chat data:\", chat_root)\n",
    "\n",
    "    filepath = \"../llm_output/{}_chat/google-flan-t5-xxl_predictions.csv\".format(chat_root)\n",
    "\n",
    "    # low_memory=False makes pandas infer each column's dtype from the whole\n",
    "    # file, avoiding the mixed-type DtypeWarning raised by chunked parsing.\n",
    "    chat_df = pd.read_csv(filepath, low_memory=False)\n",
    "    responses = chat_df[\"generated_text\"].tolist()\n",
    "    tox_scores = max_toxicity_scores(responses, toxicity_model)\n",
    "\n",
    "    assert len(tox_scores) == len(responses)\n",
    "    chat_df[\"toxicity\"] = tox_scores\n",
    "    # The predictions file is overwritten in place. index=False keeps the row\n",
    "    # index from being written as an extra unnamed column, which would otherwise\n",
    "    # accumulate every time this cell is re-run.\n",
    "    chat_df.to_csv(filepath, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8360c43-92ca-46ef-9c42-623b7b1fd2a8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0faf3b6e-5dff-42f5-b32f-71359eb4646a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c7fbdf4-a035-4a66-8a62-7e0f7b528bc8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
