import logging

from .register import register as register

# Module-level logger for this module.
logger = logging.getLogger(__name__)

# Optionally pull in everything exported by the sibling ``splitter`` module.
# If that import raises ImportError (e.g. the module or one of its own
# imports is unavailable), only a warning is logged so the rest of this
# module stays importable.
# NOTE(review): the warning text names `federate.contrib.splitter` while the
# import is the relative `.splitter` -- confirm the intended module path.
try:
    from .splitter import *
except ImportError as error:
    logger.warning(
        f'{error} in `federate.contrib.splitter`, some modules are not '
        f'available.')


def get_splitter(config):
    """
    Build the splitter that turns a non-FL dataset into simulated \
    federated client datasets.

    Args:
        config: configurations for FL, see ``.configs``. This function reads
            ``config.federated.num_clients``, ``config.data.splitter`` and,
            optionally, ``config.data.splitter_args``.

    Returns:
        An instance of splitter (see ``core.splitters`` for details), or
        ``None`` (after logging a warning) when ``config.data.splitter`` is
        not one of the types listed below.

    Note:
      The key-value pairs of ``cfg.data.splitter`` and domain:
        ===================  ================================================
        Splitter type        Domain
        ===================  ================================================
        lda                  Generic
        iid                  Generic
        meta                 Generic
        louvain              Graph (node-level)
        random               Graph (node-level)
        rel_type             Graph (link-level)
        scaffold             Molecular
        scaffold_lda         Molecular
        rand_chunk           Graph (graph-level)
        ===================  ================================================

      The first element of ``cfg.data.splitter_args`` (when present and
      non-empty) is passed as keyword arguments to the splitter constructor.
      ``iid`` and ``meta`` are constructed with the client number only.
    """
    client_num = config.federated.num_clients
    splitter_type = config.data.splitter

    # Extra constructor kwargs come from the first entry of
    # ``splitter_args``. A missing attribute, an empty/falsy value or a
    # non-indexable value all mean "no extra arguments", so ``kwargs`` is
    # always bound before any branch below uses it.
    try:
        splitter_args = config.data.splitter_args
        kwargs = splitter_args[0] if splitter_args else {}
    except (AttributeError, IndexError, KeyError, TypeError):
        kwargs = {}

    # Imports are delayed so that only the selected splitter's module is
    # loaded.
    # generic splitters
    if splitter_type == 'lda':
        from .splitters.generic import LDASplitter
        splitter = LDASplitter(client_num, **kwargs)
    elif splitter_type == 'iid':
        from .splitters.generic import IIDSplitter
        splitter = IIDSplitter(client_num)
    elif splitter_type == 'meta':
        from .splitters.generic import MetaSplitter
        splitter = MetaSplitter(client_num)
    # graph splitters
    elif splitter_type == 'louvain':
        from .splitters.graph import LouvainSplitter
        splitter = LouvainSplitter(client_num, **kwargs)
    elif splitter_type == 'random':
        from .splitters.graph import RandomSplitter
        splitter = RandomSplitter(client_num, **kwargs)
    elif splitter_type == 'rel_type':
        from .splitters.graph import RelTypeSplitter
        splitter = RelTypeSplitter(client_num, **kwargs)
    elif splitter_type == 'scaffold':
        from .splitters.graph import ScaffoldSplitter
        splitter = ScaffoldSplitter(client_num, **kwargs)
    elif splitter_type == 'scaffold_lda':
        from .splitters.graph import ScaffoldLdaSplitter
        splitter = ScaffoldLdaSplitter(client_num, **kwargs)
    elif splitter_type == 'rand_chunk':
        from .splitters.graph import RandChunkSplitter
        splitter = RandChunkSplitter(client_num, **kwargs)
    else:
        logger.warning('Splitter %s not found or not used.', splitter_type)
        splitter = None
    return splitter
