{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ec26b91-d21f-4c2b-882d-3163f4b03c38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import openai\n",
    "import jsonlines\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "import argparse\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "# model_name = \"gpt-4o\" # \"gpt-3.5-turbo-0125\"\n",
    "# openai.api_key = ''\n",
    "from langchain_text_splitters import RecursiveCharacterTextSplitter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aaf07cb3-c8ea-4fbf-bcf8-c47b8a238607",
   "metadata": {},
   "source": [
    "1. gpt_template <- get_template(paper)\n",
    "2. QAs <- gpt4o(gpt_template)\n",
    "3. database mapping: map chunk of papers into chunk_id\n",
    "4. gold passage mapping: map rationale of QAs into chunk_id \n",
    "5. build bm25 & embedding space\n",
    "\n"
   ]
  },
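  {
   "cell_type": "markdown",
   "id": "qa-gen-sketch-lead",
   "metadata": {},
   "source": [
    "A minimal sketch of steps 1-2, assuming the legacy `openai<1.0` client (the notebook reads responses as `[\"choices\"][0][\"message\"][\"content\"]`, which matches that client's return format). `get_template` below is hypothetical; the actual prompt that produced `paperpath2gen_conv` is not included in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "qa-gen-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch of steps 1-2; get_template is illustrative, not the real prompt.\n",
    "def get_template(paper_content, n_turns=5):\n",
    "    \"\"\"Builds a QA-generation prompt from a paper's extracted content.\"\"\"\n",
    "    return (\n",
    "        f\"Read the following paper and write {n_turns} question-answer pairs.\\n\"\n",
    "        \"Format each turn as Q<i>:, A<i>:, and Rationale: quoting the supporting passage.\\n\\n\"\n",
    "        f\"{paper_content[:10000]}\"\n",
    "    )\n",
    "\n",
    "# Legacy (openai<1.0) call style, matching the stored response dicts:\n",
    "# response = openai.ChatCompletion.create(\n",
    "#     model=\"gpt-4o\",\n",
    "#     messages=[{\"role\": \"user\", \"content\": get_template(extract_content_path(paper_path))}],\n",
    "# )"
   ]
  },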
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f49ed25-5a9d-4501-b182-2d808429c0e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# wiki = pd.read_csv(\"/data/../nlp_data/topiocqa/full_wiki_segments.tsv\",\n",
    "#                    sep='\\t', nrows=2000)\n",
    "wiki = pd.read_csv(\"/data/../nlp_data/topiocqa/full_wiki_segments.tsv\",\n",
    "                   sep='\\t')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83b6a1c4-f14e-4e05-86b1-a60e3a08cb46",
   "metadata": {},
   "outputs": [],
   "source": [
    "# average length of each chunk\n",
    "wiki.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14fc96f4-4146-4e63-bffa-801091dd8b7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "sampled_wiki=wiki.sample(300)\n",
    "inds = list(sampled_wiki.index)\n",
    "lengths = [len(sampled_wiki.loc[ind, 'text']) for ind in inds]\n",
    "sum(lengths)/len(lengths)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14dbd46c-31eb-46ae-80d0-4d134cb16160",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "lengths[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0351ace8-d8ea-40b6-9616-c2b54ebc9eab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# english paper "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00c63117-26d5-4727-97e9-46991d820691",
   "metadata": {},
   "outputs": [],
   "source": [
    "en_papers = torch.load(\"../../nlp_data/kisti/db_files/en_paperpath2lang_cnt\")\n",
    "paper2lang_cnt = torch.load(\"../../nlp_data/kisti/db_files/paperpath2lang_cnt\")\n",
    "paperpath2conv = torch.load(\"../../nlp_data/kisti/db_files/paperpath2gen_conv\")\n",
    "\n",
    "\n",
    "long_enough_ps = [f_path for f_path in list(en_papers.keys()) if len(extract_content_path(f_path))>15000]\n",
    "len(long_enough_ps)\n",
    "\n",
    "gen_papers = list(paperpath2conv.keys())\n",
    "# gen_papers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9976be40-46d2-4176-9393-77a0747747fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(en_papers), len(long_enough_ps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4eb5cb9d-085a-47bf-b2c3-e1f2a8df3fde",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_content_path(path):\n",
    "    f = open(path) # certain_domain_en_ps[5]\n",
    "    # print(path)\n",
    "    body = json.load(f,)['body_text']\n",
    "    all_content = \"\"\n",
    "    # print(body)\n",
    "    prev_section_name = \"\"\n",
    "    for section in body:\n",
    "\n",
    "        if 'text' in section:\n",
    "            section_name = section['section'] if 'section' in section else \"\"\n",
    "            for text in section['text']:\n",
    "                if prev_section_name:\n",
    "                    # print(prev_section_name+\": \"+section_name+\": \"+ text)\n",
    "                    all_content += prev_section_name+\": \"+section_name+\": \"+ text + \"\\n\"\n",
    "                else:\n",
    "                    # print(section_name+\": \"+ text)\n",
    "                    all_content += section_name+\": \"+ text + \"\\n\"\n",
    "            prev_section_name = \"\"\n",
    "        else:\n",
    "            prev_section_name = section['section']\n",
    "    # print()\n",
    "    # len(all_content)\n",
    "    return all_content\n",
    "\n",
    "# long_enough_ps = [f_path for f_path in certain_domain_en_ps if len(extract_content_path(f_path))>15000]\n",
    "# len(long_enough_ps)\n",
    "\n",
    "# import random\n",
    "# random.seed(0)\n",
    "# sampled_papers = random.sample(long_enough_ps, 900)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d135ae7-512d-4a30-94f6-040053590eb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def extract_rationale(text):\n",
    "    \"\"\"Extracts the rationale part from the given text.\"\"\"\n",
    "    rationale_pattern = re.compile(r'(?:rationale:|Rationale:|\\*rationale\\*|\\*Rationale\\*|_Rationale_:)(.*?)(?=\\n|$)', \n",
    "                                   re.IGNORECASE | re.DOTALL)\n",
    "    rationales = rationale_pattern.findall(text)\n",
    "    return [clean_rationale(rationale.strip()) for rationale in rationales]\n",
    "def clean_rationale(rationale):\n",
    "    \"\"\"Cleans up the extracted rationale text.\"\"\"\n",
    "    # Remove leading formatting characters like \"** \" or any other unwanted characters\n",
    "    return re.sub(r'^\\*\\*\\s*', '', rationale)\n",
    "\n",
    "def extract_q(text):\n",
    "    \"\"\"Extracts the rationale part from the given text.\"\"\"\n",
    "    rationale_pattern = re.compile(r'(?:rationale:|Rationale:|\\*rationale\\*|\\*Rationale\\*|_Rationale_:)(.*?)(?=\\n|$)', \n",
    "                                   re.IGNORECASE | re.DOTALL)\n",
    "    rationales = rationale_pattern.findall(text)\n",
    "    return [clean_rationale(rationale.strip()) for rationale in rationales]\n",
    "def clean_q(rationale):\n",
    "    \"\"\"Cleans up the extracted rationale text.\"\"\"\n",
    "    # Remove leading formatting characters like \"** \" or any other unwanted characters\n",
    "    return re.sub(r'^\\*\\*\\s*', '', rationale)\n",
    "# rats = extract_rationale(paperpath2conv[gen_papers[3]][\"choices\"][0][\"message\"][\"content\"])\n",
    "# len(rats), rats"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5375227-b3da-4385-9438-b89552ddfcb9",
   "metadata": {},
   "source": [
    "### get rationales"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59bf1f4f-eb17-4391-b7f3-8d47ea2c2e34",
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_paperpath2rationales = {}\n",
    "for paper_path in gen_papers:\n",
    "    rats = extract_rationale(paperpath2conv[paper_path][\"choices\"][0][\"message\"][\"content\"])\n",
    "    gen_paperpath2rationales[paper_path] = rats\n",
    "    # len(rats), rats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e24c071f-73ae-430c-882c-66d590bd77fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "len_rats = torch.tensor([len(gen_paperpath2rationales[paper_path]) for i,paper_path in enumerate(gen_papers)])\n",
    "zero_rat_paths = np.array(gen_papers)[len_rats == 0]\n",
    "print(len(zero_rat_paths))\n",
    "# print(paperpath2conv[zero_rat_paths[0]][\"choices\"][0][\"message\"][\"content\"])\n",
    "if zero_rat_paths:\n",
    "    print(paperpath2conv[zero_rat_paths[0]][\"choices\"][0][\"message\"][\"content\"])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ff528b9-d3ff-4235-9b4f-3e86f07b94a2",
   "metadata": {},
   "source": [
    "### get Q&A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "816e6bbe-e611-449e-b61a-738b2ef3351e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# print(paperpath2conv[gen_papers[0]][\"choices\"][0][\"message\"][\"content\"][:500])\n",
    "# print()\n",
    "# print(paperpath2conv[gen_papers[150]][\"choices\"][0][\"message\"][\"content\"][:500])\n",
    "# print()\n",
    "# print(paperpath2conv[gen_papers[180]][\"choices\"][0][\"message\"][\"content\"][:500])\n",
    "# print()\n",
    "# print(paperpath2conv[gen_papers[250]][\"choices\"][0][\"message\"][\"content\"][:500])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "178565ab-13ed-4a68-a081-2759b5c8ce09",
   "metadata": {},
   "outputs": [],
   "source": [
    "The following text is a response from an AI chatbot. Please retain the exact meaning of the subsequent sentence but rephrase it as much as possible, ensuring that the key terms remain unchanged.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e3b9fa5-3630-40a1-95fe-ca504ebac6c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(paperpath2conv[gen_papers[1]][\"choices\"][0][\"message\"][\"content\"][:4000])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d35bc168-3035-4132-ba87-340dbec5870f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def extract_questions(text):\n",
    "    \"\"\"Extracts the questions (Q) from the given text.\"\"\"\n",
    "    # Regular expression pattern to match questions in various formats\n",
    "    question_pattern = re.compile(\n",
    "        r'(?:(?<![\\w(])(?:\\*\\*)?Q\\d+(?:\\s*:|\\s*:\\s*|\\s*|\\*\\*:\\s*)(.*?)(?=\\n|$)|'\n",
    "        r'(?:\\*\\*Questioner\\*\\*:)(.*?)(?=\\n|$)|'\n",
    "        r'\\*\\*Questioner \\(Q1\\):\\*\\*\\s*(.*?)(?=\\n|$))', \n",
    "        re.IGNORECASE | re.DOTALL\n",
    "    )\n",
    "            # r'(?:\\*\\*Questioner\\s*\\(Q1\\)\\*\\*:)(.*?)(?=\\n|$))', \n",
    "    questions = question_pattern.findall(text)\n",
    "    # Flatten the list of tuples and filter out empty strings\n",
    "    questions = [item for sublist in questions for item in sublist if item]\n",
    "    return [clean_question(q.strip()) for q in questions]\n",
    "\n",
    "def clean_question(question):\n",
    "    \"\"\"Cleans up the extracted question text.\"\"\"\n",
    "    # Remove leading formatting characters like \"** \" or any other unwanted characters\n",
    "    return re.sub(r'^\\*\\*\\s*|\\*\\*$', '', question).strip()\n",
    "\n",
    "def extract_answers(text):\n",
    "    \"\"\"Extracts the questions (Q) from the given text.\"\"\"\n",
    "\n",
    "    answer_pattern = re.compile(\n",
    "        r'(?:(?:\\n\\s*)(?:\\*\\*)?A\\d+(?:\\s*:|\\s*:\\s*|\\s*|\\*\\*:\\s*)(.*?)(?=\\n|$)|'\n",
    "        r'(?:\\*\\*Answerer\\*\\*:\\s*)(.*?)(?=\\n|$)|'\n",
    "        r'\\*\\*Answerer \\(A\\d+\\):\\*\\*\\s*(.*?)(?=\\n|$))', \n",
    "        re.IGNORECASE | re.DOTALL\n",
    "    )\n",
    "       \n",
    "    answers = answer_pattern.findall(text)\n",
    "    # Flatten the list of tuples and filter out empty strings\n",
    "    answers = [item for sublist in answers for item in sublist if item]\n",
    "    return [clean_answer(a.strip()) for a in answers]\n",
    "\n",
    "def clean_answer(answer):\n",
    "    \"\"\"Cleans up the extracted question text.\"\"\"\n",
    "    # Remove leading formatting characters like \"** \" or any other unwanted characters\n",
    "    return re.sub(r'^\\*\\*\\s*|\\*\\*$', '', answer).strip()\n",
    "\n",
    "# answers = extract_answers(paperpath2conv[gen_papers[150]][\"choices\"][0][\"message\"][\"content\"])\n",
    "# for a in answers:\n",
    "#     print(a)\n"
   ]
  },
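  {
   "cell_type": "markdown",
   "id": "extractor-check-lead",
   "metadata": {},
   "source": [
    "A quick toy check of the extractors on synthetic text (not real model output):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "extractor-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected: two questions, two answers, one rationale.\n",
    "_toy = (\n",
    "    \"Q1: What hormone does the paper study?\\n\"\n",
    "    \"A1: It studies prolactin (PRL).\\n\"\n",
    "    \"Rationale: Prolactin (PRL) is synthesized in the pituitary gland.\\n\"\n",
    "    \"Q2: Where does PRL bind?\\n\"\n",
    "    \"A2: To the extracellular domain of the prolactin receptor.\\n\"\n",
    ")\n",
    "print(extract_questions(_toy))\n",
    "print(extract_answers(_toy))\n",
    "print(extract_rationale(_toy))"
   ]
  },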
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "772b4f73-d5b7-44f1-99a2-3ed1a06022a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_paperpath2Qs = {}\n",
    "for paper_path in gen_papers:\n",
    "    qs = extract_questions(paperpath2conv[paper_path][\"choices\"][0][\"message\"][\"content\"])\n",
    "    gen_paperpath2Qs[paper_path] = qs\n",
    "    # len(rats), rats\n",
    "    \n",
    "gen_paperpath2As = {}\n",
    "for paper_path in gen_papers:\n",
    "    ans = extract_answers(paperpath2conv[paper_path][\"choices\"][0][\"message\"][\"content\"])\n",
    "    gen_paperpath2As[paper_path] = ans\n",
    "    # len(rats), rats\n",
    "\n",
    "len_qs = torch.tensor([len(gen_paperpath2Qs[paper_path]) for i,paper_path in enumerate(gen_papers)])\n",
    "zero_q_paths = np.array(gen_papers)[len_qs == 0]\n",
    "print(len(zero_q_paths))\n",
    "# print(paperpath2conv[zero_rat_paths[0]][\"choices\"][0][\"message\"][\"content\"])\n",
    "if zero_q_paths:\n",
    "    print(paperpath2conv[zero_q_paths[0]][\"choices\"][0][\"message\"][\"content\"])\n",
    "\n",
    "len_as = torch.tensor([len(gen_paperpath2As[paper_path]) for i,paper_path in enumerate(gen_papers)])\n",
    "zero_a_paths = np.array(gen_papers)[len_as == 0]\n",
    "print(len(zero_a_paths))\n",
    "# print(paperpath2conv[zero_rat_paths[0]][\"choices\"][0][\"message\"][\"content\"])\n",
    "if zero_a_paths:\n",
    "    print(paperpath2conv[zero_a_paths[0]][\"choices\"][0][\"message\"][\"content\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc15e3eb-bebe-415d-8157-67236c8dbc44",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(gen_paperpath2rationales), sum([len(rs) for rs in list(gen_paperpath2rationales.values())])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bd5763c-721c-4078-bf07-ff156fc9a635",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(gen_paperpath2Qs),sum([len(qs) for qs in list(gen_paperpath2Qs.values())])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0c59754-a073-41e8-b031-47fdf158659a",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(gen_paperpath2As),sum([len(ans) for ans in list(gen_paperpath2As.values())])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f5ef64b-ff3a-4323-85ce-a89ffe9f5a1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(paperpath2conv, \"../../nlp_data/kisti/db_files/paperpath2gen_conv_v2\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "800b05be-e444-4cf7-8ec9-fc3191aeebfc",
   "metadata": {},
   "source": [
    "### DB data chunking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aec0e915-6584-4382-a829-ec47a66efa6e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "def get_chunks_from_path(paper_path, max_n_letters=10000):\n",
    "    # print(extract_content_path(ex_paper_path))\n",
    "    content=extract_content_path(paper_path)[:max_n_letters]\n",
    "\n",
    "    # Initialize the text splitter\n",
    "    text_splitter = RecursiveCharacterTextSplitter(\n",
    "        separators=[\"\\n\", \". \"],  # Secondary and tertiary split markers\n",
    "        keep_separator=False,\n",
    "        chunk_size=500,  # Adjust the chunk size as needed\n",
    "        chunk_overlap=100  # Adjust the overlap as needed\n",
    "    )\n",
    "\n",
    "    # Split the preprocessed content\n",
    "    chunks = text_splitter.split_text(content)\n",
    "    # print([len(chunk) for chunk in chunks])\n",
    "    return chunks\n",
    "\n",
    "chunks = get_chunks_from_path(long_enough_ps[0])\n",
    "len(chunks), chunks,"
   ]
  },
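  {
   "cell_type": "markdown",
   "id": "splitter-demo-lead",
   "metadata": {},
   "source": [
    "A toy illustration of the splitter settings (smaller `chunk_size`/`chunk_overlap` than above so the overlap is visible on short text):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "splitter-demo-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same splitter type as get_chunks_from_path, scaled down so the character\n",
    "# overlap between consecutive chunks is easy to see.\n",
    "_demo_splitter = RecursiveCharacterTextSplitter(\n",
    "    separators=[\"\\n\", \". \"],\n",
    "    keep_separator=False,\n",
    "    chunk_size=80,\n",
    "    chunk_overlap=20,\n",
    ")\n",
    "_demo_text = \". \".join(f\"Sentence number {i} about the paper\" for i in range(8))\n",
    "for _c in _demo_splitter.split_text(_demo_text):\n",
    "    print(len(_c), repr(_c))"
   ]
  },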
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "081741fa-c53a-4a05-96e1-755e4b4bc201",
   "metadata": {},
   "outputs": [],
   "source": [
    "paperpath2chunks = {}\n",
    "for paper_path in long_enough_ps:\n",
    "    chunks = get_chunks_from_path(paper_path)\n",
    "    paperpath2chunks[paper_path] = chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37318d16-a3e0-4249-ae22-48a1b277601c",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(paperpath2chunks), len(long_enough_ps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2c1482e-e998-4611-9eab-3a3f3a70ac92",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize an empty list to store rows\n",
    "rows = []\n",
    "\n",
    "# Iterate over the dictionary\n",
    "for path, chunks in paperpath2chunks.items():\n",
    "    for i,chunk in enumerate(chunks):\n",
    "        # Append a tuple (path, chunk) to the list of rows\n",
    "        rows.append((path, i, path+'_'+str(i), chunk))\n",
    "\n",
    "# Create a DataFrame from the list of rows\n",
    "df = pd.DataFrame(rows, columns=['path', 'chunk_id', 'path_chunk_id', 'text'])\n",
    "df['id'] = df.index\n",
    "df = df[['id', 'path_chunk_id', 'path','text']]\n",
    "# Display the DataFrame\n",
    "df.head()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb05e302-ea5c-4b8a-9e6f-97506cad348a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "df.shape, df['path_chunk_id'][13:20].values, "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a08ba557-30b0-473c-bc79-51ce9cf30df1",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(df, \"../../nlp_data/kisti/db_files/papers_chunks\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdd538c3-8cca-4279-9db0-d4430d5ca23f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import pandas as pd\n",
    "import json\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Step 2: Define the number of rows per segment\n",
    "rows_per_segment = 50000  # Adjust this number based on your needs\n",
    "\n",
    "# Step 3: Split the DataFrame into smaller segments and write each to a .jsonl file\n",
    "# \"/data/../nlp_data/preprocessed\"\n",
    "# \"/data/../nlp_data/collection-paragraph\"\n",
    "output_dir = \"../../nlp_data/kisti/collection-paragraph-kisti\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "num_segments = (len(df) + rows_per_segment - 1) // rows_per_segment  # Calculate the number of segments\n",
    "print(\"num_segments: \", num_segments)\n",
    "\n",
    "for i in tqdm(range(num_segments)):\n",
    "    start_row = i * rows_per_segment\n",
    "    end_row = min((i + 1) * rows_per_segment, len(df))\n",
    "    segment = df.iloc[start_row:end_row]\n",
    "\n",
    "    jsonl_file_path = os.path.join(output_dir, f'segment_{i + 1}.jsonl')\n",
    "    with open(jsonl_file_path, 'w') as jsonl_file:\n",
    "        for index, row in segment.iterrows():\n",
    "            entry = {\n",
    "                \"id\": row['path_chunk_id'],\n",
    "                \"contents\": row['text']\n",
    "                # \"contents\": row['title'].replace(\" [SEP]\", \":\") + \": \" + row['text']\n",
    "                # row['text']\n",
    "            }\n",
    "            jsonl_file.write(json.dumps(entry) + '\\n')\n",
    "\n",
    "    # print(f'Segment {i + 1} written to {jsonl_file_path}')"
   ]
  },
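  {
   "cell_type": "markdown",
   "id": "index-sketch-lead",
   "metadata": {},
   "source": [
    "The `{\"id\": ..., \"contents\": ...}` lines written above match Pyserini's `JsonCollection` format, so step 5's BM25 index can be built with Pyserini's Lucene indexer. The sketch below assumes `pyserini` is installed; the index path and the commented dense encoder are assumptions, not necessarily what is used downstream."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "index-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch of step 5 (\"build bm25 & embedding space\").\n",
    "# BM25 over the jsonl collection written above (assumes pyserini is installed):\n",
    "!python -m pyserini.index.lucene --collection JsonCollection --input ../../nlp_data/kisti/collection-paragraph-kisti --index ../../nlp_data/kisti/indexes/bm25 --generator DefaultLuceneDocumentGenerator --threads 8 --storePositions --storeDocvectors --storeRaw\n",
    "\n",
    "# One possible embedding space (the encoder choice here is an assumption):\n",
    "# from sentence_transformers import SentenceTransformer\n",
    "# encoder = SentenceTransformer(\"sentence-transformers/all-MiniLM-L6-v2\")\n",
    "# chunk_embs = encoder.encode(df['text'].tolist(), batch_size=128, show_progress_bar=True)\n",
    "# torch.save(torch.from_numpy(chunk_embs), \"../../nlp_data/kisti/db_files/chunk_embs\")"
   ]
  },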
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ceebb3c-d761-49cf-a341-68e056ed8f96",
   "metadata": {},
   "outputs": [],
   "source": [
    "39*50000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4f7d8c2-7577-422f-9a75-4356f1be5dda",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(df),len(df)/8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d474ed8-3ab9-445e-b2bb-ee6016675a5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def longest_common_substring(str1, str2):\n",
    "    \"\"\"Returns the length of the longest common substring between two strings.\"\"\"\n",
    "    m = len(str1)\n",
    "    n = len(str2)\n",
    "    # Create a table to store lengths of longest common suffixes of substrings.\n",
    "    # Note that LCSuff[i][j] contains the length of the longest common suffix\n",
    "    # of str1[0...i-1] and str2[0...j-1].\n",
    "    LCSuff = [[0 for k in range(n+1)] for l in range(m+1)]\n",
    "    result = 0  # To store length of the longest common substring.\n",
    "\n",
    "    # Building the LCSuff table in a bottom-up fashion.\n",
    "    for i in range(m + 1):\n",
    "        for j in range(n + 1):\n",
    "            if i == 0 or j == 0:\n",
    "                LCSuff[i][j] = 0\n",
    "            elif str1[i-1] == str2[j-1]:\n",
    "                LCSuff[i][j] = LCSuff[i-1][j-1] + 1\n",
    "                result = max(result, LCSuff[i][j])\n",
    "            else:\n",
    "                LCSuff[i][j] = 0\n",
    "    return result\n",
    "\n",
    "def find_element_with_longest_overlap(list_of_strings, target_string):\n",
    "    \"\"\"Finds the element in the list with the longest overlap with the target string.\"\"\"\n",
    "    max_overlap = 0\n",
    "    best_match = None\n",
    "\n",
    "    for i,element in enumerate(list_of_strings):\n",
    "        overlap_length = longest_common_substring(element, target_string)\n",
    "        if overlap_length > max_overlap:\n",
    "            max_overlap = overlap_length\n",
    "            best_match = i,element\n",
    "\n",
    "    return best_match\n",
    "\n",
    "\n",
    "# target_string = \"Prolactin (PRL) is a polypeptide hormone that is synthesized in the pituitary gland, consists of 199 amino acids with a molecular mass of 23KD, and has more functions than all other pituitary hormones combined... The initial step in the action of PRL, similar to all other hormones, is binding to the extracellular domain of prolactin receptor (PRLR).\"\n",
    "\n",
    "# best_match = find_element_with_longest_overlap(chunks, target_string)\n",
    "# print(\"Element with the longest overlap:\", best_match)\n"
   ]
  },
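  {
   "cell_type": "markdown",
   "id": "matcher-demo-lead",
   "metadata": {},
   "source": [
    "A toy check of the matcher before running it over all rationales: the second chunk shares the longest substring with the target, so index 1 should come back."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "matcher-demo-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "_chunks = [\n",
    "    \"Introduction: prior work on hormone receptors\",\n",
    "    \"Methods: prolactin binds the extracellular domain of its receptor\",\n",
    "    \"Results: expression levels varied across tissues\",\n",
    "]\n",
    "_target = \"binding to the extracellular domain of prolactin receptor\"\n",
    "find_element_with_longest_overlap(_chunks, _target)  # expected: (1, _chunks[1])"
   ]
  },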
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b30fa182-d57f-48c8-998b-e9d126bd1d5a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "paperpath2matches = {}\n",
    "for paper_path, rationales in gen_paperpath2rationales.items():\n",
    "    chunks = paperpath2chunks[paper_path]\n",
    "    matches = []\n",
    "    for target_string in rationales:\n",
    "        best_match = find_element_with_longest_overlap(chunks, target_string)\n",
    "        # print(\"Element with the longest overlap:\", best_match)\n",
    "        matches += [best_match[0]]\n",
    "    paperpath2matches[paper_path] = matches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "197309df-eec1-49ed-a392-a77c81a26974",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(paperpath2matches, \"../../nlp_data/kisti/db_files/gen_conv2matches\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9bd4424-a035-47e3-87f1-a4aee54d43fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(paperpath2matches),sum([len(matches) for matches in list(paperpath2matches.values())])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "914dd35c-df37-4e0d-abe3-1086d25c129f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "exp_paper = list(paperpath2matches.keys())[0]\n",
    "# paperpath2matches[exp_paper]\n",
    "\n",
    "for rat, chunk_id in zip(gen_paperpath2rationales[exp_paper], paperpath2matches[exp_paper]):\n",
    "    print(\"rat: \", rat)\n",
    "    print(\"chunk: \",paperpath2chunks[exp_paper][chunk_id])\n",
    "    print(\"####\"*10 )\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "891b5808-4dfe-4f2e-ba27-227c0b1ad4bc",
   "metadata": {},
   "source": [
    "### create conv data as json "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c518e747-3c38-4de2-8157-f84662e5ce79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_paperpath\n",
    "total_paperpaths = set(list(gen_paperpath2Qs.keys()))\n",
    "len(total_paperpaths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7f569bc-d90e-407d-9409-7fbebb0dee66",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "random.seed(2)\n",
    "val_paperpaths = set(random.sample(list(total_paperpaths), 150))\n",
    "train_paperpath = total_paperpaths-val_paperpaths\n",
    "len(train_paperpath), len(val_paperpaths)\n",
    "\n",
    "sum([len(gen_paperpath2Qs[p]) for p in list(train_paperpath)]), sum([len(gen_paperpath2Qs[p]) for p in list(val_paperpaths)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b1770a1-d1b2-474c-a7e1-8759a8fb11da",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dt = []\n",
    "for conv_id, path in enumerate(train_paperpath, 1):\n",
    "    Qs, As, Rs = gen_paperpath2Qs[path], gen_paperpath2As[path], gen_paperpath2rationales[path]\n",
    "    history_qs, history_as = [], []\n",
    "    for turn_id, (q,a,r) in enumerate(zip(Qs, As, Rs), 1):\n",
    "        sample_id = str(conv_id)+\"-\"+str(turn_id)\n",
    "        instance = {'id':sample_id, 'conv_id':conv_id, 'turn_id':turn_id, 'query':q, 'answer':a,\n",
    "                     'history_query':deepcopy(history_qs), 'history_answer':deepcopy(history_as)}\n",
    "        # ['id', 'conv_id', 'turn_id', 'query', 'answer', 'history_query', 'history_answer',\n",
    "        # 'pos_docs', 'pos_docs_id']\n",
    "        history_qs += [q]\n",
    "        history_as += [a]\n",
    "        train_dt += [instance]\n",
    "    \n",
    "val_dt = []\n",
    "for conv_id, path in enumerate(val_paperpaths, 1):\n",
    "    Qs, As, Rs = gen_paperpath2Qs[path], gen_paperpath2As[path], gen_paperpath2rationales[path]\n",
    "    history_qs, history_as = [], []\n",
    "    for turn_id, (q,a,r) in enumerate(zip(Qs, As, Rs), 1):\n",
    "        sample_id = str(conv_id)+\"-\"+str(turn_id)\n",
    "        instance = {'id':sample_id, 'conv_id':conv_id, 'turn_id':turn_id, 'query':q, 'answer':a,\n",
    "                     'history_query':deepcopy(history_qs), 'history_answer':deepcopy(history_as)}\n",
    "        # ['id', 'conv_id', 'turn_id', 'query', 'answer', 'history_query', 'history_answer',\n",
    "        # 'pos_docs', 'pos_docs_id']\n",
    "        history_qs += [q]\n",
    "        history_as += [a]\n",
    "        val_dt += [instance]\n",
    "        \n",
    "len(train_dt), len(val_dt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc3c59c4-e366-4bdc-91c2-082c5e741638",
   "metadata": {},
   "outputs": [],
   "source": [
    "root = \"../../nlp_data/kisti\"\n",
    "split = \"train\"\n",
    "with open(os.path.join(root,f'{split}_new'+\".json\"), 'w', encoding='utf-8') as file:\n",
    "    for item in train_dt:\n",
    "        # Convert the dictionary to a JSON string and write it to the file\n",
    "        json_str = json.dumps(item)\n",
    "        file.write(json_str)\n",
    "        # if item != data[-1]:  # Check if it's not the last item\n",
    "        #     file.write(\",\\n\")  # For the last item, we don't add a comma\n",
    "        # else:\n",
    "        file.write(\"\\n\")\n",
    "\n",
    "root = \"../../nlp_data/kisti\"\n",
    "split = \"dev\"\n",
    "with open(os.path.join(root,f'{split}_new'+\".json\"), 'w', encoding='utf-8') as file:\n",
    "    for item in val_dt:\n",
    "        # Convert the dictionary to a JSON string and write it to the file\n",
    "        json_str = json.dumps(item)\n",
    "        file.write(json_str)\n",
    "        # if item != data[-1]:  # Check if it's not the last item\n",
    "        #     file.write(\",\\n\")  # For the last item, we don't add a comma\n",
    "        # else:\n",
    "        file.write(\"\\n\")\n",
    "\n",
    "# print(f\"Data saved to '{file_path}' successfully.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc0fb9e7-fd04-42c0-9e1c-4717c811bbe5",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(gen_paperpath2As)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c575860-9cb3-4398-9558-c04e77e84de5",
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_paperpath2As[list(gen_paperpath2As.keys())[0]]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6089647-2e3c-493c-a5d6-54ca7be2b9fd",
   "metadata": {},
   "source": [
    "### create qrels for conv data as trec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e454a330-130b-4a8a-897e-889d2cfa5d59",
   "metadata": {},
   "outputs": [],
   "source": [
    "qrels_train = []\n",
    "for conv_id, path in enumerate(train_paperpath, 1):\n",
    "    # Qs, As, Rs = gen_paperpath2Qs[path], gen_paperpath2As[path], gen_paperpath2rationales[path]\n",
    "    matches = paperpath2matches[path]\n",
    "    for i, match_id in enumerate(matches, 0):\n",
    "        turn_id = i+1\n",
    "        sample_id = str(conv_id)+\"-\"+str(turn_id)\n",
    "        path_chunk_id = path+'_'+str(match_id)\n",
    "        path_chunk_id = '/'.join(path_chunk_id.split(\"/\")[6:])\n",
    "        path_chunk_id = path_chunk_id.replace(\" \", \"\")\n",
    "        qrel = f\"{sample_id} Q0 {path_chunk_id} 1\"\n",
    "        \n",
    "        qrels_train += [qrel]\n",
    "\n",
    "\n",
    "# Save to a .trec file\n",
    "with open('../../nlp_data/kisti/train_gold.trec', 'w') as f:\n",
    "    for line in qrels_train:\n",
    "        f.write(line + '\\n')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f40087f7-b816-4681-9d86-e55700c1746b",
   "metadata": {},
   "outputs": [],
   "source": [
    "qrels_train = []\n",
    "for conv_id, path in enumerate(val_paperpaths, 1):\n",
    "    # Qs, As, Rs = gen_paperpath2Qs[path], gen_paperpath2As[path], gen_paperpath2rationales[path]\n",
    "    matches = paperpath2matches[path]\n",
    "    for i, match_id in enumerate(matches, 0):\n",
    "        turn_id = i+1\n",
    "        sample_id = str(conv_id)+\"-\"+str(turn_id)\n",
    "        path_chunk_id = path+'_'+str(match_id)\n",
    "        path_chunk_id = '/'.join(path_chunk_id.split(\"/\")[6:])\n",
    "        path_chunk_id = path_chunk_id.replace(\" \", \"\")\n",
    "        qrel = f\"{sample_id} Q0 {path_chunk_id} 1\"\n",
    "        \n",
    "        qrels_train += [qrel]\n",
    "\n",
    "# Save to a .trec file\n",
    "with open('../../nlp_data/kisti/dev_gold.trec', 'w') as f:\n",
    "    for line in qrels_train:\n",
    "        f.write(line + '\\n')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llmcqr",
   "language": "python",
   "name": "llmcqr"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
