import os

from torchvision.datasets import CIFAR10
from vissl.data.dataset_catalog import VisslDatasetCatalog
from vissl.utils.io import save_file, load_file
from json import JSONDecodeError


# Directory the dataset is stored in.
ROOT = "./ats/data/datasets/"

# Class names in index order: the position of a name in the tuple is its class index.
IDX2LABEL = dict(enumerate((
    "plane", "car", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)))
# Inverse lookup: class name -> class index.
LABEL2IDX = {label: idx for idx, label in IDX2LABEL.items()}
N_CLASSES = len(IDX2LABEL)


if __name__ == "__main__":
    # Script entry point: download CIFAR10 into ROOT, add a "CIFAR10" entry to the
    # dataset catalog JSON in the current working directory, and register that
    # catalog with VISSL.

    # The datasets are instantiated only for the download side effect
    # (download=True); the returned objects are discarded.
    _ = CIFAR10(root=ROOT, train=True, download=True)
    _ = CIFAR10(root=ROOT, train=False, download=True)

    # Both splits point at the same absolute directory. The second slot of each
    # pair is filled with a literal "<unused>" placeholder.
    data_dir = os.path.abspath(ROOT)
    json_data = {
        "CIFAR10": {
            "train": [data_dir, "<unused>"],
            "val": [data_dir, "<unused>"],
        }
    }

    catalog_path = "dataset_catalog.json"
    try:
        metadata = load_file(catalog_path)
    except (FileNotFoundError, JSONDecodeError):
        # No catalog yet (first run) or an empty/corrupt one: start a fresh
        # catalog instead of crashing after the download has already happened.
        # NOTE(review): assumes load_file reports a missing local file as
        # FileNotFoundError -- confirm against the installed VISSL version.
        metadata = json_data
    else:
        # Existing catalog: keep its other entries, add or overwrite "CIFAR10".
        metadata.update(json_data)

    # append_to_json=False is passed so the merged catalog is written out as a whole.
    save_file(metadata, catalog_path, append_to_json=False)
    VisslDatasetCatalog.register_json(catalog_path)
    print(f"Registered datasets: {VisslDatasetCatalog.list()}")
