{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rouge_score.rouge_scorer import _create_ngrams\n",
    "from nltk.stem.porter import PorterStemmer\n",
    "import spacy, six, json, openai\n",
    "from utils import tokenize\n",
    "import numpy as np\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import roc_curve, auc\n",
    "\n",
    "# NOTE: this rebinds the name `PorterStemmer` from the imported class to an instance,\n",
    "# so the class itself is no longer reachable under this name once the line has run.\n",
    "PorterStemmer = PorterStemmer()\n",
    "# Load the spaCy pipeline named 'en_core_web_sm'.\n",
    "nlp = spacy.load('en_core_web_sm')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def plot_roc_curve(human_scores, gpt_scores,\n",
    "                   title='ROC curve: Open-gen w/ GPT3.5-Reddit w prompts',\n",
    "                   target_fpr=0.01):\n",
    "    \"\"\"Plot the ROC curve separating GPT scores from human scores.\n",
    "\n",
    "    Human scores are labelled 0 and GPT scores 1, so a higher score counts as\n",
    "    evidence for the GPT class. The figure is shown and the TPR reached within\n",
    "    the `target_fpr` budget is printed. Nothing is returned.\n",
    "\n",
    "    Args:\n",
    "        human_scores: sequence of scores for the human-written samples.\n",
    "        gpt_scores: sequence of scores for the GPT-generated samples.\n",
    "        title: title drawn above the figure.\n",
    "        target_fpr: false-positive-rate budget at which the TPR is reported.\n",
    "    \"\"\"\n",
    "    # list() so that NumPy arrays are concatenated rather than added elementwise.\n",
    "    scores = list(human_scores) + list(gpt_scores)\n",
    "    labels = [0] * len(human_scores) + [1] * len(gpt_scores)\n",
    "\n",
    "    fpr, tpr, _ = roc_curve(labels, scores)\n",
    "    roc_auc = auc(fpr, tpr)\n",
    "\n",
    "    plt.figure()\n",
    "    plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.4f)' % roc_auc)\n",
    "    # Diagonal reference line from (0, 0) to (1, 1).\n",
    "    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')\n",
    "    plt.xlim([0.0, 1.0])\n",
    "    plt.ylim([0.0, 1.05])\n",
    "    plt.xlabel('False Positive Rate')\n",
    "    plt.ylabel('True Positive Rate')\n",
    "    plt.title(title)\n",
    "    plt.legend(loc=\"lower right\")\n",
    "    plt.show()\n",
    "\n",
    "    # roc_curve returns fpr in non-decreasing order starting at 0, so the last\n",
    "    # point with fpr <= target_fpr is the best TPR that stays within the budget.\n",
    "    # Taking the first point with fpr > target_fpr would report a TPR that is\n",
    "    # only reachable at a higher false-positive rate.\n",
    "    idx = max(int(np.searchsorted(fpr, target_fpr, side='right')) - 1, 0)\n",
    "    print(f\"TPR at {target_fpr * 100:g}% FPR: {tpr[idx]:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the regeneration results: one JSON object per line.\n",
    "# NOTE(review): the file name says 0.7 while the variable name says 05 -- confirm which is intended.\n",
    "with open('results/regen_davinci003_20_0.7.jsonl', \"r\", encoding=\"utf-8\") as f:\n",
    "    # Skip blank lines so an empty file or a stray empty line does not crash json.loads.\n",
    "    regen_davinci003_20_05 = [json.loads(line) for line in f if line.strip()]\n",
    "len(regen_davinci003_20_05)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _continuation_gap(instance, full_key, truncate_key, regen_key, num_samples):\n",
    "    \"\"\"Mean token log-prob of a response tail minus the mean over its regenerations.\"\"\"\n",
    "    # The first len(truncated tokens) entries of the full response are dropped;\n",
    "    # presumably they are the shared prefix -- confirm against the generation script.\n",
    "    prefix_len = len(instance[truncate_key]['choices'][0]['logprobs']['token_logprobs'])\n",
    "    tail_logprobs = instance[full_key]['choices'][0]['logprobs']['token_logprobs'][prefix_len:]\n",
    "    tail_mean = np.mean(tail_logprobs)\n",
    "\n",
    "    # Slice rather than index range(num_samples) so that having fewer choices\n",
    "    # than num_samples does not raise IndexError. Empty regenerations are skipped.\n",
    "    regen_means = [\n",
    "        sum(choice['logprobs']['token_logprobs']) / len(choice['logprobs']['token_logprobs'])\n",
    "        for choice in instance[regen_key]['choices'][:num_samples]\n",
    "        if len(choice['logprobs']['token_logprobs']) != 0\n",
    "    ]\n",
    "    return tail_mean - np.mean(regen_means)\n",
    "\n",
    "\n",
    "def get_ratio_avgk(instance, num_samples=20):\n",
    "    \"\"\"Score one instance by comparing each response with its regenerations.\n",
    "\n",
    "    For both the human response and the generated response, the mean token\n",
    "    log-prob of the part after the truncated prefix is compared with the\n",
    "    average per-regeneration mean token log-prob.\n",
    "\n",
    "    Args:\n",
    "        instance: dict holding the '*_response', '*_response_truncate' and\n",
    "            '*_regen' entries, each with choices[i]['logprobs']['token_logprobs'].\n",
    "        num_samples: maximum number of regenerated choices to average over.\n",
    "\n",
    "    Returns:\n",
    "        (original_th, gen_th): the log-prob gap for the human response and for\n",
    "        the generated response. A gap is NaN when its tail or all of its\n",
    "        regenerations are empty, because np.mean of an empty list is NaN.\n",
    "    \"\"\"\n",
    "    original_th = _continuation_gap(\n",
    "        instance, 'original_human_response', 'original_human_response_truncate',\n",
    "        'gold_gen_regen', num_samples)\n",
    "    gen_th = _continuation_gap(\n",
    "        instance, 'original_gen_response', 'original_gen_response_truncate',\n",
    "        'gen_completion_regen', num_samples)\n",
    "    return original_th, gen_th"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "########################## with questions ##########################\n",
    "# Score every loaded instance once, then split the (human, gpt) pairs into two lists.\n",
    "score_pairs = [get_ratio_avgk(instance, num_samples=20) for instance in regen_davinci003_20_05]\n",
    "human_scores = [pair[0] for pair in score_pairs]\n",
    "gpt_scores = [pair[1] for pair in score_pairs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop NaN scores (get_ratio_avgk yields NaN when np.mean sees an empty list).\n",
    "# np.isnan replaces the fragile str(x) != 'nan' comparison.\n",
    "# The two lists are filtered independently, so their indices no longer pair up.\n",
    "human_scores = [x for x in human_scores if not np.isnan(x)]\n",
    "gpt_scores = [x for x in gpt_scores if not np.isnan(x)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay the human and GPT score sequences on a single axis, one colour each.\n",
    "import matplotlib.pyplot as plt\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(human_scores, label='human')\n",
    "ax.plot(gpt_scores, label='gpt')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the ROC curve (human scores labelled 0, GPT scores labelled 1) and print the TPR near 1% FPR.\n",
    "plot_roc_curve( human_scores, gpt_scores )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vpt",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "015bfb409bf441c0a66e03b2de1c9b891435fcbf36ed1d1e9d7c8167e73e6b62"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
