{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a01493d0-2c47-44dd-9541-88f675bc6515",
   "metadata": {},
   "source": [
    "## Close-ended Data generation for Visual Factual Hallucination using IU-Xray"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "50f3e10a-207b-413a-b0fc-b9d9efe16db2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "#set the openai env variable here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ad131f03-0e5e-4618-91a7-a3370a5b7708",
   "metadata": {},
   "outputs": [],
   "source": [
    "api_version = \"2024-05-01-preview\"\n",
    "from langchain_openai import AzureChatOpenAI\n",
    "\n",
    "model_kwargs = dict(\n",
    "    model=\"gpt-4-128k\", \n",
    "    azure_endpoint=\"endpoint here\",\n",
    "    api_key=\"key here\",\n",
    "    api_version=api_version,\n",
    "    temperature=0.0,\n",
    ")\n",
    "\n",
    "gpt_model = AzureChatOpenAI(**model_kwargs, cache=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "99faaac1-edae-4abc-80f2-585845ccc7ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "logging.basicConfig()\n",
    "\n",
    "from langchain.chains import LLMChain\n",
    "from langchain.output_parsers import PydanticOutputParser\n",
    "from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate\n",
    "from pydantic import BaseModel, Field\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4102b637-787d-4756-bbd9-c3da28354339",
   "metadata": {},
   "outputs": [],
   "source": [
    "system_msg_close = \"\"\" You are provided with the clinical report about a medical image. You task is to synthesize a set of new close-ended QA pairs according to the requirements. Unfortunately, you don’t have access to the actual image.\n",
    "Below are requirements for generating the questions and answers in the conversation.\n",
    "- Do not use phrases like \"mentioned\", \"report\", \"context\" in the conversation. Instead, refer to the information as being \"in the image.\"\n",
    "- Answer responsibly within the given information and knowledge context, avoiding information not included in the given context.\n",
    "- Make sure your generation contains a key \"type\" for each QA pair, indicating its question type (as detailed in the following part).\n",
    "- Make the questions more diverse.\n",
    "- The ground truth type can be \"yes or no\" or one choice from multi-choice (you need to synthesize several choices in this case), condition on the question and available materials.\n",
    "\n",
    "Here are a set of question types for your generation, you must assign each new QA pair a question type.\n",
    "type_1: Anatomical Hallucination. Example questions: \"Which part of the body does this image belong to?\" \"Does the picture contain liver?\"\n",
    "type_2: Measurement Hallucination like location, size. Example questions: \"Where is the liver?\"\n",
    "type_3: Symptom-Based Hallucination. Example questions: \"Is the lung healthy?\" \"Is there evidence of a pneumothorax?\" \"Is there a fracture?\"\n",
    "type_4: Technique Hallucination. Example questions:  \"What modality is used to take this image? \"\n",
    "\n",
    "Instructions:\n",
    "- Add one key to each QA pair, key= \"ground_truth_type\", value= \"binary\" if the type is \"yes or no\" else \"multi-choice\"\n",
    "- When you see diagnosis information of a disease (e.g. lung cancer) in QA pairs, you should generate new QA pair by asking the symptoms of the disease \n",
    "- For \"multi-choice\" type QA, you must include one key \"choices\" of string type.\n",
    "- Avoid the question that you can not generate a ground truth, for example avoid the answer \"The image does not provide information\". \n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "3d415b23-49a4-4892-b1a1-f2d73c038c73",
   "metadata": {},
   "outputs": [],
   "source": [
    "#for filtering low-quality generated questions\n",
    "sys_msg_filter = \"\"\"\n",
    "You are provided with a set of QA pairs in a json format. Your task is to improve or filter out low quality QA pairs.\n",
    "Here are some standards for the measurement of low-quality:\n",
    "- The answer is vague when the question type is multi-choice, try to make the answer clear, which choice is correct.\n",
    "- The answer is like \"xxx is not specified in the image.\", which means that this qa pair should be filtered out becuase there is no ground truth.\n",
    "You can also judge from other common standards.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "a6bb57f8-8176-4c23-abdb-b1da5d02dc9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Dict, Iterator, List, Optional, Sequence, Union\n",
    "\n",
    "class Closed_QA(BaseModel):\n",
    "    Q: str = Field(description=\"The question (Q)\")\n",
    "    A: str = Field(description=\"The answer (A)\")\n",
    "    type: str = Field(description=\"The QA type\")\n",
    "    choices: str = Field(default=\"\", description=\"The QA choices for multi-choice question type\")\n",
    "    ground_truth_type: str = Field(description=\"The ground_truth type\")\n",
    "    \n",
    "\n",
    "class Closed_QueryAnswer(BaseModel):\n",
    "    qa_pairs: List[Closed_QA] = Field(description=\"The QA pairs list generated.\")\n",
    "    \n",
    "#Here we use the classic parser in LangChain: Pydantic, to ensure a strict parsing process\n",
    "closed_parser = PydanticOutputParser(pydantic_object=Closed_QueryAnswer)\n",
    "\n",
    "colsed_prompt = ChatPromptTemplate.from_messages([\n",
    "    SystemMessagePromptTemplate.from_template(system_msg_close),\n",
    "    HumanMessagePromptTemplate(\n",
    "        prompt=PromptTemplate(\n",
    "            template=\"The medical report: {report}. generate high-quality questions for which the correct answers can be inferred solely from the provided report. \",\n",
    "            input_variables=[\"report\"],\n",
    "            partial_variables={\"format_instructions\": closed_parser.get_format_instructions()},\n",
    "        )\n",
    "    ),\n",
    "])\n",
    "\n",
    "#for filtering out low-quality generated questions\n",
    "filter_prompt = ChatPromptTemplate.from_messages([\n",
    "    SystemMessagePromptTemplate.from_template(sys_msg_filter),\n",
    "    HumanMessagePromptTemplate(\n",
    "        prompt=PromptTemplate(\n",
    "            template=\"The generated QA pairs in a json format: {qa_pairs}. Try to improve or filter out low-quality QA pairs, return in the original json format, but output the QA pair list as the value of additional key 'qa_pairs'. No other items, ensure the complete json format for parser!\",\n",
    "            input_variables=[\"qa_pairs\"],\n",
    "            partial_variables={\"format_instructions\": closed_parser.get_format_instructions()},\n",
    "        )\n",
    "    ),\n",
    "])\n",
    "\n",
    "\n",
    "from operator import itemgetter, attrgetter\n",
    "\n",
    "closed_chain = (\n",
    "    colsed_prompt \n",
    "    | gpt_model \n",
    "    | closed_parser\n",
    "    | {\"qa_pairs\": attrgetter(\"qa_pairs\")}\n",
    "    | filter_prompt\n",
    "    | gpt_model\n",
    "    | closed_parser\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a11228b2-d4b3-4207-b96c-4fdc33f0c8de",
   "metadata": {},
   "outputs": [],
   "source": [
    "#batch processing\n",
    "files_path = \"the path of IU-Xray data, the original QAs json file\"\n",
    "\n",
    "with open(files_path, 'r', encoding='utf8') as file:\n",
    "    xray_files = json.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "e30abff7-1889-4ad0-9343-7afec15a5bb2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(590,\n",
       " {'id': 'CXR2915_IM-1317',\n",
       "  'report': 'The heart size is normal. The mediastinal contour is within normal limits. The lungs are free of any focal infiltrates. There are no nodules or masses. No visible pneumothorax. No visible pleural fluid. The XXXX are grossly normal. There is no visible free intraperitoneal air under the diaphragm.',\n",
       "  'image_path': ['CXR2915_IM-1317/0.png', 'CXR2915_IM-1317/1.png'],\n",
       "  'split': 'test'})"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(xray_files['test']), xray_files['test'][5] #use the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "cf581c6d-248a-453c-8a42-db37e73e6c2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#use 290 for close-ended generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "651e70a3-943a-41e1-a4f9-abe89b2b2e44",
   "metadata": {},
   "outputs": [],
   "source": [
    "xray_close_qas = []\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a1f62bbb-229f-4670-b61d-e8e9341c00bc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-19T17:33:06.719529100Z",
     "start_time": "2024-08-19T17:33:06.713514600Z"
    }
   },
   "outputs": [],
   "source": [
    "for i in tqdm(xray_files['test']):\n",
    "    img_path, report = i['image_path'][0] if i['image_path'][0][-5]==\"0\" else i['image_path'][1], i['report']\n",
    "    ans = closed_chain.invoke(dict(report=report))\n",
    "    final_new_qa = {\"img_path\":img_path, \"new_qa_pairs\": ans}\n",
    "    xray_close_qas.append(final_new_qa)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "037cf2aa-e039-4bc1-8b3f-7589105670e2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "290"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(xray_close_qas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "fa02cdab-3685-498b-a898-de28869110c9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'img_path': 'CXR3030_IM-1405/0.png',\n",
       "  'new_qa_pairs': Closed_QueryAnswer(qa_pairs=[Closed_QA(Q='Does the image show any signs of a large pleural effusion?', A='No', type='type_3', choices='', omission_type=0, ground_truth_type='binary'), Closed_QA(Q='Is there any evidence of pneumothorax in the image?', A='No', type='type_3', choices='', omission_type=0, ground_truth_type='binary'), Closed_QA(Q='Can we see any acute bony abnormalities in the image?', A='No', type='type_3', choices='', omission_type=0, ground_truth_type='binary'), Closed_QA(Q='What part of the body is depicted in the image?', A='Chest', type='type_1', choices='A. Chest, B. Abdomen, C. Skull, D. Pelvis', omission_type=-1, ground_truth_type='multi-choice'), Closed_QA(Q='Is the cardiomediastinal silhouette normal in the image?', A='Yes', type='type_3', choices='', omission_type=1, ground_truth_type='binary'), Closed_QA(Q='Are there any signs of focal consolidation in the image?', A='No', type='type_3', choices='', omission_type=0, ground_truth_type='binary'), Closed_QA(Q='What is the condition of the lungs in the image?', A='Healthy', type='type_3', choices='A. Healthy, B. Consolidated, C. Effusion present, D. Pneumothorax', omission_type=-1, ground_truth_type='multi-choice')])},\n",
       " {'img_path': 'CXR38_IM-1911/0.png',\n",
       "  'new_qa_pairs': Closed_QueryAnswer(qa_pairs=[Closed_QA(Q='Which part of the body does this image show?', A='The image shows the chest area.', type='type_1', choices='A. Chest area, B. Abdominal area, C. Pelvic area', omission_type=-1, ground_truth_type='multi-choice'), Closed_QA(Q='Is there a pneumothorax present in the image?', A='No', type='type_3', choices='', omission_type=0, ground_truth_type='binary'), Closed_QA(Q='Are the bony structures in the image intact?', A='Yes', type='type_3', choices='', omission_type=1, ground_truth_type='binary'), Closed_QA(Q='What modality was used to take this image?', A='The modality is not specified in the image.', type='type_4', choices='A. MRI, B. CT, C. X-ray, D. Ultrasound', omission_type=-1, ground_truth_type='multi-choice'), Closed_QA(Q='Is there any evidence of pleural effusion in the image?', A='No', type='type_3', choices='', omission_type=0, ground_truth_type='binary'), Closed_QA(Q='Are the lungs in the image clear?', A='Yes', type='type_3', choices='', omission_type=1, ground_truth_type='binary'), Closed_QA(Q='Is the heart size within normal limits according to the image?', A='Yes', type='type_3', choices='', omission_type=1, ground_truth_type='binary'), Closed_QA(Q='Where is the mediastinum located in the image?', A='The location of the mediastinum is not specified in the image.', type='type_2', choices='A. Upper left, B. Upper right, C. Center, D. Not specified', omission_type=-1, ground_truth_type='multi-choice')])}]"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#first version, with low-quality pairs\n",
    "xray_close_qas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "52dac661-2531-45e4-92e6-40f91a2e3514",
   "metadata": {},
   "outputs": [],
   "source": [
    "# store\n",
    "def trans_file(qas_list):\n",
    "    new_list = []\n",
    "    id_ = 0\n",
    "    for i in qas_list:\n",
    "        img_id = i['img_path']\n",
    "        new_qas = i['new_qa_pairs'].dict()['qa_pairs']\n",
    "        # old_qa_ids = img_data_ids[img_id]\n",
    "        # old_qa_pair = test_data_id_map[old_qa_ids[0]]\n",
    "        for j in new_qas:\n",
    "            new_dict = dict()\n",
    "            new_dict['img_name'] = img_id\n",
    "            new_dict['question'] = j['Q']\n",
    "            new_dict['answer'] = j['A']\n",
    "            if \"image\" in j['A'].lower() and \"not\" in j['A'] and \"provide\" in j['A']:\n",
    "                continue\n",
    "            if j['ground_truth_type'] == \"binary\":\n",
    "                if \"yes\" not in j['A'].lower() and \"no\" not in j['A'].lower():\n",
    "                    print(j['A'].lower())\n",
    "                    continue\n",
    "            new_dict['hallucination_type'] = j['type']\n",
    "            new_dict['question_type'] = j['ground_truth_type']\n",
    "            new_dict['choices'] = j['choices']\n",
    "            new_dict['qid'] = id_\n",
    "            new_dict['img_id'] = img_id\n",
    "            new_dict['location'] = None\n",
    "            new_dict['modality'] = None\n",
    "            new_list.append(new_dict)\n",
    "            id_ += 1\n",
    "    return new_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "45a120a6-f0d6-451d-b1e9-119413399bff",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-19T17:33:54.926433600Z",
     "start_time": "2024-08-19T17:33:54.897322900Z"
    }
   },
   "outputs": [],
   "source": [
    "xray_new_qas_json = trans_file(xray_close_qas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "f3c68897-39ef-4de3-b6b9-3fc4ea3461b4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2017"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(xray_new_qas_json)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "446c939a-c369-4beb-907c-6bca88b2fd35",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'img_name': 'CXR141_IM-0260/0.png',\n",
       " 'question': 'Does the image show any abnormalities in the cardiac size?',\n",
       " 'answer': 'No',\n",
       " 'hallucination_type': 'type_3',\n",
       " 'omission_type': 0,\n",
       " 'question_type': 'binary',\n",
       " 'choices': '',\n",
       " 'qid': 890,\n",
       " 'img_id': 'CXR141_IM-0260/0.png',\n",
       " 'location': None,\n",
       " 'modality': None}"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xray_new_qas_json[890]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "44508bc9-9eb0-4ff6-978c-c754a38095b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data saved!\n"
     ]
    }
   ],
   "source": [
    "with open(\"out path for iu_xray data\", 'w') as json_file:\n",
    "    json.dump(xray_new_qas_json, json_file, indent=4)\n",
    "\n",
    "print(f\"Data saved!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e87c9e9d-f8df-4600-a60f-6c64353a9b55",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "language": "python",
   "display_name": "Python 3 (ipykernel)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
