"""Computing the salience of input features to components via factorization of cross terms."""
import dataclasses
from typing import Sequence

import numpy as np
import tensorflow as tf
from transformers import TFBertForSequenceClassification
from transformers.models.bert import modeling_tf_bert

from em.fishers import lrm_pefs
from em.models import em_models
from em.tools.nmf import lrm_npeff
from em.util import monkey_patching

###############################################################################

# Short local aliases for helpers defined in `lrm_pefs`.
expand_batch_dims = lrm_pefs.expand_batch_dims
flatten_batch_mpefs = lrm_pefs.flatten_batch_mpefs

# Short local alias; used as the annotation for the `nmf` argument of
# `compute_inputs_salience`.
LrmNpeffDecomposition = lrm_npeff.LrmNpeffDecomposition

###############################################################################


def add_dummy_batch_dim(batch):
    """Prepend a size-1 axis to a tensor, or to every value of a mapping.

    A `tf.Tensor` is expanded directly. Anything else is treated as a mapping
    (it must provide `.items()`), and a new dict with the same keys and
    expanded values is returned.
    """
    if not isinstance(batch, tf.Tensor):
        return {name: tf.expand_dims(value, axis=0) for name, value in batch.items()}
    return tf.expand_dims(batch, axis=0)


def compute_mpef_frobenius_norms(flat_pefs: tf.Tensor) -> tf.Tensor:
    """Return the Frobenius norm of the Gram matrix of the rows of `flat_pefs`.

    With A = flat_pefs, the Gram matrix A A^T is formed and the square root of
    the sum of its squared entries is returned. Mathematically this equals the
    Frobenius norm of A^T A, but only an [n_classes, n_classes] matrix is ever
    materialized.
    """
    # flat_pefs.shape = [n_classes, n_params]
    gram = tf.einsum('cj,kj->ck', flat_pefs, flat_pefs)
    squared_entries = tf.square(gram)
    return tf.sqrt(tf.reduce_sum(squared_entries, axis=[-2, -1]))


###############################################################################

@dataclasses.dataclass
class TransformerJointLrmPefComputer:
    """Computes per-class weighted gradients for parameters and input features.

    For each class c the gradient of log p_c is taken with respect to both
    `variables` and the input features, and scaled by sqrt(p_c). The results
    are stacked along a leading class axis.

    NOTE: Right now only works with BERT.

    We treat the outputs of the embeddings layer as the input features.

    Computes the PEF for only a single example at a time. This might be
    improved in the future.
    """
    model: TFBertForSequenceClassification
    variables: Sequence[tf.Variable]

    def __post_init__(self):
        self.n_labels = self.model.num_labels

        # Scratch state shared between the forward pass and the embeddings
        # hook; both are populated while `_fisher_single_example` runs.
        self._input_features = None
        self._tape = None

        self.mp_ctx = monkey_patching.MonkeyPatcherContext()
        self._set_up_monkey_patching()

    ############################################################
    # Identity-based hashing/equality: believed to be required so that
    # instances are hashable when their methods are wrapped in tf.function.

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return other is self

    ############################################################

    def _set_up_monkey_patching(self):
        """Register the hook that intercepts the BERT embeddings layer's `call`."""
        self.mp_ctx.patch_method(
            modeling_tf_bert.TFBertEmbeddings,
            "call",
            self._embeddings_call_override_fn,
        )

    def _embeddings_call_override_fn(self, original_fn, *args, **kwargs):
        """Run the original embeddings call, record its output and watch it on the tape."""
        features = original_fn(*args, **kwargs)
        self._input_features = features
        self._tape.watch([features])
        return features

    ############################################################

    @tf.function
    def _fisher_single_example(self, single_example):
        """Return (params_fisher, inputs_fisher, logits) for one unbatched example.

        `params_fisher` is a list with one tensor per entry of `variables`,
        each with a leading axis of size `n_labels`. `inputs_fisher` has a
        leading axis of size `n_labels` followed by the shape of the
        (unbatched) embeddings output. `logits` has its batch axis removed.
        """
        batched_example = add_dummy_batch_dim(single_example)

        with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
            tape.watch(self.variables)

            # Expose the tape so the embeddings hook can watch its output.
            self._tape = tape
            with self.mp_ctx:
                logits = em_models.compute_logits(self.model, batched_example)

            # The model is called with a batch of size 1; drop that axis.
            logits = tf.squeeze(logits, axis=0)

            probs = tf.nn.softmax(logits, axis=-1)
            log_probs = tf.nn.log_softmax(logits, axis=-1)

            # Slice out each class's log-probability while still recording, so
            # that each one can be differentiated separately below.
            per_class_log_probs = [log_probs[c] for c in range(self.n_labels)]

            params_rows = []
            inputs_rows = []
            with tape.stop_recording():
                for c, class_log_prob in enumerate(per_class_log_probs):
                    scale = tf.sqrt(probs[c])

                    grads_wrt_params = tape.gradient(class_log_prob, self.variables)
                    params_rows.append([scale * g for g in grads_wrt_params])

                    grad_wrt_inputs = tape.gradient(class_log_prob, self._input_features)
                    # The tf.squeeze removes the dummy batch dimension.
                    inputs_rows.append(scale * tf.squeeze(grad_wrt_inputs, axis=0))

        # Transpose from per-class lists of per-variable gradients to
        # per-variable tensors stacked over classes.
        params_fisher = [tf.stack(per_var, axis=0) for per_var in zip(*params_rows)]
        inputs_fisher = tf.stack(inputs_rows, axis=0)

        return params_fisher, inputs_fisher, logits

    @tf.function
    def process_example(self, example):
        """Return (params_fisher, inputs_fisher) for one example.

        The example should NOT be batched.
        """
        params_fisher, inputs_fisher, _unused_logits = self._fisher_single_example(example)
        return params_fisher, inputs_fisher


