"""Common code for merging / ablation of HANS lone nmfs."""
import dataclasses
from typing import Optional, Sequence, Union

import numpy as np
import tensorflow as tf
from transformers import PreTrainedTokenizer
from tqdm import tqdm

from em import datasets as em_datasets
from em.merging import merging
from em.util import hf_util

from em.projects.ll import hans_components_context as HCC
from em.projects.ll import hans_util


class HansLoneMerger:
    """Sweeps merge coefficients and evaluates the merged model after each merge.

    Work in progress: ``perform_merges`` currently performs a single merge and
    evaluation and then raises ``NotImplementedError``.
    """

    def __init__(
        self,
        *merge_args,
        evaluation_context,
        output_model,
        example_subset_indices: Sequence[Sequence[int]],
        **merge_kwargs,
    ):
        # Objects used after every merge.
        self.evaluation_context = evaluation_context
        self.output_model = output_model
        self.example_subset_indices = example_subset_indices

        # Forwarded verbatim to ``merging._merge_with_coeffs``.
        self._merge_args, self._merge_kwargs = merge_args, merge_kwargs

    def perform_merges(self, coeffs_set: Union[int, Sequence[Sequence[float]]]):
        """Merge and evaluate once per coefficient setting.

        Args:
            coeffs_set: either explicit coefficient settings, or an int that is
                passed to ``merging.create_pairwise_grid_coeffs`` to build them.

        Raises:
            NotImplementedError: always, after the first merge/evaluation
                (unfinished implementation).
        """
        coefficient_grid = (
            merging.create_pairwise_grid_coeffs(coeffs_set)
            if isinstance(coeffs_set, int)
            else coeffs_set
        )

        for coeffs in coefficient_grid:
            # NOTE(review): `coeffs` is never passed to `_merge_with_coeffs`, so
            # every iteration would perform the same merge. Forward it once the
            # helper's signature is settled.
            merging._merge_with_coeffs(*self._merge_args, **self._merge_kwargs)
            eval_results = self.evaluation_context.evaluate(self.output_model)

            # TODO: record `eval_results` for this coefficient setting.
            raise NotImplementedError('TODO')


###############################################################################
###############################################################################


@dataclasses.dataclass
class MnliEvaluationResults(HCC.EvaluationResultsMixIn):
    """Raw MNLI evaluation outputs for a single model.

    After construction, ``set_up_derived_attributes`` (inherited, presumably
    from ``HCC.EvaluationResultsMixIn``) is invoked to compute any derived
    attributes from these fields.
    """

    # Labels of the evaluated examples, as yielded by the evaluation dataset.
    labels: np.ndarray
    # Logits produced by the evaluated model, aligned with `labels`.
    logits: np.ndarray
    # Optional logits of the original ("og") model, kept for comparison.
    og_logits: Optional[np.ndarray] = dataclasses.field(default=None)

    def __post_init__(self):
        # Runs right after the generated __init__ has assigned every field.
        self.set_up_derived_attributes()


@dataclasses.dataclass
class MnliEvaluationContext:
    """Evaluates models on the first ``n_examples`` of the MNLI validation split.

    Attributes:
        n_examples: number of validation examples kept (via ``Dataset.take``).
        tokenizer: tokenizer handed to ``em_datasets.load``.
        sequence_length: sequence length handed to ``em_datasets.load``.
        batch_size: batch size used when running a model over the examples.
        og_logits: logits of the original model; filled by ``set_up_og_data``
            and attached to every result returned by ``evaluate``.
        labels: labels recorded by ``set_up_og_data``.
    """

    n_examples: int
    tokenizer: PreTrainedTokenizer

    sequence_length: int = 64
    batch_size: int = 128

    og_logits: Optional[np.ndarray] = None
    labels: Optional[np.ndarray] = None

    def __post_init__(self):
        # Build the (cached) evaluation dataset once, at construction time.
        self.ds = self._load_ds()

    def _load_ds(self):
        """Load, relabel, truncate and cache the MNLI validation split."""
        raw_ds = em_datasets.load(
            'glue/mnli',
            split='validation',
            sequence_length=self.sequence_length,
            tokenizer=self.tokenizer,
        )
        # Presumably remaps labels to the TextAttack MNLI label order; confirm
        # against em_datasets.glue.
        relabeled_ds = em_datasets.glue.fix_text_attack_mnli_labeling(raw_ds)
        return relabeled_ds.take(self.n_examples).cache()

    ############################################################

    def _compute_labels_and_logits(self, model):
        """Run ``model`` over ``self.ds`` and return ``(labels, logits)`` arrays.

        ``model(x, training=False)`` must return an object with a ``.logits``
        tensor. Per-batch results are concatenated along axis 0;
        ``np.concatenate`` raises ValueError if the dataset yields no batches.
        """
        label_chunks = []
        logit_chunks = []
        for features, targets in self.ds.batch(self.batch_size):
            label_chunks.append(targets.numpy())
            logit_chunks.append(model(features, training=False).logits.numpy())
        return (
            np.concatenate(label_chunks, axis=0),
            np.concatenate(logit_chunks, axis=0),
        )

    def set_up_og_data(self, model):
        """Record labels and logits of the original model for later comparison."""
        labels, og_logits = self._compute_labels_and_logits(model)
        self.labels = labels
        self.og_logits = og_logits

    def evaluate(self, model):
        """Evaluate ``model`` and bundle the outputs with the stored og logits."""
        labels, logits = self._compute_labels_and_logits(model)
        result_fields = dict(
            labels=labels,
            logits=logits,
            og_logits=self.og_logits,
        )
        return MnliEvaluationResults(**result_fields)


