"""Code for computing and storing the diagonal Fisher approximation.

This is a second, reworked pass at some of this functionality.
"""
import dataclasses
from typing import Optional, Sequence

import tensorflow as tf

from em.models import em_models

from . import diagonal


# Module-level alias for ``diagonal.batch_size_from_batch``; used below by
# FisherComputer's 'fast_compile' path to check the batch size.
batch_size_from_batch = diagonal.batch_size_from_batch


###############################################################################


def _expand_batch_dims(batch):
    """Insert a size-1 axis at position 1 of ``batch``.

    ``batch`` is either a single ``tf.Tensor`` or a mapping of names to
    tensors (anything that is not a ``tf.Tensor`` is treated as a mapping).
    A mapping is returned as a new dict with every value expanded.
    """
    if not isinstance(batch, tf.Tensor):
        expanded = {}
        for name, tensor in batch.items():
            expanded[name] = tf.expand_dims(tensor, axis=1)
        return expanded
    return tf.expand_dims(batch, axis=1)


# def _batch_to_list(batch):
#     if isinstance(batch, tf.Tensor):
#         return list(batch)
#     else:
#         return [
#             {k: v[i] for k, v in batch.items()}
#             for i in range(batch_size_from_batch(batch))
#         ]


###############################################################################
# Supported values of ``FisherComputer.algorithm``; the first entry is the
# default algorithm.
FISHER_ALGORITHMS = ('vectorized', 'fast_compile')


