import json
import pandas as pd
from tqdm import tqdm
import os

# Prompt fragments that the process_dict_icl_* builders below concatenate to
# form the 'human' turn of each sample. These strings are emitted verbatim
# into the generated datasets, so editing any of them changes the output.
PROMPT = {
# Task question used by the differential-diagnosis (ddx) builders.
'DEFAULT_DDX':("What are the possible differential diagnoses for this patient?"),
# Task instruction used by the findings-generation (gen) builders; ends with
# a "Findings: " cue.
'DEFAULT_GEN':("Write a findings report on the given chest x-ray, including information about any abnormalities that you see.\nFindings: "),
# Task instruction used by the impression-summarization (sum) builders; ends
# with a "Findings: " cue that the record's findings text is meant to follow.
'DEFAULT_SUM':("Write an impression summarization of the given chest x-ray and findings report, "
               "including information about any abnormalities that you see.\nFindings: "),
# Preamble used by the *_heat builders (image with a gaze heat map overlay).
'HEATMAP_DESC':("The chest x-ray image below is overlaid with red heat map dots, representing key areas of the radiologist's eye gaze. "
                "The intensity of the red color indicates the duration of fixation, meaning the darker the dot, the longer the gaze in that region. "
                "Use this eye gaze pattern to answer the question that follows.\n"),
# Preamble used by the *_fix builders; ends with "Eye Fixation Data:\n" and
# the builders append the record's 'fixation_text' right after it.
'FIX_DESC':("The chest x-ray image is paired with key eye fixation data, "
            "detailing where and for how long the radiologist's attention was focused. "
            "The coordinates (X, Y) represent the relative position on the image, "
            "and the fixation time indicates the duration of focus at each point. "
            "This information follows the sequence of the radiologist's eye movements. "
            "Use the provided eye fixation data to answer the question that follows.\n"
            "Eye Fixation Data:\n"
           ),
# Preamble used by the *_video builders (gaze rendered as a video).
'VIDEO_DESC':("The video below shows a dynamic representation of the radiologist's eye fixation as they review a chest x-ray. "
              "Red dots indicate fixation points, and the duration of each fixation is represented by the length of time each dot remains visible on the screen. "
              "The video also shows the tracking sequence of the eye gaze. "
              "Use both the fixation points and the gaze tracking in the video to answer the question that follows.\n"),
# In-context-learning block shared by every builder: an introduction followed
# by three fixed exemplar reports (findings + impression each).
# NOTE(review): the exemplar text contains what look like transcription slips
# ("above the carinal", "increase in day hazy") - left untouched here because
# the string is runtime prompt content; confirm against the source reports.
'INCONTEXT_LEARNING_DESC':("The following are three exemplar chest x-ray reports, each providing a detailed analysis of the findings, impressions, and any identified abnormalities. "
                           "Review these reports carefully to understand the structure, terminology, and format used in professional radiology reporting. "
                           "Use them as a reference to answer the question based on the chest x-ray provided.\n\n"
                           "Exemplar Report 1:\n"
                           "Findings: Thin left-sided dual-chamber pacemaker device is again noted with leads terminating in the right atrium and ventricle. "
                           "The aorta is mildly tortuous, otherwise the hilar and mediastinal contours are unremarkable. Note is made of mild pulmonary vascular congestion. "
                           "No focal consolidations concerning for pneumonia are identified. There is no pleural effusion or pneumothorax. "
                           "Mild to moderate degenerative changes are seen involving both acromioclavicular joints.\n"
                           "Impression: No acute intrathoracic abnormalities identified.\n\n"
                           "Exemplar Report 2:\n"
                           "Findings: PA and lateral views of the chest provided. There is no focal consolidation, effusion, or pneumothorax. "
                           "The cardiomediastinal silhouette is normal. Imaged osseous structures are intact. No free air below the right hemidiaphragm is seen.\n"
                           "Impression: No acute intrathoracic process.\n\n"
                           "Exemplar Report 3:\n"
                           "Findings: ET tube tip is 5 cm above the carinal. Swan-Ganz catheter tip is in the right main pulmonary artery. "
                           "Heart size and mediastinum are unremarkable. No pneumothorax is seen. "
                           "There has been interval increase in day hazy alveolar infiltrate in the right lung and patchy alveolar infiltrate in the left lung.\n"
                           "Impression: Worsened appearance of the bilateral alveolar infiltrate.\n\n"
                          ),
}

