import os
import re
import time
import openai
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from openai import ChatCompletion, APIError
from openai.error import APIConnectionError, Timeout, RateLimitError, ServiceUnavailableError

from bimcv_utils import *

# Which prompt/output pipeline to run: "correction", "history", "template" or
# "comparison" (see the prompt definitions and the main loop below).
task = "comparison"

# Azure OpenAI (openai<1.0 module-level client configuration).
openai.api_type = "azure"
openai.api_base = "https://gcrgpt4aoai9c.openai.azure.com/"
openai.api_version = "2023-03-15-preview"
# Read from the environment so the secret never lives in source control.
openai.api_key = os.getenv("OPENAI_API_KEY")

data_dir = Path("/scratch/datasets/BIMCV-COVID19-cIter_1_2-Negative_CT_only/")
data_path_pos = data_dir / "master_max_CT_EN_split_pos.csv"
data_path_neg = data_dir / "master_max_CT_EN_split_neg.csv"
# Positive and negative report tables are processed together as one frame.
data_df = pd.concat([pd.read_csv(data_path_pos), pd.read_csv(data_path_neg)])

# NOTE(review): the dataset-local output directory is immediately overridden
# by "tmp/" (kept to preserve the current behavior). This looks like a
# debugging leftover; remove the override to write next to the dataset.
out_dir = data_dir / f"reports_{task}"
out_dir = Path("tmp/")

# Create the output directory, including missing parents. This is a no-op if
# it already exists and avoids the isdir()/mkdir() check-then-act race.
out_dir.mkdir(parents=True, exist_ok=True)

# Pathology keywords. They are only referenced by commented-out
# `remove_leakage` calls, so they are currently unused.
filter_words = ["cardiomegaly", "edema", "consolidation", "pneumonia", "atelectasis", "pneumothorax", "effusion",]

# Each task's prompt is: a system instruction, an optional few-shot
# user/assistant exchange, then a final empty user turn. The main loop
# overwrites that last turn (messages[-1]["content"]) with each row's report
# text.
if task == "correction":
    messages = [
        {
            "role":"system",
            "content":"You are a careless radiologist and will make some minor diagnosis error when reading a chest CT scan. You will be given a correct chest CT report with a findings section, and you will be asked to rewrite the report with a few diagnostic errors. Then provide the instructions of how to correct your wrong diagnosis."
        },{
            "role":"user",
            "content":"Normal lung parenchyma without consolidation or mass lesions. Centrilobular emphysema in upper lobes with small nodular opacities (4-6 mm). No lymphadenopathy or pleural effusion. Unremarkable cardiac and upper abdominal structures. Normal lung parenchyma. Centrilobular emphysema in upper lobes with small nodular opacities. No lymphadenopathy or pleural effusion. Unremarkable cardiac and upper abdominal structures."
        },{
            "role":"assistant",
            "content":"Report#: Normal lung parenchyma without consolidation or mass lesions. No emphysema. Pleural effusion in lower left lobe. Unremarkable cardiac and upper abdominal structures. Normal lung parenchyma. Lower left pleural effusion. Unremarkable cardiac and upper abdominal structures.\nInstructions#: Include centrilobular emphysema in upper lobes with small nodular opacities (4-6 mm). The patient has no pleural effusion."
        },{
            "role":"user",
            "content":""
        }
    ]
elif task == "history":
    messages = [
        {
            "role": "system",
            "content": "You are a clinical assistant that helps radiologists retrieve information of patient's other medical records and examination results. You will be given a chest CT diagnosis report, and you need to infer the patient's possible medical conditions or history and other medical examination the patient should have done. You need to first give several examples of their possible medical conditions based on the report diagnosis, then list a few the medical examinations that the patient should have done, and finally give a few examples of the patient's medical examination results that may lead to the diagnosis in the given report. List none if the report indicates no acute cardiopulmonary disease. Please make sure the inferred possible medical condition does not include any information described in the report findings or impression. Also make sure the examples of exam results are consistent with the findings described in the given report. This is very important to my career."
        }, {
            "role": "user",
            "content": ""
        }, {
            "role": "assistant",
            # Output skeleton whose headers the main loop parses.
            "content": "Possible Medical Conditions:\nPossible Medical Examinations:\nExamples of Examination Results:"
        }, {
            "role": "user",
            "content": ""
        }
    ]
