"""Code for getting BERT activations from calls to models."""
import collections
import dataclasses
import itertools
from typing import List, Sequence, Union

import tensorflow as tf
from transformers.models.bert import modeling_tf_bert as hf

from mm.utils.monkey_patching import MonkeyPatcherContext


# Every point inside a transformer layer at which activations can be captured.
POSITIONS = frozenset(
    [
        # Self-attention sub-layer.
        "ATTENTION_KEY",
        "ATTENTION_QUERY",
        "ATTENTION_VALUE",
        "ATTENTION_PRE_OUTPUT",
        "ATTENTION_OUTPUT",
        # Feed-forward sub-layer.
        "FFW_INTERMEDIATE",
        "FFW_OUTPUT",
        # Output of a whole transformer block, taken after the residual add
        # and LayerNorm.
        "LAYER_OUTPUT",
    ]
)


@dataclasses.dataclass
class BertActivationsInfo:
    """A single activation tensor recorded during a BERT forward pass.

    Attributes:
        activations: The recorded tensor.
        layer_index: Zero-based index of the transformer layer it came from.
        position: Where inside that layer it was taken; one of the names in
            ``POSITIONS``.
    """

    activations: tf.Tensor  # the captured value itself
    layer_index: int  # which transformer layer produced it
    position: str  # which point within the layer


@dataclasses.dataclass
class BertAttentionWeightsInfo:
    """An attention tensor recorded from one BERT self-attention layer.

    As filled in by ``BertActivationsContext``, ``weights`` is the ``logits``
    argument handed to ``tf.nn.softmax`` inside self-attention, i.e. the
    scores *before* normalization rather than the attention probabilities.

    Attributes:
        weights: The recorded attention tensor.
        layer_index: Zero-based index of the transformer layer it came from.
    """

    weights: tf.Tensor  # pre-softmax attention scores
    layer_index: int  # which transformer layer produced it


@dataclasses.dataclass
class BertActivationsParams:
    """Configuration describing which BERT activations to capture.

    Attributes:
        layers: ``None`` to capture every layer, a sequence of zero-based
            layer indices to capture only those, or the string ``"LAST"`` to
            capture only the final layer.
        positions: Names from ``POSITIONS`` at which to capture activations.
        activation_grouping: ``"NONE"``, ``"BLOCKWISE"`` or
            ``"BLOCKWISE_UNIFORM"``; how captured activations are grouped
            when read back.
        embeddings_position: ``"PRE_LAYER_NORM"`` or ``"POST_LAYER_NORM"``;
            where the embeddings output is captured.
        include_attention_weights: Whether to also capture attention tensors.

    Raises:
        ValueError: From ``__post_init__`` if any field has an unsupported
            value.
    """

    # Either None for all layers, a list of layer indices, or
    # the string "LAST"
    layers: Union[None, Sequence[int], str]

    positions: Sequence[str]

    activation_grouping: str = 'NONE'
    embeddings_position: str = 'POST_LAYER_NORM'

    include_attention_weights: bool = False

    def __post_init__(self):
        """Validate the configuration.

        Uses explicit exceptions rather than ``assert`` so validation still
        happens when Python runs with ``-O``.

        Raises:
            ValueError: If ``layers`` is a string other than ``"LAST"``; if
                ``activation_grouping`` or ``embeddings_position`` is not a
                supported value; or if ``positions`` is a bare string or
                contains a name not in ``POSITIONS``.
        """
        if isinstance(self.layers, str) and self.layers != "LAST":
            raise ValueError(
                "layers must be None, a sequence of layer indices or 'LAST'; "
                f"got {self.layers!r}"
            )
        if self.activation_grouping not in ("NONE", "BLOCKWISE", "BLOCKWISE_UNIFORM"):
            raise ValueError(
                f"Unsupported activation_grouping: {self.activation_grouping!r}"
            )
        if self.embeddings_position not in ("PRE_LAYER_NORM", "POST_LAYER_NORM"):
            raise ValueError(
                f"Unsupported embeddings_position: {self.embeddings_position!r}"
            )
        # A bare string would otherwise be validated character by character
        # (and an empty string would silently pass).
        if isinstance(self.positions, str):
            raise ValueError(
                "positions must be a sequence of position names, not a single "
                f"string; got {self.positions!r}"
            )
        unknown = set(self.positions) - set(POSITIONS)
        if unknown:
            raise ValueError(f"Unknown positions: {sorted(unknown)}")