def process_dict_icl_ddx(di, idx):
    """Build the plain in-context-learning differential-diagnosis sample.

    Args:
        di: Record providing 'image_id' and 'differential_diagnosis'; both
            values are stripped, so they must support ``.strip()``.
        idx: Value stored as the sample's 'question_id'.

    Returns:
        A dict with 'image', 'sys', 'question_id' and a two-turn
        'conversations' list (human prompt followed by the gpt answer).
    """
    image_ref = di['image_id'].strip()
    question = "".join([
        '<image>\n',
        PROMPT['INCONTEXT_LEARNING_DESC'],
        PROMPT['DEFAULT_DDX'],
    ])
    answer = di['differential_diagnosis'].strip()
    return {
        'image': image_ref,
        'sys': "",
        'question_id': idx,
        'conversations': [
            {'from': 'human', 'value': question},
            {'from': 'gpt', 'value': answer},
        ],
    }

def process_dict_icl_ddx_fix(di, idx):
    """Build the differential-diagnosis sample that includes fixation text.

    The human prompt is: image placeholder, fixation description, the
    record's 'fixation_text' plus a newline, the in-context exemplars and
    finally the differential-diagnosis question.

    Args:
        di: Record providing 'image_id', 'fixation_text' and
            'differential_diagnosis'.
        idx: Value stored as the sample's 'question_id'.

    Returns:
        A dict with 'image', 'sys', 'question_id' and a two-turn
        'conversations' list.
    """
    image_ref = di['image_id'].strip()
    question = "".join([
        '<image>\n',
        PROMPT['FIX_DESC'],
        f"{di['fixation_text']}\n",
        PROMPT['INCONTEXT_LEARNING_DESC'],
        PROMPT['DEFAULT_DDX'],
    ])
    answer = di['differential_diagnosis'].strip()
    return {
        'image': image_ref,
        'sys': "",
        'question_id': idx,
        'conversations': [
            {'from': 'human', 'value': question},
            {'from': 'gpt', 'value': answer},
        ],
    }

def process_dict_icl_ddx_heat(di, idx, _id='heatmap_image_id'):
    """Build the differential-diagnosis sample for the heat-map image.

    Args:
        di: Record providing the media reference under ``_id`` and the
            'differential_diagnosis' text.
        idx: Value stored as the sample's 'question_id'.
        _id: Key of ``di`` whose stripped value becomes the 'image' entry.

    Returns:
        A dict with 'image', 'sys', 'question_id' and a two-turn
        'conversations' list.
    """
    image_ref = di[_id].strip()
    question = "".join([
        '<image>\n',
        PROMPT['HEATMAP_DESC'],
        PROMPT['INCONTEXT_LEARNING_DESC'],
        PROMPT['DEFAULT_DDX'],
    ])
    answer = di['differential_diagnosis'].strip()
    return {
        'image': image_ref,
        'sys': "",
        'question_id': idx,
        'conversations': [
            {'from': 'human', 'value': question},
            {'from': 'gpt', 'value': answer},
        ],
    }

def process_dict_icl_ddx_video(di, idx, _id='video_id'):
    """Build the differential-diagnosis sample for the gaze video.

    Args:
        di: Record providing the media reference under ``_id`` and the
            'differential_diagnosis' text.
        idx: Value stored as the sample's 'question_id'.
        _id: Key of ``di`` whose stripped value becomes the 'video' entry.

    Returns:
        A dict with 'video', 'sys', 'question_id' and a two-turn
        'conversations' list.
    """
    video_ref = di[_id].strip()
    # NOTE(review): the placeholder is '<image>' although the media key is
    # 'video' - confirm this is what the downstream loader expects.
    question = "".join([
        '<image>\n',
        PROMPT['VIDEO_DESC'],
        PROMPT['INCONTEXT_LEARNING_DESC'],
        PROMPT['DEFAULT_DDX'],
    ])
    answer = di['differential_diagnosis'].strip()
    return {
        'video': video_ref,
        'sys': "",
        'question_id': idx,
        'conversations': [
            {'from': 'human', 'value': question},
            {'from': 'gpt', 'value': answer},
        ],
    }