elif task == "template":
    messages = [
        {
            "role": "system",
            "content": "You are a clinical assistant helping radiologists write detailed and well-formatted chest CT reports. Without referencing the given report, you should first write a detailed chest CT report template with the findings section have different sections considering different pathological observations. Then, you should fill the template you have based on the given chest CT report. Your response should include both the blank template and the filled template following this format:\n\nTemplate:\n<blank template you designed>\n\nFilled Template:\n<template filled based on the given report>"
        }, {
            "role": "user",
            "content": ""
        }, {
            "role": "assistant",
            "content": "Template:\n<blank template you designed>\n\nFilled Template:\n<template filled based on the given report>"
        }, {
            "role": "user",
            "content": ""
        }
    ]
elif task == "comparison":
    messages = [
        {
            "role": "system",
            # The dataset is chest CT (as in every other prompt), not x-ray.
            "content": "You are a clinical assistant that helps manage the chest CT diagnosis reports of a patient. You will be given two chest CT reports in chronological order, and you are asked to rewrite the findings and impression sections of the second report but with more focus on comparison with the previous report. Your response should follow the format:\n\nRewritten Report:"
        }, {
            "role": "user",
            "content": ""
        }, {
            "role": "assistant",
            # Format exemplar. It must use the same "Rewritten Report:" header
            # the system prompt requires; it previously reused the "history"
            # task's headers by mistake.
            "content": "Rewritten Report:\n<findings and impression of the second report, rewritten with comparison to the first report>"
        }, {
            "role": "user",
            "content": ""
        }
    ]
else:
    raise NotImplementedError(f"Unsupported task: {task!r}")

def _last_report_text(raw_reports):
    """Return the last non-trivial report text from a "Report" cell.

    The cell appears to hold a stringified Python list of report strings
    (the parsing strips the outer brackets and splits on single quotes).
    Chunks with at most one non-whitespace character are dropped, and the
    last remaining chunk is returned stripped of list punctuation.
    Returns None for non-string cells (e.g. NaN) or when nothing usable
    remains.
    """
    if not isinstance(raw_reports, str):
        return None
    chunks = [chunk for chunk in raw_reports[1:-1].split("'") if len(chunk.strip()) > 1]
    if not chunks:
        return None
    return chunks[-1].strip("\n\t,'[] ")


def _load_patient_sessions(patient_id):
    """Concatenate the patient's covid19_pos/covid19_neg sessions tables.

    Returns None when neither ``<patient_id>_sessions.tsv`` file exists.
    """
    frames = []
    for subset in ("covid19_pos", "covid19_neg"):
        sessions_path = data_dir / subset / patient_id / f"{patient_id}_sessions.tsv"
        if sessions_path.exists():
            frames.append(pd.read_csv(sessions_path, delimiter="\t"))
    return pd.concat(frames) if frames else None


def _request_completion(chat_messages):
    """Call the Azure chat deployment, retrying indefinitely on transient errors."""
    while True:
        try:
            return openai.ChatCompletion.create(
                messages=chat_messages,
                engine="gpt-35-turbo",
                temperature=0.7,
                max_tokens=350,
                top_p=0.95,
                frequency_penalty=0,
                presence_penalty=0,
                stop=None,
                request_timeout=15,
            )
        except (APIConnectionError, Timeout, RateLimitError, APIError, ServiceUnavailableError) as e:
            print(e)
            time.sleep(5)


def _write_sections(path, sections):
    """Write (title, text) pairs as "TITLE:\\n<text>" blocks separated by a blank line."""
    body = "\n\n".join(title + ":\n" + text.strip("\n ") for title, text in sections)
    with open(path, "w", encoding="utf-8") as out_file:
        out_file.write(body + "\n")


