import csv
import math
import re
import random
import os.path
import sys
from pathlib import Path
import pandas as pd
from enum import Enum, auto

# sys.path.insert(0, str(Path(__file__).parent.parent.parent))
# from chexpert_labeler import ChexpertLabeler

# CheXpert observation names. The order of this list fixes the position of each
# observation in the 0/1 label vectors built by create_id2chexpert_dict().
CATEGORIES = ["No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly",
              "Lung Lesion", "Lung Opacity", "Edema", "Consolidation",
              "Pneumonia", "Atelectasis", "Pneumothorax", "Pleural Effusion",
              "Pleural Other", "Fracture", "Support Devices"]

# Prompt templates for instruction-style training samples. Each string contains a
# named placeholder ({instructions}, {history}, {template} or {previous_report});
# presumably one template is picked per sample and filled with str.format — confirm
# against the caller. The strings are runtime data: editing them changes the
# generated dataset.
# NOTE(review): some prompts contain wording slips ("the give chest x-ray",
# "investigating the given chest x-ray radiology report"); they are left untouched
# here because they are runtime text.

# Report-correction task; placeholder: {instructions}.
correction_instructions = [
    "Revise the medical report based on the chest x-ray radiographs and these instructions: {instructions}",
    "Fix this incorrect medical report of these chest x-ray images using these guidelines: {instructions}",
    "Update the medical report of the given chest x-ray images with these changes: {instructions}",
    "Based on the given chest x-ray images, edit this medical report following these suggestions: {instructions}",
    "Apply these revisions to the given medical report of the chest x-ray radiographs: {instructions}",
    "Refine the given medical report of the chest x-ray images with these improvements: {instructions}",
    "Enhance the medical report by incorporating these notes: {instructions}",
    "Revise the medical report based on the chest x-ray radiographs considering these recommendations: {instructions}",
    "Fix the given incorrect medical report based on the chest x-ray images and these instructions: {instructions}",
]

# Report generation conditioned on patient history; placeholder: {history}.
history_instructions = [
    "The patient has the following medical conditions and exam result: \n{history}\n Examine the given chest x-ray images and patient's medical conditions, and write a medical report detailing the findings:",
    "The patient has following information: \n{history} Review the attached chest x-ray images and relevant patient information to write a detailed medical report:",
    "Medical conditions of the patient: \n{history}\n Based on the chest x-ray images and patient's medical details, draft a detailed diagnostic medical report:",
    "Given that the patient has the following medical history: \n{history}\n, write a detailed medical report for the patient based on the given medical history and chest x-ray radiographs:",
    "The patient has the following medical record: \n{history}\n, combine with the chest x-ray images, write a detailed diagnostic medical report for the patient:",
]

# Template-filling task; placeholder: {template}.
template_instructions = [
    "Please act as a radiologist and write a diagnostic radiology report for the patient based on their chest radiographs, the format should follow the template:\n{template}",
    "Write a diagnostic radiology report for the patient based on their chest radiographs following the given template:\n{template}",
    "Please fill the following chest x-ray radiology template based on the given chest x-ray images:\n{template}",
    "Template:\n{template}\nPlease fill this chest x-ray diagnostic report template based on the give chest x-ray radiographs.",
    "Template:\n{template}\nGiven this template, please fill it after investigating the given chest x-ray radiology report.",
    "Referencing the given chest x-ray images, please fill the following chest x-ray report template:\n{template}",
]

# Report generation conditioned on the prior visit's report; placeholder: {previous_report}.
comparison_instructions = [
    "Previous medical report:\n{previous_report}\nAct as a radiologist and write a diagnostic radiology report for the patient based on their chest radiographs and previous medical report:",
    "Medical report from the last visit:\n{previous_report}\nPlease write a diagnostic radiology report for the patient based on their chest radiographs considering the report from last visit:",
    "The patient has a previous visit with the report:\n{previous_report}\nConsidering the patient's previous report, please write a chest x-ray report for the patient based on the chest x-ray images:",
    "Act as a radiologist and write a diagnosis chest x-ray report by inspecting patient's chest x-ray images and previous report. The patient's previous report:\n{previous_report}",
    "Please write a diagnosis chest x-ray report by investigating the given chest x-ray images, referencing the patient's previous report:\n{previous_report}",
]



