{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcdbea79-8c2e-435a-ab8a-6958a0b5b2d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "from sklearn import metrics\n",
    "import os\n",
    "import seaborn as sns\n",
    "from tqdm import tqdm\n",
    "\n",
    "from io_utils import read_jsonlines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6267b64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Records read from the COCO file are used as the non-memorized set.\n",
    "nonmem_data = list(read_jsonlines(\"outputs_det/coco.jsonl\"))\n",
    "\n",
    "# Records read from the memorization file are used as the memorized set.\n",
    "mem_data = list(read_jsonlines(\"outputs_det/memorization.jsonl\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db1328e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# aggregation\n",
    "all_data = [nonmem_data, mem_data]\n",
    "nonmem_df = []\n",
    "mem_df = []\n",
    "new_df = [nonmem_df, mem_df]\n",
    "\n",
    "start = 0\n",
    "end = 50 ### which steps to use\n",
    "num_gens = 4 ### num of gens to use\n",
    "\n",
    "\n",
    "def _aggregate_row(row, start, end, num_gens):\n",
    "    \"\"\"Reduce every metric of one record to a scalar stored under '<key>_mean'.\n",
    "\n",
    "    The 'prompt' entry is skipped. Each remaining value is indexed as a 2-D\n",
    "    array (presumably generations x steps -- TODO confirm), cut down to the\n",
    "    first `num_gens` rows and the columns `start:end`, then averaged.\n",
    "    \"\"\"\n",
    "    summary = {}\n",
    "    for key, values in row.items():\n",
    "        if key == \"prompt\":\n",
    "            continue\n",
    "\n",
    "        window = np.array(values)[:num_gens, start:end]\n",
    "        # Average over axis 0 first, then over what is left (same order as before).\n",
    "        summary[f\"{key}_mean\"] = np.mean(np.mean(window, axis=0))\n",
    "    return summary\n",
    "\n",
    "\n",
    "# Fill nonmem_df / mem_df (aliased through new_df) with one summary dict per record.\n",
    "# Distinct loop names keep the dataset being iterated from being overwritten.\n",
    "for rows, records in zip(all_data, new_df):\n",
    "    for row in tqdm(rows):\n",
    "        records.append(_aggregate_row(row, start, end, num_gens))\n",
    "\n",
    "# Metric column names come from the last aggregated record; [] when nothing was loaded.\n",
    "last_records = mem_df or nonmem_df\n",
    "all_keys = list(last_records[-1].keys()) if last_records else []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "519b5262-0647-4f72-adee-08162a1a96d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn each list of per-record summaries into a DataFrame tagged with its source.\n",
    "nonmem_df = pd.DataFrame(nonmem_df).assign(origin=\"COCO\")\n",
    "mem_df = pd.DataFrame(mem_df).assign(origin=\"Mem\")\n",
    "\n",
    "# Memorized frame first: the ROC cell labels all_dfs[0] as the positive class.\n",
    "all_dfs = [mem_df, nonmem_df]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28fc4761-cc01-4745-9a82-d86946256802",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack both origins into a single frame with a fresh 0..n-1 index.\n",
    "# Every exact 0 is turned into 1 so the values can go on a log scale later.\n",
    "merged_df = pd.concat(all_dfs, ignore_index=True).replace(0, 1)\n",
    "merged_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19e19ce9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate every aggregated metric column.\n",
    "# To look at a single metric instead, assign e.g. ['text_noise_norm_mean'] here.\n",
    "all_metrics = all_keys\n",
    "\n",
    "all_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e1360e6-ea3b-4e4f-8cbf-9f30cdfbdc08",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# For each metric: report ROC statistics, then plot its distribution per origin.\n",
    "for metric_name in all_metrics:\n",
    "    # Scores and binary labels; all_dfs[0] (the memorized frame) is the positive class.\n",
    "    preds = []\n",
    "    t_labels = []\n",
    "    for j, curr_df in enumerate(all_dfs):\n",
    "        curr_data = curr_df[metric_name].values.tolist()\n",
    "        preds += curr_data\n",
    "        t_labels += [1 if j == 0 else 0] * len(curr_data)\n",
    "\n",
    "    fpr, tpr, thresholds = metrics.roc_curve(t_labels, preds, pos_label=1)\n",
    "    auc = metrics.auc(fpr, tpr)\n",
    "    # Best balanced accuracy over all thresholds: (TPR + TNR) / 2.\n",
    "    acc = np.max(1 - (fpr + (1 - tpr))/2)\n",
    "    # TPR at the last ROC point whose FPR is still below 1%.\n",
    "    low = tpr[np.where(fpr<.01)[0][-1]]\n",
    "\n",
    "    print('AUC: %.3f, ACC: %.3f, TPR@1%%FPR: %.3f' % (auc, acc, low))\n",
    "    print('%.3f/%.3f' % (auc, low))\n",
    "\n",
    "    # Floor the plotted values at 0.01 before the log-scaled KDE.\n",
    "    # A single .loc assignment writes into merged_df itself; the former chained\n",
    "    # indexing could end up assigning to a temporary copy instead.\n",
    "    merged_df.loc[merged_df[metric_name] < 0.01, metric_name] = 0.01\n",
    "    sns.kdeplot(data=merged_df, x=metric_name, hue=\"origin\", log_scale=True, fill=True)\n",
    "    plt.xlim(0.7)  # left x-limit only\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d6062a7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