###############################################################################
###############################################################################


def compute_inputs_salience(
    nmf: LrmNpeffDecomposition,
    w: np.ndarray,
    flattened_params_lrm_fisher: np.ndarray,
    flattened_inputs_lrm_fisher: np.ndarray,
    n_top_components: int,
    **kwargs,
):
    """Build the salience computer for the components with the largest coefficients.

    The `n_top_components` largest entries of `w` are selected, together with
    the matching rows of `nmf.G`. The columns of the parameter Fisher are
    re-indexed with `nmf.new_to_old_col_indices`. Everything is cast to
    float32 and handed to `_ComponentInputSalienceComputer`; extra keyword
    arguments are forwarded unchanged.

    Returns the constructed computer. No optimization steps are run here.
    """
    # w.shape = [n_components]
    # flattened_inputs_lrm_fisher.shape = [n_classes, d_inputs]
    # flattened_params_lrm_fisher.shape = [n_classes, n_params]

    def as_float32(array):
        return tf.cast(array, tf.float32)

    # Indices of the largest coefficients, in descending order of value.
    selected = np.argsort(-w)[:n_top_components]

    reindexed_params_fisher = flattened_params_lrm_fisher[:, nmf.new_to_old_col_indices]

    return _ComponentInputSalienceComputer(
        w=as_float32(w[selected]),
        G=as_float32(nmf.G[selected]),
        params_fisher=as_float32(reindexed_params_fisher),
        inputs_fisher=as_float32(flattened_inputs_lrm_fisher),
        **kwargs,
    )


