import os
import sys
import time
import openai
from tqdm import tqdm
from pathlib import Path
from openai import APIError
from openai.error import APIConnectionError, Timeout, RateLimitError, ServiceUnavailableError

from mimic_cxr_utils import *
from chexpert_labeler import ChexpertLabeler

# Generation tasks this script supports. Each task has a prompt defined below
# and a matching response parser in the main loop.
TASKS = ("correction", "history", "template", "comparison")

# Validate the CLI explicitly: `assert` statements are stripped under `python -O`.
if len(sys.argv) != 2 or sys.argv[1] not in TASKS:
    choices = "|".join(TASKS)
    sys.exit(f"usage: {sys.argv[0]} <{choices}>")
task = sys.argv[1]


# Azure OpenAI endpoint configuration (module-level settings of the legacy
# openai<1.0 client, as implied by the `openai.error` import).
openai.api_type = "azure"
openai.api_base = "https://gcrgpt4aoai9c.openai.azure.com/"
openai.api_version = "2023-03-15-preview"
openai.api_key = os.getenv("OPENAI_API_KEY")
if not openai.api_key:
    # Fail before loading metadata instead of on the first API request.
    sys.exit("OPENAI_API_KEY environment variable is not set")

data_dir = Path("/data/datasets/MIMIC-CXR/")
# One output directory per task; each study's result is written there as <study_id>.txt.
out_dir = f"/data/datasets/MIMIC-CXR/files/reports_{task}/"

# makedirs also creates missing parent directories, and exist_ok avoids a
# check-then-create race.
os.makedirs(out_dir, exist_ok=True)

# Lookup tables built by mimic_cxr_utils helpers (not visible here). In the main
# loop, dicomid2label is indexed by image file stem, studyid2split by the study id
# without its "s" prefix, and studyid2path by the full "s..." study id.
dicomid2label = create_id2label_dict(data_dir/"mimic-cxr-2.0.0-metadata.csv")
studyid2split = create_id2split_dict(data_dir/"mimic-cxr-2.0.0-split.csv")
studyid2path = create_id2path_dict(data_dir/"mimic-cxr-2.0.0-metadata.csv")
# NOTE(review): `pd` is not imported explicitly in this file; presumably it comes
# from the `mimic_cxr_utils` star import. Confirm, or import pandas explicitly.
metadata = pd.read_csv(data_dir/"mimic-cxr-2.0.0-metadata.csv")

# Chat prompt per task. The last message is always a user turn with empty
# content; the main loop fills it with the current study's report(s) before each
# request. Except for "correction" (which has a real worked example), the
# empty-user / header-only-assistant pair only demonstrates the response format.
# That format must match the section headers the parser searches for.
if task == "correction":
    messages = [
        {
            "role":"system",
            "content":"You are a careless radiologist and will make some minor diagnosis error when reading a chest x-ray radiograph. You will be given a correct chest x-ray report with a findings section, and you will be asked to rewrite the report with a few diagnostic errors. Then provide the instructions of how to correct your wrong diagnosis."
        },{
            "role":"user",
            "content":"Findings: Normal lung parenchyma without consolidation or mass lesions. Centrilobular emphysema in upper lobes with small nodular opacities (4-6 mm). No lymphadenopathy or pleural effusion. Unremarkable cardiac and upper abdominal structures. Impression: Normal lung parenchyma. Centrilobular emphysema in upper lobes with small nodular opacities. No lymphadenopathy or pleural effusion. Unremarkable cardiac and upper abdominal structures."
        },{
            "role":"assistant",
            "content":"Report#: Findings: Normal lung parenchyma without consolidation or mass lesions. No emphysema. Pleural effusion in lower left lobe. Unremarkable cardiac and upper abdominal structures. Impression: Normal lung parenchyma. Lower left pleural effusion. Unremarkable cardiac and upper abdominal structures.\nInstructions#: Include centrilobular emphysema in upper lobes with small nodular opacities (4-6 mm). The patient has no pleural effusion."
        },{
            "role":"user",
            "content":""  # filled per study with the parsed report
        }
    ]
