{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b27ce4ee-22fa-4988-b300-533a3e7b3b83",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "#set the openai env variable here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "13dde9dd-4623-4fff-bed2-3c75bb2d2125",
   "metadata": {},
   "outputs": [],
   "source": [
    "api_version = \"2024-05-01-preview\"\n",
    "from langchain_openai import AzureChatOpenAI\n",
    "model_kwargs = dict(\n",
    "    model=\"gpt-4-128k\",\n",
    "    azure_endpoint=\"endpoint here\",\n",
    "    api_key=\"key here\",\n",
    "    api_version=api_version,\n",
    "    temperature=0.0,\n",
    ")\n",
    "\n",
    "gpt_model = AzureChatOpenAI(**model_kwargs, cache=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ce6e1e0e-70bf-4a10-bc3f-4b1c7d95b56e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Dict, Iterator, List, Optional, Sequence, Union\n",
    "from langchain.output_parsers import PydanticOutputParser\n",
    "from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate\n",
    "from pydantic import BaseModel, Field\n",
    "import json\n",
    "from langchain_core.messages import HumanMessage, SystemMessage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1257588c-b384-4328-963c-13791fb29beb",
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2024-08-19T19:14:19.286495800Z",
     "start_time": "2024-08-19T19:14:19.283427200Z"
    }
   },
   "outputs": [],
   "source": [
    "mimic_report_path = r\"mimic reports folder\"\n",
    "sampled_files_path = r\"mimic_cxr/sampled_paths.json\" # the sampled images from the MIMIC-CXR\n",
    "\n",
    "with open(sampled_files_path, 'r', encoding='utf8') as file:\n",
    "    mimic_files = json.load(file)\n",
    "\n",
    "def to_report(file_path): #get the report of a medical image\n",
    "    path = \"/\".join(file_path.split(\"/\")[:3]) + \".txt\"\n",
    "    with open(os.path.join(mimic_report_path, path), \"r\") as f:\n",
    "        report = f.read()\n",
    "    return path, file_path, report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "sampled_type4_paths = mimic_files[1000:1400] #sample 500 for close-ended"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "290599abd0c3a3e3"
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "129d495e-ce78-4795-8e78-4548c402b9bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "useful_ids = []\n",
    "useful_data = []\n",
    "for _id in sampled_type4_paths:\n",
    "    _data = dict()\n",
    "    img_path, file_path, report = to_report(_id)\n",
    "    useful_ids.append(_id)\n",
    "    _data['img_path'] = img_path\n",
    "    _data['report'] = report\n",
    "    useful_data.append(_data)\n",
    "    # ans = closed_chain.invoke(dict(report=report))\n",
    "    # final_new_qa = {\"img_path\":img_path, \"new_qa_pairs\": ans}\n",
    "    # mimic_close_qas.append(final_new_qa)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "022a28ea-4963-447d-be4b-39f431fa1255",
   "metadata": {},
   "source": [
    "# data generation for close-ended evaluation of Knowledge Hallucination beyond Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "1a2a46b6-9dff-4ec6-91bf-f532d4b4f294",
   "metadata": {},
   "outputs": [],
   "source": [
    "system_msg = \"\"\"\n",
    "You are provided with the clinical report about a medical image. You task is to synthesize a set of close-ended QA pairs (diagnosis) according to the requirements. Unfortunately, you don’t have access to the actual image. Below are requirements for generating the questions and answers in the conversation.\n",
    "\n",
    "Do not use phrases like \"mentioned\", \"report\", \"context\" in the conversation. Instead, refer to the information as being \"in the image\".\n",
    "Answer responsibly within the given report, avoiding information not included in the given context. \n",
    "\n",
    "Instructions:\n",
    "Ensure balanced labels in your generated questions. For example, \"yes-or-no\" questions should have an equal number of \"yes\" and \"no\" answers. To achieve this balance, you may use negative sampling to generate questions with the answer \"no\".\n",
    "\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "2bb0a536-e141-47f5-aca2-b9a92723b760",
   "metadata": {},
   "outputs": [],
   "source": [
    "class QA(BaseModel):\n",
    "    question: str = Field(description=\"The question\")\n",
    "    answer: str = Field(description=\"The answer\")\n",
    "    \n",
    "class QueryAnswer(BaseModel):\n",
    "    qa_pairs: List[QA] = Field(description=\"The QA pairs for contextual evaluation.\")\n",
    "\n",
    "#Here we use the classic parser in LangChain: Pydantic, to ensure a strict parsing process\n",
    "qa_parser = PydanticOutputParser(pydantic_object=QueryAnswer)\n",
    "\n",
    "qa_prompt = ChatPromptTemplate.from_messages([\n",
    "    SystemMessage(content=system_msg),\n",
    "    HumanMessagePromptTemplate(\n",
    "        prompt=PromptTemplate(\n",
    "            template=\"The medical report of the image: {report}. Generate high-quality close-ended question-answer pairs focused on the diagnosis.\",\n",
    "            input_variables=[\"report\"],\n",
    "            partial_variables={\"format_instructions\": qa_parser.get_format_instructions()},\n",
    "        )\n",
    "    ),\n",
    "])\n",
    "\n",
    "qa_chain = (\n",
    "    qa_prompt \n",
    "    | gpt_model \n",
    "    | qa_parser\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "9f4aaa66-7e8e-452b-bcca-1564c212476d",
   "metadata": {},
   "outputs": [],
   "source": [
    "type2_qas = []\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "50c9a388-99f2-4ce6-abf5-c47104942d02",
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2024-08-19T19:29:33.040697700Z",
     "start_time": "2024-08-19T19:29:33.025414600Z"
    }
   },
   "outputs": [],
   "source": [
    "for i in tqdm(range(len(useful_data))):\n",
    "    report = useful_data[i]['report']\n",
    "    img_path = useful_data[i]['img_path']\n",
    "    qa = qa_chain.invoke(dict(report=report))\n",
    "    final_new_qa = {\"img_path\":img_path, \"qa_pairs\": qa.dict()}\n",
    "    type2_qas.append(final_new_qa)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "646a4a6b-e2f2-42ae-b7ef-52f9f42e0c93",
   "metadata": {},
   "outputs": [],
   "source": [
    "# store\n",
    "def trans_file_close(qas_list):\n",
    "    new_list = []\n",
    "    id_ = 0\n",
    "    for i in qas_list:\n",
    "        img_id = i['img_path']\n",
    "        new_qas = i['qa_pairs']['qa_pairs']\n",
    "        \n",
    "        for j in new_qas:\n",
    "            if j['question_topic'] == \"diagnosis\":\n",
    "                new_dict = dict()\n",
    "                new_dict['img_name'] = img_id\n",
    "                new_dict['question'] = j[\"question\"]\n",
    "                new_dict['answer'] = j[\"answer\"]\n",
    "                new_dict['question_type'] = 'type4_Diagnosis'\n",
    "                new_dict['ground_truth_type'] = j['ground_truth_type']\n",
    "                new_dict['choices'] = j['choices']\n",
    "                \n",
    "                new_dict['qid'] = id_\n",
    "                new_dict['img_id'] = img_id\n",
    "    \n",
    "                new_list.append(new_dict)\n",
    "                id_ += 1\n",
    "\n",
    "    return new_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "e5d80c41-5982-460a-8362-a3c851c41f4e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1172,\n",
       " {'img_name': 'p19/p19454978/s57331547/7d047120-d24a497e-fc26ea7e-6c3acc0c-ce5bc190.jpg',\n",
       "  'question': 'Is there evidence of atelectasis according to the image?',\n",
       "  'answer': 'Yes',\n",
       "  'question_type': 'type4_Diagnosis',\n",
       "  'ground_truth_type': 'binary',\n",
       "  'choices': '',\n",
       "  'qid': 0,\n",
       "  'img_id': 'p19/p19454978/s57331547/7d047120-d24a497e-fc26ea7e-6c3acc0c-ce5bc190.jpg'})"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type2_close = trans_file_close(type2_qas)\n",
    "len(type2_close), type2_close[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "348f1da9-cc85-4992-94e4-af8f1b59c23f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data saved!\n"
     ]
    }
   ],
   "source": [
    "with open(\"saved json file\", 'w') as json_file:\n",
    "    json.dump(type2_close, json_file, indent=4)\n",
    "\n",
    "print(f\"Data saved!\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "language": "python",
   "display_name": "Python 3 (ipykernel)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
