{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60679094-6f53-4c9d-ab3d-1aa8b98707ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import os\n",
    "# os.environ['JAVA_HOME'] = \"/usr/lib/jvm/java-11-openjdk-amd64\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "509a01f4-a9e6-4ee7-966a-9c1cf99f8765",
   "metadata": {},
   "source": [
    "### topi 13k turns!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "353800af-88ec-4980-ba16-a68bb4bc3008",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from shared_utils.indexing_utils import SparseIndexer, DocumentCollection\n",
    "import json\n",
    "import jsonlines\n",
    "from tqdm import tqdm\n",
    "from copy import deepcopy\n",
    "import io\n",
    "import argparse\n",
    "from statistics import mean, stdev\n",
    "4\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import pytrec_eval\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "207efa5f-9669-433d-93b6-ab3e502a0abc",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"/data/../nlp_data/LongAlpaca-12k/LongAlpaca-12k.json\"\n",
    "lines = json.load(open(path,\"r\", encoding=\"utf-8\"))\n",
    "\n",
    "attr_required = list(lines[0].keys())\n",
    "list(lines[0].keys()), list(lines[-1].keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0db498d8-79a6-496d-a416-7b95e0d8c1a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def print_res(run_file, qrel_data, rel_threshold, return_summary=True):\n",
    "    with open(run_file, 'r' )as f:\n",
    "        run_data = f.readlines()\n",
    "    # with open(qrel_file, 'r') as f:\n",
    "    #     qrel_data = f.readlines()\n",
    "    # print(run_data)\n",
    "    qrels = {}\n",
    "    qrels_ndcg = {}\n",
    "    runs = {}\n",
    "    \n",
    "    for line in qrel_data:\n",
    "        line = line.strip().split()\n",
    "        query = line[0]\n",
    "        passage = line[2]\n",
    "        rel = int(line[3])\n",
    "        if query not in qrels:\n",
    "            qrels[query] = {}\n",
    "        if query not in qrels_ndcg:\n",
    "            qrels_ndcg[query] = {}\n",
    "\n",
    "        # for NDCG\n",
    "        qrels_ndcg[query][passage] = rel\n",
    "        # for MAP, MRR, Recall\n",
    "        if rel >= rel_threshold:\n",
    "            rel = 1\n",
    "        else:\n",
    "            rel = 0\n",
    "        qrels[query][passage] = rel\n",
    "    \n",
    "    for line in run_data:\n",
    "        line = line.split(\" \")\n",
    "        query = line[0]\n",
    "        passage = line[2]\n",
    "        rel = int(line[4])\n",
    "        if query not in runs:\n",
    "            runs[query] = {}\n",
    "        runs[query][passage] = rel\n",
    "\n",
    "    # pytrec_eval eval\n",
    "    evaluator = pytrec_eval.RelevanceEvaluator(qrels, {\"map\", \"recip_rank\", \"recall.1\", \"recall.3\", \"recall.5\", \"recall.10\", \n",
    "                         \"recall.20\", \"recall.30\", \"recall.100\", })\n",
    "    res = evaluator.evaluate(runs)\n",
    "    # map_list = [v['map'] for v in res.values()]\n",
    "    mrr_list = [v['recip_rank'] for v in res.values()]\n",
    "    recall_1_list = [v[\"recall_1\"] for v in res.values()]\n",
    "    recall_3_list = [v[\"recall_3\"] for v in res.values()]\n",
    "    recall_5_list = [v[\"recall_5\"] for v in res.values()]\n",
    "    recall_10_list = [v[\"recall_10\"] for v in res.values()]\n",
    "    recall_20_list = [v[\"recall_20\"] for v in res.values()]\n",
    "    recall_30_list = [v[\"recall_30\"] for v in res.values()]\n",
    "    recall_100_list = [v[\"recall_100\"] for v in res.values()]\n",
    "    # print(res)\n",
    "\n",
    "    evaluator = pytrec_eval.RelevanceEvaluator(qrels_ndcg, {\"ndcg_cut.3\"})\n",
    "    res_ndcg = evaluator.evaluate(runs)\n",
    "    ndcg_3_list = [v['ndcg_cut_3'] for v in res_ndcg.values()]\n",
    "    # print(res)\n",
    "    \n",
    "    res_summary = {\n",
    "            # \"MAP\": round(100*np.average(map_list),2),\n",
    "            \"MRR\": round(100*np.average(mrr_list),2),\n",
    "            \"NDCG@3\": round(100*np.average(ndcg_3_list),2),\n",
    "            \"Recall@1\": round(100*np.average(recall_1_list), 2),\n",
    "            \"Recall@3\": round(100*np.average(recall_3_list), 2),\n",
    "            \"Recall@5\": round(100*np.average(recall_5_list), 2),\n",
    "            \"Recall@10\": round(100*np.average(recall_10_list), 2),\n",
    "            \"Recall@20\": round(100*np.average(recall_20_list), 2),\n",
    "            \"Recall@30\": round(100*np.average(recall_30_list), 2),\n",
    "            \"Recall@100\": round(100*np.average(recall_100_list), 2),\n",
    "        }\n",
    "    if return_summary:\n",
    "        return res_summary \n",
    "    else:\n",
    "        for k in res.keys():\n",
    "            res[k].update(res_ndcg[k])\n",
    "        return res\n",
    "    \n",
    "def print_res_pseudo_qrels(run_file, pseudo_qrels, rel_threshold, return_summary=True):\n",
    "    with open(run_file, 'r' )as f:\n",
    "        run_data = f.readlines()\n",
    "    # with open(qrel_file, 'r') as f:\n",
    "    #     qrel_data = f.readlines()\n",
    "    # print(run_data)\n",
    "    qrels = pseudo_qrels # {}\n",
    "    qrels_ndcg = pseudo_qrels # {}\n",
    "    runs = {}\n",
    "    \n",
    "#     for line in qrel_data:\n",
    "#         line = line.strip().split()\n",
    "#         query = line[0]\n",
    "#         passage = line[2]\n",
    "#         rel = int(line[3])\n",
    "#         if query not in qrels:\n",
    "#             qrels[query] = {}\n",
    "#         if query not in qrels_ndcg:\n",
    "#             qrels_ndcg[query] = {}\n",
    "\n",
    "#         # for NDCG\n",
    "#         qrels_ndcg[query][passage] = rel\n",
    "#         # for MAP, MRR, Recall\n",
    "#         if rel >= rel_threshold:\n",
    "#             rel = 1\n",
    "#         else:\n",
    "#             rel = 0\n",
    "#         qrels[query][passage] = rel\n",
    "    \n",
    "    for line in run_data:\n",
    "        line = line.split(\" \")\n",
    "        query = line[0]\n",
    "        passage = line[2]\n",
    "        rel = int(line[4])\n",
    "        if query not in runs:\n",
    "            runs[query] = {}\n",
    "        runs[query][passage] = rel\n",
    "\n",
    "    # pytrec_eval eval\n",
    "    evaluator = pytrec_eval.RelevanceEvaluator(qrels, {\"map\", \"recip_rank\", \"recall.1\", \"recall.3\", \"recall.5\", \"recall.10\", \n",
    "                         \"recall.20\", \"recall.30\", \"recall.100\", })\n",
    "    res = evaluator.evaluate(runs)\n",
    "    # map_list = [v['map'] for v in res.values()]\n",
    "    mrr_list = [v['recip_rank'] for v in res.values()]\n",
    "    recall_1_list = [v[\"recall_1\"] for v in res.values()]\n",
    "    recall_3_list = [v[\"recall_3\"] for v in res.values()]\n",
    "    recall_5_list = [v[\"recall_5\"] for v in res.values()]\n",
    "    recall_10_list = [v[\"recall_10\"] for v in res.values()]\n",
    "    recall_20_list = [v[\"recall_20\"] for v in res.values()]\n",
    "    recall_30_list = [v[\"recall_30\"] for v in res.values()]\n",
    "    recall_100_list = [v[\"recall_100\"] for v in res.values()]\n",
    "    # print(res)\n",
    "\n",
    "    evaluator = pytrec_eval.RelevanceEvaluator(qrels_ndcg, {\"ndcg_cut.3\"})\n",
    "    res_ndcg = evaluator.evaluate(runs)\n",
    "    ndcg_3_list = [v['ndcg_cut_3'] for v in res_ndcg.values()]\n",
    "    # print(res)\n",
    "    \n",
    "    res_summary = {\n",
    "            # \"MAP\": round(100*np.average(map_list),2),\n",
    "            \"MRR\": round(100*np.average(mrr_list),2),\n",
    "            \"NDCG@3\": round(100*np.average(ndcg_3_list),2),\n",
    "            \"Recall@1\": round(100*np.average(recall_1_list), 2),\n",
    "            \"Recall@3\": round(100*np.average(recall_3_list), 2),\n",
    "            \"Recall@5\": round(100*np.average(recall_5_list), 2),\n",
    "            \"Recall@10\": round(100*np.average(recall_10_list), 2),\n",
    "            \"Recall@20\": round(100*np.average(recall_20_list), 2),\n",
    "            \"Recall@30\": round(100*np.average(recall_30_list), 2),\n",
    "            \"Recall@100\": round(100*np.average(recall_100_list), 2),\n",
    "        }\n",
    "    if return_summary:\n",
    "        return res_summary \n",
    "    else:\n",
    "        for k in res.keys():\n",
    "            res[k].update(res_ndcg[k])\n",
    "        return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23aa3247-f950-4e13-bf4e-0f138b8df3d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_prompt(line, args, n_recent=3):\n",
    "    # (w pssg, wo pssg) x (icl, zsl): prompt_type, use_pssg\n",
    "    \n",
    "    # Inst: \n",
    "    # \"Given a question and its context, decontextualize the question by addressing coreference and omission issues. \n",
    "    # The resulting question should retain its original meaning and be as informative as possible, \n",
    "    # and should not duplicate any previously asked questions in the context.\"\n",
    "    # if args.use_pssg:    \n",
    "    #     Instruction = \"Given a question, its previous questions (Q) & answers (A) and retrieved documents (Document), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context.\"\n",
    "    # else:\n",
    "    #     Instruction = \"Given a question and its context, decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context.\"\n",
    "    \n",
    "    if args.use_pssg:    \n",
    "        if args.instruct_pssg == 'original':\n",
    "            Instruction = \"Given a question, its previous questions (Q), retrieved documents (Document), and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context.\"\n",
    "        elif args.instruct_pssg == 'filter_irrelevant':\n",
    "            Instruction = \"Given a question, its previous questions (Q), retrieved documents (Document), and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context. Use the documents to enrich your question if they're relevant, or draw on the Q&A context for a precise reformulation if the documents aren't helpful.\"\n",
    "        elif args.instruct_pssg == 'summary':\n",
    "            Instruction = \"Given a question, its previous questions (Q), retrieved documents (Document), and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context. Given the potential noise and dependencies within the context, creating a concise summary of it first could be an effective strategy for accurately rephrasing the question. Therefore, start by summarizing the context before you decontextualize the question.\"\n",
    "        elif args.instruct_pssg == 'filter_irrelevant_summary':\n",
    "            Instruction = \"Given a question, its previous questions (Q), retrieved documents (Document), and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context. Use the documents to enrich your question if they're relevant, or draw on the Q&A context for a precise reformulation if the documents aren't helpful. Considering the potential noise and dependencies within the context, creating a concise summary of it first could be an effective strategy for accurately rephrasing the question. Therefore, start by summarizing the context before you decontextualize the question.\"\n",
    "        elif args.instruct_pssg == 'reasoning':\n",
    "            Instruction = \"Given a question, its previous questions (Q), retrieved documents (Document), and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context. Use the documents to enrich your question if they're relevant, or draw on the Q&A context for a precise reformulation if the documents aren't helpful.\"\n",
    "            Instruction = Instruction + \" Before rewriting, evaluate which parts of the context are essential to address, helping to rewrite your question effectively.\"\n",
    "    else:\n",
    "        if args.instruct_pssg == 'original':\n",
    "            Instruction = \"Given a question, its previous questions (Q) and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context.\"\n",
    "        elif args.instruct_pssg == 'filter_irrelevant':\n",
    "            Instruction = \"Given a question, its previous questions (Q) and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context.\"\n",
    "        elif args.instruct_pssg == 'summary':\n",
    "            Instruction = \"Given a question, its previous questions (Q) and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context. Given the potential noise and dependencies within the context, creating a concise summary of it first could be an effective strategy for accurately rephrasing the question. Therefore, start by summarizing the context before you decontextualize the question.\"\n",
    "        elif args.instruct_pssg == 'filter_irrelevant_summary':\n",
    "            Instruction = \"Given a question, its previous questions (Q) and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context. Considering the potential noise and dependencies within the context, creating a concise summary of it first could be an effective strategy for accurately rephrasing the question. Therefore, start by summarizing the context before you decontextualize the question.\"\n",
    "        elif args.instruct_pssg == 'reasoning':\n",
    "            Instruction = \"Given a question, its previous questions (Q) and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context.\"\n",
    "            Instruction = Instruction + \" Before rewriting, evaluate which parts of the context are essential to address, helping to rewrite your question effectively.\"\n",
    "            \n",
    "            \n",
    "    curr_ctx = []\n",
    "    if args.use_pssg: # using {}-passages-per-line.json\n",
    "        n_prev_QAturn = len(line['NewContext'])//2\n",
    "        s_idx_adddocs = max(n_prev_QAturn - n_recent, 0) * 2 # starting-idx to add passage\n",
    "        p_docs = [ f\"Document: {d}.\" for d in line['Truth_passages_contents'][-n_recent:] ] # recent top1 docs\n",
    "        \n",
    "        p_docs_i = 0\n",
    "        # (Q-Doc-A)-...\n",
    "        for idx, sent in enumerate(line['NewContext']): # run the below when turn_no >= 1\n",
    "            if idx % 2 == 0:\n",
    "                curr_ctx.append(f\"Q: {sent}\")\n",
    "                if idx >= s_idx_adddocs:\n",
    "                    curr_ctx.append(p_docs[p_docs_i])\n",
    "                    p_docs_i += 1\n",
    "                else:\n",
    "                    curr_ctx.append(\"Document: No relevant documents.\")\n",
    "            else:\n",
    "                curr_ctx.append(f\"A: {sent}\")\n",
    "                \n",
    "    else:\n",
    "        # ctx = [ x for pair in zip(line[\"history_rewrite\"], line[\"history_answer\"]) for x in pair]\n",
    "        ctx = [ x for pair in zip(line[\"history_query\"], line[\"history_answer\"]) for x in pair]\n",
    "        for idx, sent in enumerate(ctx):\n",
    "            if idx % 2 == 0:\n",
    "                curr_ctx.append(f\"Q: {sent}\")\n",
    "            else:\n",
    "                curr_ctx.append(f\"A: {sent}\")\n",
    "                \n",
    "    curr_ctx = \" \".join(curr_ctx)\n",
    "    curr_ctx = f\"[{curr_ctx}]\"\n",
    "    \n",
    "    if args.prompt_type == \"icl\":\n",
    "        if args.use_pssg:\n",
    "            # e1 = \"Context: [Q: When was Born to Fly released? Document: Born to Fly is a song co-written and recorded by American country music artist Sara Evans. It was released in June 2000 as the first single and title track from her 2000 album of the same name. A: Sara Evans's third studio album, Born to Fly, was released on October 10, 2000.] \\nQuestion: Was Born to Fly well received by critics?\\nRewrite: Was Born to Fly well received by critics?\"\n",
    "            # e2 = \"Context: [Q: When was Keith Carradine born? Document: No relevant documents. A: Keith Ian Carradine was born August 8, 1949. Q: Is he married? Document: Carradine married Sandra Will on February 6, 1982. They were separated in 1993, before Will filed for divorce in 1999. The couple had two children: Cade Richmond Carradine (born July 19, 1982) and Sorel Johannah Carradine (born June 18, 1985). A: Keith Carradine married Sandra Will on February 6, 1982.]\\nQuestion: Do they have any children?\\nRewrite: Do Keith Carradine and Sandra Will have any children?\"\n",
    "            # e3 = \"Context: [Q: Who proposed that atoms are the basic units of matter? Document: Arguably the most important of all Dalton's investigations are concerned with the atomic theory in chemistry. While his name is inseparably associated with this theory, the origin of Dalton's atomic theory is not fully understood.[19][20] The theory may have been suggested to him either by researches on ethylene (olefiant gas) and methane (carburetted hydrogen) or by analysis of nitrous oxide (protoxide of azote) and nitrogen dioxide (deutoxide of azote), both views resting on the authority of Thomas Thomson. A: John Dalton proposed that each chemical element is composed of atoms of a single, unique type, and they can combine to form more complex structures called chemical compounds.] \\nQuestion: How did the proposal come about?\\nRewrite: How did John Dalton's proposal that each chemical element is composed of atoms of a single unique type, and they can combine to form more complex structures called chemical compounds come about?\"\n",
    "            # e4 = \"Context: [Q: What is it called when two liquids separate? Document: Decantation is a process for the separation of mixtures of immiscible liquids or of a liquid and a solid mixture such as a suspension.[1] The layer closer to the top of the container—the less dense of the two liquids, or the liquid from which the precipitate or sediment has settled out—is poured off, leaving denser liquid or the solid behind. The process typically is unable to remove all of the top layer, meaning the separation is incomplete or at least one of the two separated components is still contaminated by the other one. A: Decantation is a process for the separation of mixtures of immiscible liquids or of a liquid and a solid mixture such as a suspension.  Q: How does the separation occur?  Document: No relevant documents.  A: The layer closer to the top of the container-the less dense of the two liquids, or the liquid from which the precipitate or sediment has settled out-is poured off.]\\nQuestion: Then what happens?\\nRewrite: Then what happens after the layer closer to the top of the container is poured off with decantation?\"\n",
    "            # # e4 = \"Context: [No previous conversation.]\\nQuestion: Then what happens?\\nRewrite: Then what happens after the layer closer to the top of the container is poured off with decantation?\"\n",
    "            e1 = \"Context: [Q: When was Born to Fly released? Document: Born to Fly is a song co-written and recorded by American country music artist Sara Evans. It was released in June 2000 as the first single and title track from her 2000 album of the same name. A: Sara Evans's third studio album, Born to Fly, was released on October 10, 2000.] \\nQuestion: Was Born to Fly well received by critics?\\nRewrite: Was Born to Fly well received by critics?\"\n",
    "            e2 = \"Context: [Q: When was Keith Carradine born? Document: No relevant documents. A: Keith Ian Carradine was born August 8, 1949. Q: Is he married? Document: Carradine married Sandra Will on February 6, 1982. They were separated in 1993, before Will filed for divorce in 1999. The couple had two children: Cade Richmond Carradine (born July 19, 1982) and Sorel Johannah Carradine (born June 18, 1985). A: Keith Carradine married Sandra Will on February 6, 1982.]\\nQuestion: Do they have any children?\\nRewrite: Do Keith Carradine and Sandra Will have any children?\"\n",
    "            e3 = \"Context: [Q: Who proposed that atoms are the basic units of matter? Document: Arguably the most important of all Dalton's investigations are concerned with the atomic theory in chemistry. While his name is inseparably associated with this theory, the origin of Dalton's atomic theory is not fully understood. The theory may have been suggested to him either by researches on ethylene (olefiant gas) and methane (carburetted hydrogen) or by analysis of nitrous oxide (protoxide of azote) and nitrogen dioxide (deutoxide of azote), both views resting on the authority of Thomas Thomson. A: John Dalton proposed that each chemical element is composed of atoms of a single, unique type, and they can combine to form more complex structures called chemical compounds.] \\nQuestion: How did the proposal come about?\\nRewrite: How did John Dalton's proposal that each chemical element is composed of atoms of a single unique type, and they can combine to form more complex structures called chemical compounds come about?\"\n",
    "            e4 = \"Context: [Q: What is it called when two liquids separate? Document: Decantation is a process for the separation of mixtures of immiscible liquids or of a liquid and a solid mixture such as a suspension. The layer closer to the top of the container—the less dense of the two liquids, or the liquid from which the precipitate or sediment has settled out—is poured off, leaving denser liquid or the solid behind. The process typically is unable to remove all of the top layer, meaning the separation is incomplete or at least one of the two separated components is still contaminated by the other one. A: Decantation is a process for the separation of mixtures of immiscible liquids or of a liquid and a solid mixture such as a suspension.  Q: How does the separation occur?  Document: No relevant documents.  A: The layer closer to the top of the container-the less dense of the two liquids, or the liquid from which the precipitate or sediment has settled out-is poured off.]\\nQuestion: Then what happens?\\nRewrite: Then what happens after the layer closer to the top of the container is poured off with decantation?\"\n",
    "            if args.instruct_pssg == 'original' or args.instruct_pssg == 'filter_irrelevant':\n",
    "                e1, e2, e3, e4 = e1, e2, e3, e4\n",
    "\n",
    "            elif args.instruct_pssg == 'summary' or args.instruct_pssg == 'filter_irrelevant_summary':\n",
    "                e1_tldr = \"TLDR Summary: Born to Fly is both a song and the title of Sara Evans's third studio album. The song was released as the album's first single in June 2000, and the album itself was released on October 10, 2000.\"\n",
    "                e2_tldr = \"TLDR Summary: Keith Ian Carradine, born on August 8, 1949, married Sandra Will on February 6, 1982. They separated in 1993, and Sandra Will filed for divorce in 1999. The couple has two children, Cade Richmond Carradine and Sorel Johannah Carradine.\"\n",
    "                e3_tldr = \"TLDR Summary: John Dalton proposed the atomic theory, which posits that atoms are the fundamental units of matter, with each chemical element being composed of unique atoms that can combine to form complex compounds. The exact inspiration for Dalton's theory is unclear, but it might have stemmed from his research on gases or the analysis of nitrous oxide and nitrogen dioxide, possibly influenced by Thomas Thomson.\"\n",
    "                e4_tldr = \"TLDR Summary: The context explains decantation, a separation process for mixtures of immiscible liquids or liquid-solid mixtures like suspensions. It involves pouring off the top, less dense liquid or the liquid cleared of sediment, leaving behind the denser liquid or solid. The process may not completely remove the top layer, potentially leaving some contamination.\"\n",
    "\n",
    "                e1 = e1.split('Rewrite:')[0] + 'Rewrite: ' + e1_tldr +\\\n",
    "                         ' The rewritten query is ' + \"\\\"\" + e1.split('Rewrite: ')[-1] + \"\\\"\"\n",
    "                e2 = e2.split('Rewrite:')[0] + 'Rewrite: ' + e2_tldr +\\\n",
    "                         ' The rewritten query is ' + \"\\\"\" + e2.split('Rewrite: ')[-1] + \"\\\"\"\n",
    "                e3 = e3.split('Rewrite:')[0] + 'Rewrite: ' + e3_tldr +\\\n",
    "                         ' The rewritten query is ' + \"\\\"\" + e3.split('Rewrite: ')[-1] + \"\\\"\"\n",
    "                e4 = e4.split('Rewrite:')[0] + 'Rewrite: ' + e4_tldr +\\\n",
    "                         ' The rewritten query is ' + \"\\\"\" + e4.split('Rewrite: ')[-1] + \"\\\"\"\n",
    "\n",
    "            elif args.instruct_pssg == 'reasoning':\n",
    "                e1_reasoning = \"The question is already clear.\"\n",
    "                e2_reasoning = \"The original question uses the pronoun \\\"they\\\" which is ambiguous without explicit context. By specifying \\\"Keith Carradine and Sandra Will\\\" as the subjects, the revised question eliminates any ambiguity about who \\\"they\\\" refers to, directly connecting the inquiry to the individuals mentioned in the previous context.\"\n",
    "                e3_reasoning = \"The original question omits what the proposal actually is. Including the specific details of Dalton's atomic theory (that each chemical element is composed of atoms of a single unique type, and they can combine to form more complex structures called chemical compounds) directly in the question adds necessary context and allows the question to stand alone, making it understandable even without prior knowledge of the conversation.\"\n",
    "                e4_reasoning = \"The context revolves around decantation, a specific scientific process. Recognizing this as the core topic ensures that the rewrite focuses on the next logical step in this particular procedure. Question: Then what happens? is vague without specifying what it refers to. By identifying that it refers to the action of pouring off the top layer in the decantation process, we address coreference issues, making it clear what the 'then' is referring to.\"\n",
    "\n",
    "                e1 = e1.split('Rewrite:')[0] + 'Rewrite: ' + e1_reasoning +\\\n",
    "                         ' The rewritten query is ' + \"\\\"\" + e1.split('Rewrite: ')[-1] + \"\\\"\"\n",
    "                e2 = e2.split('Rewrite:')[0] + 'Rewrite: ' + e2_reasoning +\\\n",
    "                         ' The rewritten query is ' + \"\\\"\" + e2.split('Rewrite: ')[-1] + \"\\\"\"\n",
    "                e3 = e3.split('Rewrite:')[0] + 'Rewrite: ' + e3_reasoning +\\\n",
    "                         ' The rewritten query is ' + \"\\\"\" + e3.split('Rewrite: ')[-1] + \"\\\"\"\n",
    "                e4 = e4.split('Rewrite:')[0] + 'Rewrite: ' + e4_reasoning +\\\n",
    "                         ' The rewritten query is ' + \"\\\"\" + e4.split('Rewrite: ')[-1] + \"\\\"\"\n",
    "\n",
    "        else: # without past passages    \n",
    "            \n",
    "            e1 = \"Context: [Q: When was Born to Fly released? A: Sara Evans's third studio album, Born to Fly, was released on October 10, 2000.]\\nQuestion: Was Born to Fly well received by critics?\\nRewrite: Was Born to Fly well received by critics?\"\n",
    "            e2 = \"Context: [Q: When was Keith Carradine born? A: Keith Ian Carradine was born August 8, 1949. Q: Is he married? A: Keith Carradine married Sandra Will on February 6, 1982.]\\nQuestion: Do they have any children?\\nRewrite: Do Keith Carradine and Sandra Will have any children?\"\n",
    "            e3 = \"Context: [Q: Who proposed that atoms are the basic units of matter? A: John Dalton proposed that each chemical element is composed of atoms of a single, unique type, and they can combine to form more complex structures called chemical compounds.]\\nQuestion: How did the proposal come about?\\nRewrite: How did John Dalton's proposal that each chemical element is composed of atoms of a single unique type, and they can combine to form more complex structures called chemical compounds come about?\"\n",
    "            e4 = \"Context: [Q: What is it called when two liquids separate? A: Decantation is a process for the separation of mixtures of immiscible liquids or of a liquid and a solid mixture such as a suspension. Q: How does the separation occur? A: The layer closer to the top of the container-the less dense of the two liquids, or the liquid from which the precipitate or sediment has settled out-is poured off.]\\nQuestion: Then what happens?\\nRewrite: Then what happens after the layer closer to the top of the container is poured off with decantation?\"\n",
    "            # e4 = \"Context: [No previous conversation.]\\nQuestion: Then what happens?\\nRewrite: Then what happens after the layer closer to the top of the container is poured off with decantation?\"\n",
    "            \n",
    "            if args.instruct_pssg == 'original' or args.instruct_pssg == 'filter_irrelevant':\n",
    "                e1, e2, e3, e4 = e1, e2, e3, e4\n",
    "            \n",
    "            elif args.instruct_pssg == 'summary' or args.instruct_pssg == 'filter_irrelevant_summary':\n",
    "                e1_tldr = \"TLDR Summary: Inquiry about the release date of Sara Evans's album \\\"Born to Fly,\\\" which was on October 10, 2000.\"\n",
    "                e2_tldr = \"TLDR Summary: Inquiry about Keith Carradine's birth date, which is August 8, 1949, and marital status, revealing he married Sandra Will on February 6, 1982.\"\n",
    "                e3_tldr = \"TLDR Summary: John Dalton proposed atoms as the basic units of matter, which can combine to form chemical compounds.\"\n",
    "                e4_tldr = \"TLDR Summary: Decantation separates mixtures of immiscible liquids or liquids and solids by pouring off the top layer after settling.\"\n",
    "\n",
    "                e1 = e1.split('Rewrite:')[0] + 'Rewrite: ' + e1_tldr +\\\n",
    "                         ' The rewritten query is ' + \"\\\"\" + e1.split('Rewrite: ')[-1] + \"\\\"\"\n",
    "                e2 = e2.split('Rewrite:')[0] + 'Rewrite: ' + e2_tldr +\\\n",
    "                         ' The rewritten query is ' + \"\\\"\" + e2.split('Rewrite: ')[-1] + \"\\\"\"\n",
    "                e3 = e3.split('Rewrite:')[0] + 'Rewrite: ' + e3_tldr +\\\n",
    "                         ' The rewritten query is ' + \"\\\"\" + e3.split('Rewrite: ')[-1] + \"\\\"\"\n",
    "                e4 = e4.split('Rewrite:')[0] + 'Rewrite: ' + e4_tldr +\\\n",
    "                         ' The rewritten query is ' + \"\\\"\" + e4.split('Rewrite: ')[-1] + \"\\\"\"\n",
    "\n",
    "            elif args.instruct_pssg == 'reasoning':\n",
    "                e1_reasoning = \"The question is already clear.\"\n",
    "                e2_reasoning = \"The question \\\"Do they have any children?\\\" is ambiguous without directly referencing who \\\"they\\\" are. By naming \\\"Keith Carradine and Sandra Will\\\" explicitly, we eliminate any ambiguity regarding who the question is about.\"\n",
    "                e3_reasoning = \"The question \\\"How did the proposal come about?\\\" is vague because it doesn't specify which proposal it's referring to. By restating that the proposal is about \\\"each chemical element being composed of atoms of a single, unique type, and they can combine to form more complex structures called chemical compounds,\\\" we make the question self-contained.\"\n",
    "                e4_reasoning = \"The question \\\"Then what happens?\\\" is vague without specifying which process it refers to. By stating \\\"after the layer closer to the top of the container is poured off,\\\" the question explicitly refers to the action that was previously described, making it clear which stage of the process we're inquiring about what happens next.\"\n",
    "\n",
    "                e1 = e1.split('Rewrite:')[0] + 'Rewrite: ' + e1_reasoning +\\\n",
    "                         ' The rewritten query is ' + \"\\\"\" + e1.split('Rewrite: ')[-1] + \"\\\"\"\n",
    "                e2 = e2.split('Rewrite:')[0] + 'Rewrite: ' + e2_reasoning +\\\n",
    "                         ' The rewritten query is ' + \"\\\"\" + e2.split('Rewrite: ')[-1] + \"\\\"\"\n",
    "                e3 = e3.split('Rewrite:')[0] + 'Rewrite: ' + e3_reasoning +\\\n",
    "                         ' The rewritten query is ' + \"\\\"\" + e3.split('Rewrite: ')[-1] + \"\\\"\"\n",
    "                e4 = e4.split('Rewrite:')[0] + 'Rewrite: ' + e4_reasoning +\\\n",
    "                         ' The rewritten query is ' + \"\\\"\" + e4.split('Rewrite: ')[-1] + \"\\\"\"\n",
    "                         \n",
    "\n",
    "        prompt = f\"{Instruction}\\n\\n{e1}\\n\\n{e2}\\n\\n{e3}\\n\\n{e4}\\n\\nContext: {curr_ctx}\\nQuestion: {line['query']}\\nRewrite: \"\n",
    "        \n",
    "        \n",
    "    elif args.prompt_type == \"zsl\":\n",
    "        prompt = f\"{Instruction}\\n\\nContext: {curr_ctx}\\nQuestion: {line['Question']}\\nRewrite: \"\n",
    "    # print(\"prompt: \", prompt)\n",
    "\n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1da9ce3-7760-4b30-9f98-d6db20ce7742",
   "metadata": {},
   "outputs": [],
   "source": [
    "qrel_file = \"/data/../nlp_data/topiocqa/train_gold.trec\"\n",
    "with open(qrel_file, 'r') as f:\n",
    "    qrel_data = f.readlines()\n",
    "    \n",
    "rel_threshold = 1\n",
    "run_file_dir = \"/data2/../nlp_data/convgqr/bm25/chatgpt/\"\n",
    "p_type = \"icl\"\n",
    "inst_pssg = \"original\"\n",
    "seed = \"0\"\n",
    "temp = \"8\"\n",
    "topp = \"8\"\n",
    "eval_type = \"oracle\"\n",
    "run_file = run_file_dir + f\"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_seed{seed}_{eval_type}.trec\"\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cdfba7b-cc10-4e48-9c54-0beccd81e982",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "all_res = []\n",
    "for s in range(12):\n",
    "    fname = f\"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_seed{s}_{eval_type}.trec\"\n",
    "    run_file = run_file_dir + fname\n",
    "    res = print_res(run_file, qrel_data, rel_threshold, return_summary=False)\n",
    "    all_res += [res]\n",
    "for s in range(15,18):\n",
    "    fname = f\"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_seed{s}_{eval_type}.trec\"\n",
    "    run_file = run_file_dir + fname\n",
    "    res = print_res(run_file, qrel_data, rel_threshold, return_summary=False)\n",
    "    all_res += [res]\n",
    "    \n",
    "# all_res = [res1, ..., res12]\n",
    "best_res_dict = {}\n",
    "conv_q_ids = list(all_res[0].keys())\n",
    "# print(\"conv_q_ids: \", conv_q_ids[:30])\n",
    "for conv_q_i in conv_q_ids:\n",
    "    res_list = []\n",
    "    for res in all_res:\n",
    "        res_list += [res[conv_q_i]]\n",
    "    # print('res_list ', res_list)\n",
    "    # take best\n",
    "    # Calculate the average score for each dictionary\n",
    "    avg_scores = [sum(d.values()) / len(d) for d in res_list]\n",
    "\n",
    "    # Identify the index of the dictionary with the highest average score\n",
    "    index_of_highest_avg = avg_scores.index(max(avg_scores))\n",
    "\n",
    "    # Retrieve the dictionary with the highest average score\n",
    "    dict_with_highest_avg = res_list[index_of_highest_avg]\n",
    "    # print('dict_with_highest_avg: ', dict_with_highest_avg)\n",
    "    best_res_dict[conv_q_i] = dict_with_highest_avg\n",
    "\n",
    "metrics = best_res_dict\n",
    "map_list = [v['map'] for v in metrics.values()]\n",
    "mrr_list = [v['recip_rank'] for v in metrics.values()]\n",
    "recall_100_list = [v['recall_100'] for v in metrics.values()]\n",
    "recall_20_list = [v['recall_20'] for v in metrics.values()]\n",
    "recall_10_list = [v['recall_10'] for v in metrics.values()]\n",
    "recall_5_list = [v['recall_5'] for v in metrics.values()]\n",
    "ndcg_3_list = [v['ndcg_cut_3'] for v in metrics.values()]\n",
    "\n",
    "np.set_printoptions(precision=4)\n",
    "\n",
    "eval_metrics = {\n",
    "            \"MAP\": round(100*np.average(map_list),2),\n",
    "            \"MRR\": round(100*np.average(mrr_list),2),\n",
    "            \"NDCG@3\": round(100*np.average(ndcg_3_list),2),\n",
    "            \"Recall@5\": round(100*np.average(recall_5_list),2),\n",
    "            \"Recall@10\": round(100*np.average(recall_10_list),2),\n",
    "            \"Recall@20\": round(100*np.average(recall_20_list),2),\n",
    "            \"Recall@100\": round(100*np.average(recall_100_list),2), \n",
    "        }\n",
    "eval_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9173dd3-d20b-4f57-b5aa-24f2019cbf4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get perf dicts by conv-turn\n",
    "\n",
    "all_res = all_res # + [gt_res]\n",
    "stat_res_dict = {}\n",
    "detailed_res_dict = {}\n",
    "conv_q_ids = list(all_res[0].keys())\n",
    "# print(\"conv_q_ids: \", conv_q_ids[:30])\n",
    "for conv_q_i in conv_q_ids:\n",
    "    res_list = []\n",
    "    for res in all_res:\n",
    "        res_list += [res[conv_q_i]]\n",
    "\n",
    "    values_by_d = []\n",
    "\n",
    "    # Populate the lists with values from each dictionary\n",
    "    for d in res_list:\n",
    "        vals = list(d.values())\n",
    "        d_avg = mean(vals)\n",
    "        values_by_d += [d_avg]\n",
    "        \n",
    "    # Calculate averages and stds\n",
    "    avg, std = mean(values_by_d), stdev(values_by_d) if len(values_by_d) > 1 else 0\n",
    "    stat_res_dict[conv_q_i] = (avg,std)\n",
    "    \n",
    "    # save detailed performance\n",
    "    detailed_res_dict[conv_q_i] = res_list\n",
    "    \n",
    "# stat_res_dict\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cded214a-a1a4-4f26-b9a7-5cbf2e14cc25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get rewritten queries by conv-turns\n",
    "\n",
    "# /data2/../nlp_data/infocqr_data/topiocqa/\n",
    "# train_chatgpt_${p_type}_WOpssg_${inst_pssg}_seed${seed}_temp${temp}_p${topp}_sampled.jsonl\n",
    "pred_file_dir = \"/data2/../nlp_data/infocqr_data/topiocqa/\"\n",
    "\n",
    "# load all pred-queries from temp_paths\n",
    "all_pred_data = {}\n",
    "pred_i = 0\n",
    "for s in range(12):\n",
    "    fname = f\"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_seed{s}_temp{temp}_p{topp}_sampled.jsonl\"\n",
    "    pred_file = pred_file_dir + fname\n",
    "    with open(pred_file, \"r\") as f:\n",
    "        data = f.readlines()\n",
    "    data = [json.loads(data[i]) for i in range(len(data))]\n",
    "    all_pred_data[pred_i] = data\n",
    "    pred_i += 1\n",
    "    \n",
    "for s in range(15,18):\n",
    "    fname = f\"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_originalQ_seed{s}_temp{temp}_p{topp}_sampled.jsonl\"\n",
    "    pred_file = pred_file_dir + fname\n",
    "    with open(pred_file, \"r\") as f:\n",
    "        data = f.readlines()\n",
    "    data = [json.loads(data[i]) for i in range(len(data))]\n",
    "    all_pred_data[pred_i] = data\n",
    "    pred_i += 1\n",
    "    \n",
    "    \n",
    "all_proc_preds = {}\n",
    "for i,data in all_pred_data.items():\n",
    "    temp_data = {}\n",
    "    for dt in tqdm(data):\n",
    "        guid = f\"{dt['conv_id']}-{dt['turn_id']}\"\n",
    "        pred_query = dt['oracle_utt_text']\n",
    "        temp_data[guid] = {'pred_query':pred_query}\n",
    "    all_proc_preds[i] = temp_data\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4f9f14f-8bb7-4884-9be3-3550cad6e462",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Step 1 & 2: Calculate the average score for each key\n",
    "avg_scores = {}\n",
    "for key, values in detailed_res_dict.items():\n",
    "    avg_scores[key] = []\n",
    "    for metrics in values:\n",
    "        # metrics = v  # Assuming we're always interested in the 0-th element\n",
    "        avg_score = sum(metrics.values()) / len(metrics)\n",
    "        avg_scores[key] += [avg_score]\n",
    "\n",
    "# Step 3: Group keys by their average scores\n",
    "grouped_by_avg_score = {}\n",
    "for key, avgs in avg_scores.items():\n",
    "    groups_by_avg = {}\n",
    "    for i, avg in enumerate(avgs):\n",
    "        if avg not in groups_by_avg:\n",
    "            groups_by_avg[avg] = [i]\n",
    "        else:\n",
    "            # print(key, i)\n",
    "            groups_by_avg[avg].append(i)\n",
    "            \n",
    "    grouped_by_avg_score[key] = groups_by_avg\n",
    "\n",
    "# # If you need the groups sorted by the average score\n",
    "# sorted_grouped_by_avg_score = dict(sorted(grouped_by_avg_score.items()))\n",
    "\n",
    "# Displaying the result\n",
    "for avg_score, keys in grouped_by_avg_score.items():\n",
    "    print(f\"Average Score: {avg_score}, Keys: {keys}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77bdc3a7-e0ad-4459-8a07-efdb0834a17e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "for k in list(grouped_by_avg_score.keys()):\n",
    "    grouped_by_avg_score[k] = dict(sorted(grouped_by_avg_score[k].items(), key=lambda k: -k[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96ab03b8-dfdd-4e2a-9419-15c1b2d57039",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k in list(grouped_by_avg_score.keys()):\n",
    "    # remove duplicates\n",
    "    new_dict = {}\n",
    "    for score, gen_ids in grouped_by_avg_score[k].items():\n",
    "        uniq_ids = []\n",
    "        same_score = []\n",
    "        for i in gen_ids:\n",
    "            if all_proc_preds[i][k]['pred_query'] not in same_score:\n",
    "                same_score += [all_proc_preds[i][k]['pred_query']]\n",
    "                uniq_ids += [i]\n",
    "        new_dict[score] = uniq_ids\n",
    "    \n",
    "    # update\n",
    "    grouped_by_avg_score[k] = new_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a590a83-fafa-4f41-9ea6-fa4a64e9a845",
   "metadata": {},
   "source": [
    "#### get pseudo gold documents using back-retrieval results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8233e27f-ea3f-463d-bc46-7c58f34ae3c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['JAVA_HOME'] = \"/usr/lib/jvm/java-11-openjdk-amd64\"\n",
    "# from shared_utils.indexing_utils import SparseIndexer, DocumentCollection\n",
    "from pyserini.search.lucene import LuceneSearcher\n",
    "\n",
    "bm25_k1 = 0.9\n",
    "bm25_b = 0.4\n",
    "\n",
    "index_dir_path = \"/data2/../nlp_data/topiocqa/indexes/bm25\"\n",
    "searcher = LuceneSearcher(index_dir_path)\n",
    "searcher.set_bm25(bm25_k1, bm25_b)\n",
    "\n",
    "\n",
    "doc = searcher.doc(1)\n",
    "json_doc = json.loads(doc.raw())\n",
    "print(json_doc['contents'])\n",
    "\n",
    "run_file_dir = \"/data2/../nlp_data/convgqr/bm25/topiocqa-back-retrieval/\"\n",
    "\n",
    "# fname = f\"train_chatgpt_selfask+this_answer_only_train_selfask_init.trec\"\n",
    "# fname = f\"train_chatgpt_selfask+this_answer_only_train_selfask_1R.trec\"\n",
    "fname = f\"train_chatgpt_selfask+this_answer_only_train_selfask_2R.trec\"\n",
    "back_ret_pssgs_answer_save = {}\n",
    "\n",
    "with open(run_file_dir+fname, 'r' )as f:\n",
    "    run_data = f.readlines()\n",
    "    \n",
    "    for line in run_data:\n",
    "        line = line.split(\" \")\n",
    "        qid = line[0]\n",
    "        docid = line[2]\n",
    "        score = float((line[5])) \n",
    "        \n",
    "        doc = searcher.doc(int(docid))\n",
    "        json_doc = json.loads(doc.raw())\n",
    "        text = json_doc['contents']\n",
    "        \n",
    "        if qid not in back_ret_pssgs_answer_save:\n",
    "            back_ret_pssgs_answer_save[qid] = []\n",
    "        back_ret_pssgs_answer_save[qid] += [{\"id\": docid, \"score\": score, \"text\": text}]\n",
    "        \n",
    "            \n",
    "print(len(back_ret_pssgs_answer_save))\n",
    "\n",
    "import torch\n",
    "\n",
    "# torch.save(back_ret_pssgs_answer_save, \n",
    "#            \"/data2/../nlp_data/llm_qr/outputs/BM25/1R_topi_back_retrieval_answer\")\n",
    "# torch.save(back_ret_pssgs_answer_save, \n",
    "#            \"/data2/../nlp_data/llm_qr/outputs/BM25/2R_topi_back_retrieval_answer\")\n",
    "torch.save(back_ret_pssgs_answer_save, \n",
    "           \"/data2/../nlp_data/llm_qr/outputs/BM25/3R_topi_back_retrieval_answer\")\n",
    "\n",
    "# back_retrieval_answer_bm25 = torch.load(\"/data2/../nlp_data/llm_qr/outputs/BM25/1R_topi_back_retrieval_answer\")\n",
    "# back_retrieval_answer_bm25 = torch.load(\"/data2/../nlp_data/llm_qr/outputs/BM25/2R_topi_back_retrieval_answer\")\n",
    "back_retrieval_answer_bm25 = torch.load(\"/data2/../nlp_data/llm_qr/outputs/BM25/3R_topi_back_retrieval_answer\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07073412-c544-49b2-ba89-6f1d8e041944",
   "metadata": {},
   "source": [
    "#### get predicted documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0242d3e-2204-4bfb-95b2-1bbf02a9e0d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sparse predictions\n",
    "# get predicted docs \n",
    "\n",
    "run_file_dir = \"/data2/../nlp_data/convgqr/bm25/chatgpt/\"\n",
    "all_results_cands_bm25 = []\n",
    "\n",
    "    \n",
    "for s in range(12): #12\n",
    "    runs = {}\n",
    "    fname = f\"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_seed{s}_{eval_type}.trec\"\n",
    "    run_file = run_file_dir + fname\n",
    "    # res = print_res(run_file, qrel_data, rel_threshold, return_summary=True)\n",
    "    with open(run_file, 'r' )as f:\n",
    "        run_data = f.readlines()\n",
    "    for line in run_data:\n",
    "        line = line.split(\" \")\n",
    "        query = line[0]\n",
    "        passage = line[2]\n",
    "        rel = int(line[4])\n",
    "        if query not in runs:\n",
    "            runs[query] = []\n",
    "        runs[query] += [int(passage)] # [passage] = rel\n",
    "        \n",
    "    all_results_cands_bm25 += [runs] # 12 x 13k x 100\n",
    "    \n",
    "for s in range(15,18): #12\n",
    "    runs = {}\n",
    "    fname = f\"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_seed{s}_{eval_type}.trec\"\n",
    "    run_file = run_file_dir + fname\n",
    "    # res = print_res(run_file, qrel_data, rel_threshold, return_summary=True)\n",
    "    with open(run_file, 'r' )as f:\n",
    "        run_data = f.readlines()\n",
    "    for line in run_data:\n",
    "        line = line.split(\" \")\n",
    "        query = line[0]\n",
    "        passage = line[2]\n",
    "        rel = int(line[4])\n",
    "        if query not in runs:\n",
    "            runs[query] = []\n",
    "        runs[query] += [int(passage)] # [passage] = rel\n",
    "        \n",
    "    all_results_cands_bm25 += [runs] # 12 x 13k x 100\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c55550f-9100-4304-a3d9-107038343717",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def jaccard_index(set_a, set_b):\n",
    "    # Calculate the intersection of the two sets\n",
    "    intersection = set_a.intersection(set_b)\n",
    "    \n",
    "    # Calculate the union of the two sets\n",
    "    union = set_a.union(set_b)\n",
    "    \n",
    "    # Compute the Jaccard Index, which is the size of the intersection divided by the size of the union\n",
    "    jaccard = len(intersection) / len(union)\n",
    "    \n",
    "    return jaccard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa935c40-0ad1-467e-95aa-f965c6b6e9a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(all_res)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37f6683a-7076-4755-b9bd-3721a27cfd4b",
   "metadata": {},
   "source": [
    "### create preference data\n",
    "- group completions by their jaccard scores\n",
    "- sample winning completions based on jaccard scores\n",
    "- generate preference data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c117ed5-1315-485c-ace6-44ca6faff622",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize a dictionary to store results for each cut_pseudo value\n",
    "pseudo_qrels_dict = {}\n",
    "\n",
    "# Iterate over cut_pseudo values from 1 to 10\n",
    "for cut_pseudo in [3]:\n",
    "    pseudo_qrels = {}\n",
    "    \n",
    "    # Iterate over keys in grouped_by_avg_score\n",
    "    for iter_i in range(len(grouped_by_avg_score)):\n",
    "        qid = list(grouped_by_avg_score.keys())[iter_i]\n",
    "        \n",
    "        if qid in pseudo:\n",
    "            pseudo_qrels[qid] = {}\n",
    "            n_iters = min(len(pseudo[qid]), 100)\n",
    "            pseudo_gold = set()\n",
    "            \n",
    "            for i in range(n_iters):\n",
    "                pseudo_gold.add(pseudo[qid][i]['id'])\n",
    "                if len(pseudo_gold) >= cut_pseudo:\n",
    "                    break\n",
    "            \n",
    "            for passage in pseudo_gold:\n",
    "                pseudo_qrels[qid][passage] = 1\n",
    "                \n",
    "    # Store the result for the current cut_pseudo value\n",
    "    pseudo_qrels_dict[cut_pseudo] = pseudo_qrels\n",
    "\n",
    "# Optionally, print the length of pseudo_qrels for each cut_pseudo value\n",
    "for cut_pseudo, qrels in pseudo_qrels_dict.items():\n",
    "    print(f\"cut_pseudo = {cut_pseudo}, number of qrels: {len(qrels)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73b74420-d664-422e-a97b-be1d84e680d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize a dictionary to store results for each cut_pseudo value\n",
    "all_res_pseudo_dict = {}\n",
    "\n",
    "# Iterate over each cut_pseudo value and its corresponding pseudo_qrels\n",
    "for cut_pseudo, pseudo_qrels in tqdm(pseudo_qrels_dict.items()):\n",
    "    all_res_pseudo = []\n",
    "    \n",
    "    # Iterate over seed values and compute results\n",
    "    for s in range(12):\n",
    "        fname = f\"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_seed{s}_{eval_type}.trec\"\n",
    "        run_file = run_file_dir + fname\n",
    "        res = print_res_pseudo_qrels(run_file, pseudo_qrels, rel_threshold, return_summary=False)\n",
    "        all_res_pseudo.append(res)\n",
    "        \n",
    "    for s in range(15, 18):\n",
    "        fname = f\"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_seed{s}_{eval_type}.trec\"\n",
    "        run_file = run_file_dir + fname\n",
    "        res = print_res_pseudo_qrels(run_file, pseudo_qrels, rel_threshold, return_summary=False)\n",
    "        all_res_pseudo.append(res)\n",
    "    \n",
    "    # Store the result for the current cut_pseudo value\n",
    "    all_res_pseudo_dict[cut_pseudo] = all_res_pseudo\n",
    "\n",
    "\n",
    "\n",
    "# Optionally, print the results for each cut_pseudo value\n",
    "for cut_pseudo, results in all_res_pseudo_dict.items():\n",
    "    print(f\"cut_pseudo = {cut_pseudo}, results: {len(results), len(results[0])}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2fbbbb6-cd6c-4e19-b978-a9dbef64e23f",
   "metadata": {},
   "outputs": [],
   "source": [
    "## top3\n",
    "all_res_pseudo_dict.keys()\n",
    "temp_all_res_pseudo_dict = {}\n",
    "temp_all_res_pseudo_dict[3] = all_res_pseudo_dict[3]\n",
    "\n",
    "# all_res = [res1, ..., res12]\n",
    "best_res_dict = {}\n",
    "pseudo_avg_scores = {}\n",
    "conv_q_ids = list(all_res_pseudo[0].keys())\n",
    "# print(\"conv_q_ids: \", conv_q_ids[:30])\n",
    "for conv_q_i in conv_q_ids:\n",
    "    res_list = []\n",
    "    for res in all_res:\n",
    "        res_list += [res[conv_q_i]]\n",
    "        \n",
    "    # 모든 cut_pseudo 값에 대해 pseudo_res_list를 저장할 딕셔너리 초기화\n",
    "    pseudo_res_lists_dict = {}\n",
    "\n",
    "    # 각 cut_pseudo 값에 대해 반복\n",
    "    for cut_pseudo, all_res_pseudo in temp_all_res_pseudo_dict.items():\n",
    "        pseudo_res_list = []\n",
    "\n",
    "        # 각 res에 대해 conv_q_i 인덱스를 사용하여 결과를 수집\n",
    "        for res in all_res_pseudo:\n",
    "            pseudo_res_list.append(res[conv_q_i])\n",
    "\n",
    "        # 현재 cut_pseudo 값에 대한 결과를 딕셔너리에 저장\n",
    "        pseudo_res_lists_dict[cut_pseudo] = pseudo_res_list\n",
    "        \n",
    "    # take best\n",
    "    # Calculate the average score for each dictionary\n",
    "    avg_scores = []\n",
    "    for pred_i in range(len(pseudo_res_lists_dict[1])):\n",
    "        scores = 0\n",
    "        for cut in list(pseudo_res_lists_dict.keys()):\n",
    "            res_dict  = pseudo_res_lists_dict[cut][pred_i]\n",
    "            scores += (sum(res_dict.values()) / len(res_dict))*(1/cut)\n",
    "        avg_scores += [scores]\n",
    "    \n",
    "    pseudo_avg_scores[conv_q_i] = avg_scores\n",
    "    \n",
    "    # print(avg_scores)\n",
    "    # avg_scores = [sum(d.values()) / len(d) for d in pseudo_res_list]\n",
    "    # avg_scores = [sum(d1.values()) / len(d1) + sum(d2.values()) / len(d2) for d1, d2 in zip(pseudo_res_list, pseudo_res_list_cut3)]\n",
    "\n",
    "    # Identify the index of the dictionary with the highest average score\n",
    "    index_of_highest_avg = avg_scores.index(max(avg_scores))\n",
    "\n",
    "    # Retrieve the dictionary with the highest average score\n",
    "    dict_with_highest_avg = res_list[index_of_highest_avg]\n",
    "    # print('dict_with_highest_avg: ', dict_with_highest_avg)\n",
    "    best_res_dict[conv_q_i] = dict_with_highest_avg\n",
    "\n",
    "metrics = best_res_dict\n",
    "map_list = [v['map'] for v in metrics.values()]\n",
    "mrr_list = [v['recip_rank'] for v in metrics.values()]\n",
    "recall_100_list = [v['recall_100'] for v in metrics.values()]\n",
    "recall_20_list = [v['recall_20'] for v in metrics.values()]\n",
    "recall_10_list = [v['recall_10'] for v in metrics.values()]\n",
    "recall_5_list = [v['recall_5'] for v in metrics.values()]\n",
    "ndcg_3_list = [v['ndcg_cut_3'] for v in metrics.values()]\n",
    "\n",
    "np.set_printoptions(precision=4)\n",
    "\n",
    "eval_metrics = {\n",
    "            \"MAP\": round(100*np.average(map_list),2),\n",
    "            \"MRR\": round(100*np.average(mrr_list),2),\n",
    "            \"NDCG@3\": round(100*np.average(ndcg_3_list),2),\n",
    "            \"Recall@5\": round(100*np.average(recall_5_list),2),\n",
    "            \"Recall@10\": round(100*np.average(recall_10_list),2),\n",
    "            \"Recall@20\": round(100*np.average(recall_20_list),2),\n",
    "            \"Recall@100\": round(100*np.average(recall_100_list),2), \n",
    "        }\n",
    "eval_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7cb912f-5a89-4b2c-a55e-a2dad91a53f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Step 3: Group keys by their average scores\n",
    "grouped_by_avg_score_pseudo = {}\n",
    "# for comb, pseudo_avgs in pseudo_avg_scores.items():\n",
    "#     grouped_by_avg_score_pseudo[comb] = {}\n",
    "for key, avgs in pseudo_avg_scores.items():\n",
    "    groups_by_avg = {}\n",
    "    for i, avg in enumerate(avgs):\n",
    "        if avg not in groups_by_avg:\n",
    "            groups_by_avg[avg] = [i]\n",
    "        else:\n",
    "            # print(key, i)\n",
    "            groups_by_avg[avg].append(i)\n",
    "\n",
    "    grouped_by_avg_score_pseudo[key] = groups_by_avg\n",
    "\n",
    "# # If you need the groups sorted by the average score\n",
    "# sorted_grouped_by_avg_score = dict(sorted(grouped_by_avg_score.items()))\n",
    "\n",
    "# # Displaying the result\n",
    "# for avg_score, keys in grouped_by_avg_score_pseudo.items():\n",
    "#     print(f\"Average Score: {avg_score}, Keys: {keys}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c305640-aed9-417a-a668-dcd5d15693f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k in list(grouped_by_avg_score_pseudo.keys()):\n",
    "    # remove duplicates\n",
    "    new_dict = {}\n",
    "    for score, gen_ids in grouped_by_avg_score_pseudo[k].items():\n",
    "        uniq_ids = []\n",
    "        same_score = []\n",
    "        for i in gen_ids:\n",
    "            if all_proc_preds[i][k]['pred_query'] not in same_score:\n",
    "                same_score += [all_proc_preds[i][k]['pred_query']]\n",
    "                uniq_ids += [i]\n",
    "        new_dict[score] = uniq_ids\n",
    "    \n",
    "    # update\n",
    "    grouped_by_avg_score_pseudo[k] = new_dict\n",
    "    \n",
    "for k in list(grouped_by_avg_score_pseudo.keys()):\n",
    "    grouped_by_avg_score_pseudo[k] = dict(sorted(grouped_by_avg_score_pseudo[k].items(), key=lambda k: -k[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f02e6ca-53df-459c-b6fc-85ad3e3e8a3c",
   "metadata": {},
   "source": [
    "### Stratified ratio\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36557195-398c-421d-bdc3-05dc100c8a71",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from collections import defaultdict\n",
    "\n",
    "SEED = 1\n",
    "random.seed(SEED)\n",
    "\n",
    "score_gap_dict = {}\n",
    "num_pair_dict = {}\n",
    "pairs_dict_Q_Q = {}\n",
    "pairs_dict_QA_QA = {}\n",
    "pairs_dict_Q_QA = {}\n",
    "pairs_dict_QA_Q = {}\n",
    "allpairs_dict = defaultdict(list)\n",
    "for key, score_dict in grouped_by_avg_score_pseudo.items():\n",
    "    max_score = list(score_dict.keys())[0]\n",
    "    min_score = list(score_dict.keys())[-1]\n",
    "    gap = max_score - min_score\n",
    "    score_gap_dict[key] = gap\n",
    "    \n",
    "    \n",
    "    if gap > 0.4: # minimal gap to create pref pair\n",
    "        num_pair_dict[key] = min(len(score_dict[max_score]) * len(score_dict[min_score]), 5)\n",
    "        \n",
    "        pairs = [(j,k) for j in score_dict[max_score] for k in score_dict[min_score] if j<12 and k<12]\n",
    "        random.shuffle(pairs)\n",
    "        if pairs[:num_pair_dict[key]]:\n",
    "            pairs_dict_Q_Q[key] = pairs[:num_pair_dict[key]] \n",
    "            allpairs_dict[key] += pairs[:num_pair_dict[key]]\n",
    "            \n",
    "    if gap > 1.3: # minimal gap to create pref pair\n",
    "        num_pair_dict[key] = min(len(score_dict[max_score]) * len(score_dict[min_score]), 5)\n",
    "        pairs = [(j,k) for j in score_dict[max_score] for k in score_dict[min_score] if j>=12 and k<12]\n",
    "        random.shuffle(pairs)\n",
    "        if pairs[:num_pair_dict[key]]:\n",
    "            pairs_dict_QA_Q[key] = pairs[:num_pair_dict[key]]\n",
    "            allpairs_dict[key] += pairs[:num_pair_dict[key]]\n",
    "        \n",
    "    if gap > 0.1: # minimal gap to create pref pair\n",
    "        num_pair_dict[key] = min(len(score_dict[max_score]) * len(score_dict[min_score]), 50)\n",
    "        pairs = [(j,k) for j in score_dict[max_score] for k in score_dict[min_score] if j>=12 and k>=12]\n",
    "        random.shuffle(pairs)\n",
    "        if pairs[:num_pair_dict[key]]:\n",
    "            pairs_dict_QA_QA[key] = pairs[:num_pair_dict[key]]\n",
    "            allpairs_dict[key] += pairs[:num_pair_dict[key]]\n",
    "\n",
    "    if gap > 0.4: # minimal gap to create pref pair\n",
    "        num_pair_dict[key] = min(len(score_dict[max_score]) * len(score_dict[min_score]), 5)\n",
    "        pairs = [(j,k) for j in score_dict[max_score] for k in score_dict[min_score] if j<12 and k>=12]\n",
    "        random.shuffle(pairs)\n",
    "        if pairs[:num_pair_dict[key]]:\n",
    "            pairs_dict_Q_QA[key] = pairs[:num_pair_dict[key]]\n",
    "            allpairs_dict[key] += pairs[:num_pair_dict[key]]\n",
    "        \n",
    "    \n",
    "        \n",
    "# len([key for key, gap in score_gap_dict.items() if gap >0.0])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bf1aa59-f44e-42e0-b269-9b0597a95526",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "split = 'train'\n",
    "root = \"/data/../nlp_data/topiocqa/\" # 'datasets/qrecc/' \n",
    "input_data = \"train_new.json\" # \n",
    "with open(os.path.join(root, input_data), encoding=\"utf-8\") as f:\n",
    "    lines = f.readlines()\n",
    "    \n",
    "lines = [json.loads(l) for l in lines]\n",
    "\n",
    "args = {\n",
    "    'use_pssg': False, \n",
    "    'instruct_pssg': 'original',\n",
    "    'prompt_type': 'icl'\n",
    "}\n",
    "args = argparse.Namespace(**args)\n",
    "\n",
    "qid_prompt_dict = dict()\n",
    "for line in tqdm(lines):\n",
    "    # conv_id = f\"{line['Conversation_no']}_{line['Turn_no']}\"\n",
    "    conv_id = f\"{line['conv_id']}-{line['turn_id']}\"\n",
    "\n",
    "    prompt = set_prompt(line, args)\n",
    "    qid_prompt_dict[conv_id] = prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d5f9e57-8be5-4606-b1ab-017b9b2f0ab1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from random import sample\n",
    "columns = ['qid', 'question', 'response_j', 'response_k']\n",
    "\n",
    "for comb in list(allpairs_dict.keys()):\n",
    "    if comb != \"1-1\": # uncomment to make 1-1 as well\n",
    "        print(comb)\n",
    "        \n",
    "        qid_list, gen_j_list, gen_k_list, qs_list, js_list, ks_list = [], [], [], [], [], []\n",
    "        \n",
    "        for key, sample_pair in allpairs_dict[comb].items():\n",
    "            for gen_j,gen_k in sample_pair:\n",
    "                # preds\n",
    "                qid_list += [ key ]\n",
    "                qs_list += [ qid_prompt_dict[key] ]\n",
    "                js_list += [ all_proc_preds[gen_j][key]['pred_query'] ]\n",
    "                ks_list += [ all_proc_preds[gen_k][key]['pred_query'] ]\n",
    "                gen_j_list += [gen_j]\n",
    "                gen_k_list += [gen_k]\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ac4f1b1-ad20-4ceb-8bde-270690ae4c02",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(f\"{len(gen_k_list)} samples to 10000 samples\")\n",
    "n_pref_samples = 10000\n",
    "sampled_inds = sorted(sample(range(len(gen_k_list)),n_pref_samples, ))\n",
    "qid_list = [qid_list[_] for _ in sampled_inds]\n",
    "qs_list = [qs_list[_] for _ in sampled_inds]\n",
    "js_list = [js_list[_] for _ in sampled_inds]\n",
    "ks_list = [ks_list[_] for _ in sampled_inds]\n",
    "gen_j_list = [gen_j_list[_] for _ in sampled_inds]\n",
    "gen_k_list = [gen_k_list[_] for _ in sampled_inds]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21a97a1a-2fe6-4e37-a088-441aef1bcb65",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_pref = pd.DataFrame({'qid': qid_list, 'question':qs_list, 'response_j':js_list, 'response_k':ks_list,\n",
    "                       'gen_j':gen_j_list, 'gen_k':gen_k_list })\n",
    "df_pref.to_csv('/data/../nlp_data/LongAlpaca-12k/pref_data_topi_init.csv', index=False)\n",
    "# df_pref.to_csv('/data/../nlp_data/LongAlpaca-12k/pref_data_topi_2R.csv', index=False)\n",
    "# df_pref.to_csv('/data/../nlp_data/LongAlpaca-12k/pref_data_topi_3R.csv', index=False)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a062592-5259-435e-9aff-c81f94875552",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c813f7a6-27e0-4cfd-befc-2270f0d0e62a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65d0a6a0-c39d-47fe-b3ef-4ce24ea7f26e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1e0a26e-862b-460a-886d-5b2fee82cc70",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "441cbc08-c40d-4a14-819b-e1ca1257c7ff",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bfc8b02-3150-4194-8ad1-0469a9d831cf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88260b81-123e-410b-b88e-2c4be1e35dd4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37a917b0-8801-4748-9c56-374746e885b5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llmcqr",
   "language": "python",
   "name": "llmcqr"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
