"""Helper functions for generating grid of parameters."""

import itertools

import gin


def features_merge(keys, list_of_combinations):
    """Turns a list of value tuples into a list of feature dictionaries.

    Example: Let config be {'label_color': [0,1],
                            'label_shape': [1,3]},
    and we take the Cartesian product of label values. Then
    list_of_combinations is:
        [(0, 1), (0, 3), (1, 1), (1, 3)]
    and the function returns:
        [{'label_color': 0, 'label_shape': 1},
         {'label_color': 0, 'label_shape': 3},
         {'label_color': 1, 'label_shape': 1},
         {'label_color': 1, 'label_shape': 3}]

    Args:
        keys (tuple): Feature names, positionally aligned with the entries
            of every combination.
        list_of_combinations (iterable of tuples): One tuple of values per
            output dictionary. If a tuple's length differs from len(keys),
            the extra keys/values are dropped (``zip`` truncates).

    Return:
        list of dictionaries, one per combination, in input order.
    """
    return [dict(zip(keys, combination)) for combination in list_of_combinations]


@gin.configurable
def features_cartesian(config):
    """Returns a list of dicts covering all combinations of config values.

    Example: config = {'label_color': [0,1], 'label_shape': [1,3]}.
    Then the output is a list:
        [{'label_color': 0, 'label_shape': 1},
         {'label_color': 0, 'label_shape': 3},
         {'label_color': 1, 'label_shape': 1},
         {'label_color': 1, 'label_shape': 3}]

    Keys keep the insertion order of ``config``, and combinations follow the
    iteration order of ``itertools.product`` (the last key varies fastest).
    An empty config yields ``[{}]``, the single empty combination.

    Args:
        config (dict): Maps each feature name to an iterable of its values.

    Return:
        list of dictionaries.

    Raises:
        AssertionError: If ``config`` is not a dict.
    """
    assert isinstance(config, dict), 'The config needs to be a dictionary'

    # Unpacking ``zip(*config.items())`` raises on an empty dict, so build
    # keys and their value ranges separately (same insertion order).
    keys = tuple(config)
    values = [config[key] for key in keys]
    cartesian_product = list(itertools.product(*values))

    return features_merge(keys, cartesian_product)


@gin.configurable
def features_zip(config):
    """Returns list of dicts built by zipping the config value lists.

    Example: config = {'label_color': [0,1], 'label_shape': [1,3]}.
    Then the output is a list:
        [{'label_color': 0, 'label_shape': 1},
         {'label_color': 1, 'label_shape': 3}]

    Keys in every output dict are in sorted order. Values are paired
    positionally across keys; if the value lists have different lengths the
    result is truncated to the shortest one (``zip`` semantics).

    Args:
        config (dict): Maps each feature name to an iterable of its values.

    Return:
        list of dictionaries.

    Raises:
        AssertionError: If ``config`` is not a dict.
    """
    assert isinstance(config, dict), 'The config needs to be a dictionary'

    keys = tuple(sorted(config))
    # Take value lists in the same sorted order as ``keys``. Zipping
    # ``config.values()`` directly would use insertion order and attach
    # values to the wrong labels whenever the keys are not already sorted.
    zipped_combinations = list(zip(*(config[key] for key in keys)))

    return features_merge(keys, zipped_combinations)


@gin.configurable
def features_manual(config):
    """Passes a hand-written list of feature dicts through unchanged.

    Example:
        config = [{'label_color': 0, 'label_shape': 1},
                  {'label_color': 1, 'label_shape': 3}]
        is returned as-is (the very same list object).

    Args:
        config (list): List of dictionaries, one per feature combination.

    Return:
        The ``config`` argument itself.

    Raises:
        AssertionError: If ``config`` is not a list.
    """
    is_list = isinstance(config, list)
    assert is_list, 'The config needs to be a list'
    return config


def max_cardinality(features_list):
    """Counts max_i |F_i|, the largest number of distinct values of a feature.

    Example: For,
        features_list = [{'label_color': 0, 'label_shape': 1},
                         {'label_color': 1, 'label_shape': 1}],
        the function returns 2, since max(|{0, 1}|, |{1}|) = 2.

    Values are grouped by key, so dictionaries may list their keys in any
    order. Feature values must be hashable.

    Args:
        features_list (list): List of dictionaries

    Return:
        max_i |F_i|, or 0 if ``features_list`` (or every dict in it) is empty.
    """
    # Group by key rather than by position: zipping ``d.values()`` would
    # mix features together when dicts list their keys in different orders.
    values_per_key = {}
    for features in features_list:
        for key, value in features.items():
            values_per_key.setdefault(key, set()).add(value)

    return max((len(values) for values in values_per_key.values()), default=0)


def max_value(features_list):
    """Returns 1 + max_i max(F_i), one more than the largest feature value.

    Example: For,
        features_list = [{'label_color': 0, 'label_shape': 1},
                         {'label_color': 4, 'label_shape': 2}],
        the function returns 5, since max(max{0,4}, max{1,2}) = 4 and the
        result is 4 + 1.

    Because the maximum of the per-feature maxima equals the maximum over
    all values, key order within the dictionaries does not matter.

    Args:
        features_list (list): List of dictionaries with comparable values.

    Return:
        1 + max_i max(F_i), or 0 if there are no feature values at all.
    """
    all_values = [value for features in features_list for value in features.values()]
    if not all_values:
        return 0
    # presumably values are 0-based integer labels, so the number of slots
    # needed is max + 1 — TODO confirm with callers.
    return 1 + max(all_values)