def process_dict_icl_gen(di, idx):
    """Build the plain in-context-learning findings-generation sample.

    Args:
        di: Record providing 'image_id' and 'findings'; both values are
            stripped, so they must support ``.strip()``.
        idx: Value stored as the sample's 'question_id'.

    Returns:
        A dict with 'image', 'sys', 'question_id' and a two-turn
        'conversations' list whose gpt turn is the stripped findings.
    """
    image_ref = di['image_id'].strip()
    question = "".join([
        '<image>\n',
        PROMPT['INCONTEXT_LEARNING_DESC'],
        PROMPT['DEFAULT_GEN'],
    ])
    answer = di['findings'].strip()
    return {
        'image': image_ref,
        'sys': "",
        'question_id': idx,
        'conversations': [
            {'from': 'human', 'value': question},
            {'from': 'gpt', 'value': answer},
        ],
    }

def process_dict_icl_gen_fix(di, idx, _id='heatmap_image_id'):
    """Build the findings-generation sample that includes fixation text.

    Args:
        di: Record providing 'image_id', 'fixation_text' and 'findings'.
        idx: Value stored as the sample's 'question_id'.
        _id: Accepted but not used; the image reference is always read
            from ``di['image_id']``.

    Returns:
        A dict with 'image', 'sys', 'question_id' and a two-turn
        'conversations' list whose gpt turn is the stripped findings.
    """
    image_ref = di['image_id'].strip()
    question = "".join([
        '<image>\n',
        PROMPT['FIX_DESC'],
        f"{di['fixation_text']}\n",
        PROMPT['INCONTEXT_LEARNING_DESC'],
        PROMPT['DEFAULT_GEN'],
    ])
    answer = di['findings'].strip()
    return {
        'image': image_ref,
        'sys': "",
        'question_id': idx,
        'conversations': [
            {'from': 'human', 'value': question},
            {'from': 'gpt', 'value': answer},
        ],
    }

def process_dict_icl_gen_heat(di, idx, _id='heatmap_image_id'):
    """Build the findings-generation sample for the heat-map image.

    Args:
        di: Record providing the media reference under ``_id`` and the
            'findings' text.
        idx: Value stored as the sample's 'question_id'.
        _id: Key of ``di`` whose stripped value becomes the 'image' entry.

    Returns:
        A dict with 'image', 'sys', 'question_id' and a two-turn
        'conversations' list whose gpt turn is the stripped findings.
    """
    image_ref = di[_id].strip()
    question = "".join([
        '<image>\n',
        PROMPT['HEATMAP_DESC'],
        PROMPT['INCONTEXT_LEARNING_DESC'],
        PROMPT['DEFAULT_GEN'],
    ])
    answer = di['findings'].strip()
    return {
        'image': image_ref,
        'sys': "",
        'question_id': idx,
        'conversations': [
            {'from': 'human', 'value': question},
            {'from': 'gpt', 'value': answer},
        ],
    }

def process_dict_icl_gen_video(di, idx, _id='video_id'):
    """Build the findings-generation sample for the gaze video.

    Args:
        di: Record providing the media reference under ``_id`` and the
            'findings' text.
        idx: Value stored as the sample's 'question_id'.
        _id: Key of ``di`` whose stripped value becomes the 'video' entry.

    Returns:
        A dict with 'video', 'sys', 'question_id' and a two-turn
        'conversations' list whose gpt turn is the stripped findings.
    """
    video_ref = di[_id].strip()
    # NOTE(review): the placeholder is '<image>' although the media key is
    # 'video' - confirm this is what the downstream loader expects.
    question = "".join([
        '<image>\n',
        PROMPT['VIDEO_DESC'],
        PROMPT['INCONTEXT_LEARNING_DESC'],
        PROMPT['DEFAULT_GEN'],
    ])
    answer = di['findings'].strip()
    return {
        'video': video_ref,
        'sys': "",
        'question_id': idx,
        'conversations': [
            {'from': 'human', 'value': question},
            {'from': 'gpt', 'value': answer},
        ],
    }


