import fnmatch
from types import ModuleType
from typing import Any, Callable, Iterable, List, Optional, Set, TypeVar, Union

from torch import nn

__all__ = ["get_model", "get_model_builder", "list_models"]

M = TypeVar("M", bound=nn.Module)

BUILTIN_MODELS = {}


def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
    def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
        key = name if name is not None else fn.__name__
        if key in BUILTIN_MODELS:
            raise ValueError(f"An entry is already registered under the name '{key}'.")
        BUILTIN_MODELS[key] = fn
        return fn

    return wrapper


def list_models(
    module: Optional[ModuleType] = None,
    include: Union[Iterable[str], str, None] = None,
    exclude: Union[Iterable[str], str, None] = None,
) -> List[str]:
    """
    Returns a list with the names of registered models.

    Args:
        module (ModuleType, optional): The module from which we want to extract the available models.
        include (str or Iterable[str], optional): Filter(s) for including the models from the set of all models.
            Filters are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
            wildcards. In case of many filters, the results is the union of individual filters.
        exclude (str or Iterable[str], optional): Filter(s) applied after include_filters to remove models.
            Filter are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
            wildcards. In case of many filters, the results is removal of all the models that match any individual filter.

    Returns:
        models (list): A list with the names of available models.
    """
    all_models = {
        k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__
    }
    if include:
        models: Set[str] = set()
        if isinstance(include, str):
            include = [include]
        for include_filter in include:
            models = models | set(fnmatch.filter(all_models, include_filter))
    else:
        models = all_models

    if exclude:
        if isinstance(exclude, str):
            exclude = [exclude]
        for exclude_filter in exclude:
            models = models - set(fnmatch.filter(all_models, exclude_filter))
    return sorted(models)


def get_model_builder(name: str) -> Callable[..., nn.Module]:
    """
    Gets the model name and returns the model builder method.

    Args:
        name (str): The name under which the model is registered.

    Returns:
        fn (Callable): The model builder method.
    """
    # name = name.lower()
    try:
        fn = BUILTIN_MODELS[name]
    except KeyError:
        raise ValueError(f"Unknown model {name}")
    return fn


def get_model(model_name: str, **config: Any) -> nn.Module:
    """
    Gets the model name and configuration and returns an instantiated model.

    Args:
        model_name (str): The name under which the model is registered.
        **config (Any): parameters passed to the model builder method.

    Returns:
        model (nn.Module): The initialized model.
    """
    fn = get_model_builder(model_name)
    return fn(**config)