# ReportID -> raw "Report" cell lookup, built once instead of scanning the
# whole frame per row. The first occurrence wins, matching the previous
# `.iloc[0]` selection.
report_by_id = data_df.drop_duplicates("ReportID").set_index("ReportID")["Report"]

unique_id = 0
for _, row in data_df.iterrows():
    report_id = row["ReportID"]
    out_file_path = out_dir / f"{report_id}.txt"
    # Resume support: reports already written by an earlier run are skipped.
    if out_file_path.exists():
        continue
    report = _last_report_text(row["Report"])
    if report is None:
        print(f"Skipping {report_id}: no usable report text")
        continue

    if task in ("correction", "template", "history"):
        messages[-1]["content"] = report
    elif task == "comparison":
        patient_id = row["PatientID"]
        patient_data = _load_patient_sessions(patient_id)
        if patient_data is None:
            # Previously an `assert`, which aborted the entire run.
            print(f"Skipping {report_id}: no sessions table for patient {patient_id}")
            continue
        previous_report_id = get_previous_report_id(report_id, patient_data)
        if previous_report_id is None:
            continue
        previous_report = _last_report_text(report_by_id.get(previous_report_id))
        if previous_report is None:
            continue
        messages[-1]["content"] = f"First Report:\n{previous_report}\n\nSecond Report:\n{report}"

    response = _request_completion(messages)["choices"][0]["message"]["content"]

    if task == "correction":
        report_pos = response.find("Report#:")
        instructions_pos = response.find("Instructions#:")
        if report_pos < 0 or instructions_pos < 0:
            continue
        incorrect_report = response[report_pos:instructions_pos].replace("Report#:", "")
        instructions = response[instructions_pos:].replace("Instructions#:", "")
        sections = [
            ("INCORRECT REPORT", incorrect_report),
            ("INSTRUCTIONS", instructions),
            ("CORRECT REPORT", report),
        ]
    elif task == "history":
        med_cond_pos = response.find("Possible Medical Conditions:")
        med_exam_pos = response.find("Possible Medical Examinations:")
        exam_res_pos = response.find("Examples of Examination Results:")
        if (med_cond_pos < 0 or med_exam_pos < 0) and exam_res_pos < 0:
            continue
        if med_cond_pos < 0 or med_exam_pos < 0:
            med_cond = "None"
        else:
            med_cond = response[med_cond_pos:med_exam_pos].replace("Possible Medical Conditions:", "")
            # med_cond = remove_leakage(med_cond, filter_words)
        if exam_res_pos < 0:
            exam_res = "None"
        else:
            exam_res = response[exam_res_pos:].replace("Examples of Examination Results:", "")
            # exam_res = remove_leakage(exam_res, filter_words)
        if med_cond.strip("\n .") == "None" and exam_res.strip("\n .") == "None":
            continue
        sections = [("MEDICAL CONDITIONS", med_cond), ("EXAM RESULTS", exam_res)]
    elif task == "template":
        template_pos = response.find("Template:")
        report_pos = response.find("Filled Template:")
        if template_pos < 0 or report_pos < 0:
            continue
        template = response[template_pos:report_pos].replace("Template:", "")
        filled_template = response[report_pos:].replace("Filled Template:", "")
        sections = [("TEMPLATE", template), ("FILLED TEMPLATE", filled_template)]
    elif task == "comparison":
        # Previously missing: comparison responses were fetched and discarded.
        rewritten_pos = response.find("Rewritten Report:")
        if rewritten_pos < 0:
            continue
        rewritten = response[rewritten_pos:].replace("Rewritten Report:", "")
        sections = [
            ("PREVIOUS REPORT", previous_report),
            ("CURRENT REPORT", report),
            ("COMPARISON REPORT", rewritten),
        ]
    else:
        raise NotImplementedError(f"Unsupported task: {task!r}")

    _write_sections(out_file_path, sections)
    unique_id += 1
    print(unique_id)
