import csv
import sys
import os
import math
from pathlib import Path
import itertools
from more_itertools import batched
from multiprocessing import Pool, Lock
from tqdm import tqdm

sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from chexpert_labeler import ChexpertLabeler
from mimic_cxr_utils import CATEGORIES, parse_report, parse_report_raw



# Directory whose entries are parsed as individual report files further down.
data_dir = Path("/scratch/datasets/MIMIC-CXR/files/reports_all")
# CSV receiving one row per study: "study_id" followed by one column per
# entry of CATEGORIES.
out_file = "chexpert_extracted.csv"
# Keep one core free for the parent process, but never go below one worker.
# os.cpu_count() may return None when the core count cannot be determined.
num_processes = max(1, (os.cpu_count() or 1) - 1)


# Start the output file fresh with only the header row; get_label appends to
# it afterwards. Restricted to the launching process so that worker processes
# which re-import this module (spawn / forkserver start methods) do not
# truncate rows that were already written.
if __name__ == "__main__":
    # newline='' lets the csv module control line endings itself.
    with open(out_file, 'w', newline='') as header_file:
        header_writer = csv.DictWriter(
            header_file, fieldnames=["study_id"] + CATEGORIES, delimiter=','
        )
        header_writer.writeheader()

# Serialises the append phase of get_label across worker processes.
lock = Lock()

def get_label(report_list):
    """Label one chunk of reports and append the results to ``out_file``.

    A fresh ``ChexpertLabeler`` is created for the call, every report in the
    chunk is labelled, and the binarised labels are then appended to the CSV
    while holding the module-level ``lock`` so that concurrent callers do not
    interleave their rows.

    Args:
        report_list: Iterable of ``(study_id, report_text)`` pairs.

    Returns:
        A list of ``(study_id, label)`` tuples, where ``label`` is a list of
        0/1 ints in the same order as ``CATEGORIES``.
    """
    labeler = ChexpertLabeler()
    label_list = []
    for study_id, report in tqdm(report_list):
        raw_label = labeler.get_label(report)
        # Binarise: only an explicit 1 is kept; every other value becomes 0.
        binary_label = [1 if raw_label[k] == 1 else 0 for k in CATEGORIES]
        label_list.append((study_id, binary_label))
    # The labeler is not needed for the write phase; drop the reference
    # before (possibly) waiting on the lock.
    del labeler

    rows = [
        {"study_id": study_id, **dict(zip(CATEGORIES, label))}
        for study_id, label in label_list
    ]

    # Context managers guarantee the lock is released and the file closed
    # even if writing fails, so one failing caller cannot block the others.
    with lock:
        # newline='' lets the csv module control line endings itself.
        with open(out_file, 'a', newline='') as file:
            csv_writer = csv.DictWriter(
                file, fieldnames=["study_id"] + CATEGORIES, delimiter=','
            )
            csv_writer.writerows(rows)
        print(f"Wrote {len(label_list)} lines")
    return label_list

def _select_report_text(report_path):
    """Return the text to label for one report file.

    Candidates are tried in order: the ``findings`` value from
    ``parse_report``, then its ``report`` value, then the output of
    ``parse_report_raw``. Each is stripped and the first non-empty one is
    used; the raw fallback is returned even when it is empty.
    """
    report, findings, _ = parse_report(report_path)
    report, findings = report.strip(), findings.strip()
    if findings:
        return findings
    if report:
        return report
    return parse_report_raw(report_path).strip()


def _init_worker(shared_lock):
    """Install the parent's lock as this worker's module-level ``lock``.

    With the fork start method the workers already inherit the same lock, so
    this is a no-op there. With spawn / forkserver each worker re-imports the
    module and would otherwise end up with its own, unshared lock.
    """
    global lock
    lock = shared_lock


# The guard keeps worker processes that re-import this module (spawn /
# forkserver start methods) from re-running the whole pipeline themselves.
if __name__ == "__main__":
    # List the directory once; tqdm takes its total from the list length.
    report_paths = list(data_dir.iterdir())
    all_reports = []
    for report_path in tqdm(report_paths):
        # NOTE(review): this removes every 's' in the stem, not only a
        # leading one - fine for names like "s12345", confirm for others.
        study_id = report_path.stem.replace('s', '')
        all_reports.append((study_id, _select_report_text(report_path)))

    if not all_reports:
        # batched() rejects a chunk size of 0, so fail with a clear message.
        sys.exit(f"No reports found in {data_dir}")

    # Split the work into at most num_processes chunks of near-equal size
    # (guarding against a worker count of 0), then shrink the pool to the
    # number of chunks actually produced.
    chunk_size = math.ceil(len(all_reports) / max(1, num_processes))
    report_chunks = list(batched(all_reports, chunk_size))
    num_processes = len(report_chunks)

    with Pool(num_processes, initializer=_init_worker, initargs=(lock,)) as pool:
        # Draining the iterator waits for every chunk and re-raises any
        # exception raised inside a worker instead of silently dropping it.
        report_labels = list(pool.imap(get_label, report_chunks))
        pool.close()
        pool.join()







