"""Utilities related to sign patterns in the perturbation set-up."""
from typing import Sequence

import numpy as np
import tensorflow as tf

from em.fishers import diagonal
from em.models import em_models
from em.perturbations import perturbation_exp_util as pe_util
from em.perturbations import examples_context

# Module-local aliases for project types used in the signatures and
# isinstance checks below.
ExamplesContext = examples_context.ExamplesContext
SoftmaxKlDivergenceLoss = pe_util.SoftmaxKlDivergenceLoss


@tf.function
def _compute_loss_gradient_for_batch(model, variables, x, y):
    """Returns the gradient of the compiled loss on one batch w.r.t. `variables`.

    Logits come from `em_models.compute_logits` with `training=False`, and the
    loss is `model.compiled_loss(y, logits)`. Only `variables` are watched by
    the tape (automatic watching of accessed variables is disabled). Entries
    of the result may be None for variables the loss does not depend on.
    """
    recorder = tf.GradientTape(watch_accessed_variables=False)
    with recorder:
        recorder.watch(variables)
        batch_logits = em_models.compute_logits(model, x, training=False)
        batch_loss = model.compiled_loss(y, batch_logits)
    grads = recorder.gradient(batch_loss, variables)
    return grads


def compute_loss_gradient(
    model: tf.keras.Model,
    variables: Sequence[tf.Variable],
    ds: tf.data.Dataset,
):
    """Averages the gradient of the model's compiled loss over a dataset.

    Each batch's gradient is weighted by its batch size and the sum is divided
    by the total number of examples. The result is therefore the per-example
    average gradient, provided the compiled loss is a batch mean.
    NOTE(review): this assumes a mean-reduced `model.compiled_loss`; confirm
    for custom losses.

    Args:
      model: Compiled Keras model whose `compiled_loss` is differentiated.
      variables: Variables to differentiate with respect to.
      ds: Batched, finite dataset yielding `(x, y)` pairs.

    Returns:
      A list of non-trainable `tf.Variable`s holding the averaged gradients.
      There is one entry per element of `variables`, in the same order.

    Raises:
      ValueError: If `ds` yields no examples. Averaging would otherwise divide
        by zero and silently return NaN gradients.
    """
    acc_grads = [tf.Variable(tf.zeros_like(v), trainable=False) for v in variables]
    total_ex = 0
    for x, y in ds:
        batch_size = diagonal.batch_size_from_batch(x)
        batch_grads = _compute_loss_gradient_for_batch(model, variables, x, y)
        for acc, grad in zip(acc_grads, batch_grads):
            # GradientTape.gradient returns None for variables the loss does
            # not depend on. Their gradient is zero, so nothing accumulates.
            if grad is None:
                continue
            # convert_to_tensor densifies tf.IndexedSlices gradients (e.g.
            # from embedding lookups). It is a no-op for dense tensors.
            acc.assign_add(float(batch_size) * tf.convert_to_tensor(grad))
        total_ex += batch_size
    if total_ex == 0:
        raise ValueError("`ds` yielded no examples; cannot average gradients.")
    for acc in acc_grads:
        acc.assign(acc / float(total_ex))
    return acc_grads


def compute_kl_gradient(
    model: tf.keras.Model,
    variables: Sequence[tf.Variable],
    examples_context: ExamplesContext,
    example_indices: Sequence[int],
    batch_size: int = 32,
    *,
    allow_recompile: bool = False
):
    """Computes the gradient of the KL-divergence to the original predictions.

    The selected examples are labeled with their original logits via
    `examples_context.get_og_logits_labeled_ds`, batched, and passed to
    `compute_loss_gradient`. That function averages the gradient of the
    model's compiled (KL-divergence) loss over them.

    Args:
      model: Keras model. Its `loss` must be a `SoftmaxKlDivergenceLoss`
        unless `allow_recompile` is True.
      variables: Variables to differentiate with respect to.
      examples_context: Provides the examples and their original logits.
        This parameter shadows the `examples_context` module inside this
        function.
      example_indices: Indices of the examples to use. Must be non-empty.
      batch_size: Number of examples per batch. Must be positive.
      allow_recompile: If True and `model.loss` is not a
        `SoftmaxKlDivergenceLoss`, `model` is recompiled in place with that
        loss and a `SparseCategoricalAccuracy` metric. This mutates the
        caller's model, and `compile` is called without an `optimizer`
        argument.

    Returns:
      The list of averaged gradients produced by `compute_loss_gradient`.

    Raises:
      ValueError: If `batch_size` is not positive, if `example_indices` is
        empty, or if the model's loss is not a `SoftmaxKlDivergenceLoss` and
        `allow_recompile` is False.
    """
    # Validate the inputs before any recompilation, so a failing call never
    # leaves the caller's model mutated.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}.")
    example_indices = np.array(example_indices, dtype=np.int32)
    if example_indices.size == 0:
        # An empty selection would yield an empty dataset, so there would be
        # no examples to average the gradient over.
        raise ValueError("example_indices must be non-empty.")

    if not isinstance(model.loss, SoftmaxKlDivergenceLoss):
        if not allow_recompile:
            raise ValueError(
                "model.loss must be a SoftmaxKlDivergenceLoss, got "
                f"{type(model.loss).__name__}; pass allow_recompile=True to "
                "recompile the model with it."
            )
        # NOTE(review): `_compute_loss_gradient_for_batch` is a tf.function
        # that receives `model` as a Python argument. If it was already traced
        # for this model object, confirm that the new compiled loss is picked
        # up rather than a cached trace of the old one.
        model.compile(
            loss=SoftmaxKlDivergenceLoss(),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
        )

    ds = examples_context.get_og_logits_labeled_ds(example_indices)
    ds = ds.batch(batch_size)

    return compute_loss_gradient(model, variables, ds)
