{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "#set the openai env variable here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from langchain_openai import AzureChatOpenAI\n",
    "\n",
    "# Azure OpenAI REST API version passed to the client below.\n",
    "api_version = \"2024-05-01-preview\"\n",
    "\n",
    "# Prefer credentials from the environment so real keys never have to be pasted into\n",
    "# (and saved with) the notebook. The placeholder strings are only fallbacks and must be\n",
    "# overridden, either through the environment variables or by editing them, before running.\n",
    "model_kwargs = dict(\n",
    "    model=\"gpt-4-128k\",\n",
    "    azure_endpoint=os.environ.get(\"AZURE_OPENAI_ENDPOINT\", \"endpoint here\"),\n",
    "    api_key=os.environ.get(\"AZURE_OPENAI_API_KEY\", \"key here\"),\n",
    "    api_version=api_version,\n",
    "    temperature=0.0,\n",
    ")\n",
    "\n",
    "# cache=False is passed explicitly so responses are not served from an LLM cache.\n",
    "gpt_model = AzureChatOpenAI(**model_kwargs, cache=False)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 241,
   "metadata": {},
   "outputs": [],
   "source": [
    "# System prompt of the SLAKE close-ended QA synthesis chain: it states the task, the\n",
    "# generation requirements, the four hallucination question types (type_1..type_4) and the\n",
    "# extra keys (\"type\", \"ground_truth_type\", \"choices\") every generated QA pair must carry.\n",
    "# NOTE(review): \"You task\" in the first sentence looks like a typo for \"Your task\"; the text\n",
    "# is left untouched here because it is the prompt the model receives.\n",
    "slake_system_msg = \"\"\" You are provided with the metadata, object bounding boxes and a set of existing QA pairs about one medical image. You task is to synthesize a set of new close-ended QA pairs according to the requirements.\n",
    "Unfortunately, you don’t have access to the actual image. Below are requirements for generating the questions and answers in the conversation.\n",
    "- Do not use phrases like \"mentioned\", \"bounding boxes\", \"context\" in the conversation. Instead, refer to the information as being \"in the image.\"\n",
    "- Answer responsibly within the given information and knowledge context, avoiding information not included in the given context.\n",
    "- Make sure your generation contains a key \"type\" for each QA pair, indicating its question type (as detailed in the following part).\n",
    "- Make the questions more diverse.\n",
    "- The ground truth type can be \"yes or no\" or one choice from multi-choice (you need to synthesize several choices in this case), condition on the question and available materials.\n",
    "\n",
    "Here are a set of question types for your generation, you must assign each new QA pair a question type.\n",
    "type_1: Anatomical Hallucination. Example questions: \"Which part of the body does this image belong to?\" \"Does the picture contain liver?\"\n",
    "type_2: Measurement Hallucination like location, size. Example questions: \"Where is the liver?\"\n",
    "type_3: Symptom-Based Hallucination. Example questions: \"Is the lung healthy?\" \"Is there evidence of a pneumothorax?\" \"Is there a fracture?\"\n",
    "type_4: Technique Hallucination. Example questions:  \"What modality is used to take this image? \"\n",
    "\n",
    "Instructions:\n",
    "- Add one key to each QA pair, key= \"ground_truth_type\", value= \"binary\" if the type is \"yes or no\" else \"multi-choice\"\n",
    "- When you see diagnosis information of a disease (e.g. lung cancer) in QA pairs, you should generate new QA pair by asking the symptoms of the disease \n",
    "- For \"multi-choice\" type QA, you must include one key \"choices\" of string type.\n",
    "- Avoid the question that you can not generate a ground truth, for example avoid the answer \"The image does not provide information\". \n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 256,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.output_parsers import PydanticOutputParser\n",
    "from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate\n",
    "from pydantic import BaseModel, Field\n",
    "from typing import Any, Dict, Iterator, List, Optional, Sequence, Union\n",
    "\n",
    "def format_qa(qa_pairs):\n",
    "    \"\"\"Return the value stored under the ``'qa_pairs'`` key of ``qa_pairs``.\"\"\"\n",
    "    extracted = qa_pairs['qa_pairs']\n",
    "    return extracted\n",
    "\n",
    "# Schema of one generated QA pair; it is the item type of SlakeQueryAnswer.qa_pairs.\n",
    "# Documented with comments instead of a class docstring so the schema derived from the\n",
    "# model stays exactly as it was.\n",
    "class Slake_closed_QA(BaseModel):\n",
    "    # Question text.\n",
    "    Q: str = Field(description=\"The question (Q)\")\n",
    "    # Answer text.\n",
    "    A: str = Field(description=\"The answer (A)\")\n",
    "    # Question type label; the system prompt defines type_1 .. type_4.\n",
    "    type: str = Field(description=\"The QA type\")\n",
    "    # Answer options as one string; defaults to \"\" when none are given.\n",
    "    choices: str = Field(default=\"\", description=\"The QA choices for multi-choice question type\")\n",
    "    # \"binary\" or \"multi-choice\" according to the system prompt.\n",
    "    ground_truth_type: str = Field(description=\"The ground_truth type\")\n",
    "    \n",
    "\n",
    "# Top-level object handed to PydanticOutputParser: a list of generated QA pairs.\n",
    "class SlakeQueryAnswer(BaseModel):\n",
    "    # All QA pairs produced for one image.\n",
    "    qa_pairs: List[Slake_closed_QA] = Field(description=\"The QA pairs list generated.\")\n",
    "\n",
    "# Strict parsing: the model output has to validate against the SlakeQueryAnswer schema.\n",
    "slake_closed_parser = PydanticOutputParser(pydantic_object=SlakeQueryAnswer)\n",
    "\n",
    "# The human message ends with {format_instructions}. Without that placeholder the partial\n",
    "# variable supplied below is never rendered, so the model would not be told which JSON\n",
    "# schema the parser expects.\n",
    "slake_human_template = (\n",
    "    \"The metadata: {metadata} , object bounding boxes:{bounding_boxes}, existing QA pairs: {qa_pairs}. \"\n",
    "    \"Generate high-quality questions for which the correct answers can be inferred solely from the provided information. \"\n",
    "    \"Ensure the questions align with the specified question types.\\n\"\n",
    "    \"{format_instructions}\"\n",
    ")\n",
    "\n",
    "slake_closed_prompt = ChatPromptTemplate.from_messages([\n",
    "    SystemMessagePromptTemplate.from_template(slake_system_msg),\n",
    "    HumanMessagePromptTemplate(\n",
    "        prompt=PromptTemplate(\n",
    "            template=slake_human_template,\n",
    "            input_variables=[\"metadata\", \"bounding_boxes\", \"qa_pairs\"],\n",
    "            partial_variables={\"format_instructions\": slake_closed_parser.get_format_instructions()},\n",
    "        )\n",
    "    ),\n",
    "])\n",
    "# Backward-compatible alias for the original (misspelled) variable name.\n",
    "slake_colsed_prompt = slake_closed_prompt\n",
    "\n",
    "# prompt -> Azure chat model -> pydantic parser\n",
    "slake_closed_chain = (\n",
    "    slake_closed_prompt\n",
    "    | gpt_model\n",
    "    | slake_closed_parser\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batch Processing and Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 252,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Slake\n",
    "import os\n",
    "import json\n",
    "from collections import defaultdict\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 196,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholder locations: point these at the local SLAKE folder and its QA json file.\n",
    "slake_data = r\"SLAKE FOLDER\"\n",
    "test_file_path = os.path.join(slake_data, \"the original SLAKE QAs json file\")\n",
    "\n",
    "# Read the whole QA file and parse it in one go.\n",
    "with open(test_file_path, mode='r', encoding='utf8') as file:\n",
    "    slake_test = json.loads(file.read())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 201,
   "metadata": {},
   "outputs": [],
   "source": [
    "# qid -> full QA record\n",
    "test_data_id_map = {record['qid']: record for record in slake_test}\n",
    "\n",
    "# img_name -> qids of every QA record of that image, in file order\n",
    "img_data_ids = defaultdict(list)\n",
    "for record in slake_test:\n",
    "    img_data_ids[record['img_name']].append(record['qid'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 207,
   "metadata": {},
   "outputs": [],
   "source": [
    "# img_data_ids.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 205,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[11934, 11935, 11936, 11937, 11938, 11939, 11940, 11941, 11942, 11943, 11944]"
      ]
     },
     "execution_count": 205,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img_data_ids['xmlab102/source.jpg']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 220,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_path = 'data/slake/imgs'\n",
    "\n",
    "\n",
    "def get_slake_qa(img_id):\n",
    "    \"\"\"Collect the prompt inputs for one SLAKE image.\n",
    "\n",
    "    Args:\n",
    "        img_id: Key of ``img_data_ids``, i.e. the ``img_name`` of the image\n",
    "            (for example ``'xmlab102/source.jpg'``).\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(metadata, final_qas, bounding_dectection)``: a one-line organ/modality\n",
    "        summary, the existing QA pairs of the image rendered as ``Q:``/``A:`` lines, and\n",
    "        the parsed content of ``detection.json`` from the image's folder under ``img_path``.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If no QA pair is registered for ``img_id``.\n",
    "    \"\"\"\n",
    "    # Test membership first: indexing the defaultdict with an unknown key would insert an\n",
    "    # empty list and the function would later fail on an unbound ``qa_pair``.\n",
    "    if img_id not in img_data_ids or not img_data_ids[img_id]:\n",
    "        raise KeyError(f\"no QA pairs registered for image {img_id!r}\")\n",
    "\n",
    "    qa_records = [test_data_id_map[_id] for _id in img_data_ids[img_id]]\n",
    "    final_qas = ''.join(\n",
    "        f\"Q: {record['question']} \\nA: {record['answer']} \\n\" for record in qa_records\n",
    "    )\n",
    "\n",
    "    # Organ and modality are taken from the last QA record of the image.\n",
    "    last_record = qa_records[-1]\n",
    "    # trans_file reads the modality from the 'modality' key; 'metadata' is kept as a\n",
    "    # fallback for QA files that use the key this function originally read.\n",
    "    modality = last_record['modality'] if 'modality' in last_record else last_record['metadata']\n",
    "    # The separator keeps the organ and the modality label from running together.\n",
    "    metadata = \"image_organ:\" + last_record['location'] + \", image modality:\" + modality\n",
    "\n",
    "    bounding_path = os.path.join(img_path, img_id.split(\"/\")[0], 'detection.json')\n",
    "    with open(bounding_path, 'r', encoding='utf8') as file:\n",
    "        bounding_dectection = json.load(file)\n",
    "    return metadata, final_qas, bounding_dectection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 221,
   "metadata": {},
   "outputs": [],
   "source": [
    "metadata, qa, bounding = get_slake_qa('xmlab102/source.jpg')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
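    "### Filter rules\n",
    "- Binary type: filter out the QA pair if its answer contains neither \"yes\" nor \"no\".\n",
    "- Multi-choice type: filter out the QA pair if it has no \"choices\".\n",
    "- Any type: filter out the QA pair if its answer states that the image does not provide the information."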
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 253,
   "metadata": {},
   "outputs": [],
   "source": [
    "slake_new_qas = [] # {\"img_path\":XXX, \"new_qa_pairs\":[xxx]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 250,
   "metadata": {},
   "outputs": [],
   "source": [
    "slake_img_ids = list(img_data_ids.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2024-08-19T18:14:03.743982900Z",
     "start_time": "2024-08-19T18:14:03.741470500Z"
    }
   },
   "outputs": [],
   "source": [
    "from langchain_core.exceptions import OutputParserException\n",
    "\n",
    "# Images that already have an entry are skipped, so re-running this cell resumes the\n",
    "# batch instead of appending duplicates (and paying for the same calls again).\n",
    "processed_ids = {entry[\"img_path\"] for entry in slake_new_qas}\n",
    "# Images whose completion could not be parsed into SlakeQueryAnswer.\n",
    "slake_failed_ids = []\n",
    "\n",
    "for _id in tqdm(slake_img_ids):\n",
    "    if _id in processed_ids:\n",
    "        continue\n",
    "    metadata, qa, bounding = get_slake_qa(_id)\n",
    "    try:\n",
    "        ans = slake_closed_chain.invoke(dict(qa_pairs=qa, bounding_boxes=bounding, metadata=metadata))\n",
    "    except OutputParserException as err:\n",
    "        # Only parse failures are tolerated: one malformed completion should not abort the\n",
    "        # whole batch. Any other error (auth, network, missing files) still propagates.\n",
    "        print(f\"Skipping {_id}: could not parse model output ({str(err)[:200]})\")\n",
    "        slake_failed_ids.append(_id)\n",
    "        continue\n",
    "    final_new_qa = {\"img_path\": _id, \"new_qa_pairs\": ans}\n",
    "    slake_new_qas.append(final_new_qa)\n",
    "    processed_ids.add(_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 269,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'img_path': 'xmlab80/source.jpg',\n",
       " 'new_qa_pairs': SlakeQueryAnswer(qa_pairs=[Slake_closed_QA(Q='Is the imaging modality used for this image an MRI?', A='No', type='type_4', choices='', omission_type=0, ground_truth_type='binary'), Slake_closed_QA(Q='Is the primary organ shown in the image the heart?', A='No', type='type_1', choices='', omission_type=0, ground_truth_type='binary'), Slake_closed_QA(Q='Is the abnormality found in the left lung?', A='Yes', type='type_2', choices='', omission_type=1, ground_truth_type='binary'), Slake_closed_QA(Q='Does the image show any signs of pneumonia?', A='No', type='type_3', choices='', omission_type=0, ground_truth_type='binary'), Slake_closed_QA(Q='What disease is present in the image?', A='A. Lung Cancer', type='type_3', choices='A. Lung Cancer, B. Liver Cirrhosis, C. Cardiomegaly', omission_type=-1, ground_truth_type='multi-choice'), Slake_closed_QA(Q='Is there a presence of a tumor in the image?', A='Yes', type='type_3', choices='', omission_type=1, ground_truth_type='binary'), Slake_closed_QA(Q='Which side of the lung is the cancer located?', A='B. Right', type='type_2', choices='A. Left, B. Right, C. Both', omission_type=-1, ground_truth_type='multi-choice'), Slake_closed_QA(Q='Is the heart larger than the liver in the image?', A='Yes', type='type_1', choices='', omission_type=-1, ground_truth_type='binary'), Slake_closed_QA(Q='Is the image taken from the abdominal region?', A='No', type='type_1', choices='', omission_type=0, ground_truth_type='binary'), Slake_closed_QA(Q='Is the liver visible in the image?', A='No', type='type_1', choices='', omission_type=0, ground_truth_type='binary')])}"
      ]
     },
     "execution_count": 269,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "slake_new_qas[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 293,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trans_file(qas_list):\n",
    "    \"\"\"Flatten the generated QA pairs into SLAKE-style records and drop unusable ones.\n",
    "\n",
    "    A generated pair is dropped when\n",
    "    - its answer says that the image does not provide the information,\n",
    "    - it is \"binary\" but its answer contains neither \"yes\" nor \"no\" as a whole word, or\n",
    "    - it is \"multi-choice\" but carries no choices.\n",
    "\n",
    "    Args:\n",
    "        qas_list: Items shaped like ``{\"img_path\": img_name, \"new_qa_pairs\": parsed}``, where\n",
    "            ``parsed`` exposes ``model_dump()`` or ``dict()`` returning ``{\"qa_pairs\": [...]}``.\n",
    "\n",
    "    Returns:\n",
    "        List of dicts with the keys ``img_name``, ``question``, ``answer``,\n",
    "        ``hallucination_type``, ``question_type``, ``choices``, ``img_id``, ``qid``,\n",
    "        ``location`` and ``modality``. ``qid`` numbers the kept pairs from 0.\n",
    "    \"\"\"\n",
    "    import re  # local import keeps this cell independent of the notebook's import cells\n",
    "\n",
    "    # Whole-word match: a plain substring test accepts \"abnormal\" or \"unknown\" as \"no\".\n",
    "    yes_no_pattern = re.compile(r\"\\b(yes|no)\\b\")\n",
    "\n",
    "    id_ = 0\n",
    "    new_list = []\n",
    "    for i in qas_list:\n",
    "        img_id = i['img_path']\n",
    "        parsed = i['new_qa_pairs']\n",
    "        # pydantic v2 names the method model_dump(); fall back to the v1 dict().\n",
    "        dump = getattr(parsed, 'model_dump', None) or parsed.dict\n",
    "        new_qas = dump()['qa_pairs']\n",
    "        # Image-level fields are copied from the first original QA record of the image.\n",
    "        old_qa_pair = test_data_id_map[img_data_ids[img_id][0]]\n",
    "        for j in new_qas:\n",
    "            answer = j['A'].lower()\n",
    "            # Compare in lower case throughout so \"Not\" / \"Provide\" are caught as well.\n",
    "            if \"image\" in answer and \"not\" in answer and \"provide\" in answer:\n",
    "                continue\n",
    "            if j['ground_truth_type'] == \"binary\":\n",
    "                if not yes_no_pattern.search(answer):\n",
    "                    print(answer)\n",
    "                    continue\n",
    "            elif j['ground_truth_type'] == \"multi-choice\" and not (j['choices'] or '').strip():\n",
    "                continue\n",
    "            new_list.append({\n",
    "                'img_name': img_id,\n",
    "                'question': j['Q'],\n",
    "                'answer': j['A'],\n",
    "                'hallucination_type': j['type'],\n",
    "                'question_type': j['ground_truth_type'],\n",
    "                'choices': j['choices'],\n",
    "                'img_id': old_qa_pair['img_id'],\n",
    "                'qid': id_,\n",
    "                'location': old_qa_pair['location'],\n",
    "                'modality': old_qa_pair['modality'],\n",
    "            })\n",
    "            id_ += 1\n",
    "    return new_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2024-08-19T18:15:36.742820900Z",
     "start_time": "2024-08-19T18:15:36.728725500Z"
    }
   },
   "outputs": [],
   "source": [
    "slake_new_qas_json = trans_file(slake_new_qas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 300,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'img_name': 'xmlab103/source.jpg',\n",
       " 'question': 'What is the condition of the right lung?',\n",
       " 'answer': 'Abnormal',\n",
       " 'hallucination_type': 'type_3',\n",
       " 'omission_type': -1,\n",
       " 'question_type': 'multi-choice',\n",
       " 'choices': 'A: Healthy, B: Abnormal, C: Not visible',\n",
       " 'img_id': 103,\n",
       " 'qid': 9,\n",
       " 'location': 'Abdomen',\n",
       " 'modality': 'CT'}"
      ]
     },
     "execution_count": 300,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "slake_new_qas_json[9] #data example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 298,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data saved!\n"
     ]
    }
   ],
   "source": [
    "# Serialize the flattened QA records first, then write them out in a single call.\n",
    "serialized_qas = json.dumps(slake_new_qas_json, indent=4)\n",
    "with open(\"saved json file\", 'w') as json_file:\n",
    "    json_file.write(serialized_qas)\n",
    "\n",
    "print(\"Data saved!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "language": "python",
   "display_name": "Python 3 (ipykernel)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