# Prefixes of report lines that start a section to skip (EXAM, HISTORY,
# COMPARISON, ...). Includes misspellings and truncations seen in real reports.
# Matched case-sensitively with str.startswith against the raw line.
ignore_keys = {
    "EXAM",
    "CLINICAL",
    "COMPARISON",
    "Comparison:",
    "REASON",
    "HISTORY",
    "TECHNIQUE",
    "STUDY",
    "COMMENT",
    "NOTIFICATION",
    "PATIENT",
    "DATE",
    "TYPE",
    "NOTE",
    "REFERENCE",
    "ADDENDUM",
    "INDICATION",
    "ACCESSION NUMBER",
    "CC",
    "OMPARISON",
    "RECOMMENDATION",
    "FDG PET-CT",
    "PROCEDURE",
    "PRELIMINARY",
    "TIME",
    "OPINION",
    "NDICATION",
    "CLINCAL",
    "COMPARE",
    "COMPARRISON",
    "3HISTORY",
    "]HISTORY",
    "INDCATION",
    "IDICATION",
    "COMPARISION",
    "COMAPRISON",
    "RECOMMMENDATIONS",
    "RECOMMEDATIONS",
    "NOTFICATIONS",
    "COMPARSION",
}

# Lower-case prefixes (compared against line.lower()) that open the FINDINGS
# section, including common misspellings.
findings_keys = {
    "finding:",
    "findings:",
    "report:",
    "finidngs",
    "findigns",
    "finsings",
    "findnings",
    "findngs",
    "findins",
    "fingdings",
    "findindgs",
    "findoings",
    "fidings",
    "findgings",
    # "two views of the chest:"
}

# Lower-case prefixes (compared against line.lower()) that open the
# IMPRESSION / conclusion section, including common misspellings.
impression_keys = {
    "impession",
    "imprssion",
    "impression",
    "imperssion",
    "conlcusion",
    "coclusion",
    "mpression",
    "impresion",
    "imoression",
    "impressoin",
    "impresson",
    "imprression",
    "iimpression",
    "conclusion",
    "fimpression",
    "impressio___",
}

# Label words whose presence in a sentence marks it as leaking the answer
# (presumably passed to remove_leakage() as `words` — confirm with caller).
filter_words = ["cardiomegaly", "edema", "consolidation", "pneumonia", "atelectasis", "pneumothorax", "effusion", "report", "report.", "ct"]

# All section-opening keys as a set (computed before the tuple conversion below).
golden_keys = findings_keys.union(impression_keys)
# str.startswith() accepts a tuple, not a set. Set iteration order is arbitrary,
# but that does not matter for prefix matching.
findings_keys = tuple(findings_keys)
impression_keys = tuple(impression_keys)
ignore_keys = tuple(ignore_keys)

def create_id2label_dict(path):
    """
    Create a dict of image dicom_id to label (PA, AP, LAT, etc.)

    Reads a comma-separated file where column 0 holds the dicom_id and
    column 4 holds the view label; an empty label is stored as "NA".
    The header row (first column "dicom_id") is dropped if present.

    Args:
        path: path to the CSV file.

    Returns:
        dict mapping dicom_id -> view label string.
    """
    id2label = {}
    # `with` guarantees the file is closed even if a short/malformed row raises.
    with open(path, 'r', newline='') as f:
        for row in csv.reader(f, delimiter=',', quotechar='"'):
            id2label[row[0]] = row[4] if row[4] else "NA"
    # Drop the header row's entry; tolerate a header-less file instead of raising KeyError.
    id2label.pop("dicom_id", None)
    return id2label