class BertActivationsContext:
    """Record intermediate activations of the Hugging Face TF BERT model.

    On construction, replacements for methods of several classes in
    ``transformers.models.bert.modeling_tf_bert`` are registered on a
    ``MonkeyPatcherContext``; that patcher is entered and exited together
    with this object. While active, each forward pass appends the tensors
    selected by ``params`` to internal buffers. Read them back with
    ``get_embeddings``, ``get_activations`` and ``get_attention_weights``.

    The current layer is tracked by counting ``TFBertLayer`` calls, so call
    ``reset_buffers`` before every forward pass.

    Example::

        with BertActivationsContext(params, n_layers=12) as ctx:
            ctx.reset_buffers()
            model(inputs)
            activations = ctx.get_activations()
    """

    # (position, tensor-name fragment) pairs used to tell which projection a
    # tensor handed to ``TFBertSelfAttention.transpose_for_scores`` came from.
    _QKV_NAME_MARKERS = (
        ("ATTENTION_KEY", "/key/"),
        ("ATTENTION_QUERY", "/query/"),
        ("ATTENTION_VALUE", "/value/"),
    )

    def __init__(
        self, params: "BertActivationsParams", n_layers: int, layer_stride: int = 1
    ):
        """Register the monkey patches and initialize empty buffers.

        Args:
            params: What to capture and how to group it.
            n_layers: Number of transformer layers in the model. Used to
                identify the final layer when ``params.layers == "LAST"``.
            layer_stride: Under ``"NONE"``/``"BLOCKWISE_UNIFORM"`` grouping,
                only layers whose 1-based index is a multiple of the stride
                are captured. Under ``"BLOCKWISE"`` grouping every layer is
                captured and ``get_activations`` merges runs of
                ``layer_stride`` layers. Ignored when ``params.layers`` is
                ``"LAST"``.

        Raises:
            ValueError: If ``layer_stride != 1`` while ``params.layers`` is an
                explicit list of layer indices.
        """
        self.params = params
        self.n_layers = n_layers

        if layer_stride != 1 and not (
            self.params.layers is None or self.params.layers == "LAST"
        ):
            raise ValueError(
                "Must have params.layers as None or LAST if using non-unit layer stride."
            )
        self.layer_stride = layer_stride

        # Create the buffers up front so the hooks and getters work even if
        # reset_buffers() is never called explicitly.
        self.reset_buffers()

        self.monkey_patcher = MonkeyPatcherContext()
        for owner, attribute, replacement in (
            (hf.TFBertEmbeddings, "call", self._hf_bert_embeddings_call),
            (hf.TFBertLayer, "call", self._hf_bert_layer_call),
            (hf.TFBertOutput, "call", self._hf_bert_output_call),
            (hf.TFBertIntermediate, "call", self._hf_bert_intermediate_call),
            (
                hf.TFBertSelfAttention,
                "transpose_for_scores",
                self._hf_bert_self_attention_transpose_for_scores,
            ),
            (hf.TFBertSelfAttention, "call", self._hf_bert_self_attention_call),
            (hf.TFBertSelfOutput, "call", self._hf_bert_self_output_call),
        ):
            self.monkey_patcher.patch_method(owner, attribute, replacement)

    def __enter__(self):
        """Enter the underlying monkey-patching context and return ``self``."""
        self.monkey_patcher.__enter__()
        return self

    def __exit__(self, *args):
        """Exit the underlying monkey-patching context."""
        self.monkey_patcher.__exit__(*args)

    def _should_be_tracking_at_current_layer(self) -> bool:
        """Return whether tensors from ``self.layer_index`` should be recorded."""
        layers = self.params.layers
        if layers == "LAST":
            return self.layer_index == self.n_layers - 1
        if layers is not None and self.layer_index not in layers:
            return False
        # BLOCKWISE keeps every layer; get_activations applies the stride later.
        if self.params.activation_grouping == "BLOCKWISE":
            return True
        return (self.layer_index + 1) % self.layer_stride == 0

    def _add_to_activations_buffer(self, tensor, position: str):
        """Record ``tensor`` at ``position`` if the current layer is tracked."""
        if self._should_be_tracking_at_current_layer():
            self.activations_buffer.append(
                BertActivationsInfo(
                    activations=tensor,
                    layer_index=self.layer_index,
                    position=position,
                )
            )

    def _add_to_attention_weights_buffer(self, tensor):
        """Record an attention tensor if the current layer is tracked."""
        if self._should_be_tracking_at_current_layer():
            self.attention_weights_buffer.append(
                BertAttentionWeightsInfo(
                    weights=tensor,
                    layer_index=self.layer_index,
                )
            )

    def _hf_bert_embeddings_call(self, og_call, layer, *args, **kwargs):
        """Replacement for ``TFBertEmbeddings.call``.

        Dispatches on ``params.embeddings_position``. The returned value is
        the original layer's output, i.e. the input to the main body.

        Raises:
            ValueError: If ``params.embeddings_position`` is unsupported.
        """
        if self.params.embeddings_position == "PRE_LAYER_NORM":
            return self._hf_bert_embeddings_call_pre(og_call, layer, *args, **kwargs)
        elif self.params.embeddings_position == "POST_LAYER_NORM":
            return self._hf_bert_embeddings_call_post(og_call, layer, *args, **kwargs)
        else:
            raise ValueError(self.params.embeddings_position)

    def _tf_add_call_pre(self, og_call, layer, *args, **kwargs):
        """Replacement for ``tf.keras.layers.Add.call`` during the embeddings call.

        Stores the sum in ``embeddings_buffer`` and returns it unchanged.
        """
        ret = og_call(layer, *args, **kwargs)
        self.embeddings_buffer = ret
        return ret

    def _hf_bert_embeddings_call_pre(self, og_call, layer, *args, **kwargs):
        """Run the embeddings layer, capturing the output of its ``Add`` layer.

        NOTE(review): this relies on the installed transformers version summing
        the embeddings with a ``tf.keras.layers.Add``; presumably that sum is
        the LayerNorm input (hence "PRE_LAYER_NORM") — confirm. If no ``Add``
        is called, ``embeddings_buffer`` keeps its previous value.

        Returns:
            The embeddings layer's normal output (the input to the main body).
        """
        mp = MonkeyPatcherContext()
        mp.patch_method(tf.keras.layers.Add, "call", self._tf_add_call_pre)
        with mp:
            return og_call(layer, *args, **kwargs)

    def _hf_bert_embeddings_call_post(self, og_call, layer, *args, **kwargs):
        """Run the embeddings layer and store its output in ``embeddings_buffer``.

        Returns:
            The embeddings layer's output (the input to the main body).
        """
        ret = og_call(layer, *args, **kwargs)
        self.embeddings_buffer = ret
        return ret

    def _hf_bert_layer_call(self, og_call, layer, *args, **kwargs):
        """Replacement for ``TFBertLayer.call``; captures LAYER_OUTPUT.

        Also advances ``layer_index`` once per transformer layer.
        """
        ret = og_call(layer, *args, **kwargs)
        if "LAYER_OUTPUT" in self.params.positions:
            hidden_states = ret[0]
            self._add_to_activations_buffer(hidden_states, "LAYER_OUTPUT")
        self.layer_index += 1
        return ret

    def _dense_dropout_capture_residual_norm(
        self, layer, hidden_states, input_tensor, training, position
    ):
        """Apply dense + dropout, optionally record the result, then add & norm.

        Shared body of the ``TFBertOutput``/``TFBertSelfOutput`` replacements.
        It re-implements the original ``call`` instead of delegating, so the
        tensor before the residual add can be recorded.
        NOTE(review): must stay in sync with the installed transformers version.
        """
        hidden_states = layer.dense(inputs=hidden_states)
        hidden_states = layer.dropout(inputs=hidden_states, training=training)
        if position in self.params.positions:
            self._add_to_activations_buffer(hidden_states, position)
        return layer.LayerNorm(inputs=hidden_states + input_tensor)

    def _hf_bert_output_call(
        self, og_call, layer, hidden_states, input_tensor, training=False
    ):
        """Replacement for ``TFBertOutput.call``; captures FFW_OUTPUT."""
        return self._dense_dropout_capture_residual_norm(
            layer, hidden_states, input_tensor, training, "FFW_OUTPUT"
        )

    def _hf_bert_intermediate_call(self, og_call, layer, *args, **kwargs):
        """Replacement for ``TFBertIntermediate.call``; captures FFW_INTERMEDIATE."""
        ret = og_call(layer, *args, **kwargs)
        if "FFW_INTERMEDIATE" in self.params.positions:
            self._add_to_activations_buffer(ret, "FFW_INTERMEDIATE")
        return ret

    def _hf_bert_self_attention_transpose_for_scores(
        self, og_fn, layer, tensor, batch_size
    ):
        """Capture key/query/value projections before ``transpose_for_scores``.

        The projection is identified by a fragment of the tensor's name.
        """
        requested = [
            (position, marker)
            for position, marker in self._QKV_NAME_MARKERS
            if position in self.params.positions
        ]
        # TODO: Matching on tensor names is somewhat hacky; make this more
        # robust later. ``Tensor.name`` is not available under eager
        # execution, so only read it when one of these positions is requested.
        if requested:
            name = tensor.name
            for position, marker in requested:
                if marker in name:
                    self._add_to_activations_buffer(tensor, position)

        return og_fn(layer, tensor, batch_size)

    def _hf_bert_self_attention_call(self, og_call, layer, *args, **kwargs):
        """Replacement for ``TFBertSelfAttention.call``.

        Captures ATTENTION_PRE_OUTPUT (element 0 of the result). When
        ``params.include_attention_weights`` is set, it also temporarily
        patches ``tf.nn.softmax`` to record the attention logits.
        """
        if self.params.include_attention_weights:
            mp = MonkeyPatcherContext()
            mp.patch_method(
                tf.nn, "softmax", self._self_attn_tf_softmax_for_attn_weights
            )
            with mp:
                ret = og_call(layer, *args, **kwargs)
        else:
            ret = og_call(layer, *args, **kwargs)

        if "ATTENTION_PRE_OUTPUT" in self.params.positions:
            attention_output = ret[0]
            self._add_to_activations_buffer(attention_output, "ATTENTION_PRE_OUTPUT")

        return ret

    def _self_attn_tf_softmax_for_attn_weights(self, og_call, *args, **kwargs):
        """Replacement for ``tf.nn.softmax`` inside self-attention.

        Records the softmax *input* (the logits, not the normalized
        probabilities), then runs the real softmax unchanged.
        """
        # tf.nn.softmax(logits, axis=None, name=None): accept logits passed
        # either by keyword or positionally.
        logits = kwargs["logits"] if "logits" in kwargs else args[0]
        self._add_to_attention_weights_buffer(logits)
        return og_call(*args, **kwargs)

    def _hf_bert_self_output_call(
        self, og_call, layer, hidden_states, input_tensor, training=False
    ):
        """Replacement for ``TFBertSelfOutput.call``; captures ATTENTION_OUTPUT."""
        return self._dense_dropout_capture_residual_norm(
            layer, hidden_states, input_tensor, training, "ATTENTION_OUTPUT"
        )

    ###########################################################################

    def reset_buffers(self, layer_stride=None):
        """Clear all captured tensors and restart the layer counter at 0.

        Call before each forward pass, because ``layer_index`` is a running
        count of ``TFBertLayer`` calls.

        Args:
            layer_stride: If not None, replaces ``self.layer_stride``.
                NOTE(review): unlike ``__init__``, no check is made against
                ``params.layers`` here.
        """
        self.layer_index = 0
        self.embeddings_buffer = None
        self.activations_buffer = []
        self.attention_weights_buffer = []
        if layer_stride is not None:
            self.layer_stride = layer_stride

    def get_embeddings(self):
        """Return the captured embeddings tensor, or None if none was captured."""
        return self.embeddings_buffer

    def get_activations(self) -> List[List[BertActivationsInfo]]:
        """Return captured activations as a list of groups.

        * ``"NONE"``: one singleton group per captured tensor, in capture order.
        * ``"BLOCKWISE_UNIFORM"``, or ``params.layers == "LAST"``: one group
          per captured layer, ordered by layer index.
        * ``"BLOCKWISE"``: the per-layer groups are concatenated in runs of
          ``layer_stride`` consecutive layers. This lets one student block be
          compared against several teacher blocks when the teacher is deeper
          than the student.

        Raises:
            ValueError: If ``params.activation_grouping`` is unsupported.
            AssertionError: Under ``"BLOCKWISE"``, if the number of captured
                layers is not a multiple of ``layer_stride``.
        """
        grouping = self.params.activation_grouping

        if grouping == "NONE":
            return [[info] for info in self.activations_buffer]
        if grouping not in ("BLOCKWISE", "BLOCKWISE_UNIFORM"):
            raise ValueError(grouping)

        by_layer = collections.defaultdict(list)
        for info in self.activations_buffer:
            by_layer[info.layer_index].append(info)
        per_layer = [by_layer[index] for index in sorted(by_layer)]

        if self.params.layers == "LAST" or grouping == "BLOCKWISE_UNIFORM":
            return per_layer

        # "BLOCKWISE": merge runs of `layer_stride` consecutive layers.
        stride = self.layer_stride
        assert len(per_layer) % stride == 0, (
            f"{len(per_layer)} captured layers is not a multiple of "
            f"layer_stride={stride}"
        )
        return [
            list(itertools.chain.from_iterable(per_layer[start : start + stride]))
            for start in range(0, len(per_layer), stride)
        ]

    def get_attention_weights(self) -> List[BertAttentionWeightsInfo]:
        """Return the captured attention tensors (pre-softmax logits), in order."""
        return self.attention_weights_buffer
