"""Code for some PEF stuff.

Basically my second pass at some of this.
"""
import dataclasses
from typing import Optional

import h5py
import numpy as np
import tensorflow as tf
from tqdm import tqdm

from em import datasets as em_datasets
from em.models import em_models

from . import diagonal2

###############################################################################

# Supported values for `StreamingPefSaver.flavor`:
#   'full': every entry of each example's flattened Fisher is saved.
#   'sparse_dynamic_raw': only the top `n_fisher_values_per_example` entries
#       of each example's flattened Fisher are saved, as separate value and
#       index arrays.
FISHER_FLAVORS = ('full', 'sparse_dynamic_raw')

###############################################################################


def _flatten_fishers(fishers):
    """Flatten each tensor to 2-D and join them along the last axis.

    Every tensor in `fishers` keeps its leading dimension and has all of its
    remaining dimensions collapsed into one. The resulting 2-D tensors are
    then concatenated along the last axis.
    """
    flattened = []
    for fisher in fishers:
        leading_dim = tf.shape(fisher)[0]
        flattened.append(tf.reshape(fisher, [leading_dim, -1]))
    return tf.concat(flattened, axis=-1)


def _compute_fisher_norms(flat_fishers):
    """Return the L2 norm of `flat_fishers` taken along its last axis."""
    squared = tf.square(flat_fishers)
    sum_of_squares = tf.reduce_sum(squared, axis=-1)
    return tf.sqrt(sum_of_squares)


###############################################################################