def create_id2split_dict(path):
    """
    Create a dict of study_id to split (train/validate/test)

    Reads a comma-separated file where column 1 holds the study_id and
    column 3 holds the split name. A study may appear on several rows; all
    of them must agree on the split. The header row (column 1 "study_id")
    is dropped if present.

    Args:
        path: path to the CSV file.

    Returns:
        dict mapping study_id -> split string.

    Raises:
        AssertionError: if one study_id is listed with two different splits.
    """
    id2split = {}
    # `with` guarantees the file is closed even when the consistency check fails.
    with open(path, 'r', newline='') as f:
        for row in csv.reader(f, delimiter=',', quotechar='"'):
            study_id, split = row[1], row[3]
            if study_id in id2split:
                assert id2split[study_id] == split, (
                    f"study {study_id} assigned to both {id2split[study_id]!r} and {split!r}"
                )
            id2split[study_id] = split
    # Drop the header row's entry; tolerate a header-less file instead of raising KeyError.
    id2split.pop("study_id", None)
    return id2split

def create_id2images_dict(path):
    """
    Create a dict of image study_id to list of images

    Reads a comma-separated file where column 0 holds the image id and
    column 2 holds the study_id. Image ids are collected per study in file
    order. The header row (column 2 "study_id") is dropped if present.

    Args:
        path: path to the CSV file.

    Returns:
        dict mapping study_id -> list of image ids.
    """
    id2images = {}
    # `with` guarantees the file is closed even if a short/malformed row raises.
    with open(path, 'r', newline='') as f:
        for row in csv.reader(f, delimiter=',', quotechar='"'):
            id2images.setdefault(row[2], []).append(row[0])
    # Drop the header row's entry; tolerate a header-less file instead of raising KeyError.
    id2images.pop("study_id", None)
    return id2images

def create_id2path_dict(path):
    """
    Create a dict of study_id to report path

    Args:
        path: CSV with at least "subject_id" and "study_id" columns.

    Returns:
        dict mapping "s<study_id>" -> relative report path
        "files/reports/p<first two digits of subject>/p<subject_id>/s<study_id>.txt"
        (joined with os.path.join). If a study appears on several rows, the
        first row wins.
    """
    metadata = pd.read_csv(path)
    id2path = {}
    # Iterate the columns directly rather than DataFrame.iterrows(): iterrows
    # packs each row into one Series, so when every column is numeric and one
    # holds floats (or NaN) the ids are upcast and formatted as "123.0".
    # Column iteration keeps each column's dtype and is much faster.
    for subject_id, raw_study_id in zip(metadata["subject_id"], metadata["study_id"]):
        study_id = f"s{raw_study_id}"
        if study_id in id2path:
            continue
        patient = str(subject_id)
        id2path[study_id] = os.path.join("files", "reports", f"p{patient[:2]}", f"p{patient}", f"{study_id}.txt")
    return id2path

def create_id2imagepath_dict(path):
    """
    Create a dict of study_id to image path

    Args:
        path: CSV with at least "subject_id" and "study_id" columns.

    Returns:
        dict mapping "s<study_id>" -> relative study directory
        "files/p<first two digits of subject>/p<subject_id>/s<study_id>"
        (joined with os.path.join). If a study appears on several rows, the
        first row wins.
    """
    metadata = pd.read_csv(path)
    id2path = {}
    # Iterate the columns directly rather than DataFrame.iterrows(): iterrows
    # can upcast an all-numeric row to float, so ids would format as "123.0".
    for subject_id, raw_study_id in zip(metadata["subject_id"], metadata["study_id"]):
        study_id = f"s{raw_study_id}"
        if study_id in id2path:
            continue
        patient = str(subject_id)
        id2path[study_id] = os.path.join("files", f"p{patient[:2]}", f"p{patient}", study_id)
    return id2path

