import os

from torchvision.datasets import CIFAR100
from vissl.data.dataset_catalog import VisslDatasetCatalog
from vissl.utils.io import save_file, load_file
from json import JSONDecodeError


# Directory passed as ``root`` to the dataset constructors (where data is stored).
ROOT = "./ats/data/datasets/"

# Number of class labels enumerated below.
N_CLASSES = 100

# Class name -> integer index; the index is the position of the name in the
# tuple below (the names are listed in alphabetical order).
LABEL2IDX = {
    label: index
    for index, label in zip(range(N_CLASSES), (
        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle',
        'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
        'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup',
        'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house',
        'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man',
        'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid',
        'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
        'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew',
        'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower',
        'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
        'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm',
    ))
}

# Reverse lookup: integer index -> class name.
IDX2LABEL = dict(zip(LABEL2IDX.values(), LABEL2IDX.keys()))


if __name__ == "__main__":
    # Constructing the datasets with download=True is done only for its side
    # effect of fetching the train and test splits into ROOT; the returned
    # dataset objects are not used.
    _ = CIFAR100(root=ROOT, train=True, download=True)
    _ = CIFAR100(root=ROOT, train=False, download=True)

    catalog_path = "dataset_catalog.json"  # Relative to the current working directory.
    data_root = os.path.abspath(ROOT)

    # Catalog entry for CIFAR100: both splits point at the same absolute data
    # directory, and the second slot of each pair holds the "<unused>" placeholder.
    json_data = {
        "CIFAR100": {
            "train": [data_root, "<unused>"],
            "val": [data_root, "<unused>"],
        }
    }

    # Merge into an existing catalog when there is one, so that entries for
    # other datasets are kept. If the catalog file does not exist yet (first
    # run) or does not contain valid JSON, start from a catalog that holds only
    # the CIFAR100 entry instead of crashing.
    try:
        metadata = load_file(catalog_path)
    except (FileNotFoundError, JSONDecodeError):
        metadata = json_data
    else:
        metadata.update(json_data)

    # Overwrite the catalog file with the merged content, then register it.
    save_file(metadata, catalog_path, append_to_json=False)
    VisslDatasetCatalog.register_json(catalog_path)
    print(f"Registered datasets: {VisslDatasetCatalog.list()}")