@dataclasses.dataclass
class FisherComputer:
    """Computes the exact diagonal Fisher information of a classifier.

    For a single example ``x`` with predictive distribution
    ``p(c | x) = softmax(logits)[c]``, the diagonal Fisher of each variable
    ``theta`` in ``variables`` is computed elementwise as

        F(theta) = sum_c p(c | x) * (d log p(c | x) / d theta) ** 2

    i.e. the expectation over labels is taken exactly rather than sampled.

    Attributes:
      model: Classifier whose logits are obtained with
        ``em_models.compute_logits(model, batch)``. It must expose a
        ``num_labels`` attribute.
      variables: Variables to compute the Fisher for. Returned Fisher lists
        are aligned with this sequence; each entry has the shape of the
        corresponding variable (plus a leading batch axis when
        ``per_example`` is True).
      per_example: If True, return one Fisher per example; otherwise return
        the sum over the examples in the batch.
      expectation_wrt_logits: Must be True; sampling labels is not
        implemented.
      logits_batch_size: Only used by the ``'vectorized'`` algorithm: number
        of labels whose squared gradients are grouped before being summed.
        ``None`` (or 0) means all labels form a single group.
      algorithm: One of ``FISHER_ALGORITHMS``.
      min_prob_class: Only supported by the ``'fast_compile'`` algorithm:
        labels whose predicted probability is below this value are skipped.
    """
    model: tf.keras.Model

    variables: Sequence[tf.Variable]

    per_example: bool = True
    expectation_wrt_logits: bool = True

    logits_batch_size: Optional[int] = None

    algorithm: str = FISHER_ALGORITHMS[0]

    min_prob_class: float = 0.0

    def __post_init__(self):
        """Validates the configuration and precomputes derived state.

        Raises:
          ValueError: if ``algorithm`` is unknown or the label chunk size
            derived from ``logits_batch_size`` is not positive.
          NotImplementedError: for unsupported option combinations.
        """
        # Raise rather than assert: asserts are stripped under ``python -O``.
        if self.algorithm not in FISHER_ALGORITHMS:
            raise ValueError(
                f'Unknown algorithm {self.algorithm!r}; expected one of '
                f'{FISHER_ALGORITHMS}.')

        if not self.expectation_wrt_logits:
            raise NotImplementedError('TODO: Support sampling.')

        if self.min_prob_class != 0.0 and self.algorithm != 'fast_compile':
            raise NotImplementedError(
                "min_prob_class is only supported by the 'fast_compile' "
                'algorithm.')

        self.n_labels = self.model.num_labels

        self._logits_batch_ranges = self._make_logits_batch_ranges()

        if self.algorithm == 'fast_compile':
            # Accumulators created once here (outside any tf.function) and
            # zeroed/reused for every example by the fast_compile path.
            self._fisher_acc_vars = [
                tf.Variable(tf.zeros_like(v), trainable=False)
                for v in self.variables
            ]

    ############################################################
    # Identity-based hashing and equality. A dataclass with the default
    # ``eq=True`` and ``frozen=False`` would otherwise be unhashable;
    # presumably hashability is needed so instances can be the ``self`` of
    # the ``tf.function``-decorated methods below (TODO: confirm).

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return self is other

    ############################################################

    def _make_logits_batch_ranges(self):
        """Splits ``range(n_labels)`` into contiguous ``(start, end)`` chunks.

        Every chunk holds ``logits_batch_size`` labels (all labels when it is
        ``None`` or 0) except possibly a shorter final chunk.

        Raises:
          ValueError: if the resulting chunk size is not positive.
        """
        bs = self.logits_batch_size or self.n_labels
        if bs <= 0:
            raise ValueError(
                f'Label chunk size must be positive, got {bs} '
                f'(logits_batch_size={self.logits_batch_size!r}, '
                f'n_labels={self.n_labels!r}).')

        return tuple(
            (start, min(start + bs, self.n_labels))
            for start in range(0, self.n_labels, bs)
        )

    ############################################################

    @tf.function
    def _vectorizable_compute_fisher(self, single_example_batch):
        """Computes logits and the diagonal Fisher for a single example.

        Meant to be mapped over a batch with ``tf.vectorized_map``.

        Args:
          single_example_batch: Model inputs for one example, carrying a
            dummy leading batch dimension of size 1.

        Returns:
          ``(logits, fishers)``: the logits with the dummy batch dimension
          removed, and a list of Fisher tensors aligned with
          ``self.variables``, each shaped like its variable.
        """
        fishers = [tf.zeros_like(v) for v in self.variables]

        with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
            tape.watch(self.variables)
            logits = em_models.compute_logits(self.model, single_example_batch)

            # The batch dimension must be 1 to call the model, so we remove it here.
            logits = tf.squeeze(logits, axis=0)
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            probs = tf.nn.softmax(logits, axis=-1)

            for start, end in self._logits_batch_ranges:
                sq_grads = []
                for i in range(start, end):
                    # Slice while the tape is still recording: if the slice
                    # op is not on the tape, ``tape.gradient`` cannot connect
                    # ``log_prob`` to the variables and returns None.
                    log_prob = log_probs[i]
                    with tape.stop_recording():
                        grad = tape.gradient(log_prob, self.variables)
                        sq_grads.append([probs[i] * tf.square(g) for g in grad])

                with tape.stop_recording():
                    # Regroup per variable and sum over this chunk's labels.
                    for j, per_label in enumerate(zip(*sq_grads)):
                        fishers[j] = fishers[j] + tf.reduce_sum(per_label, axis=0)

        return logits, fishers

    @tf.function
    def _vectorized__compute_exact_fisher_and_logits_for_batch(self, batch):
        """``'vectorized'`` implementation of the batch computation."""
        # Give each example its own size-1 batch dimension so every slice
        # produced by vectorized_map can be fed to the model directly.
        batch = _expand_batch_dims(batch)
        logits, fishers = tf.vectorized_map(self._vectorizable_compute_fisher, batch)

        # vectorized_map already stacks per-example outputs along axis 0.
        if not self.per_example:
            fishers = [tf.reduce_sum(f, axis=0) for f in fishers]

        return logits, fishers

    ############################################################

    @tf.function
    def _fast_compile__fisher_for_example(self, single_example_batch):
        """Computes logits and the diagonal Fisher for a single example.

        The loop over labels uses ``tf.range``, which autograph converts into
        a graph loop instead of unrolling it (presumably the reason this path
        compiles faster). Squared gradients are accumulated into
        ``self._fisher_acc_vars``; labels whose probability is below
        ``self.min_prob_class`` are skipped.

        Returns:
          ``(logits, fishers)`` with the dummy batch dimension removed from
          the logits and one Fisher per variable, each shaped like it.
        """
        # single_example_batch is assumed to have the dummy batch dimension of 1.
        # Reset the accumulators left over from the previous example.
        for f in self._fisher_acc_vars:
            f.assign(tf.zeros_like(f))

        variables = self.variables

        with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
            tape.watch(variables)
            logits = em_models.compute_logits(self.model, single_example_batch)

            # The batch dimension must be 1 to call the model, so we remove it here.
            logits = tf.squeeze(logits, axis=0)
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            probs = tf.nn.softmax(logits, axis=-1)

            for i in tf.range(self.n_labels):
                # Sliced while recording so the gradient can flow back.
                log_prob = log_probs[i]

                with tape.stop_recording():
                    if probs[i] < self.min_prob_class:
                        continue
                    grad = tape.gradient(log_prob, variables)
                    for f, g in zip(self._fisher_acc_vars, grad):
                        f.assign_add(probs[i] * tf.square(g))

        # Snapshot the accumulators so later resets don't alias the result.
        fishers = [tf.identity(f) for f in self._fisher_acc_vars]
        return logits, fishers

    # NOTE: disabled alternative implementation (jacobian + boolean_mask),
    # kept for reference.
    # @tf.function
    # def _fast_compile__fisher_for_example(self, single_example_batch):
    #     # # single_example_batch is assumed to have the dummy batch dimension of 1.

    #     variables = self.variables

    #     with tf.GradientTape(watch_accessed_variables=False) as tape:
    #         tape.watch(variables)
    #         logits = em_models.compute_logits(self.model, single_example_batch)

    #         # The batch dimension must be 1 to call the model, so we remove it here.
    #         logits = tf.squeeze(logits, axis=0)
    #         log_probs = tf.nn.log_softmax(logits, axis=-1)
    #         probs = tf.nn.softmax(logits, axis=-1)

    #         selected_log_probs = tf.boolean_mask(log_probs, probs >= self.min_prob_class)

    #     selected_probs = tf.boolean_mask(probs, probs >= self.min_prob_class)

    #     grads = tape.jacobian(selected_log_probs, variables, experimental_use_pfor=False)

    #     fishers = [
    #         tf.einsum('c,c...->...', selected_probs, tf.square(g))
    #         for g in grads
    #     ]
    #     return logits, fishers

    @tf.function
    def _fast_compile__compute_exact_fisher_and_logits_for_batch(self, batch):
        """``'fast_compile'`` implementation; only supports a batch size of 1."""
        tf.debugging.assert_equal(batch_size_from_batch(batch), 1, 'TODO: Support batch sizes other than 1.')

        logits, fishers = self._fast_compile__fisher_for_example(batch)

        # Restore the batch axis on the logits unconditionally so they keep a
        # leading batch axis like the vectorized algorithm's output; only the
        # Fisher's shape depends on ``per_example``.
        logits = tf.expand_dims(logits, axis=0)
        if self.per_example:
            fishers = [tf.expand_dims(f, axis=0) for f in fishers]

        return logits, fishers

    # NOTE: disabled alternative implementation (map_fn over the batch),
    # kept for reference.
    # @tf.function
    # def _fast_compile__compute_exact_fisher_and_logits_for_batch(self, batch):
    #     batch = _expand_batch_dims(batch)
    #     logits, fishers = tf.map_fn(
    #         self._fast_compile__fisher_for_example,
    #         batch,
    #         fn_output_signature=(tf.float32, tf.float32),
    #     )
    #     if not self.per_example:
    #         fishers = [tf.reduce_sum(f, axis=0) for f in fishers]
    #     return logits, fishers

    ############################################################

    def compute_exact_fisher_and_logits_for_batch(self, batch):
        """Computes logits and the exact diagonal Fisher for ``batch``.

        Args:
          batch: Model inputs with a leading batch axis (a tensor or a dict of
            tensors). The ``'fast_compile'`` algorithm only supports a batch
            size of 1.

        Returns:
          ``(logits, fishers)`` where ``logits`` keeps a leading batch axis and
          ``fishers`` is a list aligned with ``self.variables``: per-example
          Fishers stacked along a leading batch axis if ``self.per_example``,
          otherwise their sum over the batch.

        Raises:
          ValueError: if ``self.algorithm`` is not recognised.
        """
        if self.algorithm == 'vectorized':
            return self._vectorized__compute_exact_fisher_and_logits_for_batch(batch)
        elif self.algorithm == 'fast_compile':
            return self._fast_compile__compute_exact_fisher_and_logits_for_batch(batch)
        else:
            raise ValueError(self.algorithm)