# Accumulators filled by key_stats(): header key -> (text after it, source path),
# keeping the most recently seen example for each key.
keys = {}
# Header key -> number of times it has been seen.
key_count = {}
def key_stats(path, line):
    """Record an upper-case ``KEY:`` header found in a report line.

    Lines without ':' or whose text before the first ':' is not upper-case
    are ignored, as are lines starting with one of ``ignore_keys``.
    Otherwise the module-level ``keys`` and ``key_count`` dicts are updated.

    Args:
        path: source file of the line (stored with the example).
        line: a single report line.
    """
    if ':' not in line:
        return
    parts = line.split(':')
    head = parts[0]
    if not head.isupper():
        return
    # ignore_keys is already a tuple at module level, as str.startswith requires.
    if line.startswith(ignore_keys):
        return
    # Text between the first and second ':' (same as the original split).
    keys[head] = (parts[1], path)
    # dict.get replaces the former bare `except:`, which also swallowed
    # KeyboardInterrupt/SystemExit.
    key_count[head] = key_count.get(head, 0) + 1

class State(Enum):
    """States of the line-by-line report parser in parse_report().

    INIT:       before the "FINAL REPORT" line; every line is skipped.
    REPORT:     after "FINAL REPORT"; lines go only to report_str, which is
                discarded if an ignore key is met.
    IGNORE:     lines are skipped until a section header or a blank line.
    FINDINGS:   lines are appended to findings_str and report_str.
    IMPRESSION: lines are appended to impression_str and report_str.
    END:        parsing stops.
    """

    INIT = auto()
    REPORT = auto()
    IGNORE = auto()
    FINDINGS = auto()
    IMPRESSION = auto()
    END = auto()


def parse_report(path):
    """Extract the report, findings and impression text from a report file.

    Lines before the exact line "FINAL REPORT" are skipped. After it, a
    line-by-line state machine (see ``State``) routes each stripped line:

    * A line starting (case-sensitively) with one of ``ignore_keys``:
      - in REPORT state, it discards the text collected so far and enters IGNORE;
      - in FINDINGS/IMPRESSION, it stops parsing.
    * A line whose lower-cased form starts with one of ``findings_keys`` /
      ``impression_keys`` opens that section. The header line itself is kept.
      Opening a section from REPORT or IGNORE restarts ``report_str`` at the
      header; FINDINGS -> IMPRESSION appends to it.
    * In IGNORE state, an empty line returns to REPORT.

    Args:
        path: path to the report text file.

    Returns:
        (report_str, findings_str, impression_str). Each is stripped of
        surrounding whitespace and is "" when absent. Lines are joined with a
        ' ' separator, so blank lines can leave doubled spaces.
    """
    report_str, findings_str, impression_str = "", "", ""
    state = State.INIT
    # `with` closes the file even if reading/decoding raises part-way through.
    with open(path, 'r') as f:
        for line in f:
            line = line.strip('\n\t ')
            if state == State.INIT:
                if line == "FINAL REPORT":
                    state = State.REPORT
            elif state == State.REPORT:
                if line.startswith(ignore_keys):
                    report_str = ""
                    state = State.IGNORE
                elif line.lower().startswith(findings_keys):
                    report_str = line
                    findings_str = line
                    state = State.FINDINGS
                elif line.lower().startswith(impression_keys):
                    report_str = line
                    impression_str = line
                    state = State.IMPRESSION
                else:
                    report_str += ' ' + line
            elif state == State.FINDINGS:
                if line.startswith(ignore_keys):
                    state = State.END
                elif line.lower().startswith(impression_keys):
                    report_str += ' ' + line
                    impression_str = line
                    state = State.IMPRESSION
                else:
                    report_str += ' ' + line
                    findings_str += ' ' + line
            elif state == State.IMPRESSION:
                if line.startswith(ignore_keys):
                    state = State.END
                else:
                    report_str += ' ' + line
                    impression_str += ' ' + line
            elif state == State.IGNORE:
                if line.lower().startswith(findings_keys):
                    report_str = line
                    findings_str = line
                    state = State.FINDINGS
                elif line.lower().startswith(impression_keys):
                    report_str = line
                    impression_str = line
                    state = State.IMPRESSION
                elif line == "":
                    state = State.REPORT
            elif state == State.END:
                break
    findings_str = findings_str.strip("\n\t ")
    impression_str = impression_str.strip("\n\t ")
    report_str = report_str.strip("\n\t ")
    return report_str, findings_str, impression_str

