"""Stuff related to ICA decompositions of BERT activations."""
import dataclasses
import re
from typing import Sequence

import tensorflow as tf
from transformers.models.bert import modeling_tf_bert as hf

from em.util.monkey_patching import MonkeyPatcherContext


@dataclasses.dataclass
class ClsTokenPerturbationContext:
    """Perturbs the first-token output of selected ``TFBertLayer`` instances.

    On construction, a replacement for ``hf.TFBertLayer.call`` is registered
    with a ``MonkeyPatcherContext`` (stored as ``self.monkey_patcher``).  The
    replacement runs the original ``call`` and, for layers whose index is in
    ``layer_indices``, adds ``magnitude_variable * component_variable`` to the
    activations at sequence position 0, which is assumed to be the CLS token.

    NOTE(review): the patch is presumably only active while
    ``self.monkey_patcher`` is entered -- confirm against
    ``MonkeyPatcherContext``.

    Attributes:
        d_model: Size of the model's hidden dimension; each component variable
            is a vector of this length.
        layer_indices: Indices of the BERT layers whose output is perturbed.
        component_variables: One ``tf.Variable`` of shape ``[d_model]`` per
            entry of ``layer_indices``, in the same order, initialised to zero.
        magnitude_variable: Scalar ``tf.Variable`` (float32) that scales every
            component, initialised to zero.
        monkey_patcher: The ``MonkeyPatcherContext`` holding the patch.
    """

    d_model: int
    layer_indices: Sequence[int]

    # Extracts the layer index from a layer name containing ``layer_._<n>``.
    # Deliberately unannotated so that dataclasses does not treat it as a field.
    _LAYER_NAME_PATTERN = re.compile(r'layer_\._(\d+)')

    def __post_init__(self):
        """Create the perturbation variables and register the patch."""
        # Set these before entering the context.
        self.component_variables = [
            tf.Variable(tf.zeros([self.d_model])) for _ in self.layer_indices
        ]
        self.magnitude_variable = tf.Variable(0.0, dtype=tf.float32)

        self.monkey_patcher = MonkeyPatcherContext()
        self.monkey_patcher.patch_method(
            hf.TFBertLayer, "call", self._hf_bert_layer_call
        )

    def _layer_to_index(self, layer):
        """Return the integer index parsed from ``layer.name``.

        Returns ``None`` when the name does not contain ``layer_._<n>``.  The
        patch is applied to the ``TFBertLayer`` class itself, so layers with
        other names can reach this code and must not crash it.
        """
        match = self._LAYER_NAME_PATTERN.search(layer.name)
        if match is None:
            return None
        return int(match.group(1))

    def _get_component_variable_for_layer(self, layer):
        """Return the component variable configured for ``layer``, or ``None``.

        ``None`` means the layer is not perturbed: either its index is not in
        ``layer_indices`` or no index could be parsed from its name.
        """
        layer_index = self._layer_to_index(layer)
        if layer_index is None:
            return None

        # Looked up on every call (rather than cached in a dict) so that a
        # reassigned ``component_variables`` list is picked up.
        for index, variable in zip(self.layer_indices, self.component_variables):
            if index == layer_index:
                return variable
        return None

    def _hf_bert_layer_call(self, og_call, layer, *args, **kwargs):
        """Replacement for ``TFBertLayer.call`` that perturbs position 0.

        Args:
            og_call: The original, unpatched ``call``; invoked as
                ``og_call(layer, *args, **kwargs)``.
            layer: The layer instance being called.
            *args: Positional arguments forwarded to ``og_call`` unchanged.
            **kwargs: Keyword arguments forwarded to ``og_call`` unchanged.

        Returns:
            The tuple returned by ``og_call``.  For perturbed layers its first
            element is replaced by the perturbed activations; for all other
            layers the original tuple is returned as-is.
        """
        # ``ret`` is a tuple; ret[0].shape = [d_batch, d_sequence, d_model]
        ret = og_call(layer, *args, **kwargs)

        component = self._get_component_variable_for_layer(layer)

        # Exit early if we are not perturbing this output.
        if component is None:
            return ret

        acts = ret[0]
        # Cast to the activation dtype so the addition below also works when
        # the layer does not compute in float32 (e.g. mixed precision); this
        # is a no-op when the dtypes already match.
        perturbation = tf.cast(self.magnitude_variable * component, acts.dtype)

        # We assume that CLS token is the first token: build a
        # [d_sequence, d_model] tensor that is zero everywhere except row 0.
        per_token = tf.zeros(tf.shape(acts)[1:], dtype=acts.dtype)
        per_token = tf.tensor_scatter_nd_update(
            per_token,
            [[0]],
            [perturbation]
        )
        # Add a leading axis so the same perturbation broadcasts over the batch.
        per_token = tf.expand_dims(per_token, axis=0)

        acts = acts + per_token

        return (acts,) + ret[1:]
