import os
import shutil

from vissl.data.dataset_catalog import VisslDatasetCatalog
from vissl.utils.io import save_file, load_file
from json import JSONDecodeError
from extra_scripts.create_imagenet_data_files import get_images_labels_info
from collections import namedtuple
import numpy as np
from iopath.common.file_io import g_pathmgr


# Base directory under which raw and processed datasets are stored.
ROOT = "./ats/data/datasets/"

# Module-level placeholders, all initialised to None; nothing in the visible
# part of this script assigns them.
# NOTE(review): presumably label<->index lookups and the class count - confirm.
LABEL2IDX = N_CLASSES = IDX2LABEL = None


if __name__ == "__main__":
    # Script entry point: for each ImageNet-1k partition, call
    # get_images_labels_info and write its results under <ROOT>/imagenet1k as
    # <partition>_images.npy, <partition>_labels.npy and, when enabled,
    # <partition>_targets.json.
    #
    # TODO: This assumes the raw data has already been downloaded; the script
    # is meant to be called at the end of get_imagenet1k.sh.
    # Consider using this script: https://gist.github.com/bonlime/4e0d236cf98cd5b15d977dfa03a63643

    # Imported once here instead of on every loop iteration.
    import json

    partitions = ["train", "val"]

    raw_dir = os.path.join(ROOT, "tmp")
    # Lightweight stand-in for the options object passed to
    # get_images_labels_info.
    # NOTE(review): presumably mirrors that script's argparse namespace - confirm.
    Args = namedtuple('Args', ['data_source_dir', 'generate_json'])
    args = Args(raw_dir, True)
    output_dir = os.path.join(os.path.abspath(ROOT), "imagenet1k")

    # np.save does not create missing parent directories, so make sure the
    # destination exists before any file is written.
    os.makedirs(output_dir, exist_ok=True)

    for partition in partitions:
        imgs_info, lbls_info, output_dict = get_images_labels_info(partition, args)
        img_info_out_path = f"{output_dir}/{partition}_images.npy"
        label_info_out_path = f"{output_dir}/{partition}_labels.npy"

        np.save(img_info_out_path, np.array(imgs_info))
        np.save(label_info_out_path, np.array(lbls_info))

        if args.generate_json:
            json_out_path = f"{output_dir}/{partition}_targets.json"
            with g_pathmgr.open(json_out_path, "w") as fp:
                json.dump(output_dict, fp)


    # json_data = {
    #     "oxford_pets_folder": {
    #       "train": [
    #         os.path.join(os.path.abspath(ROOT), "oxford_pets/train"), "<ignored>"
    #       ],
    #       "val": [
    #         os.path.join(os.path.abspath(ROOT), "oxford_pets/test"), "<ignored>"
    #       ]
    #     }
    # }

    # try:
    #     metadata = load_file("dataset_catalog.json")
    #     metadata.update(json_data)
    # except JSONDecodeError:
    #     metadata = json_data

    # save_file(metadata, "dataset_catalog.json", append_to_json=False)
    # VisslDatasetCatalog.register_json("dataset_catalog.json")
    # print(f"Registered datasets: {VisslDatasetCatalog.list()}")
