import json
import os.path

from tqdm import tqdm
from pathlib import Path
from datetime import datetime
import argparse


def parse_args(argv=None):
    """Parse the command-line options of the instruction-merging script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what a
            plain ``parse_args()`` call has always done.

    Returns:
        argparse.Namespace with the attributes ``split``, ``data_dir`` and
        ``output_path`` (``None`` when ``--output_path`` is not given).
    """
    parser = argparse.ArgumentParser()
    # ``required=True`` made the declared default unreachable; the option is
    # optional now, so omitting it really selects "train".
    parser.add_argument("--split", default="train", choices=["train"])
    parser.add_argument("--data_dir", type=str, default="/scratch/datasets/MIMIC-CXR/")
    parser.add_argument("--output_path", type=str, default=None)
    return parser.parse_args(argv)


args = parse_args()
split = args.split
data_dir = Path(args.data_dir)
# argparse stores ``--output_path`` as ``args.output_path``; when it is not
# given, the merged file is written inside the data directory.
if args.output_path is None:
    out_path = data_dir / f"merged_instructions_{split}.json"
else:
    out_path = Path(args.output_path)
# The context and image index files are looked up next to the merged output.
context_path = out_path.parent / "context_all.json"
image_path = out_path.parent / "images_all.json"  # not referenced in the visible part of the script
# Explicit check instead of ``assert`` so it still runs under ``python -O``.
if not os.path.exists(context_path):
    raise FileNotFoundError(f"{context_path} does not exist")


# Instruction files to merge: the plain instruction set first, followed by
# one file per task variant, all located in the data directory.
json_path_list = [
    data_dir / f"{prefix}instructions_{split}.json"
    for prefix in ("", "correction_", "history_", "template_", "comparison_")
]

print(f"merging {[path.stem for path in json_path_list]}")

# Merge the "data" section of every instruction file into one dictionary.
# Entries coming from a task-specific file get "_<task>" appended to their key;
# entries from the plain instruction file keep their key unchanged.
merged_json = {
    "meta": {"version": "0.0.1", "time": datetime.today().strftime('%Y-%m-%d'), "author": "annonymous"},
    "data": {},
}
for json_path in tqdm(json_path_list):
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not json_path.is_file():
        raise FileNotFoundError(f"{json_path} does not exist")
    # Detect the task from the file name only: matching against the full path
    # would mis-tag every file if a parent directory contained a task word.
    file_name = json_path.name
    if "history" in file_name:
        task = "history"
    elif "template" in file_name:
        task = "template"
    elif "correction" in file_name:
        task = "correction"
    elif "comparison" in file_name:
        task = "comparison"
    else:
        task = ""
    with open(json_path) as in_file:
        data = json.load(in_file)["data"]
    if task != "":
        data = {f"{k}_{task}": v for k, v in data.items()}
    # Later files overwrite earlier entries that share the same key.
    merged_json["data"].update(data)

# Open the output only after every input was read successfully, so a missing
# or malformed input no longer leaves a truncated, empty output file behind.
with open(out_path, 'w') as out_file:
    json.dump(merged_json, out_file, indent=2)

# Extend the context file in place: every entry whose key has no underscore
# is duplicated once per task under the key "<key>_<task>".
with open(context_path, 'r') as context_file:
    context = json.load(context_file)
# Keys added below always contain "_", so the set of un-suffixed entries never
# changes between passes; compute it once instead of re-filtering each time.
base_context = {k: v for k, v in context.items() if "_" not in k}
for task in tqdm(["correction", "history", "comparison", "template"]):
    context.update({f"{k}_{task}": v for k, v in base_context.items()})
# ``with`` closes (and therefore flushes) the file before success is reported;
# the write handle was previously never closed.
with open(context_path, 'w') as context_file:
    json.dump(context, context_file)

print(f"Saved:\n{out_path}\n{context_path}")