"""Stuff for generating per-example Fishers to share across scripts."""
import dataclasses
import functools
from typing import Any, Callable, Dict, Optional, Sequence, Union

import h5py
import numpy as np
import tensorflow as tf
from tqdm import tqdm

from em.datasets import divisibility
from em.fishers import diagonal
from em.util import sparse_tf_util

# typedefs

# Type alias for a batch of examples drawn from the dataset. Note that this
# does not include the label. Usually tf.Tensor or Dict[str, tf.Tensor],
# but it can be different depending on the dataset, which is why `Any` is
# part of the union.
DataBatch = Union[tf.Tensor, Dict[str, tf.Tensor], Any]

# Either a sparse tensor or a dense tensor consisting of
# the indices of a sparse tensor. If a dense tensor, it should be
# an int32/int64 tensor.
SparseOrIndices = Union[tf.Tensor, tf.sparse.SparseTensor]

###############################################################################


class FlavorConfig:
    """Common base for the per-example Fisher "flavor" configs.

    The defaults describe a flavor that is not saved sparsely, needs no
    per-model configuration, and leaves the Fishers untouched in both
    postprocessing hooks. Subclasses override the pieces they need.
    """

    # Disabled shortcut, kept for reference: a constant hash.
    # def __hash__(self):
    #     return 0

    def save_as_sparse(self) -> bool:
        """Whether the Fishers are stored as separate values and indices."""
        return False

    def get_flavor(self) -> str:
        """Return the flavor name, which is the name of the concrete class."""
        return type(self).__name__

    def get_dense_fisher_size(self, variables: Sequence[tf.Variable]) -> int:
        """Return the total number of elements across all `variables`."""
        per_variable_sizes = [tf.size(var) for var in variables]
        total_size = tf.reduce_sum(per_variable_sizes)
        return total_size.numpy()

    def configure_given_model_variables(self, variables: Sequence[tf.Variable]):
        """Hook for configs that need the model variables; no-op by default."""

    @tf.function
    def postprocess_per_variable_batch_fishers(
            self, batch_fishers: Sequence[tf.Tensor]) -> Sequence[tf.Tensor]:
        """Hook applied to the per-variable batch Fishers; identity by default."""
        return batch_fishers

    @tf.function
    def postprocess_flat_batch_fishers(self, batch_fishers: tf.Tensor):
        """Hook applied to the flattened batch Fishers; identity by default."""
        return batch_fishers


@dataclasses.dataclass(eq=True, frozen=True)
class DenseConfig(FlavorConfig):
    """Flavor whose stored Fisher size equals the dense Fisher size."""

    def get_fisher_size(self, variables: Sequence[tf.Variable]) -> int:
        """Return the number of Fisher entries stored per example.

        For this flavor that is the total element count of `variables`.
        """
        dense_size = self.get_dense_fisher_size(variables)
        return dense_size


@dataclasses.dataclass(eq=True, frozen=True)
class SparseStaticConfig(FlavorConfig):
    """Flavor that keeps a fixed, per-variable set of Fisher entries.

    The same indices are gathered from every example's per-variable Fisher,
    so each example contributes the same number of values.
    """
    # Must be the same length as the variables list with matching shapes.
    sparse_indices: Sequence[SparseOrIndices]

    def __post_init__(self):
        # Normalize every entry to a dense int32 index tensor.
        normalized_indices = [
            tf.cast(sparse_tf_util.get_sparse_indices(s), tf.int32)
            for s in self.sparse_indices
        ]
        # The dataclass is frozen, so a plain `self.sparse_indices = ...`
        # raises dataclasses.FrozenInstanceError. Bypass the frozen
        # __setattr__ for this one-time normalization.
        #
        # NOTE(review): the generated __hash__ hashes the field values, and a
        # list of tensors is not hashable -- confirm whether anything needs to
        # hash instances of this config.
        object.__setattr__(self, 'sparse_indices', normalized_indices)

    def save_as_sparse(self) -> bool:
        # Setting this explicitly so I don't accidentally set it wrong in the
        # future. Even though the fishers are sparse, we can save them as a
        # dense tensor since the mask is static. The flat postprocessing hook
        # is the inherited identity, which yields a single tensor rather than
        # the (values, indices) pair the sparse save path unpacks.
        return False

    @tf.function
    def postprocess_per_variable_batch_fishers(
            self, batch_fishers: Sequence[tf.Tensor]) -> Sequence[tf.Tensor]:
        """Gather the configured indices from each example's Fisher.

        `batch_fishers` is paired element-wise with `self.sparse_indices`;
        the gather is mapped over the leading dimension of each tensor.
        """
        return [
            tf.vectorized_map(functools.partial(tf.gather_nd, indices=inds), f)
            for f, inds in zip(batch_fishers, self.sparse_indices)
        ]

    def get_fisher_size(self, variables: Sequence[tf.Variable]) -> int:
        """Return the number of kept entries, summed over all index tensors."""
        del variables
        return tf.reduce_sum([si.shape[0] for si in self.sparse_indices]).numpy()


