"""Conflict count compositionality metric"""

import itertools
import collections

import gin

from ncc.compositionality_metrics import base


@gin.configurable
class ConflictCount(base.Metric):
    """Conflict count compositionality metric.

    Each one-to-one assignment of message positions to feature positions is
    tried. Under an assignment, a "conflict" is an occurrence of a symbol
    alongside a feature value other than the value that symbol co-occurs
    with most often. The metric is the smallest conflict total over all
    assignments.
    """

    name = 'conflict count'

    def measure(self, protocol: base.Protocol) -> float:
        """Return the minimum conflict count over position assignments.

        Args:
            protocol: indexable, iterable collection of ``(features, message)``
                pairs. Every ``features`` and ``message`` must have the same
                length, which is currently restricted to 2.

        Returns:
            The smallest conflict total across all permutations mapping
            message positions to feature positions (an ``int`` at runtime).

        Raises:
            IndexError: if ``protocol`` is empty.
            AssertionError: if the length is not 2, or if any pair's
                ``features``/``message`` lengths disagree.
        """
        n_positions = len(protocol[0][0])
        # Only the two-attribute case has been validated so far.
        assert n_positions == 2, 'Not tested.'

        def conflicts_under(assignment):
            # assignment[pos] is the feature index that message position
            # ``pos`` is assumed to encode.
            # One table per message position: symbol -> Counter(feature value).
            tables = [collections.defaultdict(collections.Counter)
                      for _ in range(n_positions)]
            for features, msg in protocol:
                assert len(msg) == len(features) == n_positions
                for pos, feature_idx in enumerate(assignment):
                    tables[pos][msg[pos]][features[feature_idx]] += 1

            total = 0
            for table in tables:
                for value_counts in table.values():
                    # Everything except the symbol's majority value counts
                    # as a conflict.
                    counts = value_counts.values()
                    total += sum(counts) - max(counts)
            return total

        assignments = itertools.permutations(range(n_positions))
        return min(conflicts_under(assignment) for assignment in assignments)
