"""Code for generating low-rank representations of PEF matrices.

The low rank matrices are stored densely, so this code is mostly
for initial development of the method on small models.
"""
from typing import Sequence

import numpy as np
import tensorflow as tf
from tqdm import tqdm

from em.fishers import diagonal
from em.models import em_models


###############################################################################

def _expand_batch_dims(batch):
    """Insert a size-1 axis at position 1 of a batch.

    Args:
        batch: Either a single `tf.Tensor`, or a mapping whose values are
            all expanded the same way (anything that is not a `tf.Tensor`
            is treated as such a mapping and must provide `.items()`).

    Returns:
        A tensor, or a new dict with the same keys, where every tensor has
        gained a new axis at index 1.
    """
    if not isinstance(batch, tf.Tensor):
        return {
            name: tf.expand_dims(feature, axis=1)
            for name, feature in batch.items()
        }
    return tf.expand_dims(batch, axis=1)


###############################################################################


@tf.function
def compute_m_pefs_for_batch(
    model,
    variables,
    batch,
):
    """Compute per-example, per-class weighted log-prob gradients for a batch.

    For every example and every class `i`, this computes
    `sqrt(p_i) * d(log p_i) / d(variable)` for each entry of `variables`,
    where `p` is the softmax of the model's logits for that example.

    Args:
        model: Model passed to `em_models.compute_logits`. It must expose a
            `num_labels` attribute giving the number of classes.
        variables: List of variables to differentiate with respect to.
        batch: A `tf.Tensor`, or a dict of tensors, whose leading axis is
            the batch axis.

    Returns:
        A list parallel to `variables`. Each item is a float32 tensor of
        shape `[batch, n_classes, *var_shape]`. Variables that do not
        influence the logits get all-zero entries.
    """

    @tf.function
    def fisher_single_example(single_example_batch):
        # Returns one tensor per variable, each stacked over the classes:
        # shape [n_classes, *var_shape].

        with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
            tape.watch(variables)

            logits = em_models.compute_logits(model, single_example_batch)

            # The batch dimension must be 1 to call the model, so we remove it
            # here.
            logits = tf.squeeze(logits, axis=0)

            log_probs = tf.nn.log_softmax(logits, axis=-1)
            probs = tf.nn.softmax(logits, axis=-1)

            weighted_grads = []
            # Slice out each class's log-prob while the tape is still
            # recording, so every slice is a valid gradient target.
            log_probs = [log_probs[i] for i in range(num_labels)]
            # The gradient computations themselves do not need to be
            # recorded on the tape.
            with tape.stop_recording():
                for i in range(num_labels):
                    log_prob = log_probs[i]
                    # A variable that is not connected to the logits would
                    # otherwise yield `None`, which breaks the scaling
                    # below; its gradient is zero, so ask for zeros.
                    grad = tape.gradient(
                        log_prob,
                        variables,
                        unconnected_gradients=tf.UnconnectedGradients.ZERO,
                    )
                    weighted_grad = [tf.sqrt(probs[i]) * g for g in grad]
                    weighted_grads.append(weighted_grad)

        # Transpose from [class][variable] to [variable][class] and stack the
        # classes into a leading axis.
        return [tf.stack(g, axis=0) for g in zip(*weighted_grads)]

    ########################################

    # Read here, before `fisher_single_example` is first invoked below; the
    # inner function picks it up through its closure.
    num_labels = model.num_labels

    # Give every example its own size-1 batch axis so that the slices
    # `tf.map_fn` hands to `fisher_single_example` can be fed to the model.
    batch = _expand_batch_dims(batch)

    # fishers = tf.vectorized_map(fisher_single_example, batch)
    fishers = tf.map_fn(
        fisher_single_example,
        batch,
        fn_output_signature=len(variables) * [tf.float32],
    )

    # Parallel to variables list.
    # Each item's shape is [batch, n_classes, *var_shape]
    return fishers


def _flatten_batch_pefs(batch_pefs: Sequence[tf.Tensor]) -> tf.Tensor:
    """Flatten per-variable PEF tensors into one tensor over all parameters.

    Args:
        batch_pefs: Tensors whose first two axes are kept as-is; all
            remaining axes of each tensor are collapsed into one.

    Returns:
        The collapsed tensors concatenated along the last axis, i.e. a
        tensor of shape `[batch_size, n_classes, n_params]`.
    """
    flattened = []
    for pef in batch_pefs:
        # Keep the two leading (dynamic) dims and let reshape infer the rest.
        leading_dims = tf.shape(pef)[:2]
        target_shape = tf.concat([leading_dims, [-1]], axis=0)
        flattened.append(tf.reshape(pef, target_shape))
    return tf.concat(flattened, axis=-1)


def _normalize_batch_pefs(flat_batch_pefs: tf.Tensor) -> tf.Tensor:
    """Rescale each example's flattened PEF factor by its Gram-matrix norm.

    Args:
        flat_batch_pefs: Tensor of shape `[batch_size, n_classes, n_params]`.

    Returns:
        A tensor of the same shape, where each example's slice has been
        divided by the square root of the Frobenius norm of its
        class-by-class Gram matrix.
    """
    # Per-example Gram matrix over the parameter axis: [batch, class, class].
    gram = tf.einsum('bcj,bkj->bck', flat_batch_pefs, flat_batch_pefs)
    frobenius_sq = tf.reduce_sum(
        tf.square(gram), axis=[-2, -1], keepdims=True)
    frobenius = tf.sqrt(frobenius_sq)
    # The second sqrt is intentional: the input is like a square root of the
    # actual PEF matrix, so it is rescaled by the square root of the norm.
    scale = tf.sqrt(frobenius)
    return flat_batch_pefs / scale


def compute_flat_m_pefs_for_ds(
    model: tf.keras.Model,
    variables: Sequence[tf.Variable],
    unlabeled_ds: tf.data.Dataset,
    n_examples: int,
    batch_size: int,
    *,
    normalize_pefs: bool = False,
):
    """Compute flattened PEF factors for the first examples of a dataset.

    Args:
        model: Model forwarded to `compute_m_pefs_for_batch`.
        variables: Variables forwarded to `compute_m_pefs_for_batch`.
        unlabeled_ds: Dataset to draw examples from. It should be unbatched
            and contain only the inputs, not an (input, label) tuple.
        n_examples: Number of examples taken from the start of the dataset.
        batch_size: Number of examples processed per batch.
        normalize_pefs: If true, each batch is passed through
            `_normalize_batch_pefs` after flattening.

    Returns:
        The per-batch results of `_flatten_batch_pefs` concatenated along
        axis 0.
    """
    batched_ds = unlabeled_ds.take(n_examples).batch(batch_size)

    per_batch_results = []
    for inputs in tqdm(batched_ds):
        pefs = _flatten_batch_pefs(
            compute_m_pefs_for_batch(model, variables, inputs))
        per_batch_results.append(
            _normalize_batch_pefs(pefs) if normalize_pefs else pefs)
    return tf.concat(per_batch_results, axis=0)
