"""Code for computing and storing the diagonal Fisher approximation."""
import dataclasses
import os
from typing import List

import tensorflow as tf
import tensorflow_probability as tfp
from tqdm import tqdm

from em.fishers import fisher_abcs

from em.util import hdf5_util
from em.util import hf_util

tfd = tfp.distributions


def batch_size_from_batch(batch):
    """Return the size of the leading dimension of `batch` as a scalar tensor.

    Args:
        batch: Either a `tf.Tensor`, or a mapping that contains an
            "input_ids" entry whose leading dimension is used.

    Returns:
        The first element of the dynamic shape of the selected tensor.
    """
    # A bare tensor is measured directly; otherwise "input_ids" stands in for
    # the whole batch.
    reference = batch if isinstance(batch, tf.Tensor) else batch["input_ids"]
    return tf.shape(reference)[0]


@tf.function
def compute_exact_fisher_for_batch(
    batch,
    model,
    variables,
    expectation_wrt_logits,
    *,
    per_example=False,
):
    """Compute squared log-probability gradients (diagonal Fisher) for a batch.

    Each example is processed on its own (via `tf.vectorized_map`) so that
    gradients are squared per example before being combined.

    Args:
        batch: A `tf.Tensor`, or a dict of tensors, whose leading axis indexes
            the examples.
        model: Callable as `model(inputs, training=False)`. It must expose
            `num_labels` and return either a logits tensor or an object with
            a `.logits` attribute (e.g. a Hugging Face classifier output).
        variables: The variables to differentiate with respect to.
        expectation_wrt_logits: If true, take the expectation over classes
            under the model's own output distribution. Otherwise draw a
            single class from that distribution and use its gradient.
        per_example: If true, keep one result per example instead of summing
            over the batch.

    Returns:
        A list with one tensor per entry of `variables`. Entries are summed
        over the batch unless `per_example` is true, in which case they keep
        a leading per-example axis.
    """
    num_labels = model.num_labels

    def single_example_logits(single_example_batch):
        """Run the model on one example and return logits without batch axis."""
        outputs = model(single_example_batch, training=False)
        if not isinstance(outputs, tf.Tensor):
            # This should happen for Hugging Face models, which wrap the
            # logits in an output object.
            outputs = outputs.logits
        # The batch dimension must be 1 to call the model, so we remove it
        # here.
        return tf.squeeze(outputs, axis=0)

    @tf.function
    def fisher_single_example_expected(single_example_batch):
        """Probability-weighted sum over classes of squared gradients.

        NOTE: I wrote this with Hugging Face classifiers in mind. There is
        probably a good way to do the same thing but with more customizability
        to support alternate forms of models.
        """
        with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
            tape.watch(variables)

            logits = single_example_logits(single_example_batch)

            log_probs = tf.nn.log_softmax(logits, axis=-1)
            probs = tf.nn.softmax(logits, axis=-1)

            sq_grads = []
            # Slice out each class's log-probability while the tape is still
            # recording so that each slice can be differentiated below.
            log_probs = [log_probs[i] for i in range(num_labels)]
            with tape.stop_recording():
                for i in range(num_labels):
                    log_prob = log_probs[i]
                    grad = tape.gradient(log_prob, variables)
                    sq_grad = [probs[i] * tf.square(g) for g in grad]
                    sq_grads.append(sq_grad)
        # Sum over classes. Each per-class squared gradient was already
        # weighted by that class's probability under the output distribution,
        # so the sum is the expectation over classes.
        return [tf.reduce_sum(g, axis=0) for g in zip(*sq_grads)]

    @tf.function
    def fisher_single_example_sampled(single_example_batch):
        """Squared gradient of the log-probability of one sampled class."""
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(variables)

            # Uses the same logits extraction as the expected path, so models
            # that return a plain tensor are supported here as well.
            logits = single_example_logits(single_example_batch)
            log_probs = tf.nn.log_softmax(logits, axis=-1)

            chosen_index = tfd.Categorical(logits=logits).sample()
            log_prob = log_probs[chosen_index]

        grad = tape.gradient(log_prob, variables)
        return [tf.square(g) for g in grad]

    # Insert a length-1 axis after the example axis so that every slice handed
    # to the per-example function still looks like a batch of size 1.
    if isinstance(batch, tf.Tensor):
        batch = tf.expand_dims(batch, axis=1)
    else:
        batch = {k: tf.expand_dims(v, axis=1) for k, v in batch.items()}

    if expectation_wrt_logits:
        fisher_single_example_fn = fisher_single_example_expected
    else:
        fisher_single_example_fn = fisher_single_example_sampled

    fishers = tf.vectorized_map(fisher_single_example_fn, batch)
    if per_example:
        return [tf.stack(f, axis=0) for f in fishers]
    else:
        return [tf.reduce_sum(f, axis=0) for f in fishers]


# @tf.function
# def compute_exact_fisher_for_batch(batch, model, variables, expectation_wrt_logits, *, per_example=False):
#     assert expectation_wrt_logits, "TODO: Handle sampling from logits."

#     @tf.function
#     def fisher_single_example(single_example_batch):
#         """
#         NOTE: I wrote this with Hugging Face classifiers in mind. There is
#         probably a good way to do the same thing but with more customizability
#         to support alternate forms of models.
#         """
#         with tf.GradientTape(watch_accessed_variables=False) as tape:
#             tape.watch(variables)