def process_dict_icl_sum(di, idx):
    """Build the plain in-context-learning impression-summarization sample.

    The human prompt is: image placeholder, the in-context exemplars, the
    summarization instruction (which ends with "Findings: "), the record's
    findings and finally an "Impression: " cue.

    Args:
        di: Record providing 'image_id', 'findings' and 'impression'.
        idx: Value stored as the sample's 'question_id'.

    Returns:
        A dict with 'image', 'sys', 'question_id' and a two-turn
        'conversations' list whose gpt turn is the stripped impression.
    """
    image_ref = di['image_id'].strip()
    # The exemplars must come BEFORE the instruction: PROMPT['DEFAULT_SUM']
    # ends with "Findings: ", so the record's findings have to follow it
    # directly. This is the same order the fix/heat/video sum builders use.
    question = "".join([
        '<image>\n',
        PROMPT['INCONTEXT_LEARNING_DESC'],
        PROMPT['DEFAULT_SUM'],
        f"{di['findings']}\nImpression: ",
    ])
    answer = di['impression'].strip()
    return {
        'image': image_ref,
        'sys': "",
        'question_id': idx,
        'conversations': [
            {'from': 'human', 'value': question},
            {'from': 'gpt', 'value': answer},
        ],
    }

def process_dict_icl_sum_fix(di, idx):
    """Build the impression-summarization sample that includes fixation text.

    The human prompt is: image placeholder, fixation description, the
    record's 'fixation_text' plus a newline, the in-context exemplars, the
    summarization instruction, the record's findings and an "Impression: "
    cue.

    Args:
        di: Record providing 'image_id', 'fixation_text', 'findings' and
            'impression'.
        idx: Value stored as the sample's 'question_id'.

    Returns:
        A dict with 'image', 'sys', 'question_id' and a two-turn
        'conversations' list whose gpt turn is the stripped impression.
    """
    image_ref = di['image_id'].strip()
    question = "".join([
        '<image>\n',
        PROMPT['FIX_DESC'],
        f"{di['fixation_text']}\n",
        PROMPT['INCONTEXT_LEARNING_DESC'],
        PROMPT['DEFAULT_SUM'],
        f"{di['findings']}\nImpression: ",
    ])
    answer = di['impression'].strip()
    return {
        'image': image_ref,
        'sys': "",
        'question_id': idx,
        'conversations': [
            {'from': 'human', 'value': question},
            {'from': 'gpt', 'value': answer},
        ],
    }

def process_dict_icl_sum_heat(di, idx, _id='heatmap_image_id'):
    """Build the impression-summarization sample for the heat-map image.

    Args:
        di: Record providing the media reference under ``_id`` plus the
            'findings' and 'impression' texts.
        idx: Value stored as the sample's 'question_id'.
        _id: Key of ``di`` whose stripped value becomes the 'image' entry.

    Returns:
        A dict with 'image', 'sys', 'question_id' and a two-turn
        'conversations' list whose gpt turn is the stripped impression.
    """
    image_ref = di[_id].strip()
    question = "".join([
        '<image>\n',
        PROMPT['HEATMAP_DESC'],
        PROMPT['INCONTEXT_LEARNING_DESC'],
        PROMPT['DEFAULT_SUM'],
        f"{di['findings']}\nImpression: ",
    ])
    answer = di['impression'].strip()
    return {
        'image': image_ref,
        'sys': "",
        'question_id': idx,
        'conversations': [
            {'from': 'human', 'value': question},
            {'from': 'gpt', 'value': answer},
        ],
    }

