{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "353800af-88ec-4980-ba16-a68bb4bc3008",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from shared_utils.indexing_utils import SparseIndexer, DocumentCollection\n",
    "import json\n",
    "import jsonlines\n",
    "from tqdm import tqdm\n",
    "from copy import deepcopy\n",
    "import io\n",
    "import argparse\n",
    "from statistics import mean, stdev\n",
    "4\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import pytrec_eval\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "207efa5f-9669-433d-93b6-ab3e502a0abc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the LongAlpaca-12k records and show the keys of the first and last entries.\n",
    "path = \"/data/../nlp_data/LongAlpaca-12k/LongAlpaca-12k.json\"\n",
    "# Context manager so the file handle is closed right away instead of leaking until GC.\n",
    "with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "    lines = json.load(f)\n",
    "\n",
    "attr_required = list(lines[0].keys())\n",
    "list(lines[0].keys()), list(lines[-1].keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0db498d8-79a6-496d-a416-7b95e0d8c1a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _read_run(run_file):\n",
    "    \"\"\"Parse a TREC run file into ``{query_id: {doc_id: score}}`` for pytrec_eval.\n",
    "\n",
    "    Fields 0, 2 and 4 of each line are read as query id, doc id and score. A line is\n",
    "    split on tabs first (so doc ids containing spaces survive) and on arbitrary\n",
    "    whitespace when it has fewer than five tab-separated fields. Blank lines are skipped.\n",
    "    \"\"\"\n",
    "    runs = {}\n",
    "    with open(run_file, 'r', encoding='utf-8') as f:\n",
    "        for raw_line in f:\n",
    "            if not raw_line.strip():\n",
    "                continue\n",
    "            fields = raw_line.rstrip('\\r\\n').split('\\t')\n",
    "            if len(fields) < 5:\n",
    "                fields = raw_line.split()\n",
    "            query, passage = fields[0], fields[2]\n",
    "            # float(): run scores may be real-valued; integer scores give the same ranking.\n",
    "            runs.setdefault(query, {})[passage] = float(fields[4])\n",
    "    return runs\n",
    "\n",
    "\n",
    "def _pct(values):\n",
    "    \"\"\"Return the mean of ``values`` as a percentage rounded to two decimals.\"\"\"\n",
    "    return round(100 * np.average(values), 2)\n",
    "\n",
    "\n",
    "def _evaluate(qrels, qrels_ndcg, runs, recall_cutoffs, return_summary):\n",
    "    \"\"\"Score ``runs`` with pytrec_eval; return a summary dict or per-query results.\n",
    "\n",
    "    MAP, MRR and Recall@k (for each k in ``recall_cutoffs``) use ``qrels``; NDCG@3 uses\n",
    "    ``qrels_ndcg``. Summary values are per-query means in percent, keyed \"MRR\",\n",
    "    \"NDCG@3\", then \"Recall@k\" in the order of ``recall_cutoffs``.\n",
    "    \"\"\"\n",
    "    measures = {\"map\", \"recip_rank\"} | {f\"recall.{k}\" for k in recall_cutoffs}\n",
    "    res = pytrec_eval.RelevanceEvaluator(qrels, measures).evaluate(runs)\n",
    "    res_ndcg = pytrec_eval.RelevanceEvaluator(qrels_ndcg, {\"ndcg_cut.3\"}).evaluate(runs)\n",
    "\n",
    "    if return_summary:\n",
    "        # MAP is computed (and kept in the per-query results) but not summarized.\n",
    "        summary = {\n",
    "            \"MRR\": _pct([v['recip_rank'] for v in res.values()]),\n",
    "            \"NDCG@3\": _pct([v['ndcg_cut_3'] for v in res_ndcg.values()]),\n",
    "        }\n",
    "        for k in recall_cutoffs:\n",
    "            summary[f\"Recall@{k}\"] = _pct([v[f\"recall_{k}\"] for v in res.values()])\n",
    "        return summary\n",
    "\n",
    "    # Per-query mode: merge the NDCG@3 entry into each query's metric dict.\n",
    "    for qid in res:\n",
    "        res[qid].update(res_ndcg[qid])\n",
    "    return res\n",
    "\n",
    "\n",
    "def print_res(run_file, qrel_data, rel_threshold,\n",
    "              qrel_prefix=None, return_summary=True):\n",
    "    \"\"\"Evaluate a TREC run file against TREC-format qrel lines.\n",
    "\n",
    "    Args:\n",
    "        run_file: path to the run file (parsed by ``_read_run``).\n",
    "        qrel_data: iterable of qrel lines; whitespace-separated fields 0, 2 and 3 are the\n",
    "            query id, doc id and integer relevance label. Blank lines are skipped.\n",
    "        rel_threshold: labels >= this count as relevant (1) for MAP/MRR/Recall; NDCG@3\n",
    "            uses the graded labels as-is.\n",
    "        qrel_prefix: optional string prepended to every qrel doc id so it matches the\n",
    "            doc ids written in the run file.\n",
    "        return_summary: if true, return ``{metric: percentage}`` (MRR, NDCG@3,\n",
    "            Recall@1/3/5/10/20/100); otherwise return pytrec_eval's per-query dict with\n",
    "            ``ndcg_cut_3`` merged in.\n",
    "    \"\"\"\n",
    "    runs = _read_run(run_file)\n",
    "\n",
    "    qrels = {}       # binarized labels for MAP / MRR / Recall\n",
    "    qrels_ndcg = {}  # graded labels for NDCG\n",
    "    for raw_line in qrel_data:\n",
    "        fields = raw_line.strip().split()\n",
    "        if not fields:\n",
    "            continue\n",
    "        query, passage, rel = fields[0], fields[2], int(fields[3])\n",
    "        if qrel_prefix:\n",
    "            passage = qrel_prefix + passage\n",
    "        qrels_ndcg.setdefault(query, {})[passage] = rel\n",
    "        qrels.setdefault(query, {})[passage] = 1 if rel >= rel_threshold else 0\n",
    "\n",
    "    return _evaluate(qrels, qrels_ndcg, runs, (1, 3, 5, 10, 20, 100), return_summary)\n",
    "\n",
    "\n",
    "def print_res_pseudo_qrels(run_file, pseudo_qrels, rel_threshold, return_summary=True):\n",
    "    \"\"\"Evaluate a TREC run file against an already-built qrels dict.\n",
    "\n",
    "    ``pseudo_qrels`` (``{qid: {doc_id: rel}}``) is used unchanged both for the binary\n",
    "    metrics and for NDCG@3. ``rel_threshold`` is accepted for signature parity with\n",
    "    ``print_res`` but is not used. The summary additionally reports Recall@30.\n",
    "    \"\"\"\n",
    "    runs = _read_run(run_file)\n",
    "    return _evaluate(pseudo_qrels, pseudo_qrels, runs,\n",
    "                     (1, 3, 5, 10, 20, 30, 100), return_summary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef4466a2-356b-40ce-beb9-a5fd19c0db2c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# print(lines[500]['instruction'])\n",
    "\n",
    "# print(lines[500]['output'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23aa3247-f950-4e13-bf4e-0f138b8df3d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _append_rationale(example, rationale):\n",
    "    \"\"\"Insert a rationale (TLDR summary or reasoning) before an ICL example's rewrite.\n",
    "\n",
    "    Keeps the text of ``example`` before ``Rewrite:`` and replaces the tail with\n",
    "    ``Rewrite: <rationale> The rewritten query is \"<original rewrite>\"``.\n",
    "    \"\"\"\n",
    "    head = example.split('Rewrite:')[0]\n",
    "    rewritten = example.split('Rewrite: ')[-1]\n",
    "    return f'{head}Rewrite: {rationale} The rewritten query is \"{rewritten}\"'\n",
    "\n",
    "\n",
    "def set_prompt(line, args, n_recent=3):\n",
    "    \"\"\"Build a question-decontextualization (query rewriting) prompt for one turn.\n",
    "\n",
    "    Args:\n",
    "        line: one dataset record (dict). With ``args.use_pssg`` it must provide\n",
    "            ``'NewContext'`` (alternating question/answer strings) and\n",
    "            ``'Truth_passages_contents'``; otherwise ``'history_query'`` and\n",
    "            ``'history_answer'``. The current question is read from ``'query'`` for\n",
    "            ICL prompts and from ``'Question'`` for zero-shot prompts.\n",
    "        args: namespace with ``use_pssg`` (bool), ``instruct_pssg`` (``'original'``,\n",
    "            ``'filter_irrelevant'``, ``'summary'``, ``'filter_irrelevant_summary'`` or\n",
    "            ``'reasoning'``) and ``prompt_type`` (``'icl'`` or ``'zsl'``).\n",
    "        n_recent: how many of the most recent passages are interleaved into the\n",
    "            context (only used when ``args.use_pssg`` is true).\n",
    "\n",
    "    Returns:\n",
    "        The prompt string; it ends with ``\"Rewrite: \"``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``args.instruct_pssg`` or ``args.prompt_type`` is not one of the\n",
    "            options above (previously this surfaced as a NameError).\n",
    "    \"\"\"\n",
    "    if args.instruct_pssg not in ('original', 'filter_irrelevant', 'summary',\n",
    "                                  'filter_irrelevant_summary', 'reasoning'):\n",
    "        raise ValueError(f\"Unknown instruct_pssg: {args.instruct_pssg!r}\")\n",
    "    if args.prompt_type not in ('icl', 'zsl'):\n",
    "        raise ValueError(f\"Unknown prompt_type: {args.prompt_type!r}\")\n",
    "\n",
    "    # Instruction text: (with / without retrieved passages) x instruction variant.\n",
    "    if args.use_pssg:\n",
    "        if args.instruct_pssg == 'original':\n",
    "            Instruction = \"Given a question, its previous questions (Q), retrieved documents (Document), and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context.\"\n",
    "        elif args.instruct_pssg == 'filter_irrelevant':\n",
    "            Instruction = \"Given a question, its previous questions (Q), retrieved documents (Document), and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context. Use the documents to enrich your question if they're relevant, or draw on the Q&A context for a precise reformulation if the documents aren't helpful.\"\n",
    "        elif args.instruct_pssg == 'summary':\n",
    "            Instruction = \"Given a question, its previous questions (Q), retrieved documents (Document), and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context. Given the potential noise and dependencies within the context, creating a concise summary of it first could be an effective strategy for accurately rephrasing the question. Therefore, start by summarizing the context before you decontextualize the question.\"\n",
    "        elif args.instruct_pssg == 'filter_irrelevant_summary':\n",
    "            Instruction = \"Given a question, its previous questions (Q), retrieved documents (Document), and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context. Use the documents to enrich your question if they're relevant, or draw on the Q&A context for a precise reformulation if the documents aren't helpful. Considering the potential noise and dependencies within the context, creating a concise summary of it first could be an effective strategy for accurately rephrasing the question. Therefore, start by summarizing the context before you decontextualize the question.\"\n",
    "        elif args.instruct_pssg == 'reasoning':\n",
    "            Instruction = \"Given a question, its previous questions (Q), retrieved documents (Document), and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context. Use the documents to enrich your question if they're relevant, or draw on the Q&A context for a precise reformulation if the documents aren't helpful.\"\n",
    "            Instruction = Instruction + \" Before rewriting, evaluate which parts of the context are essential to address, helping to rewrite your question effectively.\"\n",
    "    else:\n",
    "        # Without passages, 'filter_irrelevant' uses the same text as 'original'.\n",
    "        if args.instruct_pssg == 'original':\n",
    "            Instruction = \"Given a question, its previous questions (Q) and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context.\"\n",
    "        elif args.instruct_pssg == 'filter_irrelevant':\n",
    "            Instruction = \"Given a question, its previous questions (Q) and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context.\"\n",
    "        elif args.instruct_pssg == 'summary':\n",
    "            Instruction = \"Given a question, its previous questions (Q) and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context. Given the potential noise and dependencies within the context, creating a concise summary of it first could be an effective strategy for accurately rephrasing the question. Therefore, start by summarizing the context before you decontextualize the question.\"\n",
    "        elif args.instruct_pssg == 'filter_irrelevant_summary':\n",
    "            Instruction = \"Given a question, its previous questions (Q) and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context. Considering the potential noise and dependencies within the context, creating a concise summary of it first could be an effective strategy for accurately rephrasing the question. Therefore, start by summarizing the context before you decontextualize the question.\"\n",
    "        elif args.instruct_pssg == 'reasoning':\n",
    "            Instruction = \"Given a question, its previous questions (Q) and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context.\"\n",
    "            Instruction = Instruction + \" Before rewriting, evaluate which parts of the context are essential to address, helping to rewrite your question effectively.\"\n",
    "\n",
    "    # Serialize the dialogue history as \"[Q: ... A: ...]\" (with \"Document: ...\" after\n",
    "    # each question when passages are used).\n",
    "    curr_ctx = []\n",
    "    if args.use_pssg: # using {}-passages-per-line.json\n",
    "        n_prev_QAturn = len(line['NewContext'])//2\n",
    "        # Questions at context index >= s_idx_adddocs get one of the recent passages (in\n",
    "        # order); earlier questions get a \"No relevant documents.\" placeholder.\n",
    "        s_idx_adddocs = max(n_prev_QAturn - n_recent, 0) * 2\n",
    "        p_docs = [ f\"Document: {d}.\" for d in line['Truth_passages_contents'][-n_recent:] ] # recent top1 docs\n",
    "        # NOTE(review): IndexError if there are fewer passages than questions needing one;\n",
    "        # presumably the data provides one passage per previous turn -- confirm.\n",
    "        p_docs_i = 0\n",
    "        # (Q-Doc-A)-...\n",
    "        for idx, sent in enumerate(line['NewContext']):\n",
    "            if idx % 2 == 0:\n",
    "                curr_ctx.append(f\"Q: {sent}\")\n",
    "                if idx >= s_idx_adddocs:\n",
    "                    curr_ctx.append(p_docs[p_docs_i])\n",
    "                    p_docs_i += 1\n",
    "                else:\n",
    "                    curr_ctx.append(\"Document: No relevant documents.\")\n",
    "            else:\n",
    "                curr_ctx.append(f\"A: {sent}\")\n",
    "    else:\n",
    "        for q, a in zip(line[\"history_query\"], line[\"history_answer\"]):\n",
    "            curr_ctx.append(f\"Q: {q}\")\n",
    "            curr_ctx.append(f\"A: {a}\")\n",
    "    curr_ctx = f\"[{' '.join(curr_ctx)}]\"\n",
    "\n",
    "    if args.prompt_type == \"icl\":\n",
    "        # Four fixed demonstrations; the 'summary' / 'reasoning' variants prepend a\n",
    "        # rationale to each demonstration's rewrite via _append_rationale.\n",
    "        if args.use_pssg:\n",
    "            e1 = \"Context: [Q: When was Born to Fly released? Document: Born to Fly is a song co-written and recorded by American country music artist Sara Evans. It was released in June 2000 as the first single and title track from her 2000 album of the same name. A: Sara Evans's third studio album, Born to Fly, was released on October 10, 2000.] \\nQuestion: Was Born to Fly well received by critics?\\nRewrite: Was Born to Fly well received by critics?\"\n",
    "            e2 = \"Context: [Q: When was Keith Carradine born? Document: No relevant documents. A: Keith Ian Carradine was born August 8, 1949. Q: Is he married? Document: Carradine married Sandra Will on February 6, 1982. They were separated in 1993, before Will filed for divorce in 1999. The couple had two children: Cade Richmond Carradine (born July 19, 1982) and Sorel Johannah Carradine (born June 18, 1985). A: Keith Carradine married Sandra Will on February 6, 1982.]\\nQuestion: Do they have any children?\\nRewrite: Do Keith Carradine and Sandra Will have any children?\"\n",
    "            e3 = \"Context: [Q: Who proposed that atoms are the basic units of matter? Document: Arguably the most important of all Dalton's investigations are concerned with the atomic theory in chemistry. While his name is inseparably associated with this theory, the origin of Dalton's atomic theory is not fully understood. The theory may have been suggested to him either by researches on ethylene (olefiant gas) and methane (carburetted hydrogen) or by analysis of nitrous oxide (protoxide of azote) and nitrogen dioxide (deutoxide of azote), both views resting on the authority of Thomas Thomson. A: John Dalton proposed that each chemical element is composed of atoms of a single, unique type, and they can combine to form more complex structures called chemical compounds.] \\nQuestion: How did the proposal come about?\\nRewrite: How did John Dalton's proposal that each chemical element is composed of atoms of a single unique type, and they can combine to form more complex structures called chemical compounds come about?\"\n",
    "            e4 = \"Context: [Q: What is it called when two liquids separate? Document: Decantation is a process for the separation of mixtures of immiscible liquids or of a liquid and a solid mixture such as a suspension. The layer closer to the top of the container—the less dense of the two liquids, or the liquid from which the precipitate or sediment has settled out—is poured off, leaving denser liquid or the solid behind. The process typically is unable to remove all of the top layer, meaning the separation is incomplete or at least one of the two separated components is still contaminated by the other one. A: Decantation is a process for the separation of mixtures of immiscible liquids or of a liquid and a solid mixture such as a suspension.  Q: How does the separation occur?  Document: No relevant documents.  A: The layer closer to the top of the container-the less dense of the two liquids, or the liquid from which the precipitate or sediment has settled out-is poured off.]\\nQuestion: Then what happens?\\nRewrite: Then what happens after the layer closer to the top of the container is poured off with decantation?\"\n",
    "            if args.instruct_pssg == 'summary' or args.instruct_pssg == 'filter_irrelevant_summary':\n",
    "                e1_tldr = \"TLDR Summary: Born to Fly is both a song and the title of Sara Evans's third studio album. The song was released as the album's first single in June 2000, and the album itself was released on October 10, 2000.\"\n",
    "                e2_tldr = \"TLDR Summary: Keith Ian Carradine, born on August 8, 1949, married Sandra Will on February 6, 1982. They separated in 1993, and Sandra Will filed for divorce in 1999. The couple has two children, Cade Richmond Carradine and Sorel Johannah Carradine.\"\n",
    "                e3_tldr = \"TLDR Summary: John Dalton proposed the atomic theory, which posits that atoms are the fundamental units of matter, with each chemical element being composed of unique atoms that can combine to form complex compounds. The exact inspiration for Dalton's theory is unclear, but it might have stemmed from his research on gases or the analysis of nitrous oxide and nitrogen dioxide, possibly influenced by Thomas Thomson.\"\n",
    "                e4_tldr = \"TLDR Summary: The context explains decantation, a separation process for mixtures of immiscible liquids or liquid-solid mixtures like suspensions. It involves pouring off the top, less dense liquid or the liquid cleared of sediment, leaving behind the denser liquid or solid. The process may not completely remove the top layer, potentially leaving some contamination.\"\n",
    "                rationales = [e1_tldr, e2_tldr, e3_tldr, e4_tldr]\n",
    "            elif args.instruct_pssg == 'reasoning':\n",
    "                e1_reasoning = \"The question is already clear.\"\n",
    "                e2_reasoning = \"The original question uses the pronoun \\\"they\\\" which is ambiguous without explicit context. By specifying \\\"Keith Carradine and Sandra Will\\\" as the subjects, the revised question eliminates any ambiguity about who \\\"they\\\" refers to, directly connecting the inquiry to the individuals mentioned in the previous context.\"\n",
    "                e3_reasoning = \"The original question omits what the proposal actually is. Including the specific details of Dalton's atomic theory (that each chemical element is composed of atoms of a single unique type, and they can combine to form more complex structures called chemical compounds) directly in the question adds necessary context and allows the question to stand alone, making it understandable even without prior knowledge of the conversation.\"\n",
    "                e4_reasoning = \"The context revolves around decantation, a specific scientific process. Recognizing this as the core topic ensures that the rewrite focuses on the next logical step in this particular procedure. Question: Then what happens? is vague without specifying what it refers to. By identifying that it refers to the action of pouring off the top layer in the decantation process, we address coreference issues, making it clear what the 'then' is referring to.\"\n",
    "                rationales = [e1_reasoning, e2_reasoning, e3_reasoning, e4_reasoning]\n",
    "            else:  # 'original' / 'filter_irrelevant': demonstrations are used as-is\n",
    "                rationales = None\n",
    "        else:  # without past passages\n",
    "            e1 = \"Context: [Q: When was Born to Fly released? A: Sara Evans's third studio album, Born to Fly, was released on October 10, 2000.]\\nQuestion: Was Born to Fly well received by critics?\\nRewrite: Was Born to Fly well received by critics?\"\n",
    "            e2 = \"Context: [Q: When was Keith Carradine born? A: Keith Ian Carradine was born August 8, 1949. Q: Is he married? A: Keith Carradine married Sandra Will on February 6, 1982.]\\nQuestion: Do they have any children?\\nRewrite: Do Keith Carradine and Sandra Will have any children?\"\n",
    "            e3 = \"Context: [Q: Who proposed that atoms are the basic units of matter? A: John Dalton proposed that each chemical element is composed of atoms of a single, unique type, and they can combine to form more complex structures called chemical compounds.]\\nQuestion: How did the proposal come about?\\nRewrite: How did John Dalton's proposal that each chemical element is composed of atoms of a single unique type, and they can combine to form more complex structures called chemical compounds come about?\"\n",
    "            e4 = \"Context: [Q: What is it called when two liquids separate? A: Decantation is a process for the separation of mixtures of immiscible liquids or of a liquid and a solid mixture such as a suspension. Q: How does the separation occur? A: The layer closer to the top of the container-the less dense of the two liquids, or the liquid from which the precipitate or sediment has settled out-is poured off.]\\nQuestion: Then what happens?\\nRewrite: Then what happens after the layer closer to the top of the container is poured off with decantation?\"\n",
    "            if args.instruct_pssg == 'summary' or args.instruct_pssg == 'filter_irrelevant_summary':\n",
    "                e1_tldr = \"TLDR Summary: Inquiry about the release date of Sara Evans's album \\\"Born to Fly,\\\" which was on October 10, 2000.\"\n",
    "                e2_tldr = \"TLDR Summary: Inquiry about Keith Carradine's birth date, which is August 8, 1949, and marital status, revealing he married Sandra Will on February 6, 1982.\"\n",
    "                e3_tldr = \"TLDR Summary: John Dalton proposed atoms as the basic units of matter, which can combine to form chemical compounds.\"\n",
    "                e4_tldr = \"TLDR Summary: Decantation separates mixtures of immiscible liquids or liquids and solids by pouring off the top layer after settling.\"\n",
    "                rationales = [e1_tldr, e2_tldr, e3_tldr, e4_tldr]\n",
    "            elif args.instruct_pssg == 'reasoning':\n",
    "                e1_reasoning = \"The question is already clear.\"\n",
    "                e2_reasoning = \"The question \\\"Do they have any children?\\\" is ambiguous without directly referencing who \\\"they\\\" are. By naming \\\"Keith Carradine and Sandra Will\\\" explicitly, we eliminate any ambiguity regarding who the question is about.\"\n",
    "                e3_reasoning = \"The question \\\"How did the proposal come about?\\\" is vague because it doesn't specify which proposal it's referring to. By restating that the proposal is about \\\"each chemical element being composed of atoms of a single, unique type, and they can combine to form more complex structures called chemical compounds,\\\" we make the question self-contained.\"\n",
    "                e4_reasoning = \"The question \\\"Then what happens?\\\" is vague without specifying which process it refers to. By stating \\\"after the layer closer to the top of the container is poured off,\\\" the question explicitly refers to the action that was previously described, making it clear which stage of the process we're inquiring about what happens next.\"\n",
    "                rationales = [e1_reasoning, e2_reasoning, e3_reasoning, e4_reasoning]\n",
    "            else:  # 'original' / 'filter_irrelevant': demonstrations are used as-is\n",
    "                rationales = None\n",
    "\n",
    "        if rationales is not None:\n",
    "            e1, e2, e3, e4 = [_append_rationale(e, r) for e, r in zip((e1, e2, e3, e4), rationales)]\n",
    "        # NOTE(review): ICL reads the current question from line['query'] while zero-shot\n",
    "        # reads line['Question']; confirm both keys exist in the data set in use.\n",
    "        prompt = f\"{Instruction}\\n\\n{e1}\\n\\n{e2}\\n\\n{e3}\\n\\n{e4}\\n\\nContext: {curr_ctx}\\nQuestion: {line['query']}\\nRewrite: \"\n",
    "    else:  # 'zsl'\n",
    "        prompt = f\"{Instruction}\\n\\nContext: {curr_ctx}\\nQuestion: {line['Question']}\\nRewrite: \"\n",
    "\n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1da9ce3-7760-4b30-9f98-d6db20ce7742",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gold qrels for the KISTI training split (TREC qrel lines, parsed by print_res).\n",
    "qrel_file = \"/../../nlp_data/kisti/train_gold.trec\"\n",
    "# Explicit UTF-8 so decoding doesn't depend on the kernel's locale (KISTI ids/paths may\n",
    "# contain Korean text).\n",
    "with open(qrel_file, 'r', encoding='utf-8') as f:\n",
    "    qrel_data = f.readlines()\n",
    "\n",
    "# Graded labels >= rel_threshold count as relevant for MAP / MRR / Recall.\n",
    "rel_threshold = 1\n",
    "# Directory holding the BM25 run files evaluated below.\n",
    "run_file_dir = \"/../../nlp_data/convgqr/bm25/kisti/\"\n",
    "\n",
    "# Components of the run-file names built in the evaluation cells.\n",
    "p_type = \"icl\"\n",
    "inst_pssg = \"original\"\n",
    "eval_type = \"oracle\"\n",
    "\n",
    "# NOTE(review): not used by the visible evaluation cells -- confirm whether later cells\n",
    "# need them, otherwise drop.\n",
    "seed = \"0\"\n",
    "temp = \"8\"\n",
    "topp = \"8\"\n",
    "lr = \"2e-4\"\n",
    "step = 564\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36e2f2cc-c089-4871-bb5b-2e25df270e7c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# summary returned by print_res for every sampled rewrite run (gold qrels);\n",
    "# matching predictions live in e.g.\n",
    "# /../../nlp_data/infocqr_data/kisti/train_chatgpt_icl_WOpssg_original_originalQ_seed13_temp8_p8.jsonl\n",
    "for s in range(15):\n",
    "    fname = f\"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_seed{s}_{eval_type}.trec\"\n",
    "    run_file = f\"{run_file_dir}{fname}\"\n",
    "    res = print_res(run_file, qrel_data, rel_threshold, return_summary=True)\n",
    "    print(fname, res)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cdfba7b-cc10-4e48-9c54-0beccd81e982",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Oracle selection: score every sampled run against the gold qrels, then keep,\n",
    "# per conversational query, the run whose per-query metrics have the highest mean.\n",
    "all_res = []  # one {qid: {metric: value}} dict per seed\n",
    "for s in range(15):\n",
    "    fname = f\"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_seed{s}_{eval_type}.trec\"\n",
    "    run_file = run_file_dir + fname\n",
    "    res = print_res(run_file, qrel_data, rel_threshold, return_summary=False)\n",
    "    all_res.append(res)\n",
    "\n",
    "best_res_dict = {}\n",
    "conv_q_ids = list(all_res[0].keys())\n",
    "for conv_q_i in conv_q_ids:\n",
    "    res_list = [run_res[conv_q_i] for run_res in all_res]\n",
    "    avg_scores = [sum(d.values()) / len(d) for d in res_list]\n",
    "    # list.index(max(...)) -> the earliest run wins ties\n",
    "    index_of_highest_avg = avg_scores.index(max(avg_scores))\n",
    "    dict_with_highest_avg = res_list[index_of_highest_avg]\n",
    "    best_res_dict[conv_q_i] = dict_with_highest_avg\n",
    "\n",
    "# aggregate the selected per-query metrics (percent, 2 decimals)\n",
    "metrics = best_res_dict\n",
    "per_query = list(metrics.values())\n",
    "map_list, mrr_list, ndcg_3_list = (\n",
    "    [v[m] for v in per_query] for m in ('map', 'recip_rank', 'ndcg_cut_3'))\n",
    "recall_5_list, recall_10_list, recall_20_list, recall_100_list = (\n",
    "    [v[f'recall_{k}'] for v in per_query] for k in (5, 10, 20, 100))\n",
    "\n",
    "np.set_printoptions(precision=4)\n",
    "\n",
    "eval_metrics = {\n",
    "    name: round(100 * np.average(values), 2)\n",
    "    for name, values in (\n",
    "        (\"MAP\", map_list),\n",
    "        (\"MRR\", mrr_list),\n",
    "        (\"NDCG@3\", ndcg_3_list),\n",
    "        (\"Recall@5\", recall_5_list),\n",
    "        (\"Recall@10\", recall_10_list),\n",
    "        (\"Recall@20\", recall_20_list),\n",
    "        (\"Recall@100\", recall_100_list),\n",
    "    )\n",
    "}\n",
    "eval_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9173dd3-d20b-4f57-b5aa-24f2019cbf4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-query statistics across the sampled runs (by conversation turn):\n",
    "#   stat_res_dict[qid]     -> (mean, std) of the runs' metric-averaged scores\n",
    "#   detailed_res_dict[qid] -> the per-run metric dicts themselves\n",
    "# (to also include a gold-rewrite run, use e.g. all_res + [gt_res])\n",
    "stat_res_dict = {}\n",
    "detailed_res_dict = {}\n",
    "conv_q_ids = list(all_res[0].keys())\n",
    "for conv_q_i in conv_q_ids:\n",
    "    res_list = [res[conv_q_i] for res in all_res]\n",
    "    values_by_d = [mean(list(d.values())) for d in res_list]\n",
    "\n",
    "    avg = mean(values_by_d)\n",
    "    # stdev needs at least two data points\n",
    "    std = stdev(values_by_d) if len(values_by_d) > 1 else 0\n",
    "    stat_res_dict[conv_q_i] = (avg, std)\n",
    "\n",
    "    detailed_res_dict[conv_q_i] = res_list\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cded214a-a1a4-4f26-b9a7-5cbf2e14cc25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rewritten queries by conversation turn.\n",
    "# Run index -> seed: 0-14 -> seeds 0-14, 15-17 -> seeds 21-23.\n",
    "pred_file_dir = \"/../../nlp_data/infocqr_data/kisti/\"\n",
    "pred_seeds = list(range(15)) + list(range(21, 24))\n",
    "\n",
    "all_pred_data = {}\n",
    "for pred_i, s in enumerate(pred_seeds):\n",
    "    fname = f\"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_originalQ_seed{s}_temp{temp}_p{topp}.jsonl\"\n",
    "    pred_file = pred_file_dir + fname\n",
    "    # explicit UTF-8 instead of the locale default (the text may be non-ASCII);\n",
    "    # blank lines are skipped instead of failing in json.loads\n",
    "    with open(pred_file, \"r\", encoding=\"utf-8\") as f:\n",
    "        data = [json.loads(row) for row in f if row.strip()]\n",
    "    all_pred_data[pred_i] = data\n",
    "pred_i = len(all_pred_data)  # = number of loaded runs\n",
    "\n",
    "# run index -> {'conv_id-turn_id': {'pred_query': rewritten query}}\n",
    "all_proc_preds = {}\n",
    "for i, data in all_pred_data.items():\n",
    "    temp_data = {}\n",
    "    for dt in tqdm(data):\n",
    "        guid = f\"{dt['conv_id']}-{dt['turn_id']}\"\n",
    "        temp_data[guid] = {'pred_query': dt['oracle_utt_text']}\n",
    "    all_proc_preds[i] = temp_data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4f9f14f-8bb7-4884-9be3-3550cad6e462",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Step 1: per query, the metric-averaged score of every sampled run.\n",
    "# (`run_metrics` rather than `metrics`, so the global `metrics` is not clobbered)\n",
    "avg_scores = {\n",
    "    key: [sum(run_metrics.values()) / len(run_metrics) for run_metrics in values]\n",
    "    for key, values in detailed_res_dict.items()\n",
    "}\n",
    "\n",
    "# Step 2: per query, group the run indices that share the same average score\n",
    "# -> grouped_by_avg_score[qid] = {avg_score: [run indices]}\n",
    "grouped_by_avg_score = {}\n",
    "for key, avgs in avg_scores.items():\n",
    "    groups_by_avg = {}\n",
    "    for i, avg in enumerate(avgs):\n",
    "        groups_by_avg.setdefault(avg, []).append(i)\n",
    "    grouped_by_avg_score[key] = groups_by_avg\n",
    "\n",
    "# Display: the outer keys are query ids, the values map score -> run indices\n",
    "for qid, groups in grouped_by_avg_score.items():\n",
    "    print(f\"Query: {qid}, run indices by average score: {groups}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77bdc3a7-e0ad-4459-8a07-efdb0834a17e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# order every query's groups from the highest to the lowest average score\n",
    "for k, groups in list(grouped_by_avg_score.items()):\n",
    "    grouped_by_avg_score[k] = dict(sorted(groups.items(), key=lambda item: item[0], reverse=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bfc8b02-3150-4194-8ad1-0469a9d831cf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40cd2d30-b108-4f69-865d-cbe1e00ae5cb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ab4637d-7641-48b4-9f64-73f999b240a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# BM25 back-retrieval results of kisti: {qid: [{'id': doc id, ...}, ...]}\n",
    "# (TopiOCQA variant, not used here:\n",
    "#  torch.load(\"/../../nlp_data/llm_qr/outputs/BM25/topi_back_retrieval_answer_query\"))\n",
    "back_retrieval_answer_bm25 = torch.load(\n",
    "    \"/../../nlp_data/llm_qr/outputs/BM25/kisti_back_retrieval_answer\"\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3699d913-6e0a-4b21-b703-d4fe298b67e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of queries / number of back-retrieved entries of query '1-11'\n",
    "(len(back_retrieval_answer_bm25), len(back_retrieval_answer_bm25['1-11']))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b54e4f3b-77c9-45f6-b68f-0e4bf0402a72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dense counterparts (not loaded in the visible cells):\n",
    "# back_retrieval_answer_dpr['1-11'][0], back_retrieval_answer_rewrite_dpr['1-11'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4317677-97f0-4ac9-a44c-fd6a2e64a66f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# top back-retrieved entry of query '1-11'\n",
    "back_retrieval_answer_bm25['1-11'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d96332ca-ffc0-4c22-81be-90ba39766c65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sparse (BM25) predictions: ranked passage ids per query for every sampled run,\n",
    "# e.g. /../../nlp_data/convgqr/bm25/kisti/train_chatgpt_icl_WOpssg_original_seed2_oracle.trec\n",
    "run_file_dir = \"/../../nlp_data/convgqr/bm25/kisti/\"\n",
    "all_results_cands_bm25 = []  # one {qid: [passage ids in file order]} per run\n",
    "\n",
    "for s in range(15):\n",
    "    runs = {}\n",
    "    fname = f\"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_seed{s}_{eval_type}.trec\"\n",
    "    run_file = run_file_dir + fname\n",
    "    with open(run_file, 'r') as f:\n",
    "        for line in f:\n",
    "            # TREC run line: qid Q0 docid rank score tag. Only qid and docid are\n",
    "            # needed; the score column is not parsed (int() would reject float scores).\n",
    "            fields = line.split()\n",
    "            if not fields:\n",
    "                continue  # tolerate blank lines\n",
    "            runs.setdefault(fields[0], []).append(fields[2])\n",
    "\n",
    "    all_results_cands_bm25.append(runs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31e95e0a-3844-425d-bf0e-a5687b479df8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the (int) doc ids of all back-retrieved documents.\n",
    "# The kisti BM25 back-retrieval results are the ones loaded in this notebook;\n",
    "# `back_retrieval_answer_rewrite_bm25` (TopiOCQA variant) has its load commented\n",
    "# out, so it is not read here. Append further {qid: [{'id': ...}, ...]} dicts\n",
    "# (e.g. dense back-retrieval results) to the list once they are loaded.\n",
    "back_retrieval_sources = [back_retrieval_answer_bm25]\n",
    "\n",
    "all_docids_back_retrieval = [\n",
    "    int(dic['id'])\n",
    "    for source in back_retrieval_sources\n",
    "    for list_dict in source.values()\n",
    "    for dic in list_dict\n",
    "]\n",
    "\n",
    "uniq_docids_back_retrieval = set(all_docids_back_retrieval)\n",
    "len(uniq_docids_back_retrieval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd7cd6e5-9002-430c-9c8f-51eba1ff394d",
   "metadata": {},
   "outputs": [],
   "source": [
    "(len(all_docids_back_retrieval), len(uniq_docids_back_retrieval))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54102c91-dd99-4a96-887f-b38b9771e0f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense predictions are disabled in this notebook: define an empty list so that\n",
    "# the cells using `all_results_cands_dpr` still run. To enable them, uncomment\n",
    "# the loader below (it overwrites the empty default and also rebinds run_file_dir).\n",
    "all_results_cands_dpr = []  # one {qid: [int doc ids]} per run when enabled\n",
    "\n",
    "# run_file_dir = \"/../../nlp_data/convgqr/ance/test_sentence_trsf_large/\"\n",
    "# temp_paths = [\n",
    "#     run_file_dir + \"traindt_chatgpt3.5_seed{}_oracle_dpr.json\".format(s)\n",
    "#     for s in range(12)\n",
    "# ]\n",
    "# for i in range(12):\n",
    "#     all_result_ori = json.load(open(temp_paths[i], \"r\"))\n",
    "#     all_result = {}\n",
    "#     for did in all_result_ori:\n",
    "#         new_did = \"_\".join(did.split(\"_\")[-2:])\n",
    "#         all_result[new_did] = [int(k) for k in list(all_result_ori[did].keys())]\n",
    "#     all_results_cands_dpr += [all_result]  # 12 x 13k x 100\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "233f9bec-abc3-417c-8398-b4694a318a2c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# runs, queries in the first run, candidates of query '1-1'\n",
    "(len(all_results_cands_bm25), len(all_results_cands_bm25[0]), len(all_results_cands_bm25[0]['1-1']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0041a769-6e1b-4894-918f-68decb8efd63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dense runs may be absent (empty list) -> report zeros instead of an IndexError\n",
    "if all_results_cands_dpr:\n",
    "    dpr_sizes = (len(all_results_cands_dpr), len(all_results_cands_dpr[0]), len(all_results_cands_dpr[0]['1-1']))\n",
    "else:\n",
    "    dpr_sizes = (0, 0, 0)\n",
    "dpr_sizes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c440ee7d-cc51-45b4-a38d-b5f24ef60fc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidate doc ids retrieved by every run (dense runs first, then BM25):\n",
    "#   all_doc_ids     -> flat list over all runs and queries\n",
    "#   doc_ids_grouped -> one {qid: [doc ids]} per run\n",
    "# Ids are normalised to int: the BM25 run files yield str ids, while the\n",
    "# back-retrieval ids and the docid2emb keys are int, so set operations\n",
    "# between them would otherwise never match.\n",
    "all_doc_ids = []\n",
    "doc_ids_grouped = []\n",
    "for i_gens in list(all_results_cands_dpr) + list(all_results_cands_bm25):\n",
    "    doc_ids_gen = {}\n",
    "    for qid, qid_top100 in i_gens.items():\n",
    "        qid_doc_ids = [int(doc_id) for doc_id in qid_top100]\n",
    "        all_doc_ids += qid_doc_ids\n",
    "        doc_ids_gen[qid] = qid_doc_ids\n",
    "    doc_ids_grouped.append(doc_ids_gen)\n",
    "\n",
    "unique_doc_ids = set(all_doc_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73885f47-cbe5-4acf-9d45-bdbc365990fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(unique_doc_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77f90a88-cd87-4ccd-b203-ccee363c6010",
   "metadata": {},
   "outputs": [],
   "source": [
    "# cached {doc id: embedding} map (the same path is re-saved once new ids are added)\n",
    "docid2emb = torch.load(\n",
    "    \"/../../nlp_data/llm_qr/outputs/DPR/topi_docid2emb_curr\"\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cf59ed2-2552-4481-8d08-b6eb2c7b44a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# normalise the keys to int (values are kept as they are)\n",
    "docid2emb = dict(zip(map(int, docid2emb.keys()), docid2emb.values()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a29d6727-e34e-4eaf-af81-8d7c66f90f48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# doc ids to embed = back-retrieved ids | ids retrieved by the runs;\n",
    "# only those without an entry in docid2emb still have to be encoded\n",
    "needed_docids = uniq_docids_back_retrieval | unique_doc_ids\n",
    "additional_uniq_docids = needed_docids - set(docid2emb.keys())\n",
    "print(\"#docids that we need to encode: {}, #docids that additionally need to encode: {} \".format(\n",
    "    len(needed_docids), len(additional_uniq_docids)))\n",
    "len(additional_uniq_docids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b649bf2e-70da-44de-9c54-fdbc66fbd422",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ids still to be encoded; the embeddings are presumably computed outside this\n",
    "# notebook and loaded back from topi_add_uniq_docs_embs in the next cell\n",
    "torch.save(additional_uniq_docids, \n",
    "           \"/../../nlp_data/llm_qr/outputs/DPR/topi_add_uniq_docids4emb\")\n",
    "# torch.save(unique_doc_ids, \n",
    "#            \"/../../nlp_data/llm_qr/outputs/DPR/qr_pred_docs4emb\")\n",
    "# torch.save(uniq_docids_back_retrieval_answer_rewrite - unique_doc_ids, \n",
    "#            \"/../../nlp_data/llm_qr/outputs/DPR/qr_pseudo_docs4emb\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a9d8d4b-6a4e-46a3-889b-e729d4006d29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load from saved embeddings\n",
    "unique_added_doc_embeddings = torch.load(\"/../../nlp_data/llm_qr/outputs/DPR/topi_add_uniq_docs_embs\")\n",
    "print(\"#uniq-add: \",unique_added_doc_embeddings.size())\n",
    "\n",
    "# Row i is mapped to the i-th id in iteration order of additional_uniq_docids.\n",
    "# NOTE(review): this assumes the encoder iterated the saved id set in the same\n",
    "# order - confirm. A row-count mismatch is rejected instead of silently pairing\n",
    "# ids with the wrong rows (or dropping rows).\n",
    "n_rows = unique_added_doc_embeddings.size(0)\n",
    "if n_rows != len(additional_uniq_docids):\n",
    "    raise ValueError(\n",
    "        f\"{n_rows} embeddings loaded for {len(additional_uniq_docids)} doc ids; \"\n",
    "        \"re-run the encoder on the saved id set\")\n",
    "docid2emb.update({doc_id: unique_added_doc_embeddings[i] for i, doc_id in enumerate(additional_uniq_docids)})\n",
    "\n",
    "# len(docid2emb), len(docid2emb)== len(set(docid2emb.keys()))+unique_added_doc_embeddings.size(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "727803f0-4aa1-40d9-9e4a-be14e6113407",
   "metadata": {},
   "outputs": [],
   "source": [
    "# persist the extended map (overwrites the cached file) and drop the loaded tensor\n",
    "torch.save(docid2emb, \n",
    "           \"/../../nlp_data/llm_qr/outputs/DPR/topi_docid2emb_curr\")\n",
    "del unique_added_doc_embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9accdeae-5dd9-40b6-aec2-3083fd43bba5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e16f753-5d12-49c2-be92-d1b2fa2e45e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def jaccard_index(set_a, set_b, empty_value=1.0):\n",
    "    \"\"\"Return the Jaccard index of two collections: intersection size / union size.\n",
    "\n",
    "    Args:\n",
    "        set_a, set_b: sets of hashable items; other iterables are converted to sets.\n",
    "        empty_value: result when both inputs are empty (the ratio would be 0/0);\n",
    "            defaults to 1.0, treating two empty sets as identical.\n",
    "\n",
    "    Returns:\n",
    "        float in [0, 1], or `empty_value` for two empty inputs.\n",
    "    \"\"\"\n",
    "    a, b = set(set_a), set(set_b)\n",
    "    union = a | b\n",
    "    if not union:\n",
    "        return empty_value\n",
    "    return len(a & b) / len(union)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "882a9d7f-7b67-4d71-a4fd-a2c7c2fb0aa3",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(all_res)  # number of runs scored against the gold qrels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9f66fff-049a-424b-8494-b690dcbabf82",
   "metadata": {},
   "outputs": [],
   "source": [
    "# previously recorded summary of the seed-0 run (hard-coded, not recomputed here)\n",
    "seed0_reference = {'MAP': 13.37, 'MRR': 13.37, 'NDCG@3': 12.5, 'Recall@1': 8.78, 'Recall@3': 15.23,\n",
    "                   'Recall@5': 18.14, 'Recall@10': 22.47, 'Recall@20': 26.85, 'Recall@100': 37.33}\n",
    "print(\"train_chatgpt_icl_WOpssg_original_seed0_oracle.trec\", seed0_reference)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e12153b1-e22e-49cd-a6f5-691a03bc51ab",
   "metadata": {},
   "source": [
    "### Create SFT data\n",
    "- Rank the predictions using the pseudo-gold references\n",
    "- Compute the performance of the best predictions\n",
    "- Generate the SFT data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "522edbf0-0605-49d4-b62a-3b7d7e731878",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rewritten queries by conversation turn for the 15 runs (seeds 0-14) that are\n",
    "# scored against the pseudo-gold qrels; run index == seed.\n",
    "pred_file_dir = \"/../../nlp_data/infocqr_data/kisti/\"\n",
    "\n",
    "all_pred_data = {}\n",
    "for pred_i, s in enumerate(range(15)):\n",
    "    fname = f\"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_originalQ_seed{s}_temp{temp}_p{topp}.jsonl\"\n",
    "    pred_file = pred_file_dir + fname\n",
    "    # explicit UTF-8 instead of the locale default (the text may be non-ASCII);\n",
    "    # blank lines are skipped instead of failing in json.loads\n",
    "    with open(pred_file, \"r\", encoding=\"utf-8\") as f:\n",
    "        data = [json.loads(row) for row in f if row.strip()]\n",
    "    all_pred_data[pred_i] = data\n",
    "pred_i = len(all_pred_data)  # = number of loaded runs\n",
    "\n",
    "# run index -> {'conv_id-turn_id': {'pred_query': rewritten query}}\n",
    "all_proc_preds = {}\n",
    "for i, data in all_pred_data.items():\n",
    "    temp_data = {}\n",
    "    for dt in tqdm(data):\n",
    "        guid = f\"{dt['conv_id']}-{dt['turn_id']}\"\n",
    "        temp_data[guid] = {'pred_query': dt['oracle_utt_text']}\n",
    "    all_proc_preds[i] = temp_data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c966c83e-f84a-48b1-af81-87ace81fb94e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d1a7793-c091-405e-8880-fbede9bcecea",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "822d8ff3-a503-4d69-9eef-5d05eefb85dc",
   "metadata": {},
   "source": [
    "#### Best performance with **pseudo**-gold references"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daeedab7-5fb4-43bf-9aa1-0fef79a0360d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pseudo-gold qrels: for each query, the first `cut_pseudo` distinct doc ids\n",
    "# among its top-100 BM25 back-retrieval entries, each with relevance 1.\n",
    "# pseudo_qrels_dict[cut_pseudo] = {qid: {doc id: 1}}\n",
    "pseudo_qrels_dict = {}\n",
    "\n",
    "# only cut_pseudo = 3 is used; extend the list to sweep more cut-offs\n",
    "for cut_pseudo in tqdm([3,]):\n",
    "    pseudo_qrels = {}\n",
    "\n",
    "    # iterate the query ids directly (building list(grouped_by_avg_score.keys())\n",
    "    # once per query would make this loop quadratic in the number of queries)\n",
    "    for qid in grouped_by_avg_score:\n",
    "        if qid not in back_retrieval_answer_bm25:\n",
    "            continue\n",
    "\n",
    "        pseudo_gold = set()\n",
    "        for hit in back_retrieval_answer_bm25[qid][:100]:\n",
    "            pseudo_gold.add(hit['id'])\n",
    "            if len(pseudo_gold) >= cut_pseudo:\n",
    "                break\n",
    "\n",
    "        pseudo_qrels[qid] = {passage: 1 for passage in pseudo_gold}\n",
    "\n",
    "    pseudo_qrels_dict[cut_pseudo] = pseudo_qrels\n",
    "\n",
    "for cut_pseudo, qrels in pseudo_qrels_dict.items():\n",
    "    print(f\"cut_pseudo = {cut_pseudo}, number of qrels: {len(qrels)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bb10536-ec5d-438b-bf06-7c36a53e464c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score every sampled run against each pseudo-gold qrels variant:\n",
    "# all_res_pseudo_dict[cut_pseudo] = [run 0 result, ..., run 14 result]\n",
    "# NOTE(review): print_res_pseudo_qrels is not defined in the visible cells;\n",
    "# presumably a print_res variant that takes a {qid: {doc id: rel}} dict.\n",
    "all_res_pseudo_dict = {}\n",
    "for cut_pseudo, pseudo_qrels in tqdm(pseudo_qrels_dict.items()):\n",
    "    all_res_pseudo = [\n",
    "        print_res_pseudo_qrels(\n",
    "            run_file_dir + f\"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_seed{s}_{eval_type}.trec\",\n",
    "            pseudo_qrels, rel_threshold, return_summary=False)\n",
    "        for s in range(15)\n",
    "    ]\n",
    "    all_res_pseudo_dict[cut_pseudo] = all_res_pseudo\n",
    "\n",
    "for cut_pseudo, results in all_res_pseudo_dict.items():\n",
    "    print(f\"cut_pseudo = {cut_pseudo}, results: {len(results), len(results[0])}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc043b76-c0df-4c43-bc60-c9b185241506",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per query, pick the sampled run that scores best against the pseudo-gold\n",
    "# qrels and report that run's metrics against the real gold qrels (all_res).\n",
    "# Run score = sum over the pseudo-qrels cut-offs of mean(metrics) / cut.\n",
    "# The cut-offs come from all_res_pseudo_dict itself rather than a hard-coded\n",
    "# cut (such as 1) that is not necessarily one of its keys.\n",
    "cut_values = list(all_res_pseudo_dict.keys())\n",
    "n_runs = len(all_res_pseudo_dict[cut_values[-1]])\n",
    "# run i must denote the same rewrite run in the pseudo-gold and the gold results\n",
    "assert n_runs == len(all_res), (n_runs, len(all_res))\n",
    "\n",
    "best_res_dict = {}\n",
    "pseudo_avg_scores = {}\n",
    "indice_of_jaccard_highest = {}  # qid -> index of the selected run\n",
    "conv_q_ids = list(all_res_pseudo_dict[cut_values[-1]][0].keys())\n",
    "for conv_q_i in conv_q_ids:\n",
    "    res_list = [res[conv_q_i] for res in all_res]\n",
    "\n",
    "    avg_scores = []\n",
    "    for run_idx in range(n_runs):\n",
    "        scores = 0\n",
    "        for cut in cut_values:\n",
    "            res_dict = all_res_pseudo_dict[cut][run_idx][conv_q_i]\n",
    "            scores += (sum(res_dict.values()) / len(res_dict)) * (1 / cut)\n",
    "        avg_scores.append(scores)\n",
    "    pseudo_avg_scores[conv_q_i] = avg_scores\n",
    "\n",
    "    # list.index(max(...)) -> the earliest run wins ties\n",
    "    index_of_highest_avg = avg_scores.index(max(avg_scores))\n",
    "    indice_of_jaccard_highest[conv_q_i] = index_of_highest_avg\n",
    "    best_res_dict[conv_q_i] = res_list[index_of_highest_avg]\n",
    "\n",
    "metrics = best_res_dict\n",
    "map_list = [v['map'] for v in metrics.values()]\n",
    "mrr_list = [v['recip_rank'] for v in metrics.values()]\n",
    "recall_100_list = [v['recall_100'] for v in metrics.values()]\n",
    "recall_20_list = [v['recall_20'] for v in metrics.values()]\n",
    "recall_10_list = [v['recall_10'] for v in metrics.values()]\n",
    "recall_5_list = [v['recall_5'] for v in metrics.values()]\n",
    "ndcg_3_list = [v['ndcg_cut_3'] for v in metrics.values()]\n",
    "\n",
    "np.set_printoptions(precision=4)\n",
    "\n",
    "eval_metrics = {\n",
    "            \"MAP\": round(100*np.average(map_list),2),\n",
    "            \"MRR\": round(100*np.average(mrr_list),2),\n",
    "            \"NDCG@3\": round(100*np.average(ndcg_3_list),2),\n",
    "            \"Recall@5\": round(100*np.average(recall_5_list),2),\n",
    "            \"Recall@10\": round(100*np.average(recall_10_list),2),\n",
    "            \"Recall@20\": round(100*np.average(recall_20_list),2),\n",
    "            \"Recall@100\": round(100*np.average(recall_100_list),2), \n",
    "        }\n",
    "eval_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0624b0c5-ad2c-40ed-835f-ead6ee457b30",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90f33823-81ae-46c2-aee2-d900f9371c69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# selected rewrite per query, looked up via the run index in indice_of_jaccard_highest\n",
    "# NOTE(review): '3-100' presumably means cut_pseudo=3 within the top-100 back-retrieval hits\n",
    "pseudo_best_reQ = {\n",
    "    '3-100': {\n",
    "        qid: all_proc_preds[pred_ind][qid]['pred_query']\n",
    "        for qid, pred_ind in indice_of_jaccard_highest.items()\n",
    "    },\n",
    "}\n",
    "pseudo_best_reQ['3-100']['1-10']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a85bd298-7da3-4d70-9549-57feca8ebe6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build SFT records: the rewrite prompt of every training turn paired with the\n",
    "# rewrite selected for that turn, one record list per selection key.\n",
    "split = 'train'\n",
    "root = \"/../../nlp_data/kisti/\" # 'datasets/qrecc/' \n",
    "input_data = \"train_newAs_comb.json\" # 'train-modified-passages-per-line-add15k.json'\n",
    "with open(os.path.join(root, input_data), encoding=\"utf-8\") as f:\n",
    "    lines = [json.loads(l) for l in f]  # one json object per line\n",
    "\n",
    "args = argparse.Namespace(use_pssg=False, instruct_pssg='original', prompt_type='icl')\n",
    "\n",
    "# set_prompt(line, args) does not depend on the selection key, so it is built\n",
    "# once per line instead of once per line and key\n",
    "line_prompts = [\n",
    "    (f\"{line['conv_id']}-{line['turn_id']}\", set_prompt(line, args))\n",
    "    for line in tqdm(lines)\n",
    "]\n",
    "\n",
    "best_data = {\n",
    "    k: [{'instruction': prompt, 'output': pseudo_best_reQ[k][conv_id], } for conv_id, prompt in line_prompts]\n",
    "    for k in pseudo_best_reQ\n",
    "}\n",
    "best_data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b19e7a7-6be9-4402-a6fe-07b66667873f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One SFT json file per selection key. A single key keeps the file name\n",
    "# Kisti_SFT.json; with several keys the key is appended so the files do not\n",
    "# overwrite each other.\n",
    "sft_dir = \"/data/../nlp_data/LongAlpaca-12k/\"\n",
    "for k in list(pseudo_best_reQ.keys()):\n",
    "    suffix = \"\" if len(pseudo_best_reQ) == 1 else f\"_{k}\"\n",
    "    f_path = f\"{sft_dir}Kisti_SFT{suffix}.json\"\n",
    "    with open(f_path, 'w', encoding='utf-8') as f:\n",
    "        # ensure_ascii=False writes non-ASCII text (e.g. Korean) as is instead of escape sequences\n",
    "        json.dump(best_data[k], f, indent=4, ensure_ascii=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c54c4a00-c8ed-42af-8e04-22a662adfc0e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56a5994c-2b26-4bdb-b571-0e7cac603836",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f673c4a1-408d-4b2b-a0a5-dfa9cfb2ec21",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "589232b3-5b15-4a49-953e-af9f344a517a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96618f38-23fc-486d-aa61-554593435723",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6172f38d-3c94-464c-bae8-d5baa9ed60c8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b8bafea-d7ce-4f41-93fb-bf6189e92194",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55e41945-5c94-43ce-b5a6-043092d02a3d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee970751-083e-4f4c-9b49-2dee39d3031b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f73b2de8-74d0-429f-81ed-a3effcddd99b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b474230-02a2-481c-a221-6e0a4f277e43",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1294fdd6-3337-4d14-98dc-e196ab8ec2ed",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a756a1e-6840-450a-9094-304e56ed1afe",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "436bcb2c-b8ea-4acb-9dec-12f0931871ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# preference data to concatenate (original note: save concat of qrecc and topiocqa)\n",
    "pref_csv_path = '/data/../nlp_data/LongAlpaca-12k/pref_data.csv'\n",
    "df_qr = pd.read_csv(pref_csv_path)\n",
    "df_qr.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8f739da-6cee-4132-8081-153744e7e9ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): df_pref is not defined in any visible cell - it must already exist in the kernel\n",
    "(df_pref.shape, df_qr.shape, df_pref.shape[0] + df_qr.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d15f3fdd-16c0-4d55-9326-2d992b115928",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): cut_pseudo_top and cut_gen_top are not defined in any visible cell;\n",
    "# the output name still says 'topiocqa' - confirm it is intended for this data\n",
    "df_topi_qr = pd.concat([df_pref, df_qr], axis=0)\n",
    "df_topi_qr.to_csv(\n",
    "    f'/data/../nlp_data/LongAlpaca-12k/pref_data_pseudo_topiocqa_{cut_pseudo_top}_{cut_gen_top}_gold_qr.csv',\n",
    "    index=False,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36eb0331-ef3c-4ccb-8a26-24de6918563b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llmcqr",
   "language": "python",
   "name": "llmcqr"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
