import csv
from enum import Enum, auto

from tqdm import tqdm
from pathlib import Path


# Root of the MIMIC-CXR dataset, and the CSV this script writes inside it.
data_dir = "/data/datasets/MIMIC-CXR/"
out_path = data_dir + "processed100.csv"


# Section headers whose sections are dropped from the output. Matched
# case-sensitively against the start of each stripped report line
# (str.startswith). Entries such as "OMPARISON", "NDICATION" or "COMPARRISON"
# look like typos/truncations that presumably occur in the raw reports.
ignore_keys = {
    "EXAM",
    "CLINICAL",
    "COMPARISON",
    "Comparison:",
    "REASON",
    "HISTORY",
    "TECHNIQUE",
    "STUDY",
    "COMMENT",
    "NOTIFICATION",
    "PATIENT",
    "DATE",
    "TYPE",
    "NOTE",
    "REFERENCE",
    "ADDENDUM",
    "INDICATION",
    "ACCESSION NUMBER",
    "CC",
    "OMPARISON",
    "RECOMMENDATION",
    "FDG PET-CT",
    "PROCEDURE",
    "PRELIMINARY",
    "TIME",
    "OPINION",
    "NDICATION",
    "CLINCAL",
    "COMPARE",
    "COMPARRISON",
    "3HISTORY",
    "]HISTORY",
    "INDCATION",
    "IDICATION",
    "COMPARISION",
    "COMAPRISON",
    "RECOMMMENDATIONS",
    "RECOMMEDATIONS",
    "NOTFICATIONS",
    "COMPARSION",
}

# Prefixes that start the FINDINGS section. Matched against the lower-cased
# line, so they must be lower case.
findings_keys = {
    "finding:",
    "findings:",
    "report:",
    "finidngs",
    "findigns",
    "finsings",
    "findnings",
    "findngs",
    "findins",
    "fingdings",
    "findindgs",
    "findoings",
    "fidings",
    "findgings",
    # "two views of the chest:"
}

# Prefixes that start the IMPRESSION section (also matched lower-cased).
impression_keys = {
    "impession",
    "imprssion",
    "impression",
    "imperssion",
    "conlcusion",
    "coclusion",
    "mpression",
    "impresion",
    "imoression",
    "impressoin",
    "impresson",
    "imprression",
    "iimpression",
    "conclusion",
    "fimpression",
    "impressio___",
}

# All section-start keys, kept as a set; currently only referenced by
# commented-out code.
golden_keys = findings_keys | impression_keys
# str.startswith needs a tuple of prefixes, not a set. Sort so the tuple order
# is the same on every run (set iteration order of str depends on the hash seed).
findings_keys = tuple(sorted(findings_keys))
impression_keys = tuple(sorted(impression_keys))
ignore_keys = tuple(sorted(ignore_keys))


def create_id2label_dict(path):
    """Map each CSV row's column-0 value to its column-4 value.

    Used on ``mimic-cxr-2.0.0-metadata.csv``, where column 0 is ``dicom_id``;
    column 4 is presumably the view position — confirm against the CSV header.
    Empty labels are replaced by "NA". The header row (column 0 == "dicom_id")
    is read like any data row and then removed from the result.

    Args:
        path: path to the CSV file (str or Path).

    Returns:
        dict mapping column-0 strings to column-4 strings (or "NA").

    Raises:
        KeyError: if no row has "dicom_id" in column 0 (missing header).
        IndexError: if a row has fewer than five columns.
    """
    # `with` guarantees the file is closed even if a malformed row raises.
    with open(path, 'r', newline='') as f:
        csv_reader = csv.reader(f, delimiter=',', quotechar='"')
        id2label = {row[0]: row[4] or "NA" for row in csv_reader}
    id2label.pop("dicom_id")  # drop the header row
    return id2label

def create_id2split_dict(path):
    """Map each CSV row's column-1 value to its column-3 value.

    Used on ``mimic-cxr-2.0.0-split.csv`` with column 1 as the study id and
    column 3 as the split name. Several rows may share a study id; all of them
    must carry the same split. The header row (column 1 == "study_id") is
    read like any data row and then removed from the result.

    Args:
        path: path to the CSV file (str or Path).

    Returns:
        dict mapping study id strings to split strings.

    Raises:
        ValueError: if one study id appears with two different splits.
        KeyError: if no row has "study_id" in column 1 (missing header).
    """
    id2split = {}
    with open(path, 'r', newline='') as f:
        for row in csv.reader(f, delimiter=',', quotechar='"'):
            study_id, split = row[1], row[3]
            previous = id2split.setdefault(study_id, split)
            # Explicit check instead of `assert`, which `python -O` strips.
            if previous != split:
                raise ValueError(
                    f"study {study_id!r} assigned to both {previous!r} and "
                    f"{split!r} in {path}"
                )
    id2split.pop("study_id")  # drop the header row
    return id2split



keys = {}       # header -> (text between 1st and 2nd ':', path) of the last hit
key_count = {}  # header -> number of times it was recorded


def key_stats(path, line):
    """Record an all-upper-case "HEADER:" prefix of ``line`` for inspection.

    Lines without ':' are ignored, as are lines whose text before the first
    ':' is not upper case (``str.isupper``) and lines starting with one of
    ``ignore_keys``. Otherwise ``keys[header]`` is overwritten with
    ``(text between the first and second ':', path)`` and
    ``key_count[header]`` is incremented.

    Args:
        path: identifier of the report the line came from (stored verbatim).
        line: a single report line.
    """
    if ':' not in line:
        return
    parts = line.split(':')
    header = parts[0]
    if not header.isupper():
        return
    # ignore_keys is already a tuple of prefixes, as str.startswith requires.
    if line.startswith(ignore_keys):
        return
    keys[header] = (parts[1], path)
    key_count[header] = key_count.get(header, 0) + 1

