# Sample generated MIMIC-CXR report data (optionally with study images) into a local directory.

from pathlib import Path
import os
import shutil
import random
import csv
from tqdm import tqdm
import pandas as pd
# create_id2split_dict / create_id2imagepath_dict are defined below (formerly imported from mimic_cxr_utils).


# Input dataset root and output directory (the output directory is wiped on each run).
data_dir = Path("/scratch/datasets/MIMIC-CXR/")
dst_dir = Path("./sampled_data")

# Report variants to sample; each is read from <data_dir>/files/reports_<task>.
tasks = [
    "correction",
    "template",
    "history",
]
split = "test"  # keep only studies in this split ("train", "validate", "test"), or None for all
num = "all"  # number of reports sampled per task, or "all"

# Post-sampling options.
postprocess = True  # filter/clean the copied reports (see the "reason:" handling below)
copy_img = False  # also copy each sampled study's image directory


def create_id2split_dict(path):
    """
    Create a dict of study_id to split (train/validate/test).

    The CSV's first row is treated as a header and skipped. In every other row,
    the second column is the study id and the fourth column is the split name.
    A study id may appear on several rows; all of them must name the same split.

    Args:
        path: str or os.PathLike pointing at the split CSV.

    Returns:
        dict mapping study id (the raw string from the CSV) to its split name.

    Raises:
        ValueError: if the same study id is assigned to two different splits.
    """
    id2split = {}
    with open(path, 'r', newline='') as f:
        csv_reader = csv.reader(f, delimiter=',', quotechar='"')
        # Skip the header row instead of storing it and popping it afterwards
        # (which raised KeyError on an empty file).
        next(csv_reader, None)
        for row in csv_reader:
            if not row:  # csv.reader yields [] for blank lines
                continue
            study_id, split_name = row[1], row[3]
            previous = id2split.setdefault(study_id, split_name)
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if previous != split_name:
                raise ValueError(
                    f"study {study_id!r} is assigned to both {previous!r} and {split_name!r}"
                )
    return id2split

def create_id2imagepath_dict(path):
    """
    Create a dict of study_id to image path.

    Keys are "s<study_id>". Values are relative paths of the form
    files/p<first two digits of subject_id>/p<subject_id>/s<study_id>.
    When a study appears on several rows, the first row wins.

    Args:
        path: str or os.PathLike pointing at a CSV with "study_id" and
            "subject_id" columns.

    Returns:
        dict mapping "s<study_id>" to its relative image directory path.
    """
    # Only two columns are needed. Reading just those keeps them integer-typed
    # and avoids parsing the rest of the metadata file.
    metadata = pd.read_csv(path, usecols=["study_id", "subject_id"])
    id2path = {}
    # Zip the columns directly. DataFrame.iterrows builds a Series per row,
    # which is slow and does not preserve dtypes: with all-numeric columns the
    # ids get upcast to float (e.g. "s50414267.0").
    for study_num, patient_id in zip(metadata["study_id"], metadata["subject_id"]):
        study_id = f"s{study_num}"
        if study_id in id2path:
            continue  # keep the first row seen for each study
        patient_str = str(patient_id)
        id2path[study_id] = os.path.join("files", f"p{patient_str[:2]}", f"p{patient_str}", study_id)
    return id2path


def _study_id(report_name):
    """Return the study id of a report file name, e.g. "s50414267.txt" -> "50414267"."""
    return Path(report_name).stem[1:]


def _postprocess_reports(report_dir):
    """Filter and clean the sampled ``*.txt`` reports in ``report_dir`` in place.

    - A report whose text contains "reason:" (case-insensitive) followed by
      non-blank text is deleted.
    - A report whose "reason:" has nothing after it is kept, with "reason:"
      and everything after it stripped.
    - Reports without "reason:" are left untouched.

    Only ``*.txt`` files are considered, so copied image directories are skipped.
    """
    # NOTE(review): "reason:" presumably marks a generation-failure explanation — confirm.
    marker = "reason:"
    for txt_path in report_dir.glob("*.txt"):
        content = txt_path.read_text()
        reason_idx = content.lower().find(marker)
        if reason_idx == -1:
            # Previously content[:-1] was written back here, chopping the last character.
            continue
        if content[reason_idx + len(marker):].strip():
            txt_path.unlink()
        else:
            txt_path.write_text(content[:reason_idx])


studyid2split = create_id2split_dict(data_dir / "mimic-cxr-2.0.0-split.csv")
studyid2path = create_id2imagepath_dict(data_dir / "mimic-cxr-2.0.0-metadata.csv")

# Start every run from an empty output directory.
if os.path.exists(dst_dir):
    shutil.rmtree(dst_dir)
os.makedirs(dst_dir)

for task in tasks:
    task_dst_dir = dst_dir / f"reports_{task}"
    os.makedirs(task_dst_dir)
    task_dir = data_dir / "files" / f"reports_{task}"
    # Sort so the candidate order does not depend on filesystem listing order.
    txt_list = sorted(name for name in os.listdir(task_dir) if name.endswith(".txt"))
    if split is not None:
        # .get(): reports whose study is missing from the split CSV are skipped, not a KeyError.
        txt_list = [name for name in txt_list if studyid2split.get(_study_id(name)) == split]
    if num != "all":
        # Clamp so a task with fewer reports than `num` does not raise ValueError.
        txt_list = random.sample(txt_list, min(num, len(txt_list)))
    for txt in tqdm(txt_list):
        shutil.copy(task_dir / txt, task_dst_dir / txt)
        if copy_img:
            # Only look up the image path when it is actually needed.
            study = Path(txt).stem
            shutil.copytree(data_dir / studyid2path[study], task_dst_dir / study)
    if postprocess:
        # Use the task's own output dir, not a loop variable that may be unset
        # (empty sample) or left over from the previous task.
        _postprocess_reports(task_dst_dir)
        print(f"{len(list(task_dst_dir.glob('*.txt')))} files after postprocessing")


