"""Sign-guided ablation and merging utilities for QQP models."""
import dataclasses
from typing import Sequence

import numpy as np
import tensorflow as tf
from tqdm import tqdm

from em.fishers import diagonal

from em.projects.ll import hans_util

from em.projects.pi import qqp_components_context as QAE

###############################################################################


@dataclasses.dataclass
class SignGuider:
    """Accumulates loss gradients / diagonal Fishers of a QQP model and
    perturbs its variables along them.

    Attributes:
        mc: Model context. This class reads its `model`, `variables`,
            `tokenizer`, `pef` and `special_processing` attributes.
    """

    mc: QAE.QqpModelContext

    def __post_init__(self):
        # Copied out of the model context for convenience.
        self.special_processing = self.mc.special_processing
        # All examples, materialized once so subsets can be sliced by index.
        self.all_examples = self.mc.pef.get_full_examples(self.mc.tokenizer)

    ############################################################

    def get_ds_for_examples(self, example_indices: Sequence[int]) -> tf.data.Dataset:
        """Return an unbatched dataset over a subset of `self.all_examples`.

        Args:
            example_indices: Indices of the examples to keep.

        Returns:
            A `tf.data.Dataset` built with `from_tensor_slices` over the sliced
            examples (presumably one element per selected example — depends on
            `QAE.slice_examples_dict`).
        """
        example_indices = np.array(list(example_indices), dtype=np.int32)
        return tf.data.Dataset.from_tensor_slices(
            QAE.slice_examples_dict(self.all_examples, example_indices))

    ############################################################

    def compute_loss_gradient(
        self,
        ds,
        coeff_ds=None,
        *,
        normalize_by_example: bool = False,
        use_tqdm: bool = True,
    ):
        """Accumulate the gradient of `model.compiled_loss` over `ds`.

        Args:
            ds: Batched dataset of `(x, y)` pairs.
            coeff_ds: Optional dataset of per-example weights. When given, `ds`
                is processed one example at a time and example `i` is scaled by
                the `i`-th unbatched element of `coeff_ds`.
            normalize_by_example: If True, each per-example gradient is scaled
                to unit global L2 norm (over all variables) before accumulating.
            use_tqdm: Show a tqdm progress bar.

        Returns:
            A list of non-trainable `tf.Variable`s, one per `self.mc.variables`,
            holding the accumulated (weighted) gradient.
        """
        return self._compute_gradients_helper(
            ds, coeff_ds, tf.identity,
            normalize_by_example=normalize_by_example,
            use_tqdm=use_tqdm,
        )

    def apply_sign_guide(self, variables, sign_guide, delta: float):
        """Return `v + delta * sign(sg)` for each variable/guide pair.

        The inputs are not modified; a new list of tensors is returned.
        """
        assert len(variables) == len(sign_guide), (
            f'Got {len(variables)} variables but {len(sign_guide)} sign guides.')
        return [
            v + delta * tf.sign(sg)
            for v, sg in zip(variables, sign_guide)
        ]

    def apply_gradient(self, variables, gradient, delta: float):
        """Return `v + delta * g` for each variable/gradient pair.

        The inputs are not modified; a new list of tensors is returned.
        """
        assert len(variables) == len(gradient), (
            f'Got {len(variables)} variables but {len(gradient)} gradients.')
        return [
            v + delta * g
            for v, g in zip(variables, gradient)
        ]

    ############################################################

    def _iter_batches_with_coeffs(self, ds, coeff_ds, *, force_single_example, use_tqdm):
        """Yield `(x, y, coeff)` for every batch of `ds`.

        When `coeff_ds` is given (or `force_single_example` is set), `ds` is
        re-batched into batches of one so each example can be weighted on its
        own, and the `i`-th batch is paired with the `i`-th unbatched element of
        `coeff_ds`. Otherwise every batch gets the coefficient 1.0 (float32).
        """
        if force_single_example or coeff_ds is not None:
            # TODO: Support some level of batching here.
            ds = ds.unbatch().batch(1)
        # Materialized so it can be indexed in lockstep with `ds`.
        coeffs = None if coeff_ds is None else list(coeff_ds.unbatch())
        unit_coeff = tf.constant(1.0, dtype=tf.float32)

        batches = enumerate(ds)
        if use_tqdm:
            batches = tqdm(batches)
        for i, (x, y) in batches:
            yield x, y, (unit_coeff if coeffs is None else coeffs[i])

    ############################################################

    def compute_fisher(
        self,
        ds,
        coeff_ds=None,
        *,
        normalize_by_example: bool = False,
        use_tqdm: bool = True,
    ):
        """Accumulate the exact diagonal Fisher of the model over `ds`.

        Uses `diagonal.compute_exact_fisher_for_batch` with
        `expectation_wrt_logits=True` for each batch.

        Args:
            ds: Batched dataset of `(x, y)` pairs; `y` is ignored.
            coeff_ds: Optional dataset of per-example weights. When given, `ds`
                is processed one example at a time and example `i` is scaled by
                the `i`-th unbatched element of `coeff_ds`.
            normalize_by_example: Not supported; must be False.
            use_tqdm: Show a tqdm progress bar.

        Returns:
            A list of non-trainable `tf.Variable`s, one per `self.mc.variables`.

        Raises:
            NotImplementedError: If `normalize_by_example` is True.
        """
        assert self.special_processing != "HANS", (
            'compute_fisher does not support HANS special processing.')

        if normalize_by_example:
            raise NotImplementedError('TODO')

        model = self.mc.model
        variables = self.mc.variables

        fishers = [tf.Variable(tf.zeros_like(v), trainable=False) for v in variables]

        # Don't make this a tf.function since it recompiles each time that way.
        def do_step(x, coeff):
            batch_fishers = diagonal.compute_exact_fisher_for_batch(
                batch=x,
                model=model,
                variables=variables,
                expectation_wrt_logits=True,
            )
            for f, g in zip(fishers, batch_fishers):
                # Cast so coefficients of another dtype (e.g. float64) still work.
                f.assign_add(tf.cast(coeff, f.dtype) * g)

        for x, _, coeff in self._iter_batches_with_coeffs(
                ds, coeff_ds, force_single_example=False, use_tqdm=use_tqdm):
            do_step(x, coeff)

        return fishers

    ############################################################

    def _compute_gradients_helper(
        self,
        ds,
        coeff_ds,
        grad_postprocess_fn,
        *,
        normalize_by_example: bool = False,
        use_tqdm: bool = True,
    ):
        """Accumulate post-processed loss gradients over `ds`.

        Args:
            ds: Batched dataset of `(x, y)` pairs.
            coeff_ds: Optional dataset of per-example weights (see
                `compute_loss_gradient`).
            grad_postprocess_fn: Applied to each variable's gradient before
                normalization/accumulation.
            normalize_by_example: Process one example at a time and scale each
                post-processed gradient to unit global L2 norm. An example whose
                gradient is entirely zero contributes nothing (rather than NaN).
            use_tqdm: Show a tqdm progress bar.

        Returns:
            A list of non-trainable `tf.Variable`s, one per `self.mc.variables`.
        """
        model = self.mc.model
        variables = self.mc.variables

        acc_grads = [tf.Variable(tf.zeros_like(v), trainable=False) for v in variables]

        @tf.function
        def compute_gradient_for_batch(x, y, coeff):
            with tf.GradientTape(watch_accessed_variables=False) as tape:
                tape.watch(variables)
                logits = model(x, training=False).logits
                if self.special_processing == 'HANS':
                    logits = hans_util.fix_up_hans_logits_tf(logits)[..., :2]
                loss = model.compiled_loss(y, logits)

            grads = tape.gradient(loss, variables)
            grads = [grad_postprocess_fn(g) for g in grads]

            if normalize_by_example:
                grads_sq_mag = tf.reduce_sum([tf.reduce_sum(tf.square(g)) for g in grads])
                # rsqrt(0) is inf and 0 * inf is NaN, which would poison every
                # accumulator; an all-zero gradient contributes zero instead.
                inv_norm = tf.where(
                    grads_sq_mag > 0,
                    tf.math.rsqrt(grads_sq_mag),
                    tf.zeros_like(grads_sq_mag),
                )

            for f, g in zip(acc_grads, grads):
                # Cast so coefficients of another dtype (e.g. float64) still work.
                scaled = tf.cast(coeff, f.dtype) * g
                if normalize_by_example:
                    scaled = scaled * tf.cast(inv_norm, f.dtype)
                f.assign_add(scaled)

        for x, y, coeff in self._iter_batches_with_coeffs(
                ds, coeff_ds,
                force_single_example=normalize_by_example,
                use_tqdm=use_tqdm):
            compute_gradient_for_batch(x, y, coeff)

        return acc_grads

