{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d02398d2",
   "metadata": {},
   "source": [
    "#### Importing Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0572384d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import rouge_score\n",
    "from glob import glob\n",
    "import ast\n",
    "import pandas as pd\n",
    "from rouge_score import rouge_scorer\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65c88b4d",
   "metadata": {},
   "source": [
    "#### ROUGE-L score calculator function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0ba8a9b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)\n",
    "\n",
    "\n",
    "def score_calculator(ground_truth, gpt3response):\n",
    "    \"\"\"Return the best ROUGE-L scores of a response against a list of references.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ground_truth : str\n",
    "        Python-literal string of a list of reference strings, parsed with\n",
    "        ast.literal_eval (e.g. the text of a 'label' CSV cell).\n",
    "    gpt3response : str\n",
    "        Text to score. Non-string values (such as NaN read from a CSV) are\n",
    "        scored as the empty string.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of float\n",
    "        [max_f1, max_precision, max_recall]. Each maximum is taken independently\n",
    "        over the references, so the three values may come from different\n",
    "        references. [0, 0, 0] when the reference list is empty.\n",
    "    \"\"\"\n",
    "    # A missing CSV cell arrives as float NaN; score it as empty text rather than\n",
    "    # handing a non-string to the scorer.\n",
    "    if not isinstance(gpt3response, str):\n",
    "        gpt3response = ''\n",
    "    list_of_ground_truths = ast.literal_eval(ground_truth)\n",
    "    max_f1, max_precision, max_recall = 0, 0, 0\n",
    "    for gt in list_of_ground_truths:\n",
    "        # Score each reference once and reuse the (precision, recall, fmeasure) tuple.\n",
    "        rouge_l = scorer.score(gt, gpt3response)['rougeL']\n",
    "        max_f1 = max(rouge_l.fmeasure, max_f1)\n",
    "        max_precision = max(rouge_l.precision, max_precision)\n",
    "        max_recall = max(rouge_l.recall, max_recall)\n",
    "    return [max_f1, max_precision, max_recall]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f315eb60",
   "metadata": {},
   "source": [
    "#### Data loading and summary statistics tables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "07c9a69b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _rouge_columns(frame, pred_col, out_cols):\n",
    "    \"\"\"Score `pred_col` against 'label' row by row; return a frame with `out_cols` (f1, prec, recall).\"\"\"\n",
    "    scores = frame[['label', pred_col]].apply(lambda row: score_calculator(*row), axis=1)\n",
    "    return pd.DataFrame(scores.tolist(), index=frame.index, columns=out_cols)\n",
    "\n",
    "\n",
    "def read_process_data(base_folder, flag=None):\n",
    "    \"\"\"Load every CSV in `base_folder`, add ROUGE-L scores and summarise by entities per sentence.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    base_folder : str\n",
    "        Folder whose top-level *.csv files are concatenated. Each file needs the\n",
    "        columns 'sent', 'entity', 'label' (list literal of references) and 'key'\n",
    "        (baseline text), plus 'GPT3 Responses' when flag == 'gpt'.\n",
    "    flag : str or None\n",
    "        'gpt' to also score the 'GPT3 Responses' column; any other value scores\n",
    "        only the 'key' baseline.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    grp : pandas.DataFrame\n",
    "        One row per bucket of distinct entities per sentence (e.g. '1', '2', ...,\n",
    "        '5 or more'): number of unique sentences, mean scores, mean sentence length.\n",
    "    deduped_df : pandas.DataFrame\n",
    "        De-duplicated rows with their scores plus 'sent_len', 'num_ent', 'num_ent1'.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `base_folder` contains no CSV files.\n",
    "    \"\"\"\n",
    "    all_files = glob(base_folder + '/*.csv', recursive=True)\n",
    "    if not all_files:\n",
    "        raise ValueError(f'No CSV files found in {base_folder!r}')\n",
    "    frame = pd.concat([pd.read_csv(filename, index_col=None, header=0) for filename in all_files],\n",
    "                      axis=0, ignore_index=True)\n",
    "    # Rows with an empty reference list cannot be scored. Copy the filtered rows so\n",
    "    # the column writes below modify this frame rather than a view of the concatenation.\n",
    "    frame = frame[frame['label'] != \"[]\"].copy()\n",
    "    # Assign back: chained `fillna(inplace=True)` on a column is a no-op under pandas Copy-on-Write.\n",
    "    frame['key'] = frame['key'].fillna('')\n",
    "\n",
    "    baseline_cols = ['baseline_f1', 'baseline_prec', 'baseline_recall']\n",
    "    frame[baseline_cols] = _rouge_columns(frame, 'key', baseline_cols)\n",
    "    if flag == 'gpt':\n",
    "        gpt3_cols = ['gpt3_f1', 'gpt3_prec', 'gpt3_recall']\n",
    "        frame[gpt3_cols] = _rouge_columns(frame, 'GPT3 Responses', gpt3_cols)\n",
    "        id_cols = ['sent', 'entity', 'label', 'key', 'GPT3 Responses']\n",
    "        score_cols = gpt3_cols + baseline_cols\n",
    "    else:\n",
    "        id_cols = ['sent', 'entity', 'label', 'key']\n",
    "        score_cols = baseline_cols\n",
    "    deduped_df = frame[id_cols + score_cols].drop_duplicates().copy()\n",
    "\n",
    "    deduped_df['sent_len'] = deduped_df['sent'].apply(lambda x: len(x.split(' ')))\n",
    "    # Number of distinct entities annotated for the same sentence.\n",
    "    deduped_df['num_ent'] = deduped_df.groupby('sent')['entity'].transform('nunique')\n",
    "    deduped_df['num_ent1'] = deduped_df['num_ent'].apply(lambda x: '5 or more' if x >= 5 else str(x))\n",
    "\n",
    "    # Summary columns: unique sentence count, mean of every score, mean sentence length.\n",
    "    agg_spec = {'sent': 'nunique', **{col: 'mean' for col in score_cols}, 'sent_len': 'mean'}\n",
    "    grp = deduped_df.groupby(['num_ent1']).agg(agg_spec).reset_index()\n",
    "    return grp, deduped_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0bb454f2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>num_ent1</th>\n",
       "      <th>sent</th>\n",
       "      <th>gpt3_f1</th>\n",
       "      <th>gpt3_prec</th>\n",
       "      <th>gpt3_recall</th>\n",
       "      <th>baseline_f1</th>\n",
       "      <th>baseline_prec</th>\n",
       "      <th>baseline_recall</th>\n",
       "      <th>sent_len</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>2567</td>\n",
       "      <td>0.358512</td>\n",
       "      <td>0.418636</td>\n",
       "      <td>0.372663</td>\n",
       "      <td>0.314365</td>\n",
       "      <td>0.445057</td>\n",
       "      <td>0.284406</td>\n",
       "      <td>27.676541</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>1195</td>\n",
       "      <td>0.413787</td>\n",
       "      <td>0.447036</td>\n",
       "      <td>0.471250</td>\n",
       "      <td>0.356630</td>\n",
       "      <td>0.479905</td>\n",
       "      <td>0.332139</td>\n",
       "      <td>29.016579</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>557</td>\n",
       "      <td>0.448314</td>\n",
       "      <td>0.471564</td>\n",
       "      <td>0.537460</td>\n",
       "      <td>0.352720</td>\n",
       "      <td>0.469555</td>\n",
       "      <td>0.327552</td>\n",
       "      <td>30.717921</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>103</td>\n",
       "      <td>0.341414</td>\n",
       "      <td>0.363681</td>\n",
       "      <td>0.422724</td>\n",
       "      <td>0.303974</td>\n",
       "      <td>0.408391</td>\n",
       "      <td>0.289925</td>\n",
       "      <td>42.596529</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5 or more</td>\n",
       "      <td>55</td>\n",
       "      <td>0.409523</td>\n",
       "      <td>0.435130</td>\n",
       "      <td>0.479771</td>\n",
       "      <td>0.324413</td>\n",
       "      <td>0.485859</td>\n",
       "      <td>0.276166</td>\n",
       "      <td>53.848066</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    num_ent1  sent   gpt3_f1  gpt3_prec  gpt3_recall  baseline_f1  \\\n",
       "0          1  2567  0.358512   0.418636     0.372663     0.314365   \n",
       "1          2  1195  0.413787   0.447036     0.471250     0.356630   \n",
       "2          3   557  0.448314   0.471564     0.537460     0.352720   \n",
       "3          4   103  0.341414   0.363681     0.422724     0.303974   \n",
       "4  5 or more    55  0.409523   0.435130     0.479771     0.324413   \n",
       "\n",
       "   baseline_prec  baseline_recall   sent_len  \n",
       "0       0.445057         0.284406  27.676541  \n",
       "1       0.479905         0.332139  29.016579  \n",
       "2       0.469555         0.327552  30.717921  \n",
       "3       0.408391         0.289925  42.596529  \n",
       "4       0.485859         0.276166  53.848066  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the GPT-3 responses and the 'key' baseline, bucketed by entities per sentence.\n",
    "base_folder = './GPT3Responses'\n",
    "group_data, frame = read_process_data(base_folder, flag='gpt')\n",
    "group_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6b62a13f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>baseline_f1</th>\n",
       "      <th>baseline_prec</th>\n",
       "      <th>baseline_recall</th>\n",
       "      <th>gpt3_f1</th>\n",
       "      <th>gpt3_prec</th>\n",
       "      <th>gpt3_recall</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>8261.000000</td>\n",
       "      <td>8261.000000</td>\n",
       "      <td>8261.000000</td>\n",
       "      <td>8261.000000</td>\n",
       "      <td>8261.000000</td>\n",
       "      <td>8261.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>0.336379</td>\n",
       "      <td>0.461472</td>\n",
       "      <td>0.309335</td>\n",
       "      <td>0.397629</td>\n",
       "      <td>0.437250</td>\n",
       "      <td>0.448669</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.324280</td>\n",
       "      <td>0.421121</td>\n",
       "      <td>0.329171</td>\n",
       "      <td>0.313867</td>\n",
       "      <td>0.362869</td>\n",
       "      <td>0.357473</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.142857</td>\n",
       "      <td>0.125000</td>\n",
       "      <td>0.142857</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>0.285714</td>\n",
       "      <td>0.400000</td>\n",
       "      <td>0.250000</td>\n",
       "      <td>0.363636</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.400000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>0.571429</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.625000</td>\n",
       "      <td>0.750000</td>\n",
       "      <td>0.750000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       baseline_f1  baseline_prec  baseline_recall      gpt3_f1    gpt3_prec  \\\n",
       "count  8261.000000    8261.000000      8261.000000  8261.000000  8261.000000   \n",
       "mean      0.336379       0.461472         0.309335     0.397629     0.437250   \n",
       "std       0.324280       0.421121         0.329171     0.313867     0.362869   \n",
       "min       0.000000       0.000000         0.000000     0.000000     0.000000   \n",
       "25%       0.000000       0.000000         0.000000     0.142857     0.125000   \n",
       "50%       0.285714       0.400000         0.250000     0.363636     0.333333   \n",
       "75%       0.571429       1.000000         0.500000     0.625000     0.750000   \n",
       "max       1.000000       1.000000         1.000000     1.000000     1.000000   \n",
       "\n",
       "       gpt3_recall  \n",
       "count  8261.000000  \n",
       "mean      0.448669  \n",
       "std       0.357473  \n",
       "min       0.000000  \n",
       "25%       0.142857  \n",
       "50%       0.400000  \n",
       "75%       0.750000  \n",
       "max       1.000000  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "frame.loc[:, ['baseline_f1', 'baseline_prec', 'baseline_recall',\n",
    "              'gpt3_f1', 'gpt3_prec', 'gpt3_recall']].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f051a04f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline-only scores (the 'key' column) for the CSV files in ./Output.\n",
    "base_folder = './Output'\n",
    "group_data, frame = read_process_data(base_folder, flag=None)\n",
    "group_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be80ef6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "frame.loc[:, ['baseline_f1', 'baseline_prec', 'baseline_recall']].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af44ddee",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8568ed43",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