def process_dict_icl_sum_video(di, idx, _id='video_id'):
    """Build the impression-summarization sample for the gaze video.

    Args:
        di: Record providing the media reference under ``_id`` plus the
            'findings' and 'impression' texts.
        idx: Value stored as the sample's 'question_id'.
        _id: Key of ``di`` whose stripped value becomes the 'video' entry.

    Returns:
        A dict with 'video', 'sys', 'question_id' and a two-turn
        'conversations' list whose gpt turn is the stripped impression.
    """
    video_ref = di[_id].strip()
    # NOTE(review): the placeholder is '<image>' although the media key is
    # 'video' - confirm this is what the downstream loader expects.
    question = "".join([
        '<image>\n',
        PROMPT['VIDEO_DESC'],
        PROMPT['INCONTEXT_LEARNING_DESC'],
        PROMPT['DEFAULT_SUM'],
        f"{di['findings']}\nImpression: ",
    ])
    answer = di['impression'].strip()
    return {
        'video': video_ref,
        'sys': "",
        'question_id': idx,
        'conversations': [
            {'from': 'human', 'value': question},
            {'from': 'gpt', 'value': answer},
        ],
    }


def process(d, mode='alpha'):
    """Write every ICL prompt variant of ``d`` into DDX/, GEN/ and SUM/.

    For each of the four variants (ICLDEFAULT, ICLFIX, ICLHEAT, ICLVIDEO)
    three files named ``<task>/ICLR_<variant>_<mode>.jsonl`` are written, in
    the order DDX, GEN, SUM. Each file receives a single ``json.dumps`` of
    the whole sample list (one JSON array, despite the ``.jsonl`` suffix).

    DDX and GEN keep records whose 'findings' is non-empty; SUM additionally
    requires a non-empty 'impression'. 'question_id' is the record's position
    in ``d``, counted before filtering.

    Args:
        d: Iterable of records; it is traversed once per output file.
        mode: Suffix inserted into every output file name.
    """
    def has_findings(rec):
        return len(rec['findings']) > 0

    def has_findings_and_impression(rec):
        return len(rec['findings']) > 0 and len(rec['impression']) > 0

    # (variant tag, DDX builder, GEN builder, SUM builder), in write order.
    variants = (
        ('ICLDEFAULT', process_dict_icl_ddx, process_dict_icl_gen, process_dict_icl_sum),
        ('ICLFIX', process_dict_icl_ddx_fix, process_dict_icl_gen_fix, process_dict_icl_sum_fix),
        ('ICLHEAT', process_dict_icl_ddx_heat, process_dict_icl_gen_heat, process_dict_icl_sum_heat),
        ('ICLVIDEO', process_dict_icl_ddx_video, process_dict_icl_gen_video, process_dict_icl_sum_video),
    )

    for tag, ddx_builder, gen_builder, sum_builder in variants:
        jobs = (
            ('DDX', ddx_builder, has_findings),
            ('GEN', gen_builder, has_findings),
            ('SUM', sum_builder, has_findings_and_impression),
        )
        for folder, builder, keep in jobs:
            # The file is opened before the samples are built, so a failure
            # while building leaves an empty file behind.
            with open(f'{folder}/ICLR_{tag}_{mode}.jsonl', 'w') as out:
                samples = [builder(rec, pos) for pos, rec in enumerate(d) if keep(rec)]
                out.write(json.dumps(samples))

if __name__ == '__main__':
    # Output folders that process() writes into; exist_ok makes re-runs safe
    # without a separate existence check.
    for out_dir in ('DDX', 'GEN', 'SUM'):
        os.makedirs(out_dir, exist_ok=True)

    # The dataset files are INPUTS, so they are opened for reading: mode 'w'
    # would truncate them to zero bytes before anything is parsed. The JSON
    # is parsed from the handle bound by the ``with`` statement.
    for src_path, mode in (
        ('mimic-eye-video-alpha.json', 'ALPHA'),
        ('mimic-eye-video-beta.json', 'BETA'),
    ):
        with open(src_path, 'r') as fi:
            d = json.load(fi)
        process(d, mode=mode)