"""Shared code for seeing how close the sparse approximations to LRM Fishers are."""
import dataclasses

import numpy as np
import tensorflow as tf
from tqdm import tqdm

from em.fishers import lrm_pefs


###############################################################################

def _compute_lrm_frobenius_inner_product(A: tf.Tensor, B: tf.Tensor) -> tf.Tensor:
    """Returns, per batch element, the sum of squared entries of A @ B^T.

    For real matrices this equals trace(A^T A B^T B), i.e. the Frobenius
    inner product of the Gram matrices A^T A and B^T B, obtained without ever
    forming an [n_params, n_params] matrix.

    Args:
        A: Tensor of shape [batch_size, n_classes, n_params].
        B: Tensor of shape [batch_size, n_classes, n_params].

    Returns:
        Tensor of shape [batch_size].
    """
    # pairwise[b, c, k] = dot(A[b, c, :], B[b, k, :])
    pairwise = tf.einsum('bcj,bkj->bck', A, B)
    pairwise_sq = tf.square(pairwise)
    # Collapse both class axes, leaving one scalar per batch element.
    return tf.reduce_sum(pairwise_sq, axis=[-2, -1])

###############################################################################


@dataclasses.dataclass
class ImpactComputer:
    """Measures how far sparsified LRM Fishers are from their dense versions.

    For every example the dense Fisher and its sparsified counterpart
    (re-expanded into a dense tensor of the same shape) are compared through
    a distance normalized by the norm of the dense Fisher; see
    `_process_batch`.

    Attributes:
        fisher_computer: Produces the dense per-example Fishers and their
            sparsified representation.
        n_examples: Upper bound on the number of per-example results that
            `run` returns.
        use_tqdm: Whether `run` wraps the dataset in a tqdm progress bar.
    """
    fisher_computer: lrm_pefs.SparseLrmPefComputer

    n_examples: int

    use_tqdm: bool = True

    def __post_init__(self):
        # Mirror the sizes exposed by the Fisher computer on this instance.
        self.n_labels = self.fisher_computer.n_labels
        self.n_values_per_example = self.fisher_computer.n_values_per_example

    ############################################################
    # I think needed to be hashable to work with tf.function.
    #
    # A dataclass with the default eq=True and no explicit __hash__ gets
    # __hash__ set to None (unhashable), so identity-based hashing and
    # equality are defined explicitly here.

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return self is other

    ############################################################
    # NOTE: not decorated with tf.function itself; within this class it is
    # only called from the tf.function-decorated `_make_sp_fishers`.
    def _make_inds1(self, args):
        """Builds the class index of every stored value of one example.

        Args:
            args: Tuple `(ex_col_offsets, n_params)`. `ex_col_offsets` has
                shape [n_classes + 1]; entries i and i + 1 delimit the run of
                stored values belonging to class i. `n_params` is the length
                the result is reshaped to.

        Returns:
            int32 tensor of shape [n_params] in which class index i is
            repeated `ex_col_offsets[i + 1] - ex_col_offsets[i]` times, in
            increasing order of i.
        """
        ex_col_offsets, n_params = args
        # ex_col_offsets.shape = [n_classes + 1]
        ret = [
            i * tf.ones([ex_col_offsets[i + 1] - ex_col_offsets[i]], dtype=tf.int32)
            for i in range(self.n_labels)
        ]
        ret = tf.concat(ret, axis=0)
        # The reshape only succeeds when the runs add up to exactly n_params
        # entries; it also pins the per-example output length.
        return tf.reshape(ret, [n_params])

    @tf.function
    def _make_sp_fishers(self, dense_fishers, values, col_offsets, row_indices):
        """Expands the sparse Fishers into a dense tensor.

        Args:
            dense_fishers: Tensor of shape [batch_size, n_classes, n_params].
                Only its shape and dtype are used (via `tf.zeros_like`).
            values: Stored values, shape [batch_size, n_values_per_example].
            col_offsets: Per-example offsets delimiting the values of each
                class, shape [batch_size, n_classes + 1].
            row_indices: Parameter index of each stored value, same shape as
                `values`.

        Returns:
            The sparse Fishers represented as a dense tensor with the same
            shape as `dense_fishers`; positions without a stored value are
            zero and values scattered to the same position are summed.
        """
        batch_size = tf.shape(dense_fishers)[0]
        # Kept as a length-1 vector so it can be handed to vectorized_map
        # next to col_offsets.
        # NOTE(review): this relies on vectorized_map accepting an element
        # whose leading dimension is 1 - confirm.
        n_params1 = tf.shape(dense_fishers)[-1:]

        # Batch index of every value.
        # inds0.shape = [batch_size, n_values_per_example]
        inds0 = tf.range(batch_size, dtype=tf.int32)
        inds0 = inds0[:, None] * tf.ones(tf.shape(values), dtype=tf.int32)

        # Class index of every value.
        # inds1.shape = [batch_size, n_params]
        # NOTE(review): stacking below needs this to match the size of
        # `values`, i.e. presumably n_values_per_example == n_params.
        inds1 = tf.vectorized_map(self._make_inds1, (col_offsets, n_params1))

        # Parameter index of every value.
        # inds2.shape = [batch_size, n_values_per_example]
        inds2 = row_indices

        # One (batch, class, param) coordinate per stored value.
        inds = tf.stack([
            tf.reshape(inds0, [-1]),
            tf.reshape(inds1, [-1]),
            tf.reshape(inds2, [-1]),
        ], axis=-1)

        sp_fishers = tf.zeros_like(dense_fishers)
        sp_fishers = tf.tensor_scatter_nd_add(
            sp_fishers,
            inds,
            tf.reshape(values, [-1])
        )

        return sp_fishers

    @tf.function
    def _process_batch(self, batch):
        """Computes the normalized sparse-vs-dense distance for one batch.

        With <., .> denoting `_compute_lrm_frobenius_inner_product`, the
        result for each example is

            sqrt(<dn, dn> + <sp, sp> - 2 <sp, dn>) / max(1e-7, sqrt(<dn, dn>))

        where `dn` is the dense Fisher and `sp` the densified sparse Fisher.

        Args:
            batch: Batch of inputs forwarded to the Fisher computer.

        Returns:
            Tensor of shape [batch_size] with the normalized distances.
        """
        dense_fishers, _ = self.fisher_computer.compute_dense_lrm_pefs_and_logits_for_batch(batch)
        dense_fishers = lrm_pefs.flatten_batch_mpefs(dense_fishers)

        values, col_offsets, row_indices = self.fisher_computer.sparsify_batch_mpefs(dense_fishers)

        sp_fishers = self._make_sp_fishers(dense_fishers, values, col_offsets, row_indices)

        dnF2 = _compute_lrm_frobenius_inner_product(dense_fishers, dense_fishers)
        spF2 = _compute_lrm_frobenius_inner_product(sp_fishers, sp_fishers)
        spdn = _compute_lrm_frobenius_inner_product(sp_fishers, dense_fishers)

        sq_dists = dnF2 + spF2 - 2 * spdn
        # Floating-point cancellation can leave a (near-)zero squared distance
        # slightly negative, which sqrt would turn into NaN. Only negatives
        # are clamped; a genuine NaN compares False and is passed through.
        sq_dists = tf.where(sq_dists < 0, tf.zeros_like(sq_dists), sq_dists)
        dists = tf.sqrt(sq_dists)
        # The floor on the denominator avoids dividing by zero when the dense
        # Fisher vanishes.
        norm_dists = dists / tf.maximum(1e-7, tf.sqrt(dnF2))

        return norm_dists

    def run(self, ds: tf.data.Dataset) -> np.ndarray:
        """Computes normalized distances for up to `n_examples` examples.

        Args:
            ds: A batched dataset yielding `(x, y)` pairs. Only `x` is fed to
                the Fisher computation; `y` is used solely to count examples
                through its leading dimension.

        Returns:
            1-D array with one normalized distance per example, truncated to
            at most `n_examples` entries (fewer if the dataset runs out).

        Raises:
            ValueError: If the dataset yields no batches.
        """
        # The dataset should be batched.
        if self.use_tqdm:
            ds = tqdm(ds)

        ret = []
        n_examples_processed = 0

        for x, y in ds:
            norm_dists = self._process_batch(x)
            ret.append(norm_dists.numpy())

            n_examples_processed += y.shape[0]
            if n_examples_processed >= self.n_examples:
                break

        if not ret:
            # np.concatenate would raise a ValueError with an unhelpful
            # message here; keep the exception type but say what happened.
            raise ValueError('The dataset yielded no batches; there is nothing to evaluate.')

        ret = np.concatenate(ret, axis=0)
        # The last batch may overshoot the requested number of examples.
        return ret[:self.n_examples]