class State(Enum):
    """Parser states for ``parse_report``.

    INIT        Before the "FINAL REPORT" line; every line is skipped.
    REPORT      After "FINAL REPORT": lines go to the report text only, which
                is discarded when an ignore_keys line is met.
    IGNORE      Lines are skipped (until a section key or a blank line).
    FINDINGS    Lines are appended to both the findings and the report text.
    IMPRESSION  Lines are appended to both the impression and the report text.
    END         No further lines are parsed.
    """

    INIT = 1
    REPORT = 2
    IGNORE = 3
    FINDINGS = 4
    IMPRESSION = 5
    END = 6


def parse_report(path):
    """Extract the report, FINDINGS and IMPRESSION text from a report file.

    The file is scanned line by line (each line stripped of '\\n', '\\t' and
    ' ') with the ``State`` machine:

    * Nothing is kept until a line equal to "FINAL REPORT".
    * After that, lines are collected into the report text. A line starting
      with one of ``ignore_keys`` (case-sensitive) discards what was collected
      and skips lines until a blank line, a findings key or an impression key.
    * A line whose lower-cased form starts with a findings/impression key
      restarts the report text at that header line and starts collecting
      that section (the header line itself is included).
    * Once inside FINDINGS or IMPRESSION, the next ``ignore_keys`` line stops
      parsing.

    Blank lines are skipped, so consecutive lines are joined by exactly one
    space (previously each blank line left an extra space in the output).

    Args:
        path: path to the report text file (str or Path).

    Returns:
        ``(report, findings, impression)``; a section that was not found is "".
    """
    # Collect parts and join once at the end (avoids quadratic `+=`).
    report_parts, findings_parts, impression_parts = [], [], []
    state = State.INIT
    # `with` closes the file even if decoding or parsing raises.
    with open(path, 'r') as f:
        for raw_line in f:
            line = raw_line.strip('\n\t ')
            lowered = line.lower()
            if state == State.INIT:
                if line == "FINAL REPORT":
                    state = State.REPORT
            elif state == State.REPORT:
                if line.startswith(ignore_keys):
                    # An ignored section discards everything collected so far.
                    report_parts = []
                    state = State.IGNORE
                elif lowered.startswith(findings_keys):
                    report_parts = [line]
                    findings_parts = [line]
                    state = State.FINDINGS
                elif lowered.startswith(impression_keys):
                    report_parts = [line]
                    impression_parts = [line]
                    state = State.IMPRESSION
                elif line:
                    report_parts.append(line)
            elif state == State.IGNORE:
                if lowered.startswith(findings_keys):
                    report_parts = [line]
                    findings_parts = [line]
                    state = State.FINDINGS
                elif lowered.startswith(impression_keys):
                    report_parts = [line]
                    impression_parts = [line]
                    state = State.IMPRESSION
                elif not line:
                    # A blank line ends the ignored section.
                    state = State.REPORT
            elif state == State.FINDINGS:
                if line.startswith(ignore_keys):
                    state = State.END
                elif lowered.startswith(impression_keys):
                    report_parts.append(line)
                    impression_parts = [line]
                    state = State.IMPRESSION
                elif line:
                    report_parts.append(line)
                    findings_parts.append(line)
            elif state == State.IMPRESSION:
                if line.startswith(ignore_keys):
                    state = State.END
                elif line:
                    report_parts.append(line)
                    impression_parts.append(line)
            if state == State.END:
                break
    return ' '.join(report_parts), ' '.join(findings_parts), ' '.join(impression_parts)


data_dir = Path(data_dir)
# Stop after writing this many studies (the "100" in out_path's file name).
max_studies = 100

try:
    out_file = open(out_path, 'w', newline='')
    print(f"data saved as {out_path}")
except OSError:
    # e.g. the dataset directory is not writable: fall back to the CWD.
    out_file = open("processed.csv", 'w', newline='')
    print("data saved as processed.csv")

unique_id = 0
# `with` closes (and flushes) the CSV even if a lookup or parse below raises.
with out_file:
    csv_writer = csv.writer(out_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
    dicomid2label = create_id2label_dict(data_dir/"mimic-cxr-2.0.0-metadata.csv")
    studyid2split = create_id2split_dict(data_dir/"mimic-cxr-2.0.0-split.csv")

    # Sorted so the subset of studies written is reproducible; raw glob order
    # depends on the filesystem.
    for patient_path in tqdm(sorted((data_dir/"files").glob("p*/p*"))):
        patient_id = patient_path.name
        for study_path in sorted(patient_path.glob("s*")):
            study_id = study_path.name
            image_files = sorted(study_path.glob("*.jpg"))
            # Image paths are stored relative to data_dir; the file stem is
            # the key into the metadata CSV.
            image_paths = ','.join(p.relative_to(data_dir).as_posix() for p in image_files)
            image_labels = ','.join(dicomid2label[p.stem] for p in image_files)
            report_path = data_dir/"files"/"reports"/patient_id[:3]/patient_id/f"{study_id}.txt"
            # The split CSV keys studies without the leading "s".
            split = studyid2split[study_id[1:]]
            report, findings, impression = parse_report(report_path)
            csv_writer.writerow([unique_id, patient_id, study_id, image_paths, image_labels,
                                 findings, impression, report, split])
            unique_id += 1
            if unique_id >= max_studies:
                break
        if unique_id >= max_studies:
            break

# print(num_mult_imgs, num_multiple_views)
# for k,v in sorted(key_count.items(), key=lambda item: -item[1]):
#     print(f"{v}, {k}: {keys[k][0]}")
#     print(f"\t{keys[k][1]}")

