import os
import sys
import base64
import json
import time
import random
import openai
import pandas as pd
import argparse
from tqdm import tqdm
from pathlib import Path
from datetime import datetime

from mimic_cxr_utils import *


def parse_args(argv=None):
    """Parse the command-line options of the instruction-generation script.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``task``, ``split``,
        ``data_dir``, ``output_path`` and ``get_reason``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", required=True, choices=["correction", "comparison", "template", "history"])
    parser.add_argument("--split", required=True, choices=["train", "test"])
    # Every path the script touches is derived from data_dir, so a missing
    # value previously surfaced as a TypeError from Path(None). Let argparse
    # reject the invocation with a readable message instead.
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Dataset root holding the metadata/split CSV files and the 'files' directory.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Where to write the instruction JSON (default: <task>_instructions_<split>.json).",
    )
    parser.add_argument(
        "--get_reason",
        action="store_true",
        help="Summarize the reasons of failed generations instead of writing the instruction JSON.",
    )
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script configuration. Everything assigned here is a module-level name
    # that the functions defined below read as a global (get_reason, data_dir,
    # engine, context, reason_list, the lookup dictionaries, ...).
    args = parse_args()
    if args.data_dir is None:
        # All inputs are located relative to data_dir; stop with a clear
        # message instead of failing later inside Path(None).
        sys.exit("--data_dir is required: it must point to the dataset root directory")
    data_dir = Path(args.data_dir)
    split = args.split
    task = args.task
    get_reason = args.get_reason
    out_instructions = f"{task}_instructions_{split}.json" if args.output_path is None else args.output_path
    # The generated reports of a task live below the dataset root. The
    # previous conditional was inverted: it evaluated Path(None) when no
    # data_dir was given and otherwise left reports_path as the plain
    # data_dir string, which has no iterdir().
    reports_path = data_dir / "files" / f"reports_{task}"

    # Azure OpenAI settings; endpoint and key are read from the environment.
    openai.api_type = "azure"
    openai.api_base = os.getenv("OPENAI_API_BASE")
    openai.api_version = "2023-07-01-preview"
    openai.api_key = os.getenv("OPENAI_API_KEY")
    engine = "gpt-4-32k"

    # Lookup tables built by the mimic_cxr_utils helpers from the dataset CSVs.
    dicomid2label = create_id2label_dict(data_dir / "mimic-cxr-2.0.0-metadata.csv")
    studyid2split = create_id2split_dict(data_dir / "mimic-cxr-2.0.0-split.csv")
    studyid2images = create_id2images_dict(data_dir / "mimic-cxr-2.0.0-metadata.csv")
    studyid2path = create_id2path_dict(data_dir / "mimic-cxr-2.0.0-metadata.csv")
    metadata = pd.read_csv(data_dir / "mimic-cxr-2.0.0-metadata.csv")
    instructions_json, images_json, context_json = {}, {}, {}
    instructions_json["meta"] = {"version": "0.0.1", "time": datetime.today().strftime('%Y-%m-%d'), "author": "annonymous"}
    instructions_json["data"] = {}

    # Accumulators filled while the reports are processed.
    context = []
    reason_list = []


def process_annotation(task, annotation):
    """Collect the study IDs of annotated records that were marked "yes".

    Args:
        task: Substring that a row's ``record_path`` has to contain.
        annotation: CSV source accepted by ``pandas.read_csv``; it must
            provide the columns ``record_path`` and ``A1``.

    Returns:
        List of study IDs, in file order, taken from the rows whose
        ``record_path`` contains ``task`` and whose ``A1`` equals "yes".

    Raises:
        AssertionError: If an extracted ID does not start with 's' or is not
            exactly nine characters long.
    """
    table = pd.read_csv(annotation)
    belongs_to_task = table['record_path'].apply(lambda record: task in record)
    task_rows = table[belongs_to_task]
    approved = task_rows[task_rows["A1"] == "yes"]

    def to_study_id(record):
        # stem removes the extension; the extra split keeps only the last
        # component when the record uses backslash separators.
        study_id = Path(record).stem.split('\\')[-1]
        assert study_id[0] == 's' and len(study_id) == 9
        return study_id

    return [to_study_id(record) for record in approved["record_path"]]


def query_gpt(messages, max_retries=None, retry_delay=5):
    """Send a chat-completion request and retry when the call fails.

    Args:
        messages: Chat messages forwarded to ``openai.ChatCompletion.create``.
        max_retries: Number of retries allowed after a failed attempt. The
            default ``None`` retries without limit, which is the behaviour
            existing callers rely on; pass an integer to make a permanent
            failure surface instead of looping forever.
        retry_delay: Seconds to sleep before the next attempt.

    Returns:
        The response object returned by ``openai.ChatCompletion.create``.

    Raises:
        Exception: The error of the last attempt, once ``max_retries``
            retries have been used up.
    """
    failures = 0
    while True:
        try:
            return openai.ChatCompletion.create(
                messages=messages,
                engine=engine,
                temperature=0.7,
                max_tokens=800,
                top_p=0.95,
                frequency_penalty=0,
                presence_penalty=0,
                stop=None,
                request_timeout=20,
            )
        except Exception as e:
            print(e)
            if max_retries is not None and failures >= max_retries:
                # Retry budget exhausted: propagate the last error.
                raise
            failures += 1
            time.sleep(retry_delay)


def summarize(reason):
    """Ask the chat model to condense one error reason into a short phrase.

    Args:
        reason: Text of the reason to summarize.

    Returns:
        The message content of the first choice in the model response.
    """
    prompt = f"Reason: {reason}. Please summarize this this reason of error in a short phrase. Only output the summarized phrase."
    response = query_gpt([{"role": "user", "content": prompt}])
    first_choice = response["choices"][0]
    return first_choice["message"]["content"]


def summarize_all(reason_list):
    """Ask the chat model to group the collected reasons and count each group.

    Args:
        reason_list: List of reason phrases (strings); they are joined with
            commas and embedded in the prompt.

    Returns:
        The message content of the first choice in the model response.
    """
    joined_reasons = ','.join(reason_list)
    prompt = f"List of reasons: \"{joined_reasons}\" Given this list of phrases of reasons, first categorize them into a few different categories, then give the count of examples for each category."
    response = query_gpt([{"role": "user", "content": prompt}])
    first_choice = response["choices"][0]
    return first_choice["message"]["content"]


def _extract_section(text, start, end, marker):
    """Return text[start:end] with the marker removed and whitespace trimmed.

    A negative ``end`` means the closing marker was not found; the section
    then runs to the end of ``text`` instead of silently losing its last
    character to a ``[:-1]`` slice.
    """
    if end < 0:
        end = len(text)
    return text[start:end].replace(marker, "").strip("\n\t ")


def _extract_reason(text):
    """Return the text following "Reason:", or "" when the marker is absent."""
    reason_idx = text.find("Reason:")
    if reason_idx < 0:
        # Without this guard text[-1:] would turn the last character of the
        # file into a bogus reason.
        return ""
    return _extract_section(text, reason_idx, -1, "Reason:")


def parse_generated(path, task):
    """Parse one generated report file into its labelled sections.

    Leading and trailing tabs/spaces of every line are removed before the
    section markers are searched.

    Args:
        path: Path of the generated text file.
        task: One of "template", "correction" or "history".

    Returns:
        * "template": ``{"template", "report"}``, or None when a marker is
          missing, a section is empty, or both sections are identical.
        * "correction": ``{"incorrect_report", "instruction",
          "correct_report"}``.
        * "history": ``{"history"}``, or None when a marker is missing, or
          when ``get_reason`` is set and the file contains a reason.

    Raises:
        NotImplementedError: For any other task.
        SystemExit: For the "correction" task when a marker is missing or a
            section is empty; the script stops on such files.

    Reads the module-level ``get_reason`` flag; when it is set, summarized
    reasons are appended to the module-level ``reason_list``.
    """
    with open(path, 'r') as f:
        report_string = "".join(line.strip('\t ') for line in f)

    if task == "template":
        template_idx = report_string.find("TEMPLATE:")
        report_idx = report_string.find("TEMPLATED REPORT:")
        gt_idx = report_string.find("GT:")
        reason_idx = report_string.find("Reason:")
        if template_idx < 0 or report_idx < 0:
            print("Parse template fail, no 'TEMPLATE:' or 'TEMPLATED REPORT:'", path)
            return None
        template_str = _extract_section(report_string, template_idx, report_idx, "TEMPLATE:")
        # The report ends at "GT:"; without it, stop at "Reason:" if present.
        report_end = gt_idx if gt_idx >= 0 else reason_idx
        report_str = _extract_section(report_string, report_idx, report_end, "TEMPLATED REPORT:")
        reason_str = _extract_reason(report_string)
        if template_str == report_str or template_str == "" or report_str == "":
            print("Parse template fail")
            print(template_str, report_str)
            if get_reason and reason_str != "":
                print(reason_str)
                reason_list.append(summarize(reason_str))
            return None
        return {"template": template_str, "report": report_str}
    elif task == "correction":
        incorrect_idx = report_string.find("INCORRECT REPORT:")
        instruction_idx = report_string.find("INSTRUCTIONS:")
        # Only a "GT:" that follows the instructions marks the corrected report.
        correct_idx = report_string.find("GT:", instruction_idx) if instruction_idx >= 0 else -1
        if incorrect_idx < 0 or correct_idx < 0 or instruction_idx < 0:
            # Report the file being parsed rather than a loop variable of the
            # calling script, and stop with a failure status.
            print("No incorrect/correct/instruction key", path)
            sys.exit(1)
        incorrect_str = _extract_section(report_string, incorrect_idx, instruction_idx, "INCORRECT REPORT:")
        correct_str = _extract_section(report_string, correct_idx, -1, "GT:")
        fix_str = _extract_section(report_string, instruction_idx, correct_idx, "INSTRUCTIONS:")
        if len(incorrect_str) == 0 or len(fix_str) == 0 or len(correct_str) == 0:
            print("incorrect_str/fix_str/correct_str len=0", path)
            sys.exit(1)
        return {"incorrect_report": incorrect_str, "instruction": fix_str, "correct_report": correct_str}
    elif task == "history":
        history_idx = report_string.find("Medical History:")
        test_idx = report_string.find("Medical Tests:")
        gt_idx = report_string.find("GT:")
        reason_idx = report_string.find("Reason:")
        if history_idx < 0 or test_idx < 0:
            return None
        history_str = _extract_section(report_string, history_idx, test_idx, "Medical History:")
        # The tests section ends at "GT:"; without it, stop at "Reason:" if present.
        test_end = gt_idx if gt_idx >= 0 else reason_idx
        test_str = _extract_section(report_string, test_idx, test_end, "Medical Tests:")
        reason_str = _extract_reason(report_string)
        if get_reason and reason_str != "":
            print(reason_str)
            reason_list.append(summarize(reason_str))
            return None
        history_str = (history_str + '\n' + test_str).strip("\n\t ")
        return {"history": history_str}
    else:
        raise NotImplementedError


def get_instrucion_and_answer(path, task):
    """Build the (instruction, answer) pair for one report file.

    Args:
        path: ``pathlib.Path`` of the report file; for the "history" and
            "comparison" tasks its name without the last four characters is
            used as the key into ``studyid2path``.
        task: One of "correction", "history", "template" or "comparison".

    Returns:
        Tuple ``(instruction, answer)``; ``(None, None)`` when the file cannot
        be parsed or a required report/finding is missing.

    Raises:
        NotImplementedError: For an unsupported task.

    Appends the text used as context to the module-level ``context`` list for
    the "correction", "history" and "template" tasks.
    """
    if task == "comparison":
        # Comparison pairs are built from the original reports only. This
        # branch has to run before parse_generated, which raises
        # NotImplementedError for "comparison" and previously made the branch
        # unreachable.
        current_report_path = studyid2path[path.name[:-4]]
        previous_report_path = get_previous_report_path(Path(current_report_path), metadata)
        if previous_report_path is None:
            return None, None
        _, current_findings, _ = parse_report(data_dir / current_report_path)
        _, previous_findings, _ = parse_report(data_dir / previous_report_path)
        if len(previous_findings) == 0 or len(current_findings) == 0:
            return None, None
        instruction = random.choice(comparison_instructions).format(previous_report=previous_findings)
        return instruction, current_findings

    parsed_data = parse_generated(path, task)
    if parsed_data is None:
        print("parsed None", path)
        return None, None
    if task == "correction":
        incorrect_str = parsed_data["incorrect_report"]
        fix_str = parsed_data["instruction"]
        correct_str = parsed_data["correct_report"]
        context.append(incorrect_str)
        # Randomly place the report before or after the instruction text.
        if random.random() > 0.5:
            instruction = random.choice(correction_instructions) + '\n' + "Report: {incorrect_report}"
        else:
            instruction = "Report: {incorrect_report}" + '\n' + random.choice(correction_instructions)
        instruction = instruction.format(incorrect_report=incorrect_str, instructions=fix_str)
        return instruction, correct_str
    elif task == "history":
        history_str = parsed_data["history"]
        context.append(history_str)
        correct_report_path = data_dir / studyid2path[path.name[:-4]]
        correct_report, correct_findings, _ = parse_report(correct_report_path)
        if correct_findings.strip() == "":
            if correct_report.strip() == "":
                print("Parse failed", correct_report_path)
                return None, None
            # No findings section: fall back to the whole report as the answer.
            correct_findings = correct_report
        instruction = random.choice(history_instructions).format(history=history_str)
        return instruction, correct_findings
    elif task == "template":
        report_str = parsed_data["report"]
        template_str = parsed_data["template"]
        context.append(template_str)
        instruction = random.choice(template_instructions).format(template=template_str)
        return instruction, report_str
    else:
        raise NotImplementedError


if __name__ == "__main__":
    # Build the instruction entries for the selected task and split, then
    # write the results. `study_id` intentionally stays a module-level name
    # because other code in this file reads it while a report is processed.
    count = 0
    if task == "comparison":
        for patient_path in tqdm((data_dir / "files").glob("p*/p*")):
            for study_path in patient_path.glob("s*"):
                study_id = study_path.name
                if studyid2split[study_id[1:]] != split:
                    continue
                # Keep only the images labelled PA or AP.
                image_ids_list = [
                    image_id for image_id in studyid2images[study_id[1:]]
                    if dicomid2label[image_id] in ["PA", "AP"]
                ]
                if len(image_ids_list) == 0:
                    continue
                current_report_path = studyid2path[study_id]
                previous_report_path = get_previous_report_path(data_dir / current_report_path, metadata)
                if previous_report_path is None:
                    continue
                _, current_findings, _ = parse_report(data_dir / current_report_path)
                _, previous_findings, _ = parse_report(data_dir / previous_report_path)
                if len(previous_findings) == 0 or len(current_findings) == 0:
                    continue
                context.append(previous_findings)
                instruction = random.choice(comparison_instructions).format(previous_report=previous_findings)
                instructions_json["data"][study_id] = {
                    "instruction": instruction.lower(),
                    "answer": current_findings.lower(),
                    "image_ids": image_ids_list,
                }
                count += 1
    else:
        num_reports = len(os.listdir(reports_path))
        print(f"Total {num_reports} files for {task} task")
        for report_path in tqdm(reports_path.iterdir(), total=num_reports):
            # Strip the four-character extension (e.g. ".txt") to get the study ID.
            study_id = report_path.name[:-4]
            if studyid2split[study_id[1:]] != split:
                continue
            # Keep only the images labelled PA, AP or NA.
            image_ids_list = [
                image_id for image_id in studyid2images[study_id[1:]]
                if dicomid2label[image_id] in ["PA", "AP", "NA"]
            ]
            if len(image_ids_list) == 0:
                print("No AP/PA image", study_id)
                continue
            instruction, answer = get_instrucion_and_answer(report_path, task)
            if instruction is None or answer is None:
                continue
            instructions_json["data"][study_id] = {
                "instruction": instruction.lower(),
                "answer": answer.lower(),
                "image_ids": image_ids_list,
                "rel_ins_ids": []
            }
            # Count stored entries here as well; otherwise the summary below
            # always reports 0 for the non-comparison tasks.
            count += 1

    if not get_reason:
        with open(out_instructions, 'w') as instructions_file:
            json.dump(instructions_json, instructions_file, indent=2)
        print(f"{count} data processed for {task} task")
    else:
        result = summarize_all(reason_list)
        print(result)
        print(len(reason_list))

    # The context dump used to open an undefined `out_context` name, which
    # raised NameError after all processing had finished. The file name below
    # mirrors the default naming of the instruction file.
    out_context = f"{task}_context_{split}.json"
    with open(out_context, 'w') as context_file:
        json.dump(context, context_file)
