{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "\u001b[91mWARNING: category not found\u001b[0m error\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "WARNING: category not found (Tense Consistency)\n",
      "\u001b[91mWARNING: category not found\u001b[0m Tense Consistency\n"
     ]
    }
   ],
   "source": [
    "from utils_generate_edits import prep_sample_indices\n",
    "import json, os\n",
    "\n",
    "# Inputs: the gold annotations file and a directory of per-(model, prompt) prediction files.\n",
    "GOLD_PATH = \"all_finegrained_clean.json\"\n",
    "PREDS_DIR = \"data/detection_preds/\"\n",
    "PREDS_EXT = \".jsonl\"\n",
    "\n",
    "with open(GOLD_PATH, \"r\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "# Keep only the test split; everything below is evaluated on it.\n",
    "data = [d for d in data if d[\"split\"] == \"test\"]\n",
    "assert data, f\"no samples with split == 'test' found in {GOLD_PATH}\"\n",
    "\n",
    "# Normalize category labels: drop the space after a slash and strip one verbose suffix.\n",
    "for d in data:\n",
    "    for edit in d[\"fine_grained_edits\"]:\n",
    "        edit[\"categorization\"] = edit[\"categorization\"].replace(\"/ \", \"/\").replace(\" (Unnecessary ornamental and overly verbose)\", \"\")\n",
    "\n",
    "id2sample = {d[\"id\"]: d for d in data}\n",
    "\n",
    "# Attach every prediction to its sample under the key \"pred_<model>_<prompt_id>\".\n",
    "# The listing is sorted so runs are reproducible, and entries that are not\n",
    "# prediction files are skipped instead of breaking the filename unpacking.\n",
    "for anno_fn in sorted(os.listdir(PREDS_DIR)):\n",
    "    if not anno_fn.endswith(PREDS_EXT):\n",
    "        continue\n",
    "    # File names must be \"<model>_<prompt_id>.jsonl\" with exactly one underscore;\n",
    "    # the evaluation cell splits the resulting pred_ key on \"_\" as well.\n",
    "    model, prompt_id = anno_fn[:-len(PREDS_EXT)].split(\"_\")\n",
    "    with open(os.path.join(PREDS_DIR, anno_fn), \"r\") as f:\n",
    "        for line in f:\n",
    "            if not line.strip():\n",
    "                continue  # tolerate blank lines, e.g. a trailing newline at end of file\n",
    "            d = json.loads(line)\n",
    "            if d[\"id\"] in id2sample:\n",
    "                id2sample[d[\"id\"]][f\"pred_{model}_{prompt_id}\"] = d[\"detection\"]\n",
    "\n",
    "# NOTE(review): prep_sample_indices presumably adds \"gold_indices\" and the\n",
    "# \"idx_<model>_<prompt_id>\" entries to each sample in place - confirm in utils_generate_edits.\n",
    "for sample in data:\n",
    "    prep_sample_indices(sample)\n",
    "\n",
    "# Category names are read from the first sample's gold indices.\n",
    "all_cats = list(data[0][\"gold_indices\"].keys())\n",
    "# Sorted so that downstream iteration order does not depend on set hashing.\n",
    "pred_keys = sorted({k for d in data for k in d if k.startswith(\"pred_\")})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/export/home/anaconda3/lib/python3.9/site-packages/numpy/core/fromnumeric.py:3432: RuntimeWarning: Mean of empty slice.\n",
      "  return _methods._mean(a, axis=axis, dtype=dtype,\n",
      "/export/home/anaconda3/lib/python3.9/site-packages/numpy/core/_methods.py:190: RuntimeWarning: invalid value encountered in double_scalars\n",
      "  ret = ret.dtype.type(ret / rcount)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_2baae_row0_col2, #T_2baae_row6_col2 {\n",
       "  background-color: #caddf0;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_2baae_row0_col3 {\n",
       "  background-color: #c1d9ed;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_2baae_row1_col2 {\n",
       "  background-color: #d1e2f3;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_2baae_row1_col3 {\n",
       "  background-color: #000000;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_2baae_row2_col2 {\n",
       "  background-color: #ddeaf7;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_2baae_row2_col3 {\n",
       "  background-color: #ecf4fb;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_2baae_row3_col2 {\n",
       "  background-color: #b2d2e8;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_2baae_row3_col3 {\n",
       "  background-color: #cbdef1;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_2baae_row4_col2 {\n",
       "  background-color: #91c3de;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_2baae_row4_col3 {\n",
       "  background-color: #add0e6;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_2baae_row5_col2 {\n",
       "  background-color: #6caed6;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_2baae_row5_col3 {\n",
       "  background-color: #82bbdb;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_2baae_row6_col3 {\n",
       "  background-color: #f7fbff;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_2baae_row7_col2 {\n",
       "  background-color: #7fb9da;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_2baae_row7_col3 {\n",
       "  background-color: #abd0e6;\n",
       "  color: #000000;\n",
       "}\n",
       "</style>\n",
       "<table style=\"display:inline\" id=\"T_2baae\">\n",
       "  <caption>Precision</caption>\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_2baae_level0_col0\" class=\"col_heading level0 col0\" >Model</th>\n",
       "      <th id=\"T_2baae_level0_col1\" class=\"col_heading level0 col1\" >N</th>\n",
       "      <th id=\"T_2baae_level0_col2\" class=\"col_heading level0 col2\" >v2-fs25</th>\n",
       "      <th id=\"T_2baae_level0_col3\" class=\"col_heading level0 col3\" >v2-fs5</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_2baae_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
       "      <td id=\"T_2baae_row0_col0\" class=\"data row0 col0\" >gemini-1.5-pro</td>\n",
       "      <td id=\"T_2baae_row0_col1\" class=\"data row0 col1\" >896</td>\n",
       "      <td id=\"T_2baae_row0_col2\" class=\"data row0 col2\" >0.418</td>\n",
       "      <td id=\"T_2baae_row0_col3\" class=\"data row0 col3\" >0.423</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_2baae_level0_row1\" class=\"row_heading level0 row1\" >1</th>\n",
       "      <td id=\"T_2baae_row1_col0\" class=\"data row1 col0\" >llama3.1-70b</td>\n",
       "      <td id=\"T_2baae_row1_col1\" class=\"data row1 col1\" >2</td>\n",
       "      <td id=\"T_2baae_row1_col2\" class=\"data row1 col2\" >0.412</td>\n",
       "      <td id=\"T_2baae_row1_col3\" class=\"data row1 col3\" >nan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_2baae_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n",
       "      <td id=\"T_2baae_row2_col0\" class=\"data row2 col0\" >gemini-1.5-flash</td>\n",
       "      <td id=\"T_2baae_row2_col1\" class=\"data row2 col1\" >1445</td>\n",
       "      <td id=\"T_2baae_row2_col2\" class=\"data row2 col2\" >0.403</td>\n",
       "      <td id=\"T_2baae_row2_col3\" class=\"data row2 col3\" >0.391</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_2baae_level0_row3\" class=\"row_heading level0 row3\" >3</th>\n",
       "      <td id=\"T_2baae_row3_col0\" class=\"data row3 col0\" >mistral-large2</td>\n",
       "      <td id=\"T_2baae_row3_col1\" class=\"data row3 col1\" >1460</td>\n",
       "      <td id=\"T_2baae_row3_col2\" class=\"data row3 col2\" >0.431</td>\n",
       "      <td id=\"T_2baae_row3_col3\" class=\"data row3 col3\" >0.417</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_2baae_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n",
       "      <td id=\"T_2baae_row4_col0\" class=\"data row4 col0\" >claude3.5-sonnet</td>\n",
       "      <td id=\"T_2baae_row4_col1\" class=\"data row4 col1\" >1460</td>\n",
       "      <td id=\"T_2baae_row4_col2\" class=\"data row4 col2\" >0.446</td>\n",
       "      <td id=\"T_2baae_row4_col3\" class=\"data row4 col3\" >0.434</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_2baae_level0_row5\" class=\"row_heading level0 row5\" >5</th>\n",
       "      <td id=\"T_2baae_row5_col0\" class=\"data row5 col0\" >gpt-4o</td>\n",
       "      <td id=\"T_2baae_row5_col1\" class=\"data row5 col1\" >1460</td>\n",
       "      <td id=\"T_2baae_row5_col2\" class=\"data row5 col2\" >0.460</td>\n",
       "      <td id=\"T_2baae_row5_col3\" class=\"data row5 col3\" >0.451</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_2baae_level0_row6\" class=\"row_heading level0 row6\" >6</th>\n",
       "      <td id=\"T_2baae_row6_col0\" class=\"data row6 col0\" >claude3-haiku</td>\n",
       "      <td id=\"T_2baae_row6_col1\" class=\"data row6 col1\" >1460</td>\n",
       "      <td id=\"T_2baae_row6_col2\" class=\"data row6 col2\" >0.418</td>\n",
       "      <td id=\"T_2baae_row6_col3\" class=\"data row6 col3\" >0.382</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_2baae_level0_row7\" class=\"row_heading level0 row7\" >7</th>\n",
       "      <td id=\"T_2baae_row7_col0\" class=\"data row7 col0\" >gpt-4o-mini</td>\n",
       "      <td id=\"T_2baae_row7_col1\" class=\"data row7 col1\" >1460</td>\n",
       "      <td id=\"T_2baae_row7_col2\" class=\"data row7 col2\" >0.452</td>\n",
       "      <td id=\"T_2baae_row7_col3\" class=\"data row7 col3\" >0.434</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table style=\"display:inline\">\n",
       "<style type=\"text/css\">\n",
       "#T_4d293_row0_col1 {\n",
       "  background-color: #9ac8e0;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_4d293_row0_col2 {\n",
       "  background-color: #7fb9da;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_4d293_row1_col1 {\n",
       "  background-color: #cadef0;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_4d293_row1_col2 {\n",
       "  background-color: #000000;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_4d293_row2_col1 {\n",
       "  background-color: #94c4df;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_4d293_row2_col2 {\n",
       "  background-color: #6aaed6;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_4d293_row3_col1 {\n",
       "  background-color: #c2d9ee;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_4d293_row3_col2 {\n",
       "  background-color: #a6cee4;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_4d293_row4_col1 {\n",
       "  background-color: #d2e3f3;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_4d293_row4_col2 {\n",
       "  background-color: #bed8ec;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_4d293_row5_col1 {\n",
       "  background-color: #eef5fc;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_4d293_row5_col2 {\n",
       "  background-color: #dce9f6;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_4d293_row6_col1 {\n",
       "  background-color: #eaf3fb;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_4d293_row6_col2 {\n",
       "  background-color: #c7dbef;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_4d293_row7_col1 {\n",
       "  background-color: #f7fbff;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_4d293_row7_col2 {\n",
       "  background-color: #ecf4fb;\n",
       "  color: #000000;\n",
       "}\n",
       "</style>\n",
       "<table style=\"display:inline\" id=\"T_4d293\">\n",
       "  <caption>Recall</caption>\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_4d293_level0_col0\" class=\"col_heading level0 col0\" >Model</th>\n",
       "      <th id=\"T_4d293_level0_col1\" class=\"col_heading level0 col1\" >v2-fs25</th>\n",
       "      <th id=\"T_4d293_level0_col2\" class=\"col_heading level0 col2\" >v2-fs5</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_4d293_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
       "      <td id=\"T_4d293_row0_col0\" class=\"data row0 col0\" >gemini-1.5-pro</td>\n",
       "      <td id=\"T_4d293_row0_col1\" class=\"data row0 col1\" >0.803</td>\n",
       "      <td id=\"T_4d293_row0_col2\" class=\"data row0 col2\" >0.858</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_4d293_level0_row1\" class=\"row_heading level0 row1\" >1</th>\n",
       "      <td id=\"T_4d293_row1_col0\" class=\"data row1 col0\" >llama3.1-70b</td>\n",
       "      <td id=\"T_4d293_row1_col1\" class=\"data row1 col1\" >0.679</td>\n",
       "      <td id=\"T_4d293_row1_col2\" class=\"data row1 col2\" >nan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_4d293_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n",
       "      <td id=\"T_4d293_row2_col0\" class=\"data row2 col0\" >gemini-1.5-flash</td>\n",
       "      <td id=\"T_4d293_row2_col1\" class=\"data row2 col1\" >0.817</td>\n",
       "      <td id=\"T_4d293_row2_col2\" class=\"data row2 col2\" >0.896</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_4d293_level0_row3\" class=\"row_heading level0 row3\" >3</th>\n",
       "      <td id=\"T_4d293_row3_col0\" class=\"data row3 col0\" >mistral-large2</td>\n",
       "      <td id=\"T_4d293_row3_col1\" class=\"data row3 col1\" >0.705</td>\n",
       "      <td id=\"T_4d293_row3_col2\" class=\"data row3 col2\" >0.774</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_4d293_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n",
       "      <td id=\"T_4d293_row4_col0\" class=\"data row4 col0\" >claude3.5-sonnet</td>\n",
       "      <td id=\"T_4d293_row4_col1\" class=\"data row4 col1\" >0.645</td>\n",
       "      <td id=\"T_4d293_row4_col2\" class=\"data row4 col2\" >0.715</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_4d293_level0_row5\" class=\"row_heading level0 row5\" >5</th>\n",
       "      <td id=\"T_4d293_row5_col0\" class=\"data row5 col0\" >gpt-4o</td>\n",
       "      <td id=\"T_4d293_row5_col1\" class=\"data row5 col1\" >0.534</td>\n",
       "      <td id=\"T_4d293_row5_col2\" class=\"data row5 col2\" >0.604</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_4d293_level0_row6\" class=\"row_heading level0 row6\" >6</th>\n",
       "      <td id=\"T_4d293_row6_col0\" class=\"data row6 col0\" >claude3-haiku</td>\n",
       "      <td id=\"T_4d293_row6_col1\" class=\"data row6 col1\" >0.546</td>\n",
       "      <td id=\"T_4d293_row6_col2\" class=\"data row6 col2\" >0.693</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_4d293_level0_row7\" class=\"row_heading level0 row7\" >7</th>\n",
       "      <td id=\"T_4d293_row7_col0\" class=\"data row7 col0\" >gpt-4o-mini</td>\n",
       "      <td id=\"T_4d293_row7_col1\" class=\"data row7 col1\" >0.494</td>\n",
       "      <td id=\"T_4d293_row7_col2\" class=\"data row7 col2\" >0.540</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table style=\"display:inline\">\n",
       "<style type=\"text/css\">\n",
       "#T_cf413_row0_col1 {\n",
       "  background-color: #8dc1dd;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cf413_row0_col2 {\n",
       "  background-color: #6aaed6;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_cf413_row1_col1 {\n",
       "  background-color: #8fc2de;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cf413_row1_col2 {\n",
       "  background-color: #000000;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_cf413_row2_col1 {\n",
       "  background-color: #a0cbe2;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cf413_row2_col2 {\n",
       "  background-color: #91c3de;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cf413_row3_col1 {\n",
       "  background-color: #aed1e7;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cf413_row3_col2 {\n",
       "  background-color: #9dcae1;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cf413_row4_col1 {\n",
       "  background-color: #bad6eb;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cf413_row4_col2 {\n",
       "  background-color: #a4cce3;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cf413_row5_col1 {\n",
       "  background-color: #d9e8f5;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cf413_row5_col2 {\n",
       "  background-color: #c7dbef;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cf413_row6_col1 {\n",
       "  background-color: #f7fbff;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cf413_row6_col2 {\n",
       "  background-color: #deebf7;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cf413_row7_col1 {\n",
       "  background-color: #f0f6fd;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cf413_row7_col2 {\n",
       "  background-color: #eaf2fb;\n",
       "  color: #000000;\n",
       "}\n",
       "</style>\n",
       "<table style=\"display:inline\" id=\"T_cf413\">\n",
       "  <caption>F1</caption>\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_cf413_level0_col0\" class=\"col_heading level0 col0\" >Model</th>\n",
       "      <th id=\"T_cf413_level0_col1\" class=\"col_heading level0 col1\" >v2-fs25</th>\n",
       "      <th id=\"T_cf413_level0_col2\" class=\"col_heading level0 col2\" >v2-fs5</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_cf413_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
       "      <td id=\"T_cf413_row0_col0\" class=\"data row0 col0\" >gemini-1.5-pro</td>\n",
       "      <td id=\"T_cf413_row0_col1\" class=\"data row0 col1\" >0.514</td>\n",
       "      <td id=\"T_cf413_row0_col2\" class=\"data row0 col2\" >0.532</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_cf413_level0_row1\" class=\"row_heading level0 row1\" >1</th>\n",
       "      <td id=\"T_cf413_row1_col0\" class=\"data row1 col0\" >llama3.1-70b</td>\n",
       "      <td id=\"T_cf413_row1_col1\" class=\"data row1 col1\" >0.513</td>\n",
       "      <td id=\"T_cf413_row1_col2\" class=\"data row1 col2\" >nan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_cf413_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n",
       "      <td id=\"T_cf413_row2_col0\" class=\"data row2 col0\" >gemini-1.5-flash</td>\n",
       "      <td id=\"T_cf413_row2_col1\" class=\"data row2 col1\" >0.504</td>\n",
       "      <td id=\"T_cf413_row2_col2\" class=\"data row2 col2\" >0.513</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_cf413_level0_row3\" class=\"row_heading level0 row3\" >3</th>\n",
       "      <td id=\"T_cf413_row3_col0\" class=\"data row3 col0\" >mistral-large2</td>\n",
       "      <td id=\"T_cf413_row3_col1\" class=\"data row3 col1\" >0.494</td>\n",
       "      <td id=\"T_cf413_row3_col2\" class=\"data row3 col2\" >0.505</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_cf413_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n",
       "      <td id=\"T_cf413_row4_col0\" class=\"data row4 col0\" >claude3.5-sonnet</td>\n",
       "      <td id=\"T_cf413_row4_col1\" class=\"data row4 col1\" >0.486</td>\n",
       "      <td id=\"T_cf413_row4_col2\" class=\"data row4 col2\" >0.501</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_cf413_level0_row5\" class=\"row_heading level0 row5\" >5</th>\n",
       "      <td id=\"T_cf413_row5_col0\" class=\"data row5 col0\" >gpt-4o</td>\n",
       "      <td id=\"T_cf413_row5_col1\" class=\"data row5 col1\" >0.456</td>\n",
       "      <td id=\"T_cf413_row5_col2\" class=\"data row5 col2\" >0.477</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_cf413_level0_row6\" class=\"row_heading level0 row6\" >6</th>\n",
       "      <td id=\"T_cf413_row6_col0\" class=\"data row6 col0\" >claude3-haiku</td>\n",
       "      <td id=\"T_cf413_row6_col1\" class=\"data row6 col1\" >0.424</td>\n",
       "      <td id=\"T_cf413_row6_col2\" class=\"data row6 col2\" >0.451</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_cf413_level0_row7\" class=\"row_heading level0 row7\" >7</th>\n",
       "      <td id=\"T_cf413_row7_col0\" class=\"data row7 col0\" >gpt-4o-mini</td>\n",
       "      <td id=\"T_cf413_row7_col1\" class=\"data row7 col1\" >0.432</td>\n",
       "      <td id=\"T_cf413_row7_col2\" class=\"data row7 col2\" >0.439</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table style=\"display:inline\">\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from utils_misc import display_results\n",
    "import numpy as np, pandas as pd\n",
    "\n",
    "\n",
    "def span_scores(gold_idx, pred_idx):\n",
    "    \"\"\"Return (precision, recall, f1) of the predicted indices against the gold indices.\n",
    "\n",
    "    Both arguments are converted to sets; all three scores are 0 when they do not overlap.\n",
    "    \"\"\"\n",
    "    gold, pred = set(gold_idx), set(pred_idx)\n",
    "    tp = len(gold & pred)\n",
    "    if tp == 0:\n",
    "        return 0, 0, 0\n",
    "    precision = tp / len(pred)  # len(pred) == tp + fp\n",
    "    recall = tp / len(gold)  # len(gold) == tp + fn\n",
    "    return precision, recall, 2 * precision * recall / (precision + recall)\n",
    "\n",
    "\n",
    "def mean_or_nan(values):\n",
    "    \"\"\"Return the mean of `values`, or NaN for an empty list.\n",
    "\n",
    "    np.mean([]) gives the same NaN but also emits 'Mean of empty slice' RuntimeWarnings.\n",
    "    \"\"\"\n",
    "    return np.mean(values) if len(values) > 0 else np.nan\n",
    "\n",
    "\n",
    "# One result row per prediction key, for precision, recall and F1 respectively.\n",
    "results_ps, results_rs, results_f1s = {}, {}, {}\n",
    "for pred_key in pred_keys:\n",
    "    # Keys look like \"pred_<model>_<prompt>\"; neither part may contain an underscore.\n",
    "    _, model, prompt = pred_key.split(\"_\")\n",
    "    results_ps[pred_key] = {\"Model\": model, \"Prompt\": prompt}\n",
    "    results_rs[pred_key] = {\"Model\": model, \"Prompt\": prompt}\n",
    "    results_f1s[pred_key] = {\"Model\": model, \"Prompt\": prompt}\n",
    "\n",
    "    # Sample-level index collections for this prediction key live under the matching \"idx_\" key.\n",
    "    idx_key = pred_key.replace(\"pred_\", \"idx_\")\n",
    "\n",
    "    precs = {cat: [] for cat in all_cats}\n",
    "    recs = {cat: [] for cat in all_cats}\n",
    "    f1s = {cat: [] for cat in all_cats}\n",
    "\n",
    "    for sample in data:\n",
    "        if pred_key not in sample:\n",
    "            continue\n",
    "        # NOTE(review): initialized but never filled in this cell; presumably meant to\n",
    "        # hold per-category F1 - confirm before relying on it.\n",
    "        sample[\"f1_\" + pred_key] = {}\n",
    "        for cat in all_cats:\n",
    "            # Samples without gold indices for a category do not contribute to that category.\n",
    "            if len(sample[\"gold_indices\"][cat]) == 0:\n",
    "                continue\n",
    "            precision, recall, f1 = span_scores(sample[\"gold_indices\"][cat], sample[idx_key][cat])\n",
    "            precs[cat].append(precision)\n",
    "            recs[cat].append(recall)\n",
    "            f1s[cat].append(f1)\n",
    "\n",
    "    # N = number of samples scored on the aggregate \"all\" category.\n",
    "    results_ps[pred_key][\"N\"] = len(precs[\"all\"])\n",
    "    for cat in all_cats:\n",
    "        results_ps[pred_key][cat] = mean_or_nan(precs[cat])\n",
    "        results_rs[pred_key][cat] = mean_or_nan(recs[cat])\n",
    "        results_f1s[pred_key][cat] = mean_or_nan(f1s[cat])\n",
    "\n",
    "# Sort by overall F1, best first. NaN scores (no scored samples) are explicitly ranked\n",
    "# last: NaN compares false against everything, so using it directly as a sort key\n",
    "# would leave the order undefined.\n",
    "overall_f1 = {key: results_f1s[key][\"all\"] for key in pred_keys}\n",
    "pred_keys = sorted(\n",
    "    pred_keys,\n",
    "    key=lambda key: (not np.isnan(overall_f1[key]), 0.0 if np.isnan(overall_f1[key]) else overall_f1[key]),\n",
    "    reverse=True,\n",
    ")\n",
    "# Turn the three tables into lists of rows aligned with the sorted pred_keys.\n",
    "results_ps = [results_ps[key] for key in pred_keys]\n",
    "results_rs = [results_rs[key] for key in pred_keys]\n",
    "results_f1s = [results_f1s[key] for key in pred_keys]\n",
    "\n",
    "# Per-category tables (disabled):\n",
    "# display_results(results_ps, results_rs, results_f1s)\n",
    "\n",
    "# Global results focused on \"all\": one row per model, one column per prompt.\n",
    "all_prompts = sorted({row[\"Prompt\"] for row in results_f1s})\n",
    "# Unique models in order of first appearance, i.e. ranked by their best prompt.\n",
    "models = list(dict.fromkeys(row[\"Model\"] for row in results_f1s))\n",
    "results_ps_all, results_rs_all, results_f1s_all = [], [], []\n",
    "for model in models:\n",
    "    ps_row, rs_row, f1s_row = {\"Model\": model, \"N\": 0}, {\"Model\": model}, {\"Model\": model}\n",
    "    for prompt in all_prompts:\n",
    "        # The three lists are index-aligned, so they can be walked together.\n",
    "        for p_res, r_res, f_res in zip(results_ps, results_rs, results_f1s):\n",
    "            if p_res[\"Model\"] == model and p_res[\"Prompt\"] == prompt:\n",
    "                ps_row[\"N\"] += p_res[\"N\"]  # N is summed over the model's prompts\n",
    "                ps_row[prompt] = p_res[\"all\"]\n",
    "                rs_row[prompt] = r_res[\"all\"]\n",
    "                f1s_row[prompt] = f_res[\"all\"]\n",
    "    results_ps_all.append(ps_row)\n",
    "    results_rs_all.append(rs_row)\n",
    "    results_f1s_all.append(f1s_row)\n",
    "display_results(results_ps_all, results_rs_all, results_f1s_all)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