def parse_report_raw(path):
    """Return a whole text file flattened to one line.

    Each line is stripped of surrounding newlines, tabs and spaces and
    followed by a single space. The result therefore ends with a trailing
    space unless the file is empty, in which case "" is returned.

    Args:
        path: path to the text file.
    """
    # `with` closes the file on error; str.join avoids quadratic `+=` building.
    with open(path, 'r') as f:
        return ''.join(line.strip('\n\t ') + ' ' for line in f)


def get_report_path(study_id, metadata):
    """Return the relative report path for a study, or None if it is unknown.

    Args:
        study_id: study id such as "s50414267". Every "s" is removed before
            matching, as before.
        metadata: DataFrame with "study_id" and "subject_id" columns.

    Returns:
        Path("files/reports/p<first two digits>/p<subject_id>/s<study_id>.txt")
        built from the first matching row, or None if no row matches.
    """
    study_id = study_id.replace("s", "")
    # Vectorised match instead of a Python-level iterrows() scan. Column-wise
    # astype(str) keeps integer ids as "123"; iterrows could upcast an
    # all-numeric row to float and yield "123.0", making the match fail.
    matches = metadata.loc[metadata["study_id"].astype(str) == study_id, "subject_id"]
    if matches.empty:
        return None
    patient_id = str(matches.iloc[0])
    return Path("files") / "reports" / f"p{patient_id[:2]}" / f"p{patient_id}" / f"s{study_id}.txt"

def get_previous_report_path(report_path, metadata):
    """Return the report path of the same patient's most recent earlier study.

    "Earlier" means a strictly smaller (StudyDate, StudyTime) pair than the
    current study's. Among those, the latest pair is chosen, and a tie keeps
    the first row.

    Args:
        report_path: Path like ".../p<subject_id>/s<study_id>.txt".
        metadata: DataFrame with "subject_id", "study_id", "StudyDate" and
            "StudyTime" columns.

    Returns:
        Path to "s<prev_study_id>.txt" in the same directory, or None if the
        patient has no earlier study.

    Raises:
        IndexError: if the current study is not present in ``metadata``.
    """
    curr_study_id = int(report_path.stem[1:])
    subject_id = int(report_path.parent.stem[1:])
    curr_row = metadata[metadata["study_id"] == curr_study_id].iloc[0]
    curr_key = (curr_row["StudyDate"], curr_row["StudyTime"])

    all_studies = metadata[metadata["subject_id"] == subject_id]
    prev_study_id, prev_key = None, None
    # Lexicographic (date, time) comparison fixes three problems in the old
    # loop: the same-day branch compared StudyDate against times, it read a
    # misspelled "StudyDare" column, and for earlier dates it kept the first
    # row seen instead of the latest time.
    for study_id, study_date, study_time in zip(
        all_studies["study_id"], all_studies["StudyDate"], all_studies["StudyTime"]
    ):
        key = (study_date, study_time)
        if key < curr_key and (prev_key is None or key > prev_key):
            prev_study_id, prev_key = study_id, key
    if prev_study_id is None:
        return None
    # int() guards against float-typed ids producing "s123.0.txt".
    return report_path.parent / f"s{int(prev_study_id)}.txt"

