"""Stuff for analysis fo BERT/RoBERTa NMF decompositions.

Keep this file focused on general-purpose code, with less plot-generating
code than bert_nmf_analysis.py.
"""

import dataclasses
from typing import Optional, Sequence, Tuple, Union

import numpy as np
import tensorflow as tf

from em.models import transformer_model_vars as tmv
from em.util import flat_pack

###############################################################################


def _layer_indices_from_var_names(names: Sequence[str]) -> Tuple[int, ...]:
    """Returns the distinct layer indices found in `names`, in ascending order.

    Each name is mapped through `tmv.extract_layer_index`. Names for which it
    returns None are skipped. These are presumably variables outside any
    numbered layer, e.g. embeddings/pooler; confirm against `tmv`.
    """
    found = set()
    for var_name in names:
        found.add(tmv.extract_layer_index(var_name))
    found.discard(None)
    return tuple(sorted(found))


@dataclasses.dataclass
class ComponentLocalizationInfo:
    """Reports where the mass of a flat component vector is located.

    A component is a flat vector that `flat_pack.FlatPacker` (built from the
    shapes of `variables`) can unpack into one piece per variable. The methods
    below return the fraction of the component's total sum that falls in each
    layer, in the pooler variables, or in each individual variable.
    """
    # NOTE: Copied from bert_nmf_analysis.py. The one in that file is now
    # deprecated.

    # We only really need these for their names and shapes.
    variables: Sequence[tf.Variable]

    def __post_init__(self):
        self._var_names = [v.name for v in self.variables]
        self._var_shapes = [v.shape for v in self.variables]

        # Layer index of each variable (None when `extract_layer_index` finds
        # no layer). This is computed once here so `fraction_per_layer` does
        # not re-parse every name for every layer on every call.
        self._var_layer_indices = [
            tmv.extract_layer_index(n) for n in self._var_names
        ]
        # Distinct layer indices present in the variables, ascending.
        self._layer_indices = _layer_indices_from_var_names(self._var_names)

        self._packer = flat_pack.FlatPacker(self._var_shapes)

    def fraction_per_layer(self, component: np.ndarray) -> list:
        """Returns the fraction of `component`'s total sum in each layer.

        Args:
          component: Flat vector in the packed layout of `self.variables`.

        Returns:
          A list with one entry per layer index present in the variables, in
          increasing layer order. The entries need not sum to 1 because
          variables outside the numbered layers (e.g. embeddings and pooler)
          are not counted. If the component sums to 0, the numpy division
          yields inf/nan.
        """
        unpacked_comp = self._packer.decode_tf(component)
        comp_sum = tf.reduce_sum(component).numpy()

        # Group per-variable sums by layer in a single pass. Within a layer,
        # the sums keep variable order, matching the original summation order.
        sums_by_layer = {}
        for piece, layer in zip(unpacked_comp, self._var_layer_indices):
            sums_by_layer.setdefault(layer, []).append(tf.reduce_sum(piece))

        ret = []
        for ind in self._layer_indices:
            subset_sums = sums_by_layer.get(ind)
            if not subset_sums:
                # Defensive fallback: every index in `_layer_indices` comes
                # from some variable's name, so this should not normally occur.
                ret.append(0.0)
            else:
                ret.append(tf.reduce_sum(subset_sums).numpy() / comp_sum)

        return ret

    def fraction_in_pooler(self, component: np.ndarray):
        """Returns the fraction of `component`'s total sum in pooler variables.

        Pooler variables are those for which `tmv.is_pooler_layer` is true.
        When no variable matches, the pooler sum is 0 (the sum of an empty
        list).
        """
        unpacked_comp = self._packer.decode_tf(component)
        comp_sum = tf.reduce_sum(component).numpy()
        pooler_sum = tf.reduce_sum([
            tf.reduce_sum(s) for s, n in zip(unpacked_comp, self._var_names)
            if tmv.is_pooler_layer(n)
        ]).numpy()
        return pooler_sum / comp_sum

    def fraction_per_variable(self, component: np.ndarray) -> list:
        """Returns the fraction of `component`'s total sum in each variable.

        The entries follow the order of the pieces returned by
        `self._packer.decode_tf`, which the other methods assume is aligned
        with `self.variables`.
        """
        comp_sum = tf.reduce_sum(component).numpy()
        return [
            tf.reduce_sum(v).numpy() / comp_sum
            for v in self._packer.decode_tf(component)
        ]

###############################################################################
