import argparse
import json
import os

from tqdm import tqdm

from dm.normalise import normalise_val

# Command-line interface: both options select which scenes file is normalised.
parser = argparse.ArgumentParser(
    description=(
        "Normalise object attribute values in a CLEVR scenes file "
        "(CLEVR_<split>_extracted.json -> CLEVR_<split>_normalised.json)."
    )
)

parser.add_argument(
    "--dataset",
    required=True,
    help="Dataset directory name; files are read from and written to data/<dataset>/scenes/.",
)
parser.add_argument(
    "--split",
    required=True,
    help="Split name embedded in the input/output file names (CLEVR_<split>_*.json).",
)


def transform_data(scene):
    """Normalise and round every object attribute of ``scene`` in place.

    For each ``obj`` in ``scene["objects"]``, every value in
    ``obj["attributes"]`` is passed to ``normalise_val`` together with its
    key. The result is rounded to 5 decimal places and written back under
    the same key.

    Returns:
        The same ``scene`` object that was passed in, after mutation.
    """
    for entity in scene["objects"]:
        attrs = entity["attributes"]
        # Only existing keys are reassigned, so the dict's size never changes
        # and iterating .items() while writing is safe.
        for name, raw in attrs.items():
            attrs[name] = round(normalise_val(name, raw), 5)
    return scene


def main(dataset: str, split: str):
    """Normalise the attribute values of every scene in one split.

    Reads ``data/<dataset>/scenes/CLEVR_<split>_extracted.json``, applies
    ``transform_data`` in place to each entry of its ``"scenes"`` list, and
    writes the whole document to
    ``data/<dataset>/scenes/CLEVR_<split>_normalised.json``.

    Args:
        dataset: Dataset directory name under ``data/``.
        split: Split name embedded in the input and output file names.
    """
    scenes_dir = f"data/{dataset}/scenes"

    # input -- decode explicitly as UTF-8 rather than the platform locale
    # encoding, so the script behaves the same on every OS.
    with open(f"{scenes_dir}/CLEVR_{split}_extracted.json", encoding="utf-8") as f:
        data = json.load(f)

    # transform (transform_data mutates each scene in place)
    for scene in tqdm(data["scenes"]):
        transform_data(scene)

    # out -- the directory was just read from, so makedirs is only a guard.
    os.makedirs(scenes_dir, exist_ok=True)
    with open(f"{scenes_dir}/CLEVR_{split}_normalised.json", "w", encoding="utf-8") as f:
        json.dump(data, f)


if __name__ == "__main__":
    # Parse the CLI options and forward them by name to main().
    cli_args = parser.parse_args()
    main(dataset=cli_args.dataset, split=cli_args.split)
