import warnings
from typing import Any, Callable, Dict, Union

# Registries backing the ``register_*`` helpers below. Each one maps a
# string key to the object registered under it; every registry is its own,
# independent dictionary.

# Model building blocks.
act_dict: Dict[str, Any] = dict()
node_encoder_dict: Dict[str, Any] = dict()
edge_encoder_dict: Dict[str, Any] = dict()
stage_dict: Dict[str, Any] = dict()
head_dict: Dict[str, Any] = dict()
layer_dict: Dict[str, Any] = dict()
pooling_dict: Dict[str, Any] = dict()
network_dict: Dict[str, Any] = dict()
# Configuration and data handling.
config_dict: Dict[str, Any] = dict()
dataset_dict: Dict[str, Any] = dict()
loader_dict: Dict[str, Any] = dict()
sampler_dict: Dict[str, Any] = dict()
# Optimization and training.
optimizer_dict: Dict[str, Any] = dict()
scheduler_dict: Dict[str, Any] = dict()
loss_dict: Dict[str, Any] = dict()
train_dict: Dict[str, Any] = dict()
# Miscellaneous.
feature_augment_dict: Dict[str, Any] = dict()
metric_dict: Dict[str, Any] = dict()


def register_base(mapping: Dict[str, Any],
                  key: str,
                  module: Any = None) -> Union[None, Callable]:
    r"""Base function for registering a module in GraphGym.

    Args:
        mapping (dict): The Python dictionary hosting all registered modules
            of one kind; :obj:`module` is stored in it under :obj:`key`.
        key (str): The name of the module.
        module (any, optional): The module. If set to :obj:`None`, will return
            a decorator to register a module.

    Returns:
        :obj:`None` if :obj:`module` is given. Otherwise, a decorator that
        registers the decorated object under :obj:`key` and returns that
        object unchanged.

    Raises:
        KeyError: If :obj:`key` is already present in :obj:`mapping`. In
            decorator mode this is raised when the decorator is applied.
    """
    if module is not None:
        # Registrations are never silently overwritten.
        if key in mapping:
            raise KeyError(f"Module with '{key}' already defined")
        mapping[key] = module
        return None

    # Otherwise, act as a decorator factory: defer registration until the
    # decorated object is available.
    def bounded_register(module):
        register_base(mapping, key, module)
        return module

    return bounded_register


def register_act(key: str, module: Any = None):
    r"""Registers an activation function in GraphGym.

    Stores :obj:`module` in :obj:`act_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None`, returns a decorator that performs the
    registration instead.
    """
    return register_base(mapping=act_dict, key=key, module=module)


def register_node_encoder(key: str, module: Any = None):
    r"""Registers a node feature encoder in GraphGym.

    Stores :obj:`module` in :obj:`node_encoder_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None`, returns a decorator that performs the
    registration instead.
    """
    return register_base(mapping=node_encoder_dict, key=key, module=module)


def register_edge_encoder(key: str, module: Any = None):
    r"""Registers an edge feature encoder in GraphGym.

    Stores :obj:`module` in :obj:`edge_encoder_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None`, returns a decorator that performs the
    registration instead.
    """
    return register_base(mapping=edge_encoder_dict, key=key, module=module)


def register_stage(key: str, module: Any = None):
    r"""Registers a customized GNN stage in GraphGym.

    Stores :obj:`module` in :obj:`stage_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None`, returns a decorator that performs the
    registration instead.
    """
    return register_base(mapping=stage_dict, key=key, module=module)


def register_head(key: str, module: Any = None):
    r"""Registers a GNN prediction head in GraphGym.

    Stores :obj:`module` in :obj:`head_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None`, returns a decorator that performs the
    registration instead.
    """
    return register_base(mapping=head_dict, key=key, module=module)


def register_layer(key: str, module: Any = None):
    r"""Registers a GNN layer in GraphGym.

    Stores :obj:`module` in :obj:`layer_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None`, returns a decorator that performs the
    registration instead.
    """
    return register_base(mapping=layer_dict, key=key, module=module)


def register_pooling(key: str, module: Any = None):
    r"""Registers a GNN global pooling/readout layer in GraphGym.

    Stores :obj:`module` in :obj:`pooling_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None`, returns a decorator that performs the
    registration instead.
    """
    return register_base(mapping=pooling_dict, key=key, module=module)


def register_network(key: str, module: Any = None):
    r"""Registers a GNN model in GraphGym.

    Stores :obj:`module` in :obj:`network_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None`, returns a decorator that performs the
    registration instead.
    """
    return register_base(mapping=network_dict, key=key, module=module)


