{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, List, Callable, Tuple, Union, Callable\n",
    "import string\n",
    "import os\n",
    "import json\n",
    "import re\n",
    "import numpy as np\n",
    "from collections import Counter\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "\n",
    "class MultiHopEvaluator:\n",
    "    @classmethod\n",
    "    def get_all_alias(cls, ground_truth_id: str) -> List[str]:\n",
    "        return {}\n",
    "\n",
    "    @classmethod\n",
    "    def normalize_answer(cls, s):\n",
    "        def remove_articles(text):\n",
    "            return re.sub(r'\\b(a|an|the)\\b', ' ', text)\n",
    "        def white_space_fix(text):\n",
    "            return ' '.join(text.split())\n",
    "        def remove_punc(text):\n",
    "            exclude = set(string.punctuation)\n",
    "            return ''.join(ch for ch in text if ch not in exclude)\n",
    "        def lower(text):\n",
    "            return text.lower()\n",
    "        if not isinstance(s, str):\n",
    "            return \"\"\n",
    "        return white_space_fix(remove_articles(remove_punc(lower(s))))\n",
    "\n",
    "    @classmethod\n",
    "    def exact_match_score(\n",
    "        cls,\n",
    "        prediction: str,\n",
    "        ground_truth: Union[str, List[str]],\n",
    "        ground_truth_id: Union[str, List[str]] = None\n",
    "    ):\n",
    "        if not prediction:\n",
    "            return {'correct': 0, 'incorrect': 1}\n",
    "        ground_truths = {ground_truth} if isinstance(ground_truth, str) else set(ground_truth)\n",
    "        if ground_truth_id and isinstance(ground_truth_id, str):\n",
    "            ground_truths.update(cls.get_all_alias(ground_truth_id))\n",
    "\n",
    "        correct = np.max([int(cls.normalize_answer(prediction) == cls.normalize_answer(gt)) for gt in ground_truths])\n",
    "        return {'correct': correct, 'incorrect': 1 - correct}\n",
    "\n",
    "    @classmethod\n",
    "    def f1_score(\n",
    "        cls,\n",
    "        prediction: str,\n",
    "        ground_truth: Union[str, List[str]],\n",
    "        ground_truth_id: Union[str, List[str]] = None\n",
    "    ):\n",
    "        final_metric = {'f1': 0, 'precision': 0, 'recall': 0}\n",
    "        \n",
    "        if not prediction:\n",
    "            return final_metric\n",
    "        ground_truths = {ground_truth} if isinstance(ground_truth, str) else set(ground_truth)\n",
    "        if ground_truth_id and isinstance(ground_truth_id, str):\n",
    "            ground_truths.update(cls.get_all_alias(ground_truth_id))\n",
    "            \n",
    "        for ground_truth in ground_truths:\n",
    "            normalized_prediction = cls.normalize_answer(prediction)\n",
    "            normalized_ground_truth = cls.normalize_answer(ground_truth)\n",
    "            if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:\n",
    "                continue\n",
    "            if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:\n",
    "                continue\n",
    "            prediction_tokens = normalized_prediction.split()\n",
    "            ground_truth_tokens = normalized_ground_truth.split()\n",
    "            common = Counter(prediction_tokens) & Counter(ground_truth_tokens)\n",
    "            num_same = sum(common.values())\n",
    "            if num_same == 0:\n",
    "                continue\n",
    "\n",
    "            precision = 1.0 * num_same / len(prediction_tokens)\n",
    "            recall = 1.0 * num_same / len(ground_truth_tokens)\n",
    "            f1 = (2 * precision * recall) / (precision + recall)\n",
    "            for k in ['f1', 'precision', 'recall']:\n",
    "                final_metric[k] = max(eval(k), final_metric[k])\n",
    "        return final_metric\n",
    "    \n",
    "    def eval_answer(self, results_df, answer_col=\"Final Answer\"):\n",
    "        # for datasets don't have answer_ids, aliases\n",
    "        em_list = []\n",
    "        f1_list = []\n",
    "        for i, row in results_df.iterrows():\n",
    "            prediction = row[answer_col]\n",
    "            ground_truth = row['ground_truth']\n",
    "            em_list.append(self.exact_match_score(prediction, ground_truth, None)['correct'])\n",
    "            f1_list.append(self.f1_score(prediction, ground_truth, None)['f1'])\n",
    "        print(f\"EM: {sum(em_list)/len(em_list):4f}\\t F1: {sum(f1_list)/len(f1_list):4f}\")"
   ]
  },
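  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of the scoring helpers on toy strings (the inputs below are made up for illustration, not drawn from any dataset): `normalize_answer` lowercases, strips punctuation and articles, and collapses whitespace before exact match and token-overlap F1 are computed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy inputs only: check normalization, EM, and token-overlap F1.\n",
    "print(MultiHopEvaluator.normalize_answer(\"The Empire State Building!\"))  # 'empire state building'\n",
    "# EM ignores case, punctuation, and articles, so this counts as correct.\n",
    "print(MultiHopEvaluator.exact_match_score(\"the Empire State Building.\", \"Empire State Building\"))\n",
    "# 3 shared tokens, 5 predicted, 3 gold -> precision 0.6, recall 1.0, F1 0.75.\n",
    "print(MultiHopEvaluator.f1_score(\"empire state building in nyc\", [\"Empire State Building\"]))"
   ]
  },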
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TwoWikiHop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class WikiMultiHopEvaluator(MultiHopEvaluator):\n",
    "\n",
    "    def __init__(self, data_path: str=\"data/multihop_data/2wikimultihopqa\"): \n",
    "        # logger.info(f\"Loading WikiMultiHopQA from {data_path}\")\n",
    "        dataset = []\n",
    "        with open(os.path.join(data_path, 'dev.json'), 'r') as fin:\n",
    "            js = json.load(fin)\n",
    "            for example in tqdm(js):\n",
    "                qid = example['_id']\n",
    "                question = example['question']\n",
    "                ans = example['answer']\n",
    "                ans_id = example['answer_id']\n",
    "                # ctxs = example['ctxs']\n",
    "                dataset.append({\n",
    "                    'qid': qid,\n",
    "                    'question': question,\n",
    "                    'answer': ans,\n",
    "                    'answer_id': ans_id,\n",
    "                    # 'ctxs': ctxs,\n",
    "                })\n",
    "        self.dataset = dataset\n",
    "        self.dataset_from_qid = {entry['qid']: entry for entry in self.dataset}\n",
    "        self.init_id_aliases(data_path)\n",
    "        \n",
    "    @classmethod\n",
    "    def init_id_aliases(cls, data_path):\n",
    "        cls.id_alias: Dict[str, List[str]] = {}\n",
    "        with open(os.path.join(data_path, 'id_aliases.json'), 'r') as fin:\n",
    "            for l in fin:\n",
    "                l = json.loads(l)\n",
    "                cls.id_alias[l['Q_id']] = l['aliases']\n",
    "\n",
    "    @classmethod\n",
    "    def get_all_alias(cls, ground_truth_id: str) -> List[str]:\n",
    "        if ground_truth_id and ground_truth_id in cls.id_alias:\n",
    "            return cls.id_alias[ground_truth_id]\n",
    "        else:\n",
    "            return []\n",
    "\n",
    "    def get_real_prediction(self, pred):\n",
    "        if \"the answer is\" in pred:\n",
    "            beg = pred.find(\"the answer is\") + len(\"the answer is\") + 1\n",
    "            pred = pred[beg:] # delete final \".\"\n",
    "            if pred.endswith(\"</s>\"):\n",
    "                pred = pred[:len(pred) - len(\"</s>\")]\n",
    "            if pred.endswith(\"<|endoftext|>\"):\n",
    "                pred = pred[:len(pred) - len(\"<|endoftext|>\")]\n",
    "            if pred.endswith(\".\"):\n",
    "                pred = pred[:-1]\n",
    "            return pred\n",
    "        else:\n",
    "            return pred\n",
    "        \n",
    "    def eval_answer(self, results_df, answer_col=\"Final Answer\"):\n",
    "        em_list = []\n",
    "        f1_list = []\n",
    "        for i, row in results_df.iterrows():\n",
    "            prediction = row[answer_col]\n",
    "            ground_truth = row['ground_truth']\n",
    "            ground_truth_id = self.dataset_from_qid[row['qid']]['answer_id']\n",
    "            em_list.append(self.exact_match_score(prediction, ground_truth, ground_truth_id)['correct'])\n",
    "            f1_list.append(self.f1_score(prediction, ground_truth, ground_truth_id)['f1'])\n",
    "        print(f\"EM: {sum(em_list)/len(em_list):4f}\\t F1: {sum(f1_list)/len(f1_list):4f}\")"
   ]
  },
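  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A standalone illustration of alias handling (no data files needed; `Q_TOY` and its aliases are made up): injecting a toy entry into the class-level alias table shows that exact match accepts any alias of the gold answer. Constructing `WikiMultiHopEvaluator()` later reloads the real table from `id_aliases.json`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy alias entry; Q_TOY is a made-up id, not a real Wikidata one.\n",
    "WikiMultiHopEvaluator.id_alias = {\"Q_TOY\": [\"NYC\", \"New York City\"]}\n",
    "# 'nyc' matches the alias 'NYC' after normalization, so EM counts it correct.\n",
    "print(WikiMultiHopEvaluator.exact_match_score(\"nyc\", \"New York\", \"Q_TOY\"))"
   ]
  },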
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "results_df = pd.concat([\n",
    "    pd.read_json(\"outputs/twowikihop_finished.jsonl\", lines=True),\n",
    "    pd.read_json(\"outputs/twowikihop_llama3_rerun_failed_0612_0652/results.jsonl\", lines=True),\n",
    "])\n",
    "print(len(results_df))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_path = \"__YOUR_RESULT.JSONL_PATH__\"\n",
    "results_df = pd.read_json(results_path, lines=True)\n",
    "twowikihop_evaluator = WikiMultiHopEvaluator()\n",
    "for column_name in [\"Final Answer\", \"Final Step Answer\", \"Final Read Answer\"]:\n",
    "    print(column_name)\n",
    "    twowikihop_evaluator.eval_answer(results_df=results_df, answer_col=column_name)"
   ]
  },
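  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Illustrative check of `get_real_prediction` (the raw generation below is a made-up string, and this assumes `twowikihop_evaluator` from the cell above): everything after \"the answer is\" is kept, then trailing stop tokens and the final period are stripped."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Made-up generation; reuses twowikihop_evaluator from the previous cell.\n",
    "raw = \"Let's reason step by step... So the answer is Paris.</s>\"\n",
    "print(twowikihop_evaluator.get_real_prediction(raw))  # 'Paris'"
   ]
  },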
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Other Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_path = \"__YOUR_RESULT.JSONL_PATH__\"\n",
    "results_df = pd.read_json(results_path, lines=True)\n",
    "evaluator = MultiHopEvaluator()\n",
    "for column_name in [\"Final Answer\", \"Final Step Answer\", \"Final Read Answer\"]:\n",
    "    print(column_name)\n",
    "    evaluator.eval_answer(results_df=results_df, answer_col=column_name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "QAFINAL",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