class _FileHelper:
    """Writes batches of results into fixed-size datasets of an HDF5 file.

    Constructing the helper creates a `data` group in `file` along with one
    dataset per saved quantity, each with `n_examples` rows. Successive calls
    to `write` fill in the next rows; rows past `n_examples` are dropped.
    """

    def __init__(
        self,
        saver: 'StreamingPefSaver',
        file: h5py.File,
        n_examples: int,
        sequence_length: Optional[int],
    ):
        """Create the `data` group and all datasets in `file`.

        Args:
            saver: Supplies `flavor`, `variables`, `model.num_labels` and,
                for the sparse flavor, `n_fisher_values_per_example`.
            file: Open, writable HDF5 file.
            n_examples: Number of rows allocated in every dataset.
            sequence_length: Width of the `input_ids` dataset, or None when
                the examples have no `input_ids`.

        Raises:
            ValueError: If `saver.flavor` is not a recognized flavor.
        """
        # sequence_length will be None for non-text datasets, so set it to 1
        # to save dummy input_ids to file.
        self.has_input_ids = sequence_length is not None
        if sequence_length is None:
            sequence_length = 1

        self.saver = saver
        self.file = file
        self.n_examples = n_examples
        self.sequence_length = sequence_length

        self.n_examples_processed = 0

        self.dense_fisher_size = self._compute_dense_fisher_size()
        self.saved_fisher_size = self._compute_saved_fisher_size()

        self.data_grp = file.create_group('data')
        self.data_grp.attrs['flavor'] = saver.flavor

        self._create_common_datasets()
        self._create_fisher_datasets()

    ############################################################

    def _compute_dense_fisher_size(self):
        """Return the total number of elements across `saver.variables`."""
        # NOTE(review): tf.size is called with its default output type
        # (int32 per the TF docs), so the sum would overflow for 2**31 or
        # more elements - confirm whether that can happen here.
        return tf.reduce_sum([tf.size(v) for v in self.saver.variables]).numpy()

    def _compute_saved_fisher_size(self):
        """Return how many Fisher entries are stored per example.

        Raises:
            ValueError: If the saver's flavor is not recognized.
        """
        flavor = self.saver.flavor
        if flavor == 'full':
            return self.dense_fisher_size
        elif flavor == 'sparse_dynamic_raw':
            return self.saver.n_fisher_values_per_example
        else:
            raise ValueError(flavor)

    def _create_common_datasets(self):
        """Create the datasets that exist regardless of flavor."""
        self.labels_ds = self.data_grp.create_dataset(
            'labels',
            [self.n_examples],
            dtype=np.int32)
        self.predicted_logits_ds = self.data_grp.create_dataset(
            'predicted_logits',
            [self.n_examples, self.saver.model.num_labels],
            dtype=np.float32)

        # Has a single column when the examples carry no input_ids; see
        # `_write_input_ids` for what is stored in that case.
        self.input_ids_ds = self.data_grp.create_dataset(
            'input_ids',
            [self.n_examples, self.sequence_length],
            dtype=np.int32)

    def _create_fisher_datasets(self):
        """Create the norm dataset and the flavor-specific Fisher datasets.

        The 'full' flavor gets a single `fishers` dataset. The
        'sparse_dynamic_raw' flavor gets a `fisher` group holding `values`
        and `indices` datasets, with the dense size recorded as an attribute.

        Raises:
            ValueError: If the saver's flavor is not recognized.
        """
        flavor = self.saver.flavor

        # Norms of the entire dense diagonal Fisher.
        self.dense_fisher_norms_ds = self.data_grp.create_dataset(
            'dense_fisher_norms',
            [self.n_examples],
            dtype=np.float32)

        if flavor == 'full':
            self.fishers_ds = self.data_grp.create_dataset(
                'fishers',
                [self.n_examples, self.saved_fisher_size],
                dtype=np.float32)

        elif flavor == 'sparse_dynamic_raw':
            self.fisher_grp = self.data_grp.create_group('fisher')
            self.fisher_grp.attrs['dense_fisher_size'] = self.dense_fisher_size

            self.fisher_values_ds = self.fisher_grp.create_dataset(
                'values',
                [self.n_examples, self.saved_fisher_size],
                dtype=np.float32)
            self.fisher_indices_ds = self.fisher_grp.create_dataset(
                'indices',
                [self.n_examples, self.saved_fisher_size],
                dtype=np.int32)

        else:
            raise ValueError(flavor)

    ############################################################

    def _write(self, h5_ds: h5py.Dataset, data: tf.Tensor):
        """Copy the rows of `data` into `h5_ds` at the current write offset.

        Rows that would land past `n_examples` are dropped. The offset itself
        is only advanced by `write`, so every dataset written for one batch
        uses the same row range.
        """
        i1 = self.n_examples_processed
        i2 = min(i1 + data.shape[0], self.n_examples)
        # `data` must expose `.numpy()`; values are cast to the dataset dtype.
        h5_ds[i1:i2] = data[:i2 - i1].numpy().astype(h5_ds.dtype)

    def _write_fishers(self, fishers):
        """Write either a dense Fisher tensor or a (values, indices) pair."""
        if isinstance(fishers, tuple):
            fisher_values, fisher_indices = fishers
            self._write(self.fisher_values_ds, fisher_values)
            self._write(self.fisher_indices_ds, fisher_indices)
        else:
            self._write(self.fishers_ds, fishers)

    def _write_input_ids(self, examples):
        """Write the batch's input ids, or running row indices if it has none."""
        if self.has_input_ids:
            input_ids = examples['input_ids']

        else:
            # Just write some indices in place of input ids. The dataset has
            # shape [n_examples, 1] in this case, so the indices are shaped
            # (batch, 1): h5py cannot broadcast a 1-D (batch,) array into a
            # (batch, 1) selection when batch > 1.
            batch_size = diagonal2.batch_size_from_batch(examples)
            start = self.n_examples_processed
            input_ids = tf.reshape(
                tf.range(start, start + batch_size, dtype=tf.int32),
                [-1, 1],
            )

        self._write(self.input_ids_ds, input_ids)

    ############################################################

    def write(self, *, examples, labels, logits, norms, fishers):
        """Write one batch of results to the file.

        Args:
            examples: The batch; `examples['input_ids']` is read when the
                helper was built with a sequence length.
            labels: Labels for the batch.
            logits: Predicted logits for the batch.
            norms: Dense Fisher norms for the batch.
            fishers: A dense Fisher tensor, or a `(values, indices)` tuple.

        Returns:
            Truthy once `n_examples` rows have been written, meaning the
            caller can stop feeding batches; falsy otherwise. If the file was
            already full on entry, nothing is written.
        """
        if self.n_examples_processed >= self.n_examples:
            return True

        self._write_fishers(fishers)
        self._write_input_ids(examples)
        self._write(self.labels_ds, labels)
        self._write(self.predicted_logits_ds, logits)
        self._write(self.dense_fisher_norms_ds, norms)

        self.n_examples_processed += diagonal2.batch_size_from_batch(examples)

        # Report completion right away so the caller does not have to compute
        # another batch just to find out that the file is full.
        return self.n_examples_processed >= self.n_examples