@dataclasses.dataclass(eq=True, frozen=True)
class SparseDynamicRawConfig(FlavorConfig):
    """Flavor that keeps the largest Fisher values of each example."""
    # How many Fisher entries are kept for every example.
    n_fisher_values_per_example: int

    def save_as_sparse(self) -> bool:
        """This flavor is stored as separate values and indices."""
        return True

    def get_fisher_size(self, variables: Sequence[tf.Variable]) -> int:
        """Return the number of entries kept per example."""
        del variables
        return self.n_fisher_values_per_example

    @tf.function
    def postprocess_flat_batch_fishers(self, batch_fishers: tf.Tensor):
        """Return the top-k Fisher values and their indices.

        The selection runs over the last axis of `batch_fishers` with
        k = `n_fisher_values_per_example`.
        """
        n_kept = self.n_fisher_values_per_example
        top_values, top_indices = tf.math.top_k(batch_fishers, k=n_kept)
        return top_values, top_indices


@dataclasses.dataclass(eq=True, frozen=True)
class SparseDynamicMetricDerivedConfig(FlavorConfig):
    """Flavor that keeps the Fisher entries ranked highest by a derived score.

    The score of an entry is its Fisher value multiplied by the squared
    difference between the model variable and the corresponding pretrained
    variable. The stored values are the raw Fisher values, not the scores.
    """
    # How many Fisher entries are kept for every example.
    n_fisher_values_per_example: int
    # Must be the same length as the variables list with matching shapes.
    pretrained_model_variables: Sequence[tf.Variable]

    def save_as_sparse(self) -> bool:
        """This flavor is stored as separate values and indices."""
        return True

    def get_fisher_size(self, variables: Sequence[tf.Variable]) -> int:
        """Return the number of entries kept per example."""
        del variables
        return self.n_fisher_values_per_example

    def configure_given_model_variables(self, variables: Sequence[tf.Variable]):
        """Precompute the flat squared deltas between `variables` and the
        pretrained variables.

        Must be called before `postprocess_flat_batch_fishers`, which reads
        the cached result.
        """
        sq_parameter_deltas = _get_squared_parameter_deltas(
            variables, self.pretrained_model_variables)
        # The dataclass is frozen, so a plain attribute assignment raises
        # dataclasses.FrozenInstanceError. Bypass the frozen __setattr__ to
        # cache this derived (non-field) value.
        object.__setattr__(self, '_sq_parameter_deltas', sq_parameter_deltas)

    @tf.function
    def postprocess_flat_batch_fishers(self, batch_fishers: tf.Tensor):
        """Return the Fisher values and indices of the top-scoring entries."""
        # Rank by the score, but only the indices of the ranking are used.
        _, indices = tf.math.top_k(
            batch_fishers * self._sq_parameter_deltas,
            k=self.n_fisher_values_per_example)
        # Look up the raw Fisher values at the selected indices, per example.
        values = tf.gather(batch_fishers, indices, batch_dims=1)
        return values, indices


#######################################


def _flatten(tensors: Sequence[tf.Tensor]) -> tf.Tensor:
    """Reshape every tensor to rank 1 and concatenate them in order."""
    flat_pieces = [tf.reshape(piece, [-1]) for piece in tensors]
    return tf.concat(flat_pieces, axis=0)


def _get_squared_parameter_deltas(
    variables: Sequence[tf.Variable],
    pretrained_model_variables: Sequence[tf.Variable],
) -> tf.Tensor:
    """Return the flattened element-wise squared differences of two variable lists.

    The variables are paired by position; each pair contributes
    `(variable - pretrained_variable)**2`, and the results are flattened and
    concatenated in order.

    Raises:
        ValueError: If the two sequences have different lengths.
    """
    # Explicit check instead of `assert`: asserts are stripped under `-O`, and
    # `zip` below would then silently drop the unmatched variables.
    if len(variables) != len(pretrained_model_variables):
        raise ValueError(
            'Expected the same number of variables and pretrained variables, '
            f'got {len(variables)} and {len(pretrained_model_variables)}.')
    return _flatten([
        (ft - pt)**2
        for ft, pt in zip(variables, pretrained_model_variables)
    ])

