from .IIANet import IIANet
from .crossnet import crossnet
from .tfgridnet_v2 import tfgridnet_v2
from .tfgridnet_v2_step2 import tfgridnet_v2_step2
from .tfgridnet_v2_a_o import tfgridnet_v2_a_o
from .tfgridnet_v2_wer import tfgridnet_v2_wer
from .tfgridnet_v2_Lip_o import tfgridnet_v2_Lip_o
from .tfgridnet_v2_face_o import tfgridnet_v2_face_o

# Public names re-exported by ``from <package> import *``.
# Every entry needs a trailing comma: without it, adjacent string literals
# are implicitly concatenated into a single (non-existent) name.
__all__ = [
    "IIANet",
    "crossnet",
    "tfgridnet_v2",
    "tfgridnet_v2_step2",
    "tfgridnet_v2_a_o",
    "tfgridnet_v2_wer",
    "tfgridnet_v2_Lip_o",
    "tfgridnet_v2_face_o",
]


def register_model(custom_model):
    """Register a custom model, gettable with `models.get`.

    The model is stored in this module's globals under its ``__name__``.
    Because ``get`` resolves names case-insensitively, a name that differs
    from an existing module-level name only by letter case is rejected as
    well; otherwise the two entries would collide inside ``get`` and one
    would silently shadow the other.

    Args:
        custom_model: Custom model to register. Must expose ``__name__``
            (typically a class).

    Raises:
        ValueError: If a module-level name equal to ``custom_model.__name__``
            (compared case-insensitively) already exists.
    """
    name = custom_model.__name__
    # Compare against lowercased keys so the uniqueness rule matches the
    # case-insensitive lookup performed by ``get``.
    existing = {key.lower() for key in globals()}
    if name.lower() in existing:
        raise ValueError(
            f"Model {custom_model.__name__} already exists. Choose another name."
        )
    globals()[name] = custom_model


def get(identifier):
    """Look up a registered model by name, ignoring letter case.

    The name is matched against every module-level name in this module
    (imported models as well as anything added via ``register_model``).
    If several names differ only by case, the one appearing last in the
    module namespace wins.

    Args:
        identifier (str): the model name.

    Returns:
        The object bound to the matching module-level name (expected to be a
        model class, e.g. a :class:`torch.nn.Module` subclass).

    Raises:
        ValueError: If ``identifier`` is not a string or matches no name.
    """
    if not isinstance(identifier, str):
        raise ValueError(f"Could not interpret model name : {str(identifier)}")

    by_lower_name = dict(
        (name.lower(), obj) for name, obj in globals().items()
    )
    match = by_lower_name.get(identifier.lower())
    if match is not None:
        return match
    raise ValueError(f"Could not interpret model name : {str(identifier)}")
