{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b27ce4ee-22fa-4988-b300-533a3e7b3b83",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "#set the openai env variable here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "13dde9dd-4623-4fff-bed2-3c75bb2d2125",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "api_version = \"2024-05-01-preview\"\n",
    "from langchain_openai import AzureChatOpenAI\n",
    "model_kwargs = dict(\n",
    "    model=\"gpt-4-128k\",\n",
    "    azure_endpoint=\"endpoint here\",\n",
    "    api_key=\"key here\",\n",
    "    api_version=api_version,\n",
    "    temperature=0.0,\n",
    ")\n",
    "\n",
    "gpt_model = AzureChatOpenAI(**model_kwargs, cache=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ce6e1e0e-70bf-4a10-bc3f-4b1c7d95b56e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Dict, Iterator, List, Optional, Sequence, Union\n",
    "from langchain.output_parsers import PydanticOutputParser\n",
    "from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate\n",
    "from pydantic import BaseModel, Field\n",
    "import json\n",
    "from langchain_core.messages import HumanMessage, SystemMessage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cf7818c4-e1f2-401e-9a1b-c22657ab77eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "#batch process\n",
    "mimic_report_path = r\"mimic reports folder\"\n",
    "sampled_files_path = r\"mimic_cxr/sampled_paths.json\" # the sampled images from the MIMIC-CXR\n",
    "\n",
    "with open(sampled_files_path, 'r', encoding='utf8') as file:\n",
    "    mimic_files = json.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2594350f-e158-47da-95f6-6dc7009cc5ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "sampled_type4_path = mimic_files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "951783d9-39c0-476d-803f-cf102c1c91be",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-19T18:44:20.398853700Z",
     "start_time": "2024-08-19T18:44:20.395794100Z"
    }
   },
   "outputs": [],
   "source": [
    "notes = pd.read_csv(\"MIMIC-IV, csv file for report\")\n",
    "print(notes.shape)\n",
    "pids = list(set(notes['subject_id'].to_list()))\n",
    "print(len(pids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1257588c-b384-4328-963c-13791fb29beb",
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2024-08-19T18:44:59.512418500Z",
     "start_time": "2024-08-19T18:44:59.480909100Z"
    }
   },
   "outputs": [],
   "source": [
    "pid_repeat_dict = dict()\n",
    "import random\n",
    "def to_report_notes(file_path, notes):\n",
    "    path = \"/\".join(file_path.split(\"/\")[:3]) + \".txt\"\n",
    "    with open(os.path.join(mimic_report_path, path), \"r\") as f:\n",
    "        report = f.read()\n",
    "    pid = int(file_path.split(\"/\")[1][1:])\n",
    "    notes_list_pid = notes[notes['subject_id']==pid]['text'].to_list()\n",
    "    useful_flag = True\n",
    "    if pid_repeat_dict.__contains__(pid):\n",
    "        pid_repeat_dict[pid] += 1\n",
    "    else:\n",
    "        pid_repeat_dict[pid] = 1\n",
    "    useful_flag = True\n",
    "    sampled_notes = notes_list_pid.copy()\n",
    "    if len(notes_list_pid) > 15:\n",
    "        sampled_notes = random.sample(notes_list_pid, 15)\n",
    "    return path, file_path, report, sampled_notes, useful_flag\n",
    "to_report_notes(mimic_files[-1], notes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "129d495e-ce78-4795-8e78-4548c402b9bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "useful_ids = []\n",
    "useful_data = []\n",
    "for _id in sampled_type4_path:\n",
    "    _data = dict()\n",
    "    report_path, img_path, report, notes_pid, useful = to_report_notes(_id, notes)\n",
    "    if useful:\n",
    "        useful_ids.append(_id)\n",
    "        _data['img_path'] = img_path\n",
    "        _data['report'] = report\n",
    "        _data['notes'] = notes_pid\n",
    "        useful_data.append(_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "022a28ea-4963-447d-be4b-39f431fa1255",
   "metadata": {},
   "source": [
    "# Data generation for Contextual Hallucination"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "1a2a46b6-9dff-4ec6-91bf-f532d4b4f294",
   "metadata": {},
   "outputs": [],
   "source": [
    "system_msg = \"\"\"\n",
    "You are provided with the report about a medical image, and the additional clinical notes of this patient. You task is to synthesize a set of close-ended QA pairs according to the requirements. Unfortunately, you don’t have access to the actual image. Below are requirements for generating the questions and answers in the conversation.\n",
    "\n",
    "Answer responsibly within the given information. Given the clinical notes and x-ray reports provided, you should design questions to test for contextual visual hallucination.  The goal is to ensure that the model interprets the x-ray images accurately within the specific context provided by the clinical notes,  without generating clinically inappropriate or inconsistent responses.\n",
    "\n",
    "Example yes-or-no questions:\n",
    "(1) Diagnostic Relevance: \"Given the patient’s history of severe back pain and the x-ray findings, is it likely that the back pain is due to a cardiopulmonary issue?\" (Expected Answer: No)\n",
    "(2) Complication Risk: \"Is it necessary to monitor the patient for potential complications related to the spine, given the findings of thoracic kyphosis and vertebral wedging?\" (Expected Answer: Yes)\n",
    "(3) Family History: \"Given the patient's family history of colorectal cancer, as mentioned in the clinical notes, are there any signs of colorectal abnormalities or precancerous lesions visible in the abdominal CT scan?\" (Expected Answer: Yes)\n",
    "(4) Symptom Analysis: \"Does the absence of focal consolidation in the chest x-ray suggest that the patient’s cough is unrelated to a pulmonary infection?\" (Expected Answer: Yes)\n",
    "\n",
    "Instructions:\n",
    "Generate \"yes-or-no\" questions, ensuring a balanced distribution of labels (yes and no).\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "2bb0a536-e141-47f5-aca2-b9a92723b760",
   "metadata": {},
   "outputs": [],
   "source": [
    "class QA(BaseModel):\n",
    "    question: str = Field(description=\"The question\")\n",
    "    answer: str = Field(description=\"The answer\")\n",
    "    \n",
    "class QueryAnswer(BaseModel):\n",
    "    qa_pairs: List[QA] = Field(description=\"The QA pairs for contextual evaluation.\")\n",
    "\n",
    "#Here we use the classic parser in LangChain: Pydantic, to ensure a strict parsing process\n",
    "qa_parser = PydanticOutputParser(pydantic_object=QueryAnswer)\n",
    "\n",
    "qa_prompt = ChatPromptTemplate.from_messages([\n",
    "    SystemMessage(content=system_msg),\n",
    "    HumanMessagePromptTemplate(\n",
    "        prompt=PromptTemplate(\n",
    "            template=\"The medical report of the image: {report}, clinical notes: {notes}. Generate high-quality close-ended question-answer pairs.\",\n",
    "            input_variables=[\"report\", \"clinical_notes\"],\n",
    "            partial_variables={\"format_instructions\": qa_parser.get_format_instructions()},\n",
    "        )\n",
    "    ),\n",
    "])\n",
    "\n",
    "\n",
    "qa_chain = (\n",
    "    qa_prompt \n",
    "    | gpt_model \n",
    "    | qa_parser\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "9f4aaa66-7e8e-452b-bcca-1564c212476d",
   "metadata": {},
   "outputs": [],
   "source": [
    "qas = []\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "50c9a388-99f2-4ce6-abf5-c47104942d02",
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2024-08-19T19:01:13.690343100Z",
     "start_time": "2024-08-19T19:01:13.662873800Z"
    }
   },
   "outputs": [],
   "source": [
    "for i in tqdm(range(len(useful_data))):\n",
    "    report, notes = useful_data[i]['report'], useful_data[i][\"notes\"]\n",
    "    img_path = useful_data[i]['img_path']\n",
    "    qa = qa_chain.invoke(dict(report=report, clinical_notes=notes))\n",
    "    final_new_qa = {\"img_path\":img_path, \"qa_pairs\": qa.dict()}\n",
    "    qas.append(final_new_qa)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "17d2b112-c9d3-4c99-99a1-02335d2a915f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# store\n",
    "def trans_file_close(qas_list):\n",
    "    new_list = []\n",
    "    id_ = 0\n",
    "    for i in qas_list:\n",
    "        img_id = i['img_path']\n",
    "        new_qas = i['qa_pairs']['qa_pairs']\n",
    "        \n",
    "        for j in new_qas:\n",
    "            new_dict = dict()\n",
    "            new_dict['img_name'] = img_id\n",
    "            new_dict['question'] = j[\"question\"]\n",
    "            new_dict['answer'] = j[\"answer\"]\n",
    "            new_dict['question_type'] = 'type3_contextual'\n",
    "            new_dict['ground_truth_type'] = j['ground_truth_type']\n",
    "            new_dict['choices'] = j['choices']\n",
    "            \n",
    "            new_dict['qid'] = id_\n",
    "            new_dict['img_id'] = img_id\n",
    "\n",
    "            new_list.append(new_dict)\n",
    "            id_ += 1\n",
    "\n",
    "    return new_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "f22d926d-3b1b-4373-860d-de39e6c61070",
   "metadata": {},
   "outputs": [],
   "source": [
    "type3_list = trans_file_close(qas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "f9f1a18c-1d1f-4684-b447-15001cb1cdc4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data saved!\n"
     ]
    }
   ],
   "source": [
    "with open(\"saved json file\", 'w') as json_file:\n",
    "    json.dump(type3_list, json_file, indent=4)\n",
    "\n",
    "print(f\"Data saved!\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "language": "python",
   "display_name": "Python 3 (ipykernel)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