###############################################################################

@dataclasses.dataclass
class StreamingPefSaver:
    """Computes Fishers one batch at a time and streams them to an HDF5 file.

    `flavor` must be one of `FISHER_FLAVORS`. With 'sparse_dynamic_raw', only
    the top `n_fisher_values_per_example` entries of each example's flattened
    Fisher are kept, so that field must be set.
    """
    fisher_computer: diagonal2.FisherComputer

    flavor: str

    # Required for "sparse_dynamic_*":
    n_fisher_values_per_example: Optional[int] = None

    use_tqdm: bool = True

    def __post_init__(self):
        assert self.flavor in FISHER_FLAVORS

        needs_k = self.flavor == 'sparse_dynamic_raw'
        assert not needs_k or self.n_fisher_values_per_example is not None

        # Convenience aliases for attributes of the fisher computer.
        self.variables = self.fisher_computer.variables
        self.model = self.fisher_computer.model

    ############################################################
    # I think needed to be hashable to work with tf.function

    def __eq__(self, other):
        return other is self

    def __hash__(self):
        return id(self)

    ############################################################

    @tf.function
    def _compute_fishers_and_logits_for_batch(self, batch):
        """Return `(logits, norms, fishers)` for one batch.

        `norms` is computed from the complete flattened Fishers. `fishers` is
        the flattened Fisher tensor itself, except for the
        'sparse_dynamic_raw' flavor where it is a `(values, indices)` tuple
        produced by `tf.math.top_k`.
        """
        logits, per_variable_fishers = (
            self.fisher_computer.compute_exact_fisher_and_logits_for_batch(batch))

        flat_fishers = _flatten_fishers(per_variable_fishers)
        norms = _compute_fisher_norms(flat_fishers)

        if self.flavor != "sparse_dynamic_raw":
            return logits, norms, flat_fishers

        # The topk seems to be taking up the most time. Bigger batch sizes help.
        # TODO: Try approx top k.
        #
        # See https://gist.github.com/redzhepdx/1e65bd2a721ff7fbf532 for possible Cuda implementation of quickselect.
        # NOTE: I think that is for doing a bunch of quick selects at once, which isn't really what I want.
        #
        # https://hpc.fau.de/files/2021/03/2021-03-23_ribizel.pdf
        # https://icl.utk.edu/files/publications/2019/icl-utk-1230-2019.pdf
        top_values, top_indices = tf.math.top_k(
            flat_fishers, k=self.n_fisher_values_per_example)
        return logits, norms, (top_values, top_indices)

    def _compute_and_save_pefs(self, file: h5py.File, ds: tf.data.Dataset, n_examples: int):
        """Iterate over `ds` and write each batch's results into `file`.

        `ds` is iterated as `(examples, labels)` pairs. Iteration stops as
        soon as the file helper's `write` returns a truthy value, or when
        `ds` is exhausted.
        """
        sequence_length = em_datasets.infer_sequence_length(ds)
        writer = _FileHelper(
            self,
            file,
            n_examples=n_examples,
            sequence_length=sequence_length,
        )

        # Optionally wrap the dataset to show a progress bar.
        batches = tqdm(ds) if self.use_tqdm else ds

        for examples, labels in batches:
            logits, norms, fishers = self._compute_fishers_and_logits_for_batch(examples)
            finished = writer.write(
                examples=examples,
                labels=labels,
                logits=logits,
                norms=norms,
                fishers=fishers,
            )
            if finished:
                break

    ############################################################

    def compute_and_save_pefs(self, filepath: str, ds: tf.data.Dataset, n_examples: int):
        """Create (or overwrite) the HDF5 file at `filepath` and fill it from `ds`.

        Args:
            filepath: Destination path; opened in "w" mode.
            ds: Dataset iterated as `(examples, labels)` batches.
            n_examples: Number of rows to allocate and write.
        """
        with h5py.File(filepath, "w") as h5_file:
            self._compute_and_save_pefs(h5_file, ds, n_examples=n_examples)
