{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c4343ab0-870c-4a61-b5cc-1623ceb086a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "#set the openai env variable here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "69cd1e4c-a2e0-43cc-8222-724292c79792",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-19T19:33:40.255128100Z",
     "start_time": "2024-08-19T19:33:40.252501700Z"
    }
   },
   "outputs": [],
   "source": [
    "from langchain_chroma import Chroma\n",
    "\n",
    "# Placeholders: replace with the name and on-disk location of the Chroma collection to query.\n",
    "collection_name = \"name of the constructed RAG data base\"\n",
    "db_path = \"path to the constructed RAG data base\"\n",
    "# Open the collection from `persist_directory`; `db` is turned into a retriever further down via `db.as_retriever`.\n",
    "# NOTE(review): no embedding function is passed here -- confirm the default matches the one used to build the store.\n",
    "db = Chroma(\n",
    "    collection_name=collection_name,\n",
    "    persist_directory=db_path,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87af2b0d-8188-46ec-9db0-f18b06289d6f",
   "metadata": {},
   "source": [
    "### Data preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0cc476f5-c649-431d-be35-649351a006b5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-19T19:39:06.882949500Z",
     "start_time": "2024-08-19T19:39:06.866367100Z"
    }
   },
   "outputs": [],
   "source": [
    "#batch process\n",
    "import json\n",
    "# Placeholder: folder that the derived report paths are joined onto in `to_report`.\n",
    "mimic_report_path = r\"mimic reports folder\"\n",
    "sampled_files_path = r\"mimic_cxr/sampled_paths.json\" # the sampled images from the MIMIC-CXR\n",
    "\n",
    "# Load the sampled paths. `mimic_files` is sliced by index in a later cell, so the JSON is\n",
    "# presumably a list of path strings -- TODO confirm.\n",
    "with open(sampled_files_path, 'r', encoding='utf8') as file:\n",
    "    mimic_files = json.load(file)\n",
    "\n",
    "def to_report(file_path):\n",
    "    \"\"\"Read the report text that belongs to a sampled image path.\n",
    "\n",
    "    The report location is the first three '/'-separated components of\n",
    "    ``file_path`` joined back together with a ``.txt`` suffix, looked up\n",
    "    under ``mimic_report_path``.\n",
    "\n",
    "    Args:\n",
    "        file_path: '/'-separated relative path of a sampled image.\n",
    "\n",
    "    Returns:\n",
    "        A ``(report_path, file_path, report)`` tuple: the derived relative\n",
    "        report path, the input path unchanged, and the report text.\n",
    "\n",
    "    Raises:\n",
    "        OSError: if the report file cannot be opened.\n",
    "    \"\"\"\n",
    "    path = \"/\".join(file_path.split(\"/\")[:3]) + \".txt\"\n",
    "    # Explicit encoding (matching how the sampled-paths JSON is opened) so reading\n",
    "    # does not depend on the platform's default locale encoding.\n",
    "    with open(os.path.join(mimic_report_path, path), \"r\", encoding=\"utf8\") as f:\n",
    "        report = f.read()\n",
    "    return path, file_path, report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8ec485c6-5bfd-4fd1-91d8-383e5a3c1624",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-19T19:39:34.419610Z",
     "start_time": "2024-08-19T19:39:34.393975900Z"
    }
   },
   "outputs": [],
   "source": [
    "sampled_type4_paths = mimic_files[1400:1800]  # slice of up to 400 sampled paths (indices 1400-1799) for the type4 open-ended knowledge QA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ed2f9a3-dd0c-46df-b3a1-e95097b54418",
   "metadata": {},
   "outputs": [],
   "source": [
    "useful_ids = []\n",
    "useful_data = []\n",
    "for sample_path in sampled_type4_paths:\n",
    "    # to_report returns (derived report path, input path, report text).\n",
    "    # NOTE(review): the value stored under 'img_path' is the derived '.txt' report path,\n",
    "    # not the image path itself -- confirm this is intended.\n",
    "    report_path, _unused_path, report_text = to_report(sample_path)\n",
    "    useful_ids.append(sample_path)\n",
    "    useful_data.append({'img_path': report_path, 'report': report_text})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "outputs": [],
   "source": [
    "from typing import Any, Dict, Iterator, List, Optional, Sequence, Union"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "943074ee-31fd-4909-8a98-87b7d7291e2e"
  },
  {
   "cell_type": "markdown",
   "id": "8681de95-3cb0-44de-ae19-de45a1924705",
   "metadata": {},
   "source": [
    "### Retrieval-augmented open-ended QA generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "b16775e5-594c-4697-8e6a-9f2c10c90479",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_openai import AzureChatOpenAI\n",
    "\n",
    "# Azure OpenAI connection settings; the endpoint and key strings are placeholders to fill in before running.\n",
    "api_version = \"2024-05-01-preview\"\n",
    "model_kwargs = {\n",
    "    \"model\": \"gpt-4-128k\",\n",
    "    \"azure_endpoint\": \"endpoint here\",\n",
    "    \"api_key\": \"key here\",\n",
    "    \"api_version\": api_version,\n",
    "    \"temperature\": 0.0,\n",
    "}\n",
    "\n",
    "# Chat model shared by both chain stages; response caching is switched off explicitly.\n",
    "gpt_model = AzureChatOpenAI(cache=False, **model_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "3836789b-3006-455f-abc0-0b6c1e970f0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# System prompt for the QA-generation step (used as the SystemMessage of `rag_prompt`).\n",
    "system_msg = \"\"\"\n",
    "You are provided with the clinical report about a medical image, and relevant retrieved knowledge. You task is to synthesize a set of open-ended QA pairs (asking medical terminologies, such as disease, clinical symptoms) according to the requirements. Unfortunately, you don’t have access to the actual image. Below are requirements for generating the questions and answers in the conversation.\n",
    "\n",
    "- Do not use phrases like \"mentioned\", \"report\" in the conversation. Instead, refer to the information as being \"In the image.\"\n",
    "- Answer responsibly within the given information.\n",
    "- You could rely on the knowledge if it is useful. \n",
    "- Only focus on the most crucial several terminologies in the reports.\n",
    "Here is one example question: \"What does mediastinal lipomatosis indicate when seen in an image?\"\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "3c90b58d-b7f1-4c9f-bddc-8485be93f262",
   "metadata": {},
   "outputs": [],
   "source": [
    "# System prompt for the knowledge-filtering step (used as the SystemMessage of `filter_prompt`).\n",
    "system_msg_knowledge = \"\"\"\n",
    "You should filter out useless and noisy retrieval knowledge, keep the important and useful knowledge about the given report, especially the knowledge about the medical terminologies.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "4f1fcc90-7c19-4919-8e95-f0d137fb5193",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "from langchain.retrievers.multi_query import MultiQueryRetriever\n",
    "from langchain_core.runnables import RunnableLambda, RunnablePassthrough\n",
    "from langchain_community.document_transformers import LongContextReorder\n",
    "\n",
    "\n",
    "import logging\n",
    "\n",
    "# Default logging setup, with INFO level enabled for the multi-query retriever's logger.\n",
    "# NOTE(review): MultiQueryRetriever is imported above but not used in the visible cells --\n",
    "# this logger setting may be a leftover; confirm before removing.\n",
    "logging.basicConfig()\n",
    "logging.getLogger(\"langchain.retrievers.multi_query\").setLevel(logging.INFO)\n",
    "\n",
    "# Retriever over the Chroma store, configured with k=10 in its search kwargs.\n",
    "sem_retriever = db.as_retriever(search_kwargs=dict(k=10))\n",
    "\n",
    "from langchain.output_parsers import PydanticOutputParser\n",
    "from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate\n",
    "from pydantic import BaseModel, Field\n",
    "import json\n",
    "from langchain_core.messages import HumanMessage, SystemMessage\n",
    "\n",
    "class _BaseModel(BaseModel):\n",
    "    # Pydantic base class whose str() form is indented JSON of the instance's attributes.\n",
    "\n",
    "    def __str__(self):\n",
    "        # Dump the instance attribute dict with a 4-space indent.\n",
    "        attributes = self.__dict__\n",
    "        return json.dumps(attributes, indent=4)\n",
    "        \n",
    "def format_qa(qa_pairs):\n",
    "    \"\"\"Return the value stored under the 'report' key of the chain input.\n",
    "\n",
    "    Args:\n",
    "        qa_pairs: mapping holding the report under 'report' (despite the\n",
    "            parameter name, this is the chain's input dict).\n",
    "\n",
    "    Returns:\n",
    "        The object found at ``qa_pairs['report']``.\n",
    "    \"\"\"\n",
    "    report_text = qa_pairs['report']\n",
    "    return report_text\n",
    "\n",
    "def reorder_docs(docs):\n",
    "    \"\"\"Record each document's 1-based position, then apply LongContextReorder.\n",
    "\n",
    "    Args:\n",
    "        docs: sequence of documents exposing a mutable ``metadata`` dict.\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``LongContextReorder().transform_documents(docs)`` returns.\n",
    "    \"\"\"\n",
    "    # Store the incoming position under metadata['rank'] (mutates the documents in place).\n",
    "    for rank, doc in enumerate(docs, start=1):\n",
    "        doc.metadata[\"rank\"] = rank\n",
    "    return LongContextReorder().transform_documents(docs)\n",
    "\n",
    "def format_docs(docs):\n",
    "    \"\"\"Render documents as a numbered, newline-separated list.\n",
    "\n",
    "    Each line has the form ``<n>. (<metadata['source']>) <page_content>``,\n",
    "    numbered from 1 in the order given.\n",
    "\n",
    "    Args:\n",
    "        docs: iterable of documents with ``metadata['source']`` and ``page_content``.\n",
    "\n",
    "    Returns:\n",
    "        The joined string (empty when ``docs`` is empty).\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if a document's metadata has no 'source' entry.\n",
    "    \"\"\"\n",
    "    rendered = []\n",
    "    for position, doc in enumerate(docs, start=1):\n",
    "        rendered.append(\"{}. ({}) {}\".format(position, doc.metadata[\"source\"], doc.page_content))\n",
    "    return \"\\n\".join(rendered)\n",
    "\n",
    "# One generated question/answer pair.\n",
    "class QA(BaseModel):\n",
    "    question: str = Field(description=\"The question\")\n",
    "    answer: str = Field(description=\"The answer\")\n",
    "\n",
    "# Schema handed to `rag_parser`: the list of generated QA pairs.\n",
    "class QueryAnswer(BaseModel):\n",
    "    qa_pairs: List[QA] = Field(description=\"The QA pairs list generated.\")\n",
    "\n",
    "# Schema handed to `knowledge_filter_parser`: the retained knowledge as a single string.\n",
    "class KnowledgeFilter(_BaseModel):\n",
    "    knowledge: str = Field(description=\"The filtered knowledge.\")\n",
    "\n",
    "# Pydantic output parsers from LangChain, used to ensure a strict parsing process:\n",
    "# `rag_parser` targets the QueryAnswer schema, `knowledge_filter_parser` the KnowledgeFilter schema.\n",
    "rag_parser = PydanticOutputParser(pydantic_object=QueryAnswer)\n",
    "knowledge_filter_parser = PydanticOutputParser(pydantic_object=KnowledgeFilter)\n",
    "    \n",
    "\n",
    "# Prompt for the knowledge-filtering step: the fixed system message plus a human message\n",
    "# carrying the report and the raw retrieved knowledge.\n",
    "filter_prompt = ChatPromptTemplate.from_messages([\n",
    "    SystemMessage(content=system_msg_knowledge),\n",
    "    HumanMessagePromptTemplate(\n",
    "        prompt=PromptTemplate(\n",
    "            # The trailing {format_instructions} placeholder is what injects the parser's schema\n",
    "            # description into the prompt; without it the partial variable below is never rendered\n",
    "            # and the model is not told which JSON shape the parser expects.\n",
    "            template=(\n",
    "                \"The medical report: {report}, retrieved knowledge: {knowledge_ori}, \"\n",
    "                \"return the filtered knowledge in a json format.\\n{format_instructions}\"\n",
    "            ),\n",
    "            input_variables=[\"knowledge_ori\", \"report\"],\n",
    "            partial_variables={\"format_instructions\": knowledge_filter_parser.get_format_instructions()},\n",
    "        )\n",
    "    ),\n",
    "])\n",
    "\n",
    "\n",
    "# Prompt for the QA-generation step: the fixed system message plus a human message\n",
    "# carrying the report and the filtered knowledge.\n",
    "rag_prompt = ChatPromptTemplate.from_messages([\n",
    "    SystemMessage(content=system_msg),\n",
    "    HumanMessagePromptTemplate(\n",
    "        prompt=PromptTemplate(\n",
    "            # The trailing {format_instructions} placeholder is what injects the parser's schema\n",
    "            # description into the prompt; without it the partial variable below is never rendered\n",
    "            # and the model is not told which JSON shape the parser expects.\n",
    "            template=(\n",
    "                \"The medical report: {report}, retrieved knowledge: {knowledge}. \"\n",
    "                \"Generate high-quality open-ended question-answer pairs.\\n{format_instructions}\"\n",
    "            ),\n",
    "            input_variables=[\"knowledge\", \"report\"],\n",
    "            partial_variables={\"format_instructions\": rag_parser.get_format_instructions()},\n",
    "        )\n",
    "    ),\n",
    "])\n",
    "\n",
    "\n",
    "from operator import itemgetter, attrgetter\n",
    "\n",
    "# Three-step pipeline:\n",
    "#   1. retrieve knowledge for the report, reorder it and format it as text;\n",
    "#   2. ask the model to filter that knowledge (filter_prompt -> KnowledgeFilter);\n",
    "#   3. generate QA pairs from the report plus the filtered knowledge (rag_prompt -> QueryAnswer).\n",
    "# Input: a dict with a 'report' key. Output: the object produced by `rag_parser`.\n",
    "rag_chain = (\n",
    "    {\n",
    "        \"knowledge_ori\": format_qa | sem_retriever | reorder_docs | format_docs,\n",
    "        # Pull out the report text itself; a bare passthrough would place the whole input dict in the prompt.\n",
    "        \"report\": itemgetter(\"report\"),\n",
    "    }\n",
    "    # assign() keeps 'report' in the running dict while adding 'knowledge'. Rebuilding the dict after the\n",
    "    # parser with a passthrough would instead bind 'report' to the parsed KnowledgeFilter object, so the\n",
    "    # QA prompt would never see the actual report.\n",
    "    | RunnablePassthrough.assign(\n",
    "        knowledge=filter_prompt | gpt_model | knowledge_filter_parser | attrgetter(\"knowledge\")\n",
    "    )\n",
    "    | rag_prompt\n",
    "    | gpt_model\n",
    "    | rag_parser\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "748e7b5e-eb00-475f-80ec-77ea5b1c3b17",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "# Accumulates one {'img_path', 'qa_pairs'} entry per processed report; re-running this cell resets it.\n",
    "open_qas = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "13e5b55d-0b5d-4a27-8810-780daf30fc20",
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2024-08-19T19:45:14.897805100Z",
     "start_time": "2024-08-19T19:45:14.879594700Z"
    }
   },
   "outputs": [],
   "source": [
    "# Run the chain once per report and store the parsed QA pairs next to the recorded path.\n",
    "for sample in tqdm(useful_data):\n",
    "    qa = rag_chain.invoke({\"report\": sample['report']})\n",
    "    open_qas.append({\"img_path\": sample['img_path'], \"qa_pairs\": qa.dict()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "9740d3db-04ff-45c2-94a6-801114b7a0f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# store\n",
    "def trans_file_open(qas_list):\n",
    "    \"\"\"Flatten per-image QA results into one record per question/answer pair.\n",
    "\n",
    "    Args:\n",
    "        qas_list: list of dicts shaped like\n",
    "            ``{'img_path': ..., 'qa_pairs': {'qa_pairs': [{'question': ..., 'answer': ...}, ...]}}``.\n",
    "\n",
    "    Returns:\n",
    "        A list of dicts with the keys 'img_name', 'question', 'answer',\n",
    "        'question_type', 'structured_answer', 'qid' and 'img_id'. 'qid' is a\n",
    "        running 0-based counter over all pairs; 'img_name' and 'img_id' both\n",
    "        carry the entry's 'img_path' value.\n",
    "    \"\"\"\n",
    "    flat_records = []\n",
    "    for entry in qas_list:\n",
    "        image_ref = entry['img_path']\n",
    "        for pair in entry['qa_pairs']['qa_pairs']:\n",
    "            flat_records.append({\n",
    "                'img_name': image_ref,\n",
    "                'question': pair[\"question\"],\n",
    "                'answer': pair[\"answer\"],\n",
    "                'question_type': 'type4_Knowledge',\n",
    "                'structured_answer': None,\n",
    "                # Number of records emitted so far, i.e. the running 0-based id.\n",
    "                'qid': len(flat_records),\n",
    "                'img_id': image_ref,\n",
    "            })\n",
    "    return flat_records"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ed59cc1d-f509-4dfd-a52e-9945c1490401",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-19T19:47:08.967093800Z",
     "start_time": "2024-08-19T19:47:08.953073300Z"
    }
   },
   "outputs": [],
   "source": [
    "# Flatten the per-report results into one record per QA pair.\n",
    "mimic_open_json = trans_file_open(open_qas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "d99c9efd-3d8c-4453-b50e-6534eb4b9703",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data saved!\n"
     ]
    }
   ],
   "source": [
    "# Write the flattened QA records as indented JSON; the file name is a placeholder.\n",
    "with open(\"saved json file\", 'w') as out_file:\n",
    "    json.dump(mimic_open_json, out_file, indent=4)\n",
    "\n",
    "print(\"Data saved!\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "language": "python",
   "display_name": "Python 3 (ipykernel)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
