# Copyright (c) Facebook, Inc. and its affiliates.
import json
import os
import sys

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import load_sem_seg
from detectron2.utils.file_io import PathManager


def _get_imagenet_files(image_dir):
    """
    Collect the paths found one level below the category folders of ``image_dir``.

    The expected layout is ``image_dir/<category>/<image>``. Entries of ``image_dir``
    that are not directories are ignored, so stray files sitting next to the category
    folders do not break the scan. A flat layout (images placed directly in
    ``image_dir``) therefore yields no files and trips the first assertion.

    Args:
        image_dir (str): directory whose sub-directories each contain image files.

    Returns:
        list[str]: joined paths ``image_dir/<category>/<basename>``, ordered by
            category name and then by basename so the result is reproducible.

    Raises:
        AssertionError: if nothing was collected, or if the first collected path
            is not a regular file.
    """
    files = []
    # Sort both levels: the listing order of the underlying storage is not
    # guaranteed, and a stable order keeps dataset indices reproducible.
    for cat in sorted(PathManager.ls(image_dir)):
        cat_img_dir = os.path.join(image_dir, cat)
        # Only directories can be listed; skip anything else at the category level.
        if not PathManager.isdir(cat_img_dir):
            continue
        files.extend(
            os.path.join(cat_img_dir, basename)
            for basename in sorted(PathManager.ls(cat_img_dir))
        )
    assert len(files), "No images found in {}".format(image_dir)
    # Cheap sanity check on the first entry only, to avoid stat-ing every image.
    assert PathManager.isfile(files[0]), files[0]

    return files

def load_imagenet(image_dir):
    """
    Build the list of dataset dicts for the images found under ``image_dir``.

    Args:
        image_dir (str): directory passed to ``_get_imagenet_files`` to enumerate
            the image paths.

    Returns:
        list[dict]: one dict per image path, each containing only the key
            "file_name".

    Raises:
        AssertionError: if no image path was returned for ``image_dir``.
    """
    dataset_dicts = [{"file_name": path} for path in _get_imagenet_files(image_dir)]
    assert len(dataset_dicts), f"No images found in {image_dir}!"
    return dataset_dicts


def register_imagenet(root):
    """
    Register the "train" and "val" splits found under ``root/ILSVRC2012``.

    Each split is registered in ``DatasetCatalog`` under the name
    ``imagenet_colorization_<split>`` with a loader that calls ``load_imagenet``
    on the split directory, and its metadata receives ``image_root`` and
    ``evaluator_type``.

    Args:
        root (str): directory that contains the ``ILSVRC2012`` folder.
    """
    dataset_root = os.path.join(root, "ILSVRC2012")
    for split in ("train", "val"):
        split_dir = os.path.join(dataset_root, split)
        dataset_name = f"imagenet_colorization_{split}"
        # The default argument freezes the current split directory; a plain closure
        # would see only the value from the last loop iteration.
        DatasetCatalog.register(dataset_name, lambda x=split_dir: load_imagenet(x))
        metadata = MetadataCatalog.get(dataset_name)
        metadata.set(image_root=split_dir, evaluator_type="imagenet_colorization")


# Register the ImageNet colorization splits as an import-time side effect.
# The dataset root is read from $DETECTRON2_DATASETS; when the variable is unset
# os.getenv returns None and registration fails inside os.path.join.
_root = os.getenv("DETECTRON2_DATASETS")
try:
    register_imagenet(_root)
except Exception as err:
    # Catch Exception rather than everything so Ctrl-C and SystemExit still
    # propagate, and report the underlying cause instead of hiding it.
    print('ImageNet not registered! Check $DETECTRON2_DATASETS/ILSVRC2012')
    print(f"Reason: {type(err).__name__}: {err}")
    sys.exit(1)