###############################################################################


@tf.function
def _compute_batch_fishers(
    model: tf.keras.Model,
    variables: Sequence[tf.Variable],
    batch: DataBatch,
    postprocess_per_variable_batch_fishers_fn: Callable[[Sequence[tf.Tensor]], Sequence[tf.Tensor]],
    postprocess_flat_batch_fishers_fn: Callable[[tf.Tensor], Any],
    expectation_wrt_logits: bool,
):
    """Compute per-example Fishers for one batch and postprocess them.

    Returns a pair of:
      - the per-example L2 norm of the Fisher, taken over all variables
        before any postprocessing, and
      - the result of `postprocess_flat_batch_fishers_fn` applied to the
        per-example Fishers after per-variable postprocessing, flattening and
        concatenation along the last axis.
    """
    batch_len = diagonal.batch_size_from_batch(batch)
    per_example_shape = [batch_len, -1]

    per_variable_fishers = diagonal.compute_exact_fisher_for_batch(
        batch,
        model,
        variables,
        expectation_wrt_logits=expectation_wrt_logits,
        per_example=True,
    )

    # Sum of squares per example for each variable, then summed across
    # variables before taking the square root.
    squared_sums = []
    for fisher in per_variable_fishers:
        squared_rows = tf.reshape(fisher**2, per_example_shape)
        squared_sums.append(tf.reduce_sum(squared_rows, axis=-1))
    dense_fisher_norms = tf.sqrt(tf.reduce_sum(squared_sums, axis=0))

    processed_fishers = postprocess_per_variable_batch_fishers_fn(per_variable_fishers)

    # One row per example, with all variables side by side.
    flat_fishers = tf.concat(
        [tf.reshape(fisher, per_example_shape) for fisher in processed_fishers],
        axis=-1)

    return dense_fisher_norms, postprocess_flat_batch_fishers_fn(flat_fishers)


#######################################

