"""Measure how random the communication protocol is."""

import collections

import gin
import numpy as np
import pandas as pd


from ncc.compositionality_metrics import base


def entropy_and_argmax(messages):
    """Calculate the entropy and the most frequent element of `messages`.

    Args:
      messages: an iterable of hashable messages (for example the pandas
        Series of one cell that `pd.pivot_table` passes to an `aggfunc`).

    Returns:
      A dict with keys:
        'entropy': entropy (natural log) of the empirical distribution of
          `messages`, rounded to 3 decimal places; 0.0 for empty input.
        'argmax': the most frequent message (ties go to the message seen
          first, per `Counter.most_common`); None for empty input.
    """
    messages_list = list(messages)

    if not messages_list:
        # Return the same shape as the non-empty case so callers can always
        # index the result by key.
        return {'entropy': 0.0, 'argmax': None}

    counter = collections.Counter(messages_list)
    argmax_message = counter.most_common(1)[0][0]
    frequencies = np.asarray(list(counter.values()), dtype=float)
    probs = frequencies / frequencies.sum()

    # The small offset added before rounding turns the -0.0 produced when
    # only one distinct message occurs into +0.0.
    entropy = np.round(-np.sum(probs * np.log(probs)) + 0.0001, 3)

    return {'entropy': entropy, 'argmax': argmax_message}


@gin.configurable
class RandomnessMeasure(base.Metric):
    """Measure average sender randomness of messages for given labels.

    Inputs are expected to be pairs of features. For every combination of
    the two feature values, the entropy of the messages sent for that
    combination is computed; the metric is the mean over the full grid of
    observed `feat_1` values x observed `feat_2` values.
    """

    name = 'randomness measure'

    def measure(self, protocol: base.Protocol) -> float:
        """Return the mean per-cell message entropy of `protocol`.

        Args:
          protocol: iterable of `((feat_1, feat_2), message)` pairs.

        Returns:
          Mean entropy over the pivot grid. Feature combinations that never
          occur in `protocol` count as zero entropy.
        """
        # Works only for two features.
        protocol_flat = [(feat_1, feat_2, message)
                         for (feat_1, feat_2), message in protocol]

        protocol_df = pd.DataFrame(protocol_flat,
                                   columns=['feat_1', 'feat_2', 'message'])

        table_entropy_and_argmax = pd.pivot_table(protocol_df, values='message',
                                                  index=['feat_1'],
                                                  columns=['feat_2'],
                                                  aggfunc=entropy_and_argmax)

        # Cells for unobserved feature combinations are NaN rather than
        # dicts and contribute zero entropy. Reading the raw values avoids
        # `DataFrame.applymap` (deprecated since pandas 2.1) and the
        # version-dependent reduction semantics of `np.mean` on a DataFrame.
        entropies = np.array(
            [[cell['entropy'] if isinstance(cell, dict) else -0.0
              for cell in row]
             for row in table_entropy_and_argmax.to_numpy()],
            dtype=float)

        return float(np.mean(entropies))


@gin.configurable
class SymbolCountMetric(base.Metric):
    """Measures symbols used in communication.

    The `aggregation_method` selects what is counted:

    * ``'union'``: number of distinct symbols over all message positions;
    * ``'max'``: largest number of distinct symbols seen at any one position;
    * an ``int`` ``i``: number of distinct symbols seen at position ``i``.
    """

    def __init__(self, aggregation_method='union'):
        self._aggregation_method = aggregation_method
        self.name = 'symbol count metric ' + str(aggregation_method)

    def measure(self, protocol: base.Protocol) -> float:
        """Count symbols in the messages of `protocol`.

        Args:
          protocol: iterable of `(input, message)` pairs, where each message
            is an iterable of hashable symbols. Messages may differ in length.

        Returns:
          The symbol count selected by the aggregation method (an int).

        Raises:
          ValueError: if the aggregation method is not recognized.
          IndexError: if an integer aggregation method addresses a position
            that no message reaches.
        """
        # One independent set per message position, created on demand so
        # that messages of any length are handled. (`[set()] * n` would make
        # every position share a single set object.)
        symbols_on_position = []
        for _, message in protocol:
            for i, symbol in enumerate(message):
                if i == len(symbols_on_position):
                    symbols_on_position.append(set())
                symbols_on_position[i].add(symbol)

        if self._aggregation_method == 'union':
            return len(set().union(*symbols_on_position))

        if self._aggregation_method == 'max':
            return max((len(symbols) for symbols in symbols_on_position),
                       default=0)

        if isinstance(self._aggregation_method, int):
            return len(symbols_on_position[self._aggregation_method])

        # An explicit raise, unlike `assert`, survives `python -O`.
        raise ValueError('Unknown aggregation method.')