def remove_leakage(pred, words):
    """Drop chunks of ``pred`` that contain any of ``words``.

    ``pred`` is split into chunks with re.split("\\n,", ...). A chunk is kept
    only if none of its lower-cased whitespace-separated tokens is in
    ``words`` (exact token match, punctuation included). Kept chunks are
    re-joined with newlines. If the result is empty or "None"/"none" (after
    stripping newlines/spaces), the string "None" is returned.
    """
    # NOTE(review): "\n," matches a newline immediately followed by a comma, so
    # most predictions are treated as one chunk and a single leaked word drops
    # the whole text. A character class such as "[\n,]" may have been intended — confirm.
    kept = [
        chunk
        for chunk in re.split("\n,", pred)
        if not any(token in words for token in chunk.lower().split())
    ]
    valid = '\n'.join(kept)
    return "None" if valid.strip("\n ") in ("", "None", "none") else valid

def create_id2chexpert_dict(path, categories=None):
    """Build a dict of "s<study_id>" -> binary label vector from a CheXpert CSV.

    Args:
        path: CSV with a "study_id" column and one column per category.
        categories: column names to read, in output order. Defaults to the
            module-level CATEGORIES.

    Returns:
        dict mapping "s<study_id>" to a list of Python ints. An entry is 1
        where the cell equals 1 and 0 for anything else (e.g. 0, -1, NaN). If
        a study_id repeats, the last row wins.
    """
    # Maybe inaccurate label in mimic-cxr-2.0.0-chexpert.csv
    if categories is None:
        categories = CATEGORIES
    df = pd.read_csv(path)
    # One vectorised comparison replaces per-row iterrows() + to_dict(), which
    # is very slow on large label files.
    labels = (df[list(categories)] == 1).astype(int).values.tolist()
    # int() because study_id is read as float when the column contains NaN elsewhere.
    study_ids = [f"s{int(sid)}" for sid in df["study_id"]]
    return dict(zip(study_ids, labels))



def check_leakage(task, generated_label=None, gt_label=None):
    """Compare ground-truth labels against labels of a generated text.

    Only task "history" is supported. Returns False as soon as any non-NaN
    ground-truth label equals the generated label for the same category, and
    True otherwise. Presumably True means "no label leaked" — confirm with
    the caller.

    Raises:
        NotImplementedError: for any task other than "history".
        KeyError: if a non-NaN ground-truth category is missing from
            ``generated_label``.
    """
    if task != "history":
        raise NotImplementedError
    assert generated_label is not None and gt_label is not None
    # NaN ground truth is never a match; the generated label is only looked
    # up for non-NaN entries, and iteration stops at the first match.
    return not any(
        not math.isnan(gt_value) and gt_value == generated_label[name]
        for name, gt_value in gt_label.items()
    )


def clean_and_sample_history(history):
    """Clean a newline-separated history list and randomly keep up to two entries.

    For each line:
    * lines containing "none" (case-insensitive) are dropped;
    * a leading "<digits>." enumerator and a leading "-" bullet are removed;
    * lines that end up empty are dropped, so blank lines are never sampled.

    Returns:
        Up to two of the cleaned entries, chosen with random.sample and
        joined with newlines; "" if nothing remains.
    """
    cleaned = []
    for entry in history.split('\n'):
        if "none" in entry.lower():
            continue
        # Raw string: "\d" in a plain literal is an invalid escape sequence
        # (DeprecationWarning, SyntaxWarning on 3.12+). The regex is unchanged.
        entry = re.sub(r"^\d+\.", "", entry).strip()
        if entry.startswith('-'):
            entry = entry[1:].strip()
        if entry:
            cleaned.append(entry)
    return '\n'.join(random.sample(cleaned, min(len(cleaned), 2)))


