{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## [MAUVE & ROUGE] ASQA & ELI5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import json\n",
    "from metrics import load_file\n",
    "\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/eli5_sonnet_test_base.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/asqa_test_output_sonnet_sonnet.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/eli5_test_output_sonnet_sonnet.json\"\n",
    "# with open(\"/shared/eng/pj20/firas_data/test_datasets/results/asqa_sonnet_retrieval_top_5_base.json\", \"r\") as f:\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_asqa_answerer_test_results_all_graph_no_ret.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_asqa_answerer_test_results_no_graph.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_asqa_answerer_test_results_all_graph_no_gtoken.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/eli5_sonnet_retrieval_top_5_sure_200.jsonl\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_eli5_answerer_test_results_all_graph_3_20.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_asqa_answerer_test_results_all_graph_3_20_8b.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_asqa_answerer_test_results_all_graph_5_20_8b.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_eli5_answerer_test_results_all_graph_7_20_8b.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_asqa_answerer_test_results_all_graph_10_20_8b.json\"\n",
    "file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_asqa_answerer_test_results_graph_0.5_8b.json\"\n",
    "\n",
    "# Text that precedes the question inside a graphllm long-form prompt.\n",
    "QUESTION_MARKER = \"[Long Form] Question: \"\n",
    "\n",
    "# Build `data`: one {question, output, answer} record per example, consumed by the ROUGE/MAUVE cell.\n",
    "if file_path.endswith(\".jsonl\"):\n",
    "    # NOTE(review): assumes load_file already yields question/output/answer records - confirm.\n",
    "    data = load_file(file_path)\n",
    "else:\n",
    "    with open(file_path, \"r\") as f:\n",
    "        data_ = json.load(f)\n",
    "\n",
    "    # Only the graphllm result layout has a converter here; fail loudly instead of\n",
    "    # silently scoring an empty `data` list.\n",
    "    if \"graphllm\" not in file_path:\n",
    "        raise ValueError(f\"No converter for non-graphllm result file: {file_path}\")\n",
    "\n",
    "    data = []\n",
    "    for item in data_:\n",
    "        prompt = item['input']\n",
    "        if QUESTION_MARKER not in prompt:\n",
    "            raise ValueError(f\"Prompt lacks {QUESTION_MARKER!r}: {prompt[:80]!r}\")\n",
    "        data.append({\n",
    "            \"question\": prompt.split(QUESTION_MARKER)[1].lower(),\n",
    "            \"output\": item['prediction'].lower(),\n",
    "            \"answer\": item['label'].lower(),\n",
    "        })\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import os\n",
    "import string\n",
    "\n",
    "# Must be set before CUDA is initialized for the device restriction to take effect.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "from metrics import *\n",
    "\n",
    "# Each MAUVE text is truncated to this many whitespace-separated words.\n",
    "MAUVE_MAX_WORDS = 100\n",
    "\n",
    "\n",
    "def _mauve_text(question, text, max_words=MAUVE_MAX_WORDS):\n",
    "    \"\"\"Return `question + text` cut to `max_words` words, with trailing punctuation stripped.\"\"\"\n",
    "    words = (question + \" \" + text.strip()).split()\n",
    "    return \" \".join(words[:max_words]).rstrip(string.punctuation)\n",
    "\n",
    "\n",
    "normalized_data = copy.deepcopy(data)\n",
    "# Optional: strip citation markers from model outputs before scoring.\n",
    "# for item in normalized_data:\n",
    "#     item['output'] = remove_citations(item['output'])\n",
    "\n",
    "# MAUVE compares question-prefixed model outputs against question-prefixed gold answers.\n",
    "references = [_mauve_text(item['question'], item['answer']) for item in normalized_data]\n",
    "predictions = [_mauve_text(item['question'], item['output']) for item in normalized_data]\n",
    "\n",
    "print(\"ROUGE: \", compute_rouge(normalized_data))\n",
    "print(\"MAUVE: \", mauve_score(predictions, references))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PopQA, TriviaQA, PubHealth, ARC-C, 2WikiMultiHop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import json\n",
    "from metrics import *\n",
    "from utils import *\n",
    "\n",
    "# data = load_file(\"/shared/eng/pj20/firas_data/test_datasets/results/popqa_sonnet_retrieval_top_5.jsonl\")\n",
    "# with open(\"/shared/eng/pj20/firas_data/test_datasets/pubhealth_test_output_sonnet_sonnet.json\", \"r\") as f:\n",
    "# with open(\"/shared/eng/pj20/firas_data/test_datasets/arc_c_test_output_sonnet_sonnet.json\", \"r\") as f:\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_triviaqa_answerer_test_results.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_pubhealth_answerer_test_results_all_graph.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_pubhealth_answerer_test_results.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/triviaqa_test_output_sonnet_sonnet.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_popqa_answerer_test_results_all_graph.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_triviaqa_answerer_test_results_all_graph.json\"\n",
    "\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_arc_c_answerer_test_results_all_graph.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/2wikimultihop_test_output_llama2-7b_sonnet_v3.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_2wikimultihop_answerer_test_results_all_graph_both.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_triviaqa_answerer_test_results_all_graph_ptuned.json\"\n",
    "\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_pubhealth_answerer_test_results_all_graph_both.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_pubhealth_answerer_test_results_all_graph_ptuned.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_arc_c_answerer_test_results_all_graph_both.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_pubhealth_answerer_test_results_all_graph_no_gtoken.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_pubhealth_answerer_test_results_no_graph.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_triviaqa_answerer_test_results_no_graph.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_2wikimultihop_answerer_test_results_no_graph.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_triviaqa_answerer_test_results_all_graph_no_gtoken.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_2wikimultihop_answerer_test_results_all_graph_no_gtoken.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_triviaqa_answerer_test_results_all_graph_no_plan.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_2wikimultihop_answerer_test_results_all_graph_no_plan.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_2wikimultihop_answerer_test_results_all_graph_both.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/popqa_sonnet_retrieval_top_5_sure_200.jsonl\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/pubhealth_sonnet_retrieval_top_5_sure_200.jsonl\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/arc_c_sonnet_retrieval_top_5_sure_200.jsonl\"\n",
    "\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_pubhealth_answerer_test_results_all_graph_3_20_8b.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_triviaqa_answerer_test_results_all_graph_5_20_8b.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_pubhealth_answerer_test_results_all_graph_1_20_8b.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_2wikimultihop_answerer_test_results_all_graph_5_20_8b.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_pubhealth_answerer_test_results_all_graph_5_20.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_pubhealth_answerer_test_results_all_graph_7_20_8b.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_triviaqa_answerer_test_results_all_graph_7_20.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_2wikimultihop_answerer_test_results_all_graph_7_20.json\"\n",
    "# file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_2wikimultihop_answerer_test_results_all_graph_10_20_8b.json\"\n",
    "# \n",
    "file_path = \"/shared/eng/pj20/firas_data/test_datasets/results/graphllm_triviaqa_answerer_test_results_graph_0.7_8b.json\"\n",
    "\n",
    "# Build `data`: one {output, golds} record per example, scored by the metric cell below.\n",
    "if file_path.endswith(\".jsonl\"):\n",
    "    # NOTE(review): assumes load_file already yields output/golds records - confirm.\n",
    "    data = load_file(file_path)\n",
    "else:\n",
    "    with open(file_path, \"r\") as f:\n",
    "        data_ = json.load(f)\n",
    "\n",
    "    data = []\n",
    "    if \"graphllm\" in file_path:\n",
    "        # graphllm results: a list of {prediction, label} records; gold shape depends on dataset.\n",
    "        label_is_single_string = \"pubhealth\" in file_path or \"arc_c\" in file_path\n",
    "        label_is_single_answer = \"2wikimultihop\" in file_path\n",
    "        for item in data_:\n",
    "            prediction = item['prediction'].lower()\n",
    "            if label_is_single_string:\n",
    "                golds = item['label'].lower()  # kept as a bare string, not a list\n",
    "            elif label_is_single_answer:\n",
    "                golds = [item['label'].lower()]\n",
    "            else:\n",
    "                golds = [d.lower() for d in item['label']]  # label iterated as a list of answers\n",
    "            data.append({\"output\": prediction, \"golds\": golds})\n",
    "    else:\n",
    "        # Column-oriented results: parallel `output` / `answer` lists.\n",
    "        single_answer = \"2wikimultihop\" in file_path\n",
    "        for i, output in enumerate(data_['output']):\n",
    "            answer = data_['answer'][i]\n",
    "            golds = [answer.lower()] if single_answer else [d.lower() for d in answer]\n",
    "            data.append({\"output\": output.lower(), \"golds\": golds})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check one converted record to confirm its output/golds layout.\n",
    "data[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data = [d for d in data if d[\"output\"] == \"true\" or d[\"output\"] == \"false\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Outputs containing any of these substrings (e.g. \"apolo\" catches \"apologize\"/\"apologies\")\n",
    "# are treated as refusals / non-answers and excluded from the score. Because they are\n",
    "# dropped from the denominator, the reported Match covers answered examples only.\n",
    "SKIP_MARKERS = (\"apolo\", \"the information\")\n",
    "\n",
    "# Metric per dataset: pubhealth -> accuracy; triviaqa / popqa / arc-c -> match.\n",
    "# Other candidates: f1_score, cosine_match (returns a (result, add) pair).\n",
    "# If ARC-C gold labels are numeric, map \"1\"-\"4\" to \"a\"-\"d\" before matching.\n",
    "metric_fn = match\n",
    "\n",
    "metric_result_1 = []\n",
    "n_skipped = 0\n",
    "for item in tqdm(data):\n",
    "    if any(marker in item[\"output\"].lower() for marker in SKIP_MARKERS):\n",
    "        n_skipped += 1\n",
    "        continue\n",
    "    metric_result_1.append(metric_fn(item[\"output\"], item[\"golds\"]))\n",
    "\n",
    "print(f\"Skipped {n_skipped}/{len(data)} outputs as refusals / non-answers\")\n",
    "# np.mean([]) would warn and return nan; report nan explicitly when nothing was scored.\n",
    "print(\"Match: \", np.mean(metric_result_1) if metric_result_1 else float(\"nan\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "handbook",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