@dataclasses.dataclass
class PerExampleFisherGeneratorAndSaver:
    """The main class for computing and saving per-example Fishers.

    Takes the first `n_examples` of `dataset`, batches them, computes the
    per-example Fishers of `variables` batch by batch, and writes the results
    into a 'data' group of `file`. The dataset is iterated as
    (examples, labels) pairs.

    Subclasses can store additional per-example data by overriding
    `initialize_extra_h5_datasets` and `write_to_extra_h5_datasets`.
    """
    model: tf.keras.Model
    variables: Sequence[tf.Variable]

    # Dataset should not be batched.
    dataset: tf.data.Dataset
    n_examples: int
    batch_size: int

    flavor_config: FlavorConfig

    # Must be open in write mode.
    file: h5py.File

    expectation_wrt_logits: bool = True

    def __post_init__(self):
        # Guards against a second generate_and_save call, which would try to
        # create the same HDF5 group and datasets again.
        self._has_generated = False

        self._pe_ds = self.dataset.take(self.n_examples).batch(self.batch_size)

        self.flavor_config.configure_given_model_variables(self.variables)

        self._dense_fisher_size = self.flavor_config.get_dense_fisher_size(self.variables)
        self._fisher_size = self.flavor_config.get_fisher_size(self.variables)

    def _create_data_group(self):
        """Create the 'data' group and tag it with the flavor name."""
        self._data_grp = self.file.create_group('data')
        self._data_grp.attrs['flavor'] = self.flavor_config.get_flavor()
        return self._data_grp

    def initialize_extra_h5_datasets(self, data_group: h5py.Group):
        """Create any additional datasets inside `data_group`."""
        # Intended to be overwritten by subclasses if needed.
        pass

    def _initialize_h5_datasets(self):
        """Create every HDF5 dataset that generate_and_save writes to."""
        dg = self._create_data_group()

        # TODO: Maybe create the sq_parameter_deltas if I plan to save it.

        # Norms of the entire dense diagonal Fisher.
        self._dense_fisher_norms_ds = dg.create_dataset(
            'dense_fisher_norms',
            [self.n_examples],
            dtype=np.float32)

        if self.flavor_config.save_as_sparse():
            # Sparse layout: parallel 'values' and 'indices' datasets, plus
            # the dense size as an attribute on the group.
            fg = dg.create_group('fisher')
            fg.attrs['dense_fisher_size'] = self._dense_fisher_size

            self._fisher_values_ds = fg.create_dataset(
                'values',
                [self.n_examples, self._fisher_size],
                dtype=np.float32)

            self._fisher_indices_ds = fg.create_dataset(
                'indices',
                [self.n_examples, self._fisher_size],
                dtype=np.int32)

        else:
            self._fishers_ds = dg.create_dataset(
                'fishers',
                [self.n_examples, self._fisher_size],
                dtype=np.float32)

        self.initialize_extra_h5_datasets(dg)

    def write_to_extra_h5_datasets(
        self,
        dense_fisher_norms,
        batch_fishers,
        examples,
        labels,
        start_example_index: int,
        end_example_index: int,
    ):
        """Write one batch of additional per-example data."""
        # Intended to be overwritten by subclasses if needed.
        pass

    def write_to_h5_dataset(
        self,
        ds: h5py.Dataset,
        batch_data: Union[np.ndarray, tf.Tensor],
        start_example_index: int,
        end_example_index: int,
    ):
        """Write the leading rows of `batch_data` into `ds[start:end]`.

        Only the first `end - start` rows are written. Tensors are converted
        to numpy first, and the data is cast to the dtype of `ds`.
        """
        i1 = start_example_index
        i2 = end_example_index
        k = i2 - i1

        values = batch_data[:k]
        if isinstance(values, tf.Tensor):
            values = values.numpy()

        ds[i1:i2] = values.astype(ds.dtype)

    def _write_to_h5_datasets(
        self,
        dense_fisher_norms,
        batch_fishers,
        start_example_index: int,
        end_example_index: int,
    ):
        """Write one batch of norms and Fishers into the core datasets."""
        inds = (start_example_index, end_example_index)

        # All writes go through write_to_h5_dataset so the slicing and dtype
        # handling live in one place; each dataset's own dtype matches the
        # dtype it was created with.
        self.write_to_h5_dataset(self._dense_fisher_norms_ds, dense_fisher_norms, *inds)

        if self.flavor_config.save_as_sparse():
            fisher_values, fisher_indices = batch_fishers
            self.write_to_h5_dataset(self._fisher_values_ds, fisher_values, *inds)
            self.write_to_h5_dataset(self._fisher_indices_ds, fisher_indices, *inds)
        else:
            self.write_to_h5_dataset(self._fishers_ds, batch_fishers, *inds)

    def generate_and_save(
        self,
        use_tqdm: bool = True
    ):
        """Compute the per-example Fishers and write them to the file.

        Args:
            use_tqdm: Whether to wrap the batch loop in a tqdm progress bar.

        Raises:
            ValueError: If called more than once on the same instance, or if
                the dataset yields fewer than `n_examples` examples.
        """
        if self._has_generated:
            raise ValueError("You can only call the generate and save method once per instance.")
        self._has_generated = True

        self._initialize_h5_datasets()

        if use_tqdm:
            # Ceiling division: a trailing partial batch is still one step.
            n_batches = -(-self.n_examples // self.batch_size)
            pe_ds = tqdm(self._pe_ds, total=n_batches)
        else:
            pe_ds = self._pe_ds

        n_examples_processed = 0
        for examples, labels in pe_ds:
            if n_examples_processed >= self.n_examples:
                break
            n_batch_examples = diagonal.batch_size_from_batch(examples)

            dense_fisher_norms, batch_fishers = _compute_batch_fishers(
                self.model,
                self.variables,
                examples,
                postprocess_per_variable_batch_fishers_fn=self.flavor_config.postprocess_per_variable_batch_fishers,
                postprocess_flat_batch_fishers_fn=self.flavor_config.postprocess_flat_batch_fishers,
                expectation_wrt_logits=self.expectation_wrt_logits,
            )

            # The start_index is inclusive, end_index is exclusive.
            start_index = n_examples_processed
            end_index = min(n_examples_processed + n_batch_examples, self.n_examples)

            self._write_to_h5_datasets(
                dense_fisher_norms=dense_fisher_norms,
                batch_fishers=batch_fishers,
                start_example_index=start_index,
                end_example_index=end_index
            )

            self.write_to_extra_h5_datasets(
                dense_fisher_norms=dense_fisher_norms,
                batch_fishers=batch_fishers,
                examples=examples,
                labels=labels,
                start_example_index=start_index,
                end_example_index=end_index,
            )

            n_examples_processed += n_batch_examples

        if n_examples_processed < self.n_examples:
            raise ValueError('The dataset passed contained fewer than n_examples.')


##############################################################################

@dataclasses.dataclass
class TransformerClassifierGeneratorAndSaver(PerExampleFisherGeneratorAndSaver):
    """Generator that also saves input ids, labels and predicted logits."""

    # A default is required here: the parent dataclass already declares a
    # field with a default, and dataclasses rejects non-default fields that
    # come after it.
    sequence_length: int = None

    def initialize_extra_h5_datasets(self, data_group: h5py.Group):
        """Create the 'input_ids', 'labels' and 'predicted_logits' datasets."""
        n_rows = self.n_examples
        # (attribute name, dataset name, shape, dtype), in creation order.
        extra_specs = (
            ('_input_ids_ds', 'input_ids', [n_rows, self.sequence_length], np.int32),
            ('_labels_ds', 'labels', [n_rows], np.int32),
            ('_predicted_logits_ds', 'predicted_logits', [n_rows, self.model.num_labels], np.float32),
        )
        for attr_name, ds_name, shape, dtype in extra_specs:
            setattr(self, attr_name, data_group.create_dataset(ds_name, shape, dtype=dtype))

    def write_to_extra_h5_datasets(
        self,
        dense_fisher_norms,
        batch_fishers,
        examples,
        labels,
        start_example_index: int,
        end_example_index: int,
    ):
        """Write the batch's input ids, labels and predicted logits.

        The logits come from a forward pass of the model on `examples` with
        `training=False`.
        """
        logits = self.model(examples, training=False).logits

        span = (start_example_index, end_example_index)
        writes = (
            (self._input_ids_ds, examples['input_ids']),
            (self._labels_ds, labels),
            (self._predicted_logits_ds, logits),
        )
        for target_ds, batch_data in writes:
            self.write_to_h5_dataset(target_ds, batch_data, *span)


@dataclasses.dataclass
class DivisibilityGeneratorAndSaver(PerExampleFisherGeneratorAndSaver):
    """Generator that also saves dividends, divisors, labels and logits."""

    # A default is required here: the parent dataclass already declares a
    # field with a default, and dataclasses rejects non-default fields that
    # come after it.
    #
    # Note that this is needed in addition to the `dataset` argument.
    dataset_config: divisibility.DivisibilityDatasetConfig = None

    def initialize_extra_h5_datasets(self, data_group: h5py.Group):
        """Create the 'dividends', 'divisors', 'labels' and
        'predicted_logits' datasets."""
        n_rows = self.n_examples
        # These are stored as numbers.
        # (attribute name, dataset name, shape, dtype), in creation order.
        extra_specs = (
            ('_dividends_ds', 'dividends', [n_rows], np.int64),
            ('_divisors_ds', 'divisors', [n_rows], np.int64),
            ('_labels_ds', 'labels', [n_rows], np.int32),
            ('_predicted_logits_ds', 'predicted_logits', [n_rows, 2], np.float32),
        )
        for attr_name, ds_name, shape, dtype in extra_specs:
            setattr(self, attr_name, data_group.create_dataset(ds_name, shape, dtype=dtype))

    def write_to_extra_h5_datasets(
        self,
        dense_fisher_norms,
        batch_fishers,
        examples,
        labels,
        start_example_index: int,
        end_example_index: int,
    ):
        """Write the batch's divisors, dividends, labels and predicted logits.

        The logits are the model's output on `examples` with
        `training=False`. Divisors and dividends are recovered from the numpy
        form of `examples` via `dataset_config`.
        """
        logits = self.model(examples, training=False)

        examples_np = examples.numpy()
        divisors = self.dataset_config.get_divisors_from_examples(examples_np)
        dividends = self.dataset_config.get_dividends_from_examples(examples_np)

        span = (start_example_index, end_example_index)
        writes = (
            (self._divisors_ds, divisors),
            (self._dividends_ds, dividends),
            (self._labels_ds, labels),
            (self._predicted_logits_ds, logits),
        )
        for target_ds, batch_data in writes:
            self.write_to_h5_dataset(target_ds, batch_data, *span)