elif task == "history":
    messages = [
        {
            "role": "system",
            "content": "You are a clinical assistant that helps radiologists retrieve information of patient's other medical records and examination results. You will be given a chest x-ray diagnosis report, and you need to infer the patient's possible medical conditions or history and other medical examination the patient should have done. You need to first give several examples of their possible medical conditions based on the report diagnosis, then list a few the medical examinations that the patient should have done, and finally give a few examples of the patient's medical examination results that may lead to the diagnosis in the given report. List none if the report indicates no acute cardiopulmonary disease. Please make sure the inferred possible medical condition does not include any information described in the report findings or impression. Also make sure the examples of exam results are consistent with the findings described in the given report. This is very important to my career."
        }, {
            "role": "user",
            "content": ""
        }, {
            "role": "assistant",
            "content": "Possible Medical Conditions:\nPossible Medical Examinations:\nExamples of Examination Results:"
        }, {
            "role": "user",
            "content": ""  # filled per study with the raw report text
        }
    ]
elif task == "template":
    messages = [
        {
            "role": "system",
            "content": "You are a clinical assistant helping radiologists write detailed and well-formatted chest x-ray reports. Without referencing the given report, you should first write a detailed chest x-ray report template with the findings section have different sections considering different pathological observations. Then, you should fill the template you have based on the given chest x-ray report. Your response should include both the blank template and the filled template following this format:\n\nTemplate:\n<blank template you designed>\n\nFilled Template:\n<template filled based on the given report>"
        }, {
            "role": "user",
            "content": ""
        }, {
            "role": "assistant",
            "content": "Template:\n<blank template you designed>\n\nFilled Template:\n<template filled based on the given report>"
        }, {
            "role": "user",
            "content": ""  # filled per study with the parsed report
        }
    ]
elif task == "comparison":
    messages = [
        {
            "role": "system",
            "content": "You are a clinical assistant that helps manage the chest x-ray diagnosis reports of a patient. You will be given two chest x-ray reports in chronological order, and you are asked to rewrite the findings and impression sections of the second report but with more focus on comparison with the previous report. Your response should follow the format:\n\nRewritten Report:"
        }, {
            "role": "user",
            "content": ""
        }, {
            "role": "assistant",
            # Must demonstrate the "Rewritten Report:" header that the system
            # prompt asks for and the comparison parser looks for. It previously
            # held the history task's headers by mistake.
            "content": "Rewritten Report:"
        }, {
            "role": "user",
            "content": ""  # filled per study with the previous and current reports
        }
    ]
else:
    raise NotImplementedError(f"Unsupported task: {task}")

labeler = ChexpertLabeler()


def _request_completion(chat_messages, initial_delay=5, max_delay=60):
    """Send `chat_messages` to the Azure deployment, retrying transient errors forever.

    Only the exception types listed below are retried. The wait between
    attempts starts at `initial_delay` seconds and doubles up to `max_delay`, so
    sustained rate limiting does not hammer the endpoint. Any other exception
    propagates to the caller.
    """
    delay = initial_delay
    while True:
        try:
            return openai.ChatCompletion.create(
                messages=chat_messages,
                engine="gpt-35-turbo",
                temperature=0.7,
                max_tokens=350,
                top_p=0.95,
                frequency_penalty=0,
                presence_penalty=0,
                stop=None,
                request_timeout=15,
            )
        except (APIConnectionError, Timeout, RateLimitError, APIError, ServiceUnavailableError) as e:
            print(e)
            time.sleep(delay)
            delay = min(delay * 2, max_delay)


def _extract_section(text, header, next_header=None):
    """Return the text after `header`, up to `next_header` or the end of `text`.

    Returns None if `header` is absent, or if `next_header` is given but does not
    occur after `header`. Searching for `next_header` only past `header` has two
    effects. Out-of-order sections are rejected instead of producing an empty or
    overlapping slice. A header that only matches inside a longer one (e.g.
    "Template:" inside "Filled Template:") is rejected too.
    """
    start = text.find(header)
    if start < 0:
        return None
    start += len(header)
    if next_header is None:
        return text[start:]
    end = text.find(next_header, start)
    if end < 0:
        return None
    return text[start:end]


def _write_output(path, text):
    """Write `text` to `path` as UTF-8, closing the file even if the write fails."""
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)


