from typing import Optional, Type

from pl_bolts.datamodules.vision_datamodule import VisionDataModule

# Maps a registry name to the datamodule class registered under it.
# Entries are added by the ``register_vision_datamodule`` decorator below.
vision_datamodules_registry = dict()


def register_vision_datamodule(name: str):
    """Build a class decorator that records a class under ``name``.

    Usage::

        @register_vision_datamodule("my_dm")
        class MyDataModule(...):
            ...

    Args:
        name: Key under which the decorated class is stored in
            ``vision_datamodules_registry``.

    Returns:
        A decorator that stores the class in the registry and returns the
        class itself unchanged, so the decorated name stays bound to it.

    Raises:
        ValueError: Raised by the returned decorator (at decoration time) when
            ``name`` is already present in the registry; the existing entry is
            left untouched.
    """

    def _record(datamodule_cls):
        if name not in vision_datamodules_registry:
            vision_datamodules_registry[name] = datamodule_cls
            return datamodule_cls
        raise ValueError(f"Cannot register duplicate vision datamodule ({name})")

    return _record


def get_vision_datamodule_cls(name: Optional[str] = None) -> Type[VisionDataModule]:
    """Look up a registered vision datamodule class by name.

    Args:
        name: Registry key that was passed to ``register_vision_datamodule``.
            ``None`` (the default) selects the entry registered as
            ``"default"``.

    Returns:
        The class stored in ``vision_datamodules_registry`` under the
        resolved key.

    Raises:
        KeyError: If nothing is registered under the resolved key. The
            message lists the registered names so typos are easy to spot.
    """
    key = "default" if name is None else name
    try:
        return vision_datamodules_registry[key]
    except KeyError:
        # Keep the exception type callers already rely on, but make it
        # actionable; ``from None`` drops the uninformative inner KeyError.
        available = ", ".join(sorted(vision_datamodules_registry)) or "<none>"
        raise KeyError(
            f"Unknown vision datamodule {key!r}; registered: {available}"
        ) from None


def get_vision_datamodules_names():
    """Return the names of all registered vision datamodules.

    The result is a fresh list in registration (dict insertion) order, so
    callers may modify it without affecting the registry.
    """
    return [*vision_datamodules_registry]


from .basic import *
