import os

from torchvision.datasets import MNIST
from vissl.data.dataset_catalog import VisslDatasetCatalog
from vissl.utils.io import save_file, load_file
from json import JSONDecodeError


# Directory under which torchvision downloads the MNIST files.
ROOT = "./ats/data/datasets/"

# Class label string ("0".."9") -> integer class index, plus the inverse map.
LABEL2IDX = {f"{digit}": digit for digit in range(10)}
N_CLASSES = len(LABEL2IDX)
IDX2LABEL = {index: label for label, index in LABEL2IDX.items()}


# VISSL dataset-catalog file, resolved relative to the current working directory.
CATALOG_PATH = "dataset_catalog.json"


def build_mnist_catalog_entry(root):
    """Build the VISSL dataset-catalog entry for MNIST stored under ``root``.

    Both the ``train`` and ``val`` splits point at the same absolute directory.
    The second element of each pair is the literal placeholder ``"<unused>"``
    in place of a separate label path.

    Args:
        root: Directory passed to torchvision's ``MNIST(root=...)``.

    Returns:
        A dict of the form ``{"MNIST": {"train": [...], "val": [...]}}``.
    """
    data_dir = os.path.abspath(root)
    return {
        "MNIST": {
            "train": [data_dir, "<unused>"],
            "val": [data_dir, "<unused>"],
        }
    }


def merge_catalog(existing, entries):
    """Return a new catalog dict with ``entries`` layered over ``existing``.

    Keys in ``entries`` replace keys of the same name in ``existing``. Other
    datasets already in the catalog are kept. Neither argument is mutated.

    Args:
        existing: Previously loaded catalog contents, or ``None`` when there is
            no usable catalog (missing, empty, or unparseable file).
        entries: Dataset entries to add or replace.

    Raises:
        ValueError: if ``existing`` is neither ``None`` nor a dict (JSON object).
    """
    if existing is None:
        return dict(entries)
    if not isinstance(existing, dict):
        raise ValueError(
            "Expected the dataset catalog to be a JSON object, "
            f"got {type(existing).__name__}"
        )
    return {**existing, **entries}


if __name__ == "__main__":
    # Download both splits so the directory referenced by the catalog exists.
    MNIST(root=ROOT, train=True, download=True)
    MNIST(root=ROOT, train=False, download=True)

    try:
        existing_catalog = load_file(CATALOG_PATH)
    except (FileNotFoundError, JSONDecodeError):
        # First run (no catalog yet) or an empty/corrupt catalog: start fresh.
        existing_catalog = None

    metadata = merge_catalog(existing_catalog, build_mnist_catalog_entry(ROOT))

    # Overwrite instead of appending. `metadata` already contains the previous
    # catalog, so appending would leave two JSON documents in one file, and
    # the file could no longer be parsed as a single catalog.
    save_file(metadata, CATALOG_PATH, append_to_json=False)
    VisslDatasetCatalog.register_json(CATALOG_PATH)
    print(f"Registered datasets: {VisslDatasetCatalog.list()}")
