"""Context independence compositionality metric"""

import collections
from typing import Dict  # pylint: disable=import-only-modules

import gin
import numpy as np

from ncc.compositionality_metrics import base


@gin.configurable
class ContextIndependence(base.Metric):
    """Context independence compositionality metric.

    For every concept ``c`` the metric selects the symbol ``v_c`` with the
    highest co-occurrence count for that concept, and scores the pair as
    ``p(v_c | c) * p(c | v_c)``. Both probabilities are estimated from the
    concept/symbol co-occurrence counts of the protocol. The metric is the
    mean of these per-concept scores.
    """

    name = 'context independence'

    def measure(self, protocol: base.Protocol) -> float:
        """Return the mean context-independence score over all concepts.

        Args:
            protocol: iterable of ``(derivation, message)`` pairs, where
                ``derivation`` is an iterable of concepts and ``message`` is
                an iterable of symbols. It is iterated several times, so it
                must be a re-iterable collection, not a one-shot generator.

        Returns:
            The average over concepts ``c`` of ``p(v_c | c) * p(c | v_c)``.

        Raises:
            ValueError: if the protocol yields no concepts or no symbols,
                in which case the score is undefined.
        """
        # dict.fromkeys preserves first-occurrence order. Index assignment,
        # and therefore argmax tie-breaking below, is then reproducible.
        # Iterating a ``set`` of strings would depend on PYTHONHASHSEED.
        vocab = {
            char: idx for idx, char in enumerate(dict.fromkeys(
                c for _, message in protocol for c in message))
        }
        concepts = {
            concept: idx for idx, concept in enumerate(dict.fromkeys(
                concept for derivation, _ in protocol
                for concept in derivation))
        }
        if not concepts or not vocab:
            raise ValueError(
                'context independence is undefined for a protocol with no '
                'concepts or no symbols')

        concept_symbol_matrix = self._compute_concept_symbol_matrix(
            protocol, vocab, concepts)
        # v_c: column of the symbol most associated with each concept. numpy
        # argmax returns the first maximum, so ties go to the first-seen
        # symbol.
        v_cs = concept_symbol_matrix.argmax(axis=1)
        counts = concept_symbol_matrix[np.arange(len(concepts)), v_cs]
        p_vc_c = counts / concept_symbol_matrix.sum(axis=1)
        p_c_vc = counts / concept_symbol_matrix.sum(axis=0)[v_cs]
        return float((p_vc_c * p_c_vc).mean())

    def _compute_concept_symbol_matrix(
            self,
            protocol: base.Protocol,
            vocab: Dict[str, int],
            concepts: Dict[str, int],
            epsilon: float = 10e-8
    ) -> np.ndarray:
        """Count concept/symbol co-occurrences in ``protocol``.

        Args:
            protocol: iterable of ``(derivation, message)`` pairs.
            vocab: maps each symbol to its column index.
            concepts: maps each concept to its row index.
            epsilon: pseudo-count added to every cell so that no row or
                column sum is zero. Note that ``10e-8 == 1e-7``.

        Returns:
            A float matrix of shape ``(len(concepts), len(vocab))``. Cell
            ``[i, j]`` holds ``epsilon`` plus the number of occurrences of
            symbol ``j`` in messages whose derivation contains concept ``i``.
            A concept repeated within one derivation is counted once per
            repetition.

        Raises:
            KeyError: if the protocol contains a concept or symbol missing
                from ``concepts`` / ``vocab``.
        """
        # Size the rows by the index map itself (not by the concepts that
        # happen to be seen) so every index in ``concepts`` is addressable.
        concept_symbol_matrix = np.full((len(concepts), len(vocab)), epsilon)
        for derivation, message in protocol:
            symbol_ids = [vocab[symbol] for symbol in message]
            for concept in derivation:
                row = concepts[concept]
                for col in symbol_ids:
                    concept_symbol_matrix[row, col] += 1
        return concept_symbol_matrix
