{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "#set the openai env variable here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from langchain_openai import AzureChatOpenAI\n",
    "\n",
    "# Connection settings for the Azure OpenAI deployment used as the generator.\n",
    "# Endpoint and key are taken from the environment when available so that real\n",
    "# credentials never have to be pasted into the notebook; the placeholders are\n",
    "# only the fallback for unset variables.\n",
    "api_version = \"2024-05-01-preview\"\n",
    "model_kwargs = dict(\n",
    "    model=\"gpt-4-128k\",\n",
    "    azure_endpoint=os.environ.get(\"AZURE_OPENAI_ENDPOINT\", \"endpoint here\"),\n",
    "    api_key=os.environ.get(\"AZURE_OPENAI_API_KEY\", \"key here\"),\n",
    "    api_version=api_version,\n",
    "    temperature=0.0,\n",
    ")\n",
    "\n",
    "# cache=False is forwarded to the model constructor together with model_kwargs.\n",
    "gpt_model = AzureChatOpenAI(**model_kwargs, cache=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# System prompt for the closed-ended QA generator: describes the task, the four\n",
    "# hallucination question types (type_1..type_4) and the keys every generated QA\n",
    "# pair must carry. It is passed to SystemMessagePromptTemplate.from_template, so\n",
    "# literal curly braces added here would presumably be read as template variables\n",
    "# (TODO confirm against the template format in use).\n",
    "system_msg = \"\"\" You are provided with the metadata and a set of existing QA pairs about one medical image. You task is to synthesize a set of new close-ended QA pairs according to the requirements. Unfortunately, you don’t have access to the actual image. Below are requirements for generating the questions and answers in the conversation.\n",
    "\n",
    "- Do not use phrases like \"mentioned\", \"QA\", \"context\" in the conversation. Instead, refer to the information as being \"in the image.\"\n",
    "- Answer responsibly within the given information and knowledge context, avoiding information not included in the given context.\n",
    "- Make sure your generation contains a key \"type\" for each QA pair, indicating its question type (as detailed in the following part).\n",
    "- Make the questions diverse.\n",
    "- The ground truth type can be \"yes or no\" or one choice from multi-choice (you need to synthesize several choices in this case), condition on the question and available materials.\n",
    "\n",
    "Here are a set of question types for your generation, you must assign each new QA pair a question type.\n",
    "type_1: Anatomical Hallucination. Example questions: \"Which part of the body does this image belong to?\" \"Does the picture contain liver?\"\n",
    "type_2: Measurement Hallucination like location, size. Example questions: \"Where is the liver?\"\n",
    "type_3: Symptom-Based Hallucination. Example questions: \"Is the lung healthy?\" \"Is there evidence of a pneumothorax?\" \"Is there a fracture?\"\n",
    "type_4: Technique Hallucination. Example questions:  \"What modality is used to take this image? \"\n",
    "\n",
    "Instructions:\n",
    "- Add one key to each QA pair, key= \"ground_truth_type\", value= \"binary\" if the type is \"yes or no\" else \"multi-choice\"\n",
    "- When you see diagnosis information of a disease (e.g. lung cancer) in QA pairs, you should generate new QA pair by asking the symptoms of the disease \n",
    "- For \"multi-choice\" type QA, you must include one key \"choices\" of string type.\n",
    "- Avoid the question that you can not generate a ground truth, for example avoid the answer \"The image does not provide information\". \n",
    "\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Dict, Iterator, List, Optional, Sequence, Union\n",
    "\n",
    "def inspect(state):\n",
    "    \"\"\"Print the state passed between Runnables in a langchain and pass it on\"\"\"\n",
    "    print(state, '\\n')\n",
    "    return state\n",
    "\n",
    "from langchain.output_parsers import PydanticOutputParser\n",
    "from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate\n",
    "from pydantic import BaseModel, Field\n",
    "import json\n",
    "\n",
    "\n",
    "def format_qa(qa_pairs):\n",
    "    # print(qa_pairs)\n",
    "    return qa_pairs['qa_pairs']\n",
    "\n",
    "# Schema of one generated closed-ended QA pair.\n",
    "# Documented with `#` comments rather than a class docstring on purpose: pydantic\n",
    "# copies a model docstring into the generated JSON schema, which would change the\n",
    "# format instructions produced from this model.\n",
    "class Closed_QA(BaseModel):\n",
    "    # question text\n",
    "    Q: str = Field(description=\"The question (Q)\")\n",
    "    # answer text (ground truth for the question)\n",
    "    A: str = Field(description=\"The answer (A)\")\n",
    "    # question type label; the system prompt defines type_1 .. type_4\n",
    "    type: str = Field(description=\"The QA type\")\n",
    "    # answer options as a single string; defaults to \"\" when not supplied\n",
    "    choices: str = Field(default=\"\", description=\"The QA choices for multi-choice question type\")\n",
    "    # the system prompt asks for \"binary\" or \"multi-choice\" here\n",
    "    ground_truth_type: str = Field(description=\"The ground_truth type\")\n",
    "    \n",
    "\n",
    "# Top-level object the model reply is parsed into (see PydanticOutputParser below):\n",
    "# a list of Closed_QA entries. `#` comment instead of a docstring so the generated\n",
    "# JSON schema is left unchanged.\n",
    "class QueryAnswer(BaseModel):\n",
    "    qa_pairs: List[Closed_QA] = Field(description=\"The QA pairs list generated.\")\n",
    "\n",
    "# Parse the model reply strictly into the QueryAnswer schema.\n",
    "closed_parser = PydanticOutputParser(pydantic_object=QueryAnswer)\n",
    "\n",
    "# The human message has to contain the {format_instructions} slot: without it the\n",
    "# partial variable declared below is never rendered, and the model is never told\n",
    "# which JSON layout the parser expects.\n",
    "closed_prompt = ChatPromptTemplate.from_messages([\n",
    "    SystemMessagePromptTemplate.from_template(system_msg),\n",
    "    HumanMessagePromptTemplate(\n",
    "        prompt=PromptTemplate(\n",
    "            template=(\n",
    "                \"Given the metadata: {metadata} and existing QA pairs: {qa_pairs}, generate high-quality questions for which the correct answers can be inferred solely from the provided information. Ensure the questions align with the specified question types.\\n\"\n",
    "                \"{format_instructions}\"\n",
    "            ),\n",
    "            input_variables=[\"metadata\", \"qa_pairs\"],\n",
    "            partial_variables={\"format_instructions\": closed_parser.get_format_instructions()},\n",
    "        )\n",
    "    ),\n",
    "])\n",
    "# Backward-compatible alias for the original (misspelled) name.\n",
    "colsed_prompt = closed_prompt\n",
    "\n",
    "# prompt -> chat model -> pydantic parser; invoke with {\"metadata\": ..., \"qa_pairs\": ...}\n",
    "closed_chain = closed_prompt | gpt_model | closed_parser\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batch Processing and Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Slake\n",
    "import os\n",
    "import json\n",
    "from collections import defaultdict\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the original annotations; both strings look like placeholders to be\n",
    "# replaced with real paths before running.\n",
    "rad_data = r\"RAD-VQA DATA PATH\"\n",
    "test_file_path = os.path.join(rad_data, \"JSON file of the original RQA-VQA\")\n",
    "\n",
    "# Read the whole annotation file and decode it into `rad_test`.\n",
    "with open(test_file_path, mode='r', encoding='utf8') as file:\n",
    "    rad_test = json.loads(file.read())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two lookup tables over the loaded records:\n",
    "#   test_data_id_map: qid -> full record\n",
    "#   img_data_ids:     image_name -> qids of that image, in iteration order\n",
    "test_data_id_map = {}\n",
    "img_data_ids = defaultdict(list)\n",
    "for i in rad_test:\n",
    "    qid = i['qid']\n",
    "    test_data_id_map[qid] = i\n",
    "    img_data_ids[i['image_name']].append(qid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "314"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# number of distinct image names among the loaded records\n",
    "len(img_data_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['0', 13, 14, 16, 17, 21]"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# .get() instead of indexing: img_data_ids is a defaultdict, so indexing a name\n",
    "# that is not present would silently add an empty entry and pollute the image list.\n",
    "img_data_ids.get('synpic54610.jpg', [])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_path = 'data/rad_vqa/images'\n",
    "def get_slake_qa(img_id):\n",
    "    final_qas = ''\n",
    "    qa_ids = img_data_ids[img_id]\n",
    "    for _id in qa_ids:\n",
    "        qa_pair = test_data_id_map[_id]\n",
    "        final_qas += f\"Q: {qa_pair['question']} \\nA: {qa_pair['answer']} \\n\"\n",
    "    metadata = \"image_organ:\" + qa_pair['image_organ'] # take the last qa pair and us the metadata of this image\n",
    "    return final_qas, metadata"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Filter rules\n",
    "- Binary type: drop the QA pair if its answer contains neither \"yes\" nor \"no\".\n",
    "- Multi-choice type: drop the QA pair if it has no \"choices\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accumulator filled by the generation loop, one dict per image:\n",
    "# {\"img_path\": <image name>, \"new_qa_pairs\": <output of closed_chain for that image>}\n",
    "slake_new_qas = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# image names to process (the keys of img_data_ids)\n",
    "slake_img_ids = list(img_data_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2024-08-19T18:14:59.408101200Z",
     "start_time": "2024-08-19T18:14:59.394081200Z"
    }
   },
   "outputs": [],
   "source": [
    "# Resumable generation loop: images that already have an entry in slake_new_qas\n",
    "# are skipped, so re-running this cell after an interruption neither repeats LLM\n",
    "# calls nor appends duplicate entries.\n",
    "_done_img_ids = {item[\"img_path\"] for item in slake_new_qas}\n",
    "failed_img_ids = []  # (image name, error repr) for images whose call or parsing failed\n",
    "for _id in tqdm(slake_img_ids):\n",
    "    if _id in _done_img_ids:\n",
    "        continue\n",
    "    qa, metadata = get_slake_qa(_id)\n",
    "    try:\n",
    "        ans = closed_chain.invoke(dict(qa_pairs=qa, metadata=metadata))\n",
    "    except Exception as err:\n",
    "        # One failed reply must not abort the whole batch; the image is recorded\n",
    "        # and reported below so it can be retried by re-running the cell.\n",
    "        failed_img_ids.append((_id, repr(err)))\n",
    "        continue\n",
    "    final_new_qa = {\"img_path\":_id, \"new_qa_pairs\": ans}\n",
    "    slake_new_qas.append(final_new_qa)\n",
    "\n",
    "if failed_img_ids:\n",
    "    print(f\"{len(failed_img_ids)} image(s) failed; re-run this cell to retry them: {failed_img_ids}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'img_path': 'synpic19114.jpg',\n",
       " 'new_qa_pairs': QueryAnswer(qa_pairs=[Closed_QA(Q='Is there evidence of pleural plaques in the image?', A='Yes', type='type_3', choices='', omission_type=1, ground_truth_type='binary'), Closed_QA(Q='What is the abnormality detected in the lungs?', A='Pleural plaques', type='type_3', choices='A. Pleural effusion, B. Pleural plaques, C. Pneumothorax, D. Lung nodules', omission_type=-1, ground_truth_type='multi-choice'), Closed_QA(Q='Is the aortic arch enlarged?', A='No', type='type_3', choices='', omission_type=0, ground_truth_type='binary'), Closed_QA(Q='Does the image show a normal aortic arch contour?', A='Yes', type='type_3', choices='', omission_type=1, ground_truth_type='binary'), Closed_QA(Q='Are the pleural plaques localized to the hemithoraces?', A='Yes', type='type_2', choices='', omission_type=1, ground_truth_type='binary'), Closed_QA(Q='What part of the body is shown in the image?', A='Chest', type='type_1', choices='A. Chest, B. Abdomen, C. Pelvis, D. Skull', omission_type=-1, ground_truth_type='multi-choice'), Closed_QA(Q='Is the modality used for this image an X-ray?', A='Yes', type='type_4', choices='', omission_type=1, ground_truth_type='binary')])}"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the last generated entry; -1 rather than a hard-coded 313, which only\n",
    "# exists once every one of the 314 images has produced a result.\n",
    "slake_new_qas[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trans_file(qas_list):\n",
    "    new_list = []\n",
    "    id_ = 0\n",
    "    for i in qas_list:\n",
    "        img_id = i['img_path']\n",
    "        new_qas = i['new_qa_pairs'].dict()['qa_pairs']\n",
    "        old_qa_ids = img_data_ids[img_id]\n",
    "        old_qa_pair = test_data_id_map[old_qa_ids[0]]\n",
    "        for j in new_qas:\n",
    "            new_dict = dict()\n",
    "            new_dict['img_name'] = img_id\n",
    "            new_dict['question'] = j['Q']\n",
    "            new_dict['answer'] = j['A']\n",
    "            if \"image\" in j['A'].lower() and \"not\" in j['A'] and \"provide\" in j['A']:\n",
    "                continue\n",
    "            if j['ground_truth_type'] == \"binary\":\n",
    "                if \"yes\" not in j['A'].lower() and \"no\" not in j['A'].lower():\n",
    "                    print(j['A'].lower())\n",
    "                    continue\n",
    "            new_dict['hallucination_type'] = j['type']\n",
    "            new_dict['question_type'] = j['ground_truth_type']\n",
    "            new_dict['choices'] = j['choices']\n",
    "            new_dict['qid'] = id_\n",
    "            new_dict['img_id'] = old_qa_pair['image_name']\n",
    "            new_dict['location'] = old_qa_pair['image_organ']\n",
    "            new_dict['modality'] = None\n",
    "            new_list.append(new_dict)\n",
    "            id_ += 1\n",
    "    return new_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2024-08-19T18:01:35.086199200Z",
     "start_time": "2024-08-19T18:01:35.073170800Z"
    }
   },
   "outputs": [],
   "source": [
    "# Flatten the per-image results into one list of records (dumped to JSON further down).\n",
    "rad_new_qas_json = trans_file(slake_new_qas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'img_name': 'synpic51926.jpg',\n",
       " 'question': 'Does the image show any abnormalities in the digestive system?',\n",
       " 'answer': 'Yes',\n",
       " 'hallucination_type': 'type_3',\n",
       " 'omission_type': 1,\n",
       " 'question_type': 'binary',\n",
       " 'choices': '',\n",
       " 'qid': 789,\n",
       " 'img_id': 'synpic51926.jpg',\n",
       " 'location': 'ABD',\n",
       " 'modality': None}"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spot-check one flattened record as a data example.\n",
    "rad_new_qas_json[789]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data saved!\n"
     ]
    }
   ],
   "source": [
    "# Persist the flattened records. The encoding is given explicitly, matching the\n",
    "# one used when reading the annotations, so the file does not depend on the\n",
    "# platform's default encoding.\n",
    "with open(\"saved json file\", 'w', encoding='utf8') as json_file:\n",
    "    json.dump(rad_new_qas_json, json_file, indent=4)\n",
    "\n",
    "print(\"Data saved!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "language": "python",
   "display_name": "Python 3 (ipykernel)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