#             logits = model(single_example_batch, training=False).logits
#             # The batch dimension must be 1 to call the model, so we remove it
#             # here.
#             logits = tf.squeeze(logits, axis=0)
#             log_probs = tf.nn.log_softmax(logits, axis=-1)

#         probs = tf.nn.softmax(logits, axis=-1)
#         grads = tape.jacobian(log_probs, variables)
#         return [tf.einsum('i,i...->...', probs, tf.square(g)) for g in grads]

#     batch = {k: tf.expand_dims(v, axis=1) for k, v in batch.items()}

#     fishers = tf.vectorized_map(fisher_single_example, batch)
#     if per_example:
#         return [tf.stack(f, axis=0) for f in fishers]
#     else:
#         return [tf.reduce_sum(f, axis=0) for f in fishers]


# @tf.function
# def compute_exact_fisher_for_batch(batch, model, variables, expectation_wrt_logits, *, per_example=False):
#     assert expectation_wrt_logits, "TODO: Handle sampling from logits."

#     with tf.GradientTape(watch_accessed_variables=False) as tape:
#         tape.watch(variables)

#         logits = model(batch, training=False).logits
#         # The batch dimension must be 1 to call the model, so we remove it
#         # here.
#         log_probs = tf.nn.log_softmax(logits, axis=-1)

#     probs = tf.nn.softmax(logits, axis=-1)
#     grads = tape.jacobian(log_probs, variables)
#     # for g in grads:
#     #     print(g.shape)

#     if per_example:
#         ein_str = 'bi,bi...->b...'
#         return [tf.einsum('bi,bi...->b...', probs, tf.square(g)) for g in grads]
#     else:
#         ein_str = 'bi,bi...->...'

#     return [tf.einsum(ein_str, probs, tf.square(g)) for g in grads]


def compute_fisher_for_model(
    model, dataset: tf.data.Dataset, expectation_wrt_logits=True,
    variables=None,
    *, use_tqdm=True,
):
    """Accumulate per-batch Fishers over `dataset` and average per example.

    Args:
        model: The model passed through to `compute_exact_fisher_for_batch`.
        dataset: Iterable yielding `(batch, _)` pairs; the second element of
            each pair is ignored.
        expectation_wrt_logits: Forwarded to `compute_exact_fisher_for_batch`.
        variables: Variables to compute the Fisher for. Defaults to
            `hf_util.get_mergeable_variables(model)`.
        use_tqdm: If true, wrap the dataset in a `tqdm` progress bar.

    Returns:
        A list of non-trainable `tf.Variable`s, one per entry of `variables`
        and with the same shape and dtype, holding the summed batch Fishers
        divided by the total number of examples seen.

    Raises:
        ValueError: If `dataset` yields no examples, since the average would
            otherwise be NaN.
    """
    if variables is None:
        variables = hf_util.get_mergeable_variables(model)

    # Match each accumulator's dtype to its variable so that `assign_add`
    # also works for variables that are not float32.
    fishers = [
        tf.Variable(
            tf.zeros(w.shape, dtype=w.dtype),
            trainable=False,
            name=f"fisher/{w.name}",
        )
        for w in variables
    ]

    if use_tqdm:
        dataset = tqdm(dataset)

    n_examples = 0
    for batch, _ in dataset:
        # Keep the running count as a plain Python int.
        n_examples += int(batch_size_from_batch(batch))
        batch_fishers = compute_exact_fisher_for_batch(
            batch, model, variables, expectation_wrt_logits=expectation_wrt_logits
        )
        for f, bf in zip(fishers, batch_fishers):
            f.assign_add(bf)

    if n_examples == 0:
        raise ValueError(
            "Cannot compute a Fisher from a dataset that yields no examples."
        )

    for fisher in fishers:
        fisher.assign(fisher / float(n_examples))

    return fishers


###############################################################################


@dataclasses.dataclass
class DiagonalFisher(fisher_abcs.FisherAbc):
    """Container for a diagonal Fisher stored as a list of variables."""

    # The Fisher values; presumably one entry per model variable (confirm
    # against whatever produces the saved file).
    fishers: List[tf.Variable]

    @classmethod
    def load(cls, filepath: str):
        """Build an instance from variables stored in an HDF5 file.

        A leading `~` in `filepath` is expanded to the user's home directory.
        The variables are loaded as non-trainable.
        """
        resolved_path = os.path.expanduser(filepath)
        loaded = hdf5_util.load_variables_from_hdf5(resolved_path, trainable=False)
        return cls(loaded)

    def _save(self, filepath: str):
        """Not implemented yet; always raises `NotImplementedError`."""
        # TODO
        raise NotImplementedError

    @property
    def n_parameters(self) -> int:
        """Total number of scalar elements across all stored Fishers."""
        total = 0
        for fisher in self.fishers:
            total += int(tf.size(fisher).numpy())
        return total

    def as_flat_fisher(self) -> tf.Tensor:
        """Return all Fishers flattened and concatenated into one 1-D tensor.

        Entries appear in the same order as in `self.fishers`.
        """
        flattened = [tf.reshape(fisher, [-1]) for fisher in self.fishers]
        return tf.concat(flattened, axis=0)
