"""Helper functions for using compositionality metrics"""

from typing import Dict, List, Tuple  # pylint: disable=import-only-modules
import string
import collections

import gin
import numpy as np

from ncc.compositionality_metrics import base


def get_protocol(
        messages: np.ndarray,
        labels: List[Tuple[int, int]],
        features_names: List[str]
) -> base.Protocol:
    """Convert a record of sender-receiver interactions into a base.Protocol.

    Each message (a sequence of integer symbol ids) is encoded as a string by
    mapping symbol id ``i`` to ``string.ascii_letters[i]``. Each label is
    turned into a derivation: a tuple of ``'<feature_name>=<value>'`` strings,
    pairing ``features_names`` with the label's values positionally.

    Args:
        messages: iterable of messages; each message is an iterable of integer
            symbol ids in ``[0, len(string.ascii_letters))``, i.e. ``[0, 52)``.
        labels: per-message feature values, aligned with ``messages``.
        features_names: feature names, aligned with the values of each label.

    Returns:
        A list of ``(derivation, message_string)`` pairs, one per message
        (``zip`` semantics: stops at the shorter of ``messages``/``labels``).

    Raises:
        ValueError: if any symbol id cannot be encoded as an ASCII letter.
    """
    alphabet = string.ascii_letters
    protocol = []
    for message, label in zip(messages, labels):
        derivation = tuple(f'{feature_name}={value}' for feature_name, value
                           in zip(features_names, label))
        # Check both bounds explicitly: a negative id would otherwise silently
        # index from the end of the alphabet. A real exception is used instead
        # of `assert`, which is stripped when Python runs with -O.
        for symbol in message:
            if not 0 <= symbol < len(alphabet):
                raise ValueError(
                    f'symbol id {symbol} is outside the encodable range '
                    f'[0, {len(alphabet)})')
        encoded = ''.join(alphabet[symbol] for symbol in message)
        protocol.append((derivation, encoded))
    return protocol


def get_vocab_from_protocol(protocol: base.Protocol) -> Dict[str, int]:
    """Map every character used in the protocol's messages to an integer id.

    The distinct characters are sorted before ids are assigned, so the mapping
    is deterministic. Enumerating a ``set`` of strings directly depends on the
    per-process string hash seed (PYTHONHASHSEED), which would give different
    ids on different runs.

    Args:
        protocol: iterable of ``(derivation, message)`` pairs whose messages
            are strings.

    Returns:
        Dict from each distinct message character to a contiguous id in
        ``[0, number_of_distinct_characters)``.
    """
    characters = sorted({char for _, message in protocol for char in message})
    return {char: idx for idx, char in enumerate(characters)}


class SkipEpochMetricWrapper(base.Metric):
    """
    Compute `metric` only every `skip_epochs` calls to `measure`, returning
    the previously computed value otherwise.

    The wrapped metric is evaluated on the first call and then on every
    `skip_epochs`-th call after it (calls 0, k, 2k, ...). Presumably `measure`
    is called once per epoch, hence the parameter name.
    """
    _metric: base.Metric

    def __init__(self, metric: base.Metric, skip_epochs: int = 1):
        """
        Args:
            metric: the metric to wrap; its `name` is reused by the wrapper.
            skip_epochs: evaluate the wrapped metric once per this many calls
                to `measure`. Must be at least 1.

        Raises:
            ValueError: if `skip_epochs` is smaller than 1 (0 would otherwise
                surface later as a ZeroDivisionError inside `measure`).
        """
        if skip_epochs < 1:
            raise ValueError(f'skip_epochs must be >= 1, got {skip_epochs}')
        self._metric = metric
        self.name = self._metric.name
        self.skip_epochs = skip_epochs
        # Starts at -1 so the very first `measure` call sees counter == 0 and
        # therefore always evaluates the wrapped metric.
        self.counter = -1
        self.last_value = None

    def measure(self, protocol: base.Protocol) -> float:
        """Return the wrapped metric's value, recomputing it only when due."""
        self.counter += 1
        if self.counter % self.skip_epochs == 0:
            self.last_value = self._metric.measure(protocol)
        return self.last_value


@gin.configurable
class OneMessagePerClassWrapper(base.Metric):
    """
    Evaluate `metric` on a reduced protocol holding one message per class.

    A class is a distinct derivation (combination of feature values). For each
    class only its most frequent message is kept; per `Counter.most_common`,
    ties go to the message encountered first. Classes appear in the reduced
    protocol in order of their first occurrence.
    """
    def __init__(self, metric: base.Metric):
        self._metric = metric
        # Suffix the wrapped metric's name so results stay distinguishable.
        self.name = metric.name + '_OneMsgPerClass'

    def measure(self, protocol: base.Protocol) -> float:
        """Reduce `protocol` to one message per class, then apply the metric."""
        counts_by_class = collections.defaultdict(collections.Counter)
        for derivation, message in protocol:
            counts_by_class[derivation][message] += 1
        reduced_protocol = [
            (derivation, message_counts.most_common(1)[0][0])
            for derivation, message_counts in counts_by_class.items()
        ]
        return self._metric.measure(reduced_protocol)