###############################################################################
###############################################################################


def get_ds_by_example_indices(full_ds, example_indices: Sequence[int]):
    """Return a cached dataset holding only the selected elements of ``full_ds``.

    Elements are kept in ``full_ds`` order (not the order of
    ``example_indices``) and duplicate indices are collapsed. Indices that match
    no element (negative, or past the end of ``full_ds``) are ignored.

    Args:
        full_ds: dataset to select from. It is iterated eagerly in Python and
            must expose ``element_spec``.
        example_indices: zero-based positions of the elements to keep.

    Returns:
        A cached ``tf.data.Dataset`` with the same ``element_spec`` as
        ``full_ds``.
    """
    # TODO: This still materializes the selected elements in Python; a
    # tf.data-native selection would avoid that.
    wanted = set(example_indices)
    examples = []
    if wanted:
        last_wanted = max(wanted)
        for i, ex in enumerate(full_ds):
            # Nothing past the largest requested index can be selected, so stop
            # instead of walking the remainder of the dataset.
            if i > last_wanted:
                break
            if i in wanted:
                examples.append(ex)

    def gen():
        yield from examples

    return tf.data.Dataset.from_generator(
        gen, output_signature=full_ds.element_spec
    ).cache()


###############################################################################

def get_loss_gradient(
    model,
    variables,
    ds,
    *,
    normalize_gradients_by_example: bool = False,
    use_tqdm: bool = True,
):
    """Sum the gradient of the model's compiled loss over a dataset.

    For every ``(x, y)`` batch of ``ds``, the model's logits are passed through
    ``hans_util.fix_up_hans_logits_tf``. The loss is then computed as
    ``model.compiled_loss(y, logits)``. Its gradient with respect to
    ``variables`` is added to a running per-variable total.

    Args:
        model: called as ``model(x, training=False)``. Its output must expose
            ``.logits``, and the model must have a ``compiled_loss``.
        variables: variables to differentiate with respect to.
        ds: dataset of ``(x, y)`` batches.
        normalize_gradients_by_example: if True, ``ds`` is re-batched to one
            example per batch via ``ds.unbatch().batch(1)``, so ``ds`` must be
            batched. Each example's gradient is then scaled to unit global L2
            norm before being summed. Examples whose gradient is exactly zero
            contribute nothing instead of producing NaNs.
        use_tqdm: wrap the dataset iteration in a tqdm progress bar.

    Returns:
        A list of non-trainable ``tf.Variable``s, parallel to ``variables``,
        holding the summed gradients. Variables the loss does not depend on
        end up as zeros.
    """
    grad_sums = [tf.Variable(tf.zeros_like(v), trainable=False) for v in variables]

    @tf.function
    def compute_gradient_for_batch(x, y):
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(variables)
            logits = model(x, training=False).logits
            logits = hans_util.fix_up_hans_logits_tf(logits)
            loss = model.compiled_loss(y, logits)

        # ZERO (instead of the default None) so variables unconnected to the
        # loss don't break `g**2` / `assign_add` below.
        grads = tape.gradient(
            loss, variables, unconnected_gradients=tf.UnconnectedGradients.ZERO
        )

        if normalize_gradients_by_example:
            grads_sq_mag = tf.reduce_sum([tf.reduce_sum(g**2) for g in grads])
            # rsqrt(0) is inf and 0 * inf is NaN, which would poison every
            # accumulator; an all-zero gradient gets a zero scale instead.
            scale = tf.where(
                grads_sq_mag > 0,
                tf.math.rsqrt(grads_sq_mag),
                tf.zeros_like(grads_sq_mag),
            )
            grads = [g * scale for g in grads]

        for acc, g in zip(grad_sums, grads):
            acc.assign_add(g)

    if normalize_gradients_by_example:
        # TODO: Support some level of batching here.
        ds = ds.unbatch().batch(1)

    for x, y in tqdm(ds) if use_tqdm else ds:
        compute_gradient_for_batch(x, y)

    return grad_sums


def apply_sign_guide(variables, sign_guide, delta: float):
    """Shift each variable by ``delta`` in the direction of its guide's sign.

    Computes ``v + delta * tf.sign(sg)`` for every ``(v, sg)`` pair. Entries
    where the guide is zero are therefore left unchanged.

    Args:
        variables: sequence of tensors/variables to shift.
        sign_guide: sequence parallel to ``variables``; only the sign of each
            entry is used.
        delta: step size applied to every entry.

    Returns:
        A new list of tensors. The inputs are not modified in place.

    Raises:
        ValueError: if ``variables`` and ``sign_guide`` differ in length.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(variables) != len(sign_guide):
        raise ValueError(
            'variables and sign_guide must have the same length, got '
            f'{len(variables)} and {len(sign_guide)}.'
        )
    return [v + delta * tf.sign(sg) for v, sg in zip(variables, sign_guide)]