def register_config(key: str, module: Any = None):
    r"""Registers a configuration group in GraphGym.

    Stores :obj:`module` in :obj:`config_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None`, returns a decorator that performs the
    registration instead.
    """
    return register_base(mapping=config_dict, key=key, module=module)


def register_dataset(key: str, module: Any = None):
    r"""Registers a dataset in GraphGym.

    Stores :obj:`module` in :obj:`dataset_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None`, returns a decorator that performs the
    registration instead.
    """
    return register_base(mapping=dataset_dict, key=key, module=module)


def register_loader(key: str, module: Any = None):
    r"""Registers a data loader in GraphGym.

    Stores :obj:`module` in :obj:`loader_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None`, returns a decorator that performs the
    registration instead.
    """
    return register_base(mapping=loader_dict, key=key, module=module)


def register_sampler(key: str, module: Any = None):
    r"""Registers a graph sampler in GraphGym.

    Stores :obj:`module` in :obj:`sampler_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None`, returns a decorator that performs the
    registration instead.
    """
    return register_base(mapping=sampler_dict, key=key, module=module)


def register_optimizer(key: str, module: Any = None):
    r"""Registers an optimizer in GraphGym.

    Stores :obj:`module` in :obj:`optimizer_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None`, returns a decorator that performs the
    registration instead.
    """
    return register_base(mapping=optimizer_dict, key=key, module=module)


def register_scheduler(key: str, module: Any = None):
    r"""Registers a learning rate scheduler in GraphGym.

    Stores :obj:`module` in :obj:`scheduler_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None`, returns a decorator that performs the
    registration instead.
    """
    return register_base(mapping=scheduler_dict, key=key, module=module)


def register_loss(key: str, module: Any = None):
    r"""Registers a loss function in GraphGym.

    Stores :obj:`module` in :obj:`loss_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None`, returns a decorator that performs the
    registration instead.
    """
    return register_base(mapping=loss_dict, key=key, module=module)


def register_train(key: str, module: Any = None):
    r"""Registers a training function in GraphGym.

    Stores :obj:`module` in :obj:`train_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None`, returns a decorator that performs the
    registration instead.
    """
    return register_base(mapping=train_dict, key=key, module=module)


def register_metric(key: str, module: Any = None):
    r"""Registers a metric function in GraphGym.

    Stores :obj:`module` in :obj:`metric_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None`, returns a decorator that performs the
    registration instead.
    """
    return register_base(mapping=metric_dict, key=key, module=module)


def register_feature_augment(key: str, module: Any = None):
    r"""Registers a feature augmentation module in GraphGym.

    Stores :obj:`module` in :obj:`feature_augment_dict` under :obj:`key`. If
    :obj:`module` is :obj:`None` (the default), returns a decorator that
    performs the registration instead, matching the other ``register_*``
    helpers. Previously :obj:`module` was required, so decorator usage
    (``@register_feature_augment('name')``) raised a :obj:`TypeError`.
    """
    return register_base(feature_augment_dict, key, module)


class ModuleStore(dict):
    r"""A two-level registry: ``store[module_group][key]`` holds the module
    registered through :meth:`register`.

    Module groups are created lazily the first time they are referenced in
    :meth:`register`.
    """

    def register(self,
                 module_group: str,
                 key: str,
                 module: Any = None) -> Union[None, Callable]:
        r"""Base function for registering a module in GraphGym.

        Args:
            module_group (str): The name of the module group.
            key (str): The name of the module.
            module (any, optional): The module. If set to :obj:`None`, will
                return a decorator to register a module.

        Returns:
            :obj:`None` if :obj:`module` is given. Otherwise, a decorator that
            registers the decorated object under :obj:`key` in
            :obj:`module_group` and returns that object unchanged.

        Note:
            If :obj:`key` already exists in :obj:`module_group`, a warning is
            emitted and the existing registration is kept.
        """
        # The group is created even in decorator mode, before any module is
        # actually registered.
        if module_group not in self:
            self[module_group] = {}
        group = self[module_group]

        if module is not None:
            if key in group:
                # Keep the first registration: the warning tells the caller
                # that this registration was rejected. The previous code
                # overwrote the entry anyway, contradicting the message.
                warnings.warn(
                    f"Module group {module_group} with '{key}' already "
                    f"defined, registration failed")
                return None
            group[key] = module
            return None

        # Otherwise, act as a decorator factory: defer registration until the
        # decorated object is available.
        def bounded_register(module):
            self.register(module_group, key, module)
            return module

        return bounded_register


# Module-level ModuleStore instance. Register via
# ``module.register(module_group, key, obj)``, or omit ``obj`` to get a
# decorator. NOTE: the ``module`` parameters of the functions in this file
# shadow this name inside their bodies.
module: ModuleStore = ModuleStore()