# Input path of generated txt, output dict
# template: {"template": "", "report":, ""}
# correction: {"incorrect_report": "", "instruction": "", correct_report: ""}
# history: {"history": }
def parse_generated(path, task):
    """Parse a generated sample file into a task-specific dict.

    Each line is stripped of surrounding tabs/spaces (newlines are kept) and
    the lines are concatenated. Sections are then located by marker:

    * "template": "TEMPLATE:", "TEMPLATED REPORT:", optional "GT:" and
      "Reason:". Returns None if either required marker is missing, the
      template or report is empty, they are equal, or a non-empty Reason is
      present.
    * "correction": "INCORRECT REPORT:", "INSTRUCTIONS:" and a "GT:" after
      the instructions. Returns None if a marker is missing or a part is empty.
    * "history": "Medical History:", "Medical Tests:", optional "GT:" and
      "Reason:". Returns None if a required marker is missing or a non-empty
      Reason is present.

    Raises:
        NotImplementedError: for an unknown task (after the file is read).
    """
    with open(path, 'r') as f:
        report_string = ''.join(line.strip('\t ') for line in f)

    def section(start, end, marker):
        # Text in [start:end) with the marker removed and outer whitespace trimmed.
        return report_string[start:end].replace(marker, "").strip("\n\t ")

    # str.find returns -1 when a marker is absent. Used directly as a slice
    # bound, -1 silently dropped the last character, and a missing "Reason:"
    # turned the file's last character into a bogus reason, rejecting valid
    # samples. Absent optional markers now mean "to the end of the text" for
    # GT and "no reason" for Reason.
    text_end = len(report_string)

    if task == "template":
        template_idx = report_string.find("TEMPLATE:")
        report_idx = report_string.find("TEMPLATED REPORT:")
        gt_idx = report_string.find("GT:")
        reason_idx = report_string.find("Reason:")
        if template_idx < 0 or report_idx < 0:
            return None
        template_str = section(template_idx, report_idx, "TEMPLATE:")
        report_str = section(report_idx, gt_idx if gt_idx >= 0 else text_end, "TEMPLATED REPORT:")
        reason_str = section(reason_idx, text_end, "Reason:") if reason_idx >= 0 else ""
        if template_str == report_str or template_str == "" or report_str == "" or reason_str != "":
            return None
        return {"template": template_str, "report": report_str}
    elif task == "correction":
        incorrect_idx = report_string.find("INCORRECT REPORT:")
        instruction_idx = report_string.find("INSTRUCTIONS:")
        # Ground truth is the first "GT:" at or after the instructions.
        correct_idx = report_string.find("GT:", instruction_idx)
        if incorrect_idx < 0 or correct_idx < 0 or instruction_idx < 0:
            return None
        incorrect_str = section(incorrect_idx, instruction_idx, "INCORRECT REPORT:")
        correct_str = section(correct_idx, text_end, "GT:")
        fix_str = section(instruction_idx, correct_idx, "INSTRUCTIONS:")
        if len(incorrect_str) == 0 or len(fix_str) == 0 or len(correct_str) == 0:
            return None
        return {"incorrect_report": incorrect_str, "instruction": fix_str, "correct_report": correct_str}
    elif task == "history":
        history_idx = report_string.find("Medical History:")
        test_idx = report_string.find("Medical Tests:")
        gt_idx = report_string.find("GT:")
        reason_idx = report_string.find("Reason:")
        if history_idx < 0 or test_idx < 0:
            return None
        history_str = section(history_idx, test_idx, "Medical History:")
        test_str = section(test_idx, gt_idx if gt_idx >= 0 else text_end, "Medical Tests:")
        reason_str = section(reason_idx, text_end, "Reason:") if reason_idx >= 0 else ""
        if reason_str != "":
            return None
        # test_str = clean_and_sample_history(test_str)
        # history_str = clean_and_sample_history(history_str)
        history_str = (history_str + '\n' + test_str).strip("\n\t ")
        return {"history": history_str}
    else:
        raise NotImplementedError