unique_id = 0  # number of studies for which an output file was written this run
for patient_path in tqdm((data_dir/"files").glob("p*/p*")):
    patient_id = patient_path.name
    for study_path in patient_path.glob("s*"):
        study_id = study_path.name
        out_file_name = os.path.join(out_dir, f"{study_id}.txt")
        if os.path.exists(out_file_name):
            continue  # already generated by an earlier run, so reruns resume where they stopped
        # Image paths relative to data_dir, and their labels keyed by file stem.
        # NOTE(review): image_paths, image_labels and split are currently unused.
        # The lookups still raise KeyError for studies missing from the
        # metadata/split tables.
        image_path_list = [str(path)[len(str(data_dir))+1:] for path in study_path.glob("*.jpg")]
        image_label_list = [dicomid2label[path.split('/')[-1][:-4]] for path in image_path_list]
        image_paths = ','.join(image_path_list)
        image_labels = ','.join(image_label_list)
        report_path = data_dir/"files"/"reports"/patient_id[:3]/patient_id/f"{study_id}.txt"
        split = studyid2split[study_id[1:]]
        report, findings, impression = parse_report(report_path)
        if findings == "" or impression == "":
            continue  # Skip report without findings or impression section

        # Fill the trailing user turn of the prompt with this study's input.
        if task in ["correction", "template"]:
            messages[-1]["content"] = report
        elif task == "history":
            report_raw = parse_report_raw(report_path)
            messages[-1]["content"] = report_raw
        elif task == "comparison":
            previous_report_path = get_previous_report_path(report_path, metadata)
            if previous_report_path is None:
                continue
            prev_report, _, _ = parse_report(previous_report_path)
            if prev_report == "":
                continue
            messages[-1]["content"] = f"First Report:\n{prev_report}\n\nSecond Report:\n{report}"

        response = _request_completion(messages)
        # A missing or None "content" field is treated as an unparsable response
        # (the study is skipped) instead of raising KeyError.
        content = response["choices"][0]["message"].get("content") or ""

        if task == "correction":
            incorrect_report = _extract_section(content, "Report#:", "Instructions#:")
            instructions = _extract_section(content, "Instructions#:")
            if incorrect_report is None or instructions is None:
                continue
            _write_output(
                out_file_name,
                "INCORRECT REPORT:\n" + incorrect_report.strip("\n ")
                + "\n\nINSTRUCTIONS:\n" + instructions.strip("\n ")
                + "\n\nCORRECT REPORT:\n" + report.strip("\n ") + '\n',
            )
            print(incorrect_report, instructions, sep='\n')
        elif task == "history":
            med_cond = _extract_section(content, "Possible Medical Conditions:", "Possible Medical Examinations:")
            exam_res = _extract_section(content, "Examples of Examination Results:")
            if med_cond is None and exam_res is None:
                continue
            med_cond = "None" if med_cond is None else med_cond.strip("\n ")
            exam_res = "None" if exam_res is None else exam_res.strip("\n ")
            if med_cond.strip(".") == "None" and exam_res.strip(".") == "None":
                continue
            # Reject generations whose CheXpert labels conflict with the ground-truth
            # findings (the rule lives in mimic_cxr_utils.check_leakage).
            generated_label = labeler.get_label(med_cond + '\n' + exam_res)
            # NOTE(review): studyid2path is indexed with the "s"-prefixed study id,
            # while studyid2split above uses study_id[1:]. Confirm that both key
            # formats are correct.
            gt_path = studyid2path[study_id]
            _, gt_findings, _ = parse_report(data_dir / gt_path)
            gt_label = labeler.get_label(gt_findings)
            if not check_leakage(task, generated_label=generated_label, gt_label=gt_label):
                continue
            _write_output(
                out_file_name,
                "MEDICAL CONDITIONS:\n" + med_cond
                + "\n\nEXAM RESULTS:\n" + exam_res + '\n',
            )
            print(med_cond, exam_res, sep='\n')
        elif task == "template":
            template = _extract_section(content, "Template:", "Filled Template:")
            filled_template = _extract_section(content, "Filled Template:")
            if template is None or filled_template is None:
                continue
            _write_output(
                out_file_name,
                "TEMPLATE:\n" + template.strip("\n ")
                + "\n\nFILLED TEMPLATE:\n" + filled_template.strip("\n ") + '\n',
            )
            print(template)
        elif task == "comparison":
            rewritten = _extract_section(content, "Rewritten Report:")
            if rewritten is None:
                continue
            # Trailing newline added for consistency with the other tasks' outputs.
            _write_output(out_file_name, "REPORT:\n" + rewritten.strip("\n ") + '\n')
            print(rewritten)
        unique_id += 1
        print(unique_id)