@dataclasses.dataclass
class _ComponentInputSalienceComputer:
    """Fits one input-salience vector per component by gradient descent.

    The trainable `inputs_salience` matrix (one row per component) is fitted
    so that the `w`-weighted outer products of `G` rows with salience rows
    reconstruct the cross block `params_fisher^T inputs_fisher`. An optional
    penalty, weighted by `lmbda_diag_block`, additionally asks the weighted
    salience outer products to reconstruct `inputs_fisher^T inputs_fisher`.
    """
    # Coefficients of the selected components.
    # shape = [n_top_components]
    w: tf.Tensor

    # Selected components.
    # shape = [n_top_components, n_params]
    G: tf.Tensor

    # shape = [n_classes, n_params]
    params_fisher: tf.Tensor

    # shape = [n_classes, d_inputs]
    inputs_fisher: tf.Tensor

    # Weight of the inputs diagonal-block loss; values <= 0 disable it.
    lmbda_diag_block: float = 0.0

    def __post_init__(self):
        n_components = self.G.shape[0]
        d_inputs = self.inputs_fisher.shape[1]

        # TODO: Better input scaling.
        noise = tf.random.normal([n_components, d_inputs])
        init_scale = tf.sqrt(float(n_components * d_inputs / 2.0))
        self.inputs_salience = tf.Variable(noise / init_scale, dtype=tf.float32)

        self._precompute_constants()

    ############################################################
    # Identity-based hashing/equality: believed to be required so that
    # instances are hashable when their methods are wrapped in tf.function.

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return other is self

    ############################################################

    def _precompute_constants(self):
        """Cache every quantity in the losses that does not depend on `inputs_salience`."""
        # Gram matrices over classes, shape = [n_classes, n_classes].
        params_gram = tf.einsum('ij,kj->ik', self.params_fisher, self.params_fisher)
        inputs_gram = tf.einsum('ij,kj->ik', self.inputs_fisher, self.inputs_fisher)

        # Squared Frobenius norms of the true cross and diagonal blocks (scalars).
        self.tr_cross_true = tf.einsum('ij,ij->', params_gram, inputs_gram)
        self.tr_diag_true = tf.einsum('ij,ij->', inputs_gram, inputs_gram)

        # shape = [n_top_components, n_top_components]
        self.GG = tf.einsum('ij,kj->ik', self.G, self.G)

        # shape = [n_classes, n_top_components]
        self.pf_G = tf.einsum('ij,kj->ik', self.params_fisher, self.G)

    @tf.function
    def _compute_cross_block_loss(self):
        """Squared Frobenius norm of the cross block's reconstruction error.

        Expanded as ||true||^2 - 2 <true, reconstruction> + ||reconstruction||^2
        so that no [n_params, d_inputs] matrix is ever formed.
        """
        # <true, reconstruction>
        fisher_dot_salience = tf.einsum('ij,kj->ik', self.inputs_fisher, self.inputs_salience)
        overlap = tf.einsum('j,ij,ij->', self.w, self.pf_G, fisher_dot_salience)

        # ||reconstruction||^2
        salience_gram = tf.einsum('ij,kj->ik', self.inputs_salience, self.inputs_salience)
        recon_sq_norm = tf.einsum('i,j,ij,ij->', self.w, self.w, self.GG, salience_gram)

        return self.tr_cross_true - 2 * overlap + recon_sq_norm

    @tf.function
    def _compute_inputs_diag_block_loss(self):
        """Squared Frobenius norm of the inputs diagonal block's reconstruction error.

        Uses the same three-term expansion as the cross block loss, with the
        salience rows playing the role of both factors.
        """
        # <true, reconstruction>
        fisher_dot_salience = tf.einsum('ij,kj->ik', self.inputs_fisher, self.inputs_salience)
        overlap = tf.einsum('j,ij,ij->', self.w, fisher_dot_salience, fisher_dot_salience)

        # ||reconstruction||^2
        salience_gram = tf.einsum('ij,kj->ik', self.inputs_salience, self.inputs_salience)
        recon_sq_norm = tf.einsum('i,j,ij,ij->', self.w, self.w, salience_gram, salience_gram)

        return self.tr_diag_true - 2 * overlap + recon_sq_norm

    @tf.function
    def _compute_loss(self):
        """Return the cross block loss only (the diagonal block term is not included)."""
        return self._compute_cross_block_loss()

    @tf.function
    def _gradient_update_step(self, lr):
        """Apply one plain gradient-descent step of size `lr` to `inputs_salience`.

        Returns `(cross_loss, diag_loss)` evaluated before the update, where
        `cross_loss` is twice the cross block loss and `diag_loss` is the
        unweighted diagonal block loss, or the Python float 0.0 when
        `lmbda_diag_block` is not positive.
        """
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(self.inputs_salience)

            cross_loss = 2 * self._compute_cross_block_loss()

            # Skip the diagonal term entirely when its weight is zero.
            diag_loss = 0.0
            if self.lmbda_diag_block > 0.0:
                diag_loss = self._compute_inputs_diag_block_loss()

            loss = cross_loss + self.lmbda_diag_block * diag_loss

        step_direction = tape.gradient(loss, self.inputs_salience)
        self.inputs_salience.assign_sub(lr * step_direction)
        return cross_loss, diag_loss
