from typing import Dict, Type

from .base import BaseModelWrapper

# Maps a registered model name to the model *class* (not an instance):
# `register_model` stores the decorated class here, and `get_model`
# instantiates it on demand.
models_registry: Dict[str, Type[BaseModelWrapper]] = {}


def register_model(model_name):
    """Build a class decorator that registers a class under ``model_name``.

    Applying the decorator stores the class in ``models_registry`` and hands
    the very same class object back, so decorated classes are left unchanged.

    Raises:
        ValueError: when the decorator is applied and ``model_name`` is
            already present in the registry.
    """

    def _decorator(model_cls):
        if model_name not in models_registry:
            models_registry[model_name] = model_cls
            return model_cls
        raise ValueError(f"Cannot register duplicate model ({model_name})")

    return _decorator


def get_model(model_name: str, **kwargs) -> BaseModelWrapper:
    """Instantiate the model class registered under ``model_name``.

    Args:
        model_name: Name the class was registered with via ``register_model``.
        **kwargs: Forwarded unchanged to the model class constructor.

    Returns:
        A new instance of the registered class.

    Raises:
        ValueError: If no model is registered under ``model_name``; the
            message lists the names that are currently registered.
    """
    try:
        model_cls = models_registry[model_name]
    except KeyError:
        # str() each key: register_model does not enforce string names.
        available = ", ".join(map(str, models_registry)) or "<none>"
        # `from None`: the KeyError is an internal detail of the lookup.
        raise ValueError(
            f"Model {model_name} not found. Available models: {available}"
        ) from None
    # Construct outside the `try` so a KeyError raised by the constructor
    # itself is not misreported as an unknown model name.
    return model_cls(**kwargs)


def get_models_names():
    """Return a fresh list of every registered model name, in registration order."""
    return [*models_registry]


# Deliberately imported last, after `models_registry` and `register_model`
# exist. Presumably the `.small` module registers its model classes via
# `@register_model(...)` at import time, which would fail (circular import)
# if this import sat at the top of the file. TODO confirm; keep it last.
from .small import *  # noqa: E402,F401,F403
