import os
import base64
import json
import argparse
import random

from tqdm import tqdm
from pathlib import Path
from datetime import datetime

from mimic_cxr_utils import *

def parse_args(argv=None):
    """Parse command-line options for building the MIMIC-CXR instruction files.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``split``, ``data_dir`` and ``output_path``.
    """
    parser = argparse.ArgumentParser()
    # Deliberately not `required=True`: combined with a default, that made the
    # default unreachable and forced callers to always pass --split.
    parser.add_argument("--split", default="train", choices=["train"],
                        help="dataset split to export (default: %(default)s)")
    parser.add_argument("--data_dir", type=str, default="/scratch/datasets/MIMIC-CXR/",
                        help="MIMIC-CXR root containing files/ and the metadata/split CSVs")
    parser.add_argument("--output_path", type=str, default=None,
                        help="instructions JSON path; the images/context JSONs are written "
                             "in the same directory (default: <data_dir>/instructions_<split>.json)")
    return parser.parse_args(argv)

args = parse_args()
split = args.split
# The findings section is used as the answer. When True, a study whose
# impression is empty is still kept as long as its findings are non-empty.
findings_only = True

data_dir = Path(args.data_dir)
# All three outputs are written side by side: in data_dir by default, or next to
# --output_path when given. Path() is required because argparse delivers
# --output_path as a plain str, which has no `.parent`.
if args.output_path is None:
    out_instructions = data_dir / f"instructions_{split}.json"
else:
    out_instructions = Path(args.output_path)
out_images = out_instructions.parent / f"images_{split}.json"
out_context = out_instructions.parent / f"context_{split}.json"

print(f"Saving:\n{out_instructions}\n{out_images}\n{out_context}")
# Outputs are only written after the (long) scan below, so fail fast on a
# missing output directory instead of discovering it at the very end.
if not out_instructions.parent.is_dir():
    raise FileNotFoundError(f"Output directory does not exist: {out_instructions.parent}")

dicomid2label = create_id2label_dict(data_dir / "mimic-cxr-2.0.0-metadata.csv")
studyid2split = create_id2split_dict(data_dir / "mimic-cxr-2.0.0-split.csv")
instructions_json, images_json, context_json = {}, {}, {}
instructions_json["meta"] = {"version": "0.0.1", "time": datetime.today().strftime('%Y-%m-%d'), "author": "annonymous"}
instructions_json["data"] = {}

unique_id = 0    # studies written to instructions_json["data"]
total_count = 0  # studies belonging to the requested split (same unit as unique_id)
for patient_path in tqdm((data_dir / "files").glob("p*/p*")):
    patient_id = patient_path.name
    for study_path in patient_path.glob("s*"):
        study_id = study_path.name
        # The split table is looked up with the study id minus its leading "s".
        if split != "all" and studyid2split[study_id[1:]] != split:
            continue
        total_count += 1

        # Sorted so image order (and hence the output) does not depend on the filesystem.
        images = sorted(study_path.glob("*.jpg"))
        all_image_ids = [image.stem for image in images]
        # Only PA/AP-labelled images are used as the study's inputs.
        image_ids_list = [image_id for image_id in all_image_ids if dicomid2label[image_id] in ("PA", "AP")]
        if not image_ids_list:
            continue

        context_json[study_id] = []
        # Register every image of the study (not just PA/AP), keyed by file stem,
        # with a POSIX-style path relative to data_dir.
        for image, image_id in zip(images, all_image_ids):
            images_json[image_id] = image.relative_to(data_dir).as_posix()

        report_path = data_dir / "files" / "reports" / patient_id[:3] / patient_id / f"{study_id}.txt"
        _report, findings, impression = parse_report(report_path)
        if split != "all" and (findings == "" or impression == ""):
            # Under findings_only an empty impression is tolerated; empty findings never are.
            if not (findings_only and findings != ""):
                continue

        instructions_json["data"][study_id] = {
            "instruction": random.choice(no_context_instructions).lower(),
            "answer": findings.lower(),
            "image_ids": image_ids_list,
            "rel_ins_ids": [],
        }
        unique_id += 1

# Each output gets its own context-managed handle so it is flushed and closed
# even if a later dump fails.
for out_path, payload in ((out_images, images_json),
                          (out_context, context_json),
                          (out_instructions, instructions_json)):
    with open(out_path, "w", encoding="utf-8") as out_file:
        json.dump(payload, out_file)

print(f"{unique_id}/{total_count} data processed")

