import os
import base64
import json
from tqdm import tqdm
from pathlib import Path
from datetime import datetime

from mimic_cxr_utils import *


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Value compared against the CSV "Split" column; "all" disables the filter.
split = "test"


# Cohort suffixes: each one selects master_max_CT_EN_split_<p>.csv and the
# covid19_<p> image sub-directory under data_dir.
part = ["pos","neg"]
instruction = "Act as a radiologist and write a diagnostic chest CT report for the patient based on their chest CT scans:"
data_dir = "/data/datasets/BIMCV-COVID19-cIter_1_2-Negative_CT_only/"
out_instructions = f"/data/datasets/BIMCV-COVID19-cIter_1_2-Negative_CT_only/instructions_{split}.json"
out_images = f"/data/datasets/BIMCV-COVID19-cIter_1_2-Negative_CT_only/images_{split}.json"
out_context = f"/data/datasets/BIMCV-COVID19-cIter_1_2-Negative_CT_only/context_{split}.json"

data_dir = Path(data_dir)


def _extract_report(raw):
    """Return the report text held in a stringified list cell, or "" if none.

    The cell is expected to look like the ``str()`` of a Python list of
    strings.  The outer brackets are dropped, the text is split on single
    quotes, fragments whose stripped length is <= 1 (separators such as
    ", ") are discarded, and the LAST remaining fragment is returned with
    surrounding whitespace/punctuation stripped.

    Returns "" when the cell is not a string (e.g. a missing value read by
    pandas) or when no fragment survives the filter, so the caller can skip
    the row instead of crashing on an IndexError/TypeError.
    """
    if not isinstance(raw, str):
        return ""
    fragments = [piece for piece in raw[1:-1].split("'") if len(piece.strip()) > 1]
    if not fragments:
        return ""
    return fragments[-1].strip("\n\t,'[] ")


# ---------------------------------------------------------------------------
# Build the three JSON payloads fully in memory first
# ---------------------------------------------------------------------------
instructions_json, images_json, context_json = {}, {}, {}
instructions_json["meta"] = {"version": "0.0.1", "time":datetime.today().strftime('%Y-%m-%d'), "author": "annonymous"}
instructions_json["data"] = {}

unique_id = 0  # number of rows that produced an instruction entry
for p in part:
    # NOTE(review): `pd` is not imported explicitly in this file; presumably it
    # is re-exported by `from mimic_cxr_utils import *` -- confirm.
    data_df = pd.read_csv(data_dir/f"master_max_CT_EN_split_{p}.csv")
    for _, row in tqdm(data_df.iterrows(), total=len(data_df)):
        if split != "all" and row["Split"] != split:
            continue
        report = _extract_report(row["Report"])
        if report == "":
            continue
        report_id = row["ReportID"]
        image_id = row["Filename"].replace(".nii", "").replace(".gz", "")
        image_path = data_dir / f"covid19_{p}" / row["Filepath"] / row["Filename"]
        # Explicit check instead of `assert`: asserts are removed under
        # `python -O`, and the message tells the user which file is missing.
        if not image_path.exists():
            raise FileNotFoundError(f"Image listed in the CSV does not exist: {image_path}")
        images_json[image_id] = str(image_path)
        context_json[report_id] = []
        instructions_json["data"][report_id] = {
            "instruction": instruction,
            "answer": report,
            "image_ids": [image_id],
            "rel_ins_ids": []
        }
        unique_id += 1

# ---------------------------------------------------------------------------
# Write outputs
# ---------------------------------------------------------------------------
# Files are opened only after every row has been processed, so a failure
# above never leaves truncated/empty JSON files behind, and the context
# managers guarantee the handles are closed even if serialization fails.
with open(out_images, 'w') as images_file:
    json.dump(images_json, images_file)
with open(out_context, 'w') as context_file:
    json.dump(context_json, context_file)
with open(out_instructions, 'w') as instructions_file:
    json.dump(instructions_json, instructions_file)

print(f"{unique_id} data processed")