"""Positional disentanglement and bag-of-words disentanglement.

Adapted from
https://github.com/facebookresearch/EGG/blob/4eca7c0b0908c05d9d402c9c5d20ccf8aaae01b2/egg/zoo/compo_vs_generalization/intervention.py#L45-L92 and  # pylint: disable=line-too-long
https://github.com/facebookresearch/EGG/blob/4eca7c0b0908c05d9d402c9c5d20ccf8aaae01b2/egg/zoo/language_bottleneck/intervention.py#L14-L61  # pylint: disable=line-too-long
"""

import collections
import copy
from typing import List  # pylint: disable=import-only-modules

import gin
import numpy as np

from ncc.compositionality_metrics import base
from ncc.compositionality_metrics import utils


def compute_entropy(symbols: List[str]) -> float:
    """Return the Shannon entropy (in bits) of the empirical distribution of `symbols`.

    Each distinct element is weighted by its relative frequency in the list.
    An empty list has entropy 0.
    """
    total = len(symbols)
    # Counter keeps first-appearance order, so terms are summed in a stable order.
    probabilities = (
        count / total for count in collections.Counter(symbols).values()
    )
    return sum(-p * np.log2(p) for p in probabilities)


def compute_mutual_information(
        concepts: List[str],
        symbols: List[str]
) -> float:
    """Return the empirical mutual information I(concepts; symbols), in bits.

    Uses I(C; S) = H(C) + H(S) - H(C, S), with every entropy computed over the
    empirical distribution of the paired samples.

    Args:
        concepts: concept values, one per sample.
        symbols: symbol values, one per sample; symbols[k] is paired with
            concepts[k].

    Returns:
        The mutual information between the two sequences.

    Raises:
        ValueError: if `concepts` and `symbols` have different lengths.
    """
    if len(concepts) != len(symbols):
        # zip() would silently truncate, so only the joint term would drop
        # samples and the result would be meaningless.
        raise ValueError(
            'concepts and symbols must have the same length, got '
            f'{len(concepts)} and {len(symbols)}')

    concept_entropy = compute_entropy(concepts)  # H[p(concepts)]
    symbol_entropy = compute_entropy(symbols)  # H[p(symbols)]
    # Key the joint distribution by tuples instead of joining with '_'. A joined
    # string is ambiguous when either side contains '_' (('a_b', 'c') and
    # ('a', 'b_c') both give 'a_b_c'), which under-counts the joint entropy.
    symbol_concept_pairs = list(zip(symbols, concepts))
    # H[p(concepts, symbols)]
    symbol_concept_joint_entropy = compute_entropy(symbol_concept_pairs)
    return concept_entropy + symbol_entropy - symbol_concept_joint_entropy


def get_permutation_invariant_protocol(
        protocol: base.Protocol
) -> base.Protocol:
    """Return a copy of `protocol` in which every message's symbols are sorted.

    Sorting discards symbol order, so messages using the same symbols in a
    different order become identical. Derivations are passed through
    unchanged; each message is replaced by a sorted list of its symbols.
    """
    canonical_protocol = []
    for derivation, message in protocol:
        canonical_protocol.append((derivation, sorted(message)))
    return canonical_protocol


@gin.configurable
class PositionalDisentanglement(base.Metric):
    """Positional disentanglement of a protocol.

    Each message position j is compared against every concept slot i using
    mutual information. The position's score is the gap between its largest
    and second-largest mutual information, divided by the entropy of the
    symbols at that position. Positions whose symbol never varies (zero
    entropy) are skipped. The metric is the mean score over the remaining
    positions, or NaN if there are none.

    Args:
        max_message_length: number of leading message positions to score.
            Every message must have at least this many symbols.
        num_concept_slots: number of leading derivation slots to compare
            against. Must be at least 2, because the score is the gap between
            the two most informative slots.

    Raises:
        ValueError: if `num_concept_slots` is smaller than 2.
    """

    name = 'positional disentanglement'

    def __init__(
            self,
            max_message_length: int = 2,
            num_concept_slots: int = 2
    ):
        if num_concept_slots < 2:
            # measure() needs the top two mutual-information values.
            raise ValueError(
                'num_concept_slots must be at least 2 to compute an '
                f'information gap, got {num_concept_slots}')
        self.max_message_length = max_message_length
        self.num_concept_slots = num_concept_slots

    def measure(self, protocol: base.Protocol) -> float:
        """Return the mean normalized information gap over non-constant positions.

        Args:
            protocol: iterable of (derivation, message) pairs, where
                derivation[i] is the concept in slot i and message[j] is the
                symbol at position j.

        Returns:
            The positional disentanglement score, or NaN if every scored
            position is constant across the protocol.
        """
        # Concept columns do not depend on the message position, so build
        # them once instead of once per position.
        concept_columns = [
            [derivation[i] for derivation, _ in protocol]
            for i in range(self.num_concept_slots)
        ]

        disentanglement_scores = []
        for j in range(self.max_message_length):
            symbols_j = [message[j] for _, message in protocol]
            symbol_entropy = compute_entropy(symbols_j)
            # A constant position carries no information and would divide by 0.
            if symbol_entropy > 0:
                symbol_mutual_info = sorted(
                    (compute_mutual_information(concepts_i, symbols_j)
                     for concepts_i in concept_columns),
                    reverse=True)
                disentanglement_scores.append(
                    (symbol_mutual_info[0] - symbol_mutual_info[1])
                    / symbol_entropy)

        if not disentanglement_scores:
            return float('nan')
        return sum(disentanglement_scores) / len(disentanglement_scores)


@gin.configurable
class BagOfWordsDisentanglement(PositionalDisentanglement):
    """Bag-of-words disentanglement.

    Each message is converted to a histogram with one count per vocabulary
    symbol, and positional disentanglement is computed over those histogram
    columns. Symbol order within a message is therefore ignored. Every
    vocabulary column is scored, so `max_message_length` does not affect
    this metric. `num_concept_slots` is used as in the parent class.
    """

    name = 'bag-of-words disentanglement'

    def measure(self, protocol: base.Protocol) -> float:
        """Return the bag-of-words disentanglement score of `protocol`.

        Args:
            protocol: iterable of (derivation, message) pairs; each message is
                an iterable of symbols drawn from the protocol's vocabulary.

        Returns:
            The mean normalized information gap over non-constant histogram
            columns, or NaN if every column is constant.
        """
        # Sorted so that column order (and hence the float summation order)
        # does not depend on the iteration order of the vocabulary container.
        vocab = sorted(set(utils.get_vocab_from_protocol(protocol)))
        column_of = {symbol: column for column, symbol in enumerate(vocab)}

        bow_protocol = []
        for derivation, message in protocol:
            counts = [0] * len(vocab)
            for symbol in message:
                counts[column_of[symbol]] += 1
            bow_protocol.append((derivation, [str(count) for count in counts]))

        # The histogram has len(vocab) positions. The parent's measure() only
        # scores self.max_message_length positions, so run it on a shallow
        # copy configured to cover every column, leaving `self` untouched.
        scorer = copy.copy(self)
        scorer.max_message_length = len(vocab)
        return super(BagOfWordsDisentanglement, scorer).measure(bow_protocol)
