"""Other stuff for LRM-PEFS."""
import dataclasses

import h5py
import numpy as np
import tensorflow as tf
from tqdm import tqdm

from em.models import em_models
from em.util import hdf5_util
from . import lrm_pefs

###############################################################################
# Module-level shorthands so the code below can refer to these names without
# the `lrm_pefs.` / `hdf5_util.` module prefixes.
SparseLrmPefComputer = lrm_pefs.SparseLrmPefComputer
save_h5_ds = hdf5_util.save_h5_ds

###############################################################################


def _filter_batch_via_mask(x, mask):
    """Apply ``tf.boolean_mask`` with `mask` to a batch.

    Args:
        x: Either a single ``tf.Tensor`` or a mapping (anything exposing
            ``.items()``) whose values are maskable by ``tf.boolean_mask``.
        mask: Boolean mask forwarded unchanged to ``tf.boolean_mask``.

    Returns:
        The masked tensor when `x` is a ``tf.Tensor``; otherwise a new plain
        ``dict`` with the same keys and each value masked.
    """
    if not isinstance(x, tf.Tensor):
        # Mapping-style batch: mask every entry with the same mask.
        filtered = {}
        for name, tensor in x.items():
            filtered[name] = tf.boolean_mask(tensor, mask)
        return filtered

    return tf.boolean_mask(x, mask)


def _coalesce_cache(cache, dtype):
    """Join a list of per-batch arrays along axis 0 and cast to `dtype`.

    Args:
        cache: Sequence of arrays to concatenate along their first axis.
        dtype: NumPy dtype of the returned array.

    Returns:
        A single array holding every cached batch, converted to `dtype`.

    Raises:
        ValueError: Raised by ``np.concatenate`` when `cache` is empty.
    """
    merged = np.concatenate(cache, axis=0)
    converted = merged.astype(dtype)
    return converted


###############################################################################


@dataclasses.dataclass
class WrongsOnlyStreamingLrmPefSaver:
    """Compute PEFs only for misclassified examples and save them to HDF5.

    Every batch is first reduced to the examples whose argmax prediction
    differs from the label; only those examples are handed to
    ``fisher_computer.process_batch``. The per-batch results are accumulated
    in Python lists and written out in one go.

    NOTE: Unlike the regular streaming pefs saver, this will hold the
    entire thing in memory and save at the end.

    Attributes:
        fisher_computer: Computer providing ``model``, ``variables``,
            ``n_values_per_example``, ``effective_n_labels`` and
            ``process_batch``.
        use_tqdm: When true, the dataset is wrapped in ``tqdm`` so a progress
            bar is shown while iterating.
    """
    fisher_computer: SparseLrmPefComputer

    use_tqdm: bool = True

    ############################################################
    # Identity-based hash/equality. Per the original author's note, this is
    # believed to be needed for the instance to work with tf.function.

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return self is other

    ############################################################

    def __post_init__(self):
        """Copy frequently used values off `fisher_computer`."""
        self.model = self.fisher_computer.model
        self.variables = self.fisher_computer.variables
        self.n_values_per_example = self.fisher_computer.n_values_per_example

        self.n_classes = int(self.model.num_labels)
        self.effective_n_classes = self.fisher_computer.effective_n_labels
        # Total number of scalar entries across all tracked variables.
        self.n_parameters = int(tf.reduce_sum([tf.size(v) for v in self.variables]).numpy())

    @tf.function
    def _filter_to_only_wrong_predictions(self, x, y):
        """Drop the examples of a batch that the model classifies correctly.

        Args:
            x: Batch inputs; a tensor or a dict of tensors.
            y: Batch labels, compared against the argmax over the last axis
                of the model's logits.

        Returns:
            A ``(x, y)`` pair containing only the misclassified examples.
            `y` keeps its original dtype.
        """
        logits = em_models.compute_logits(self.model, x)
        predictions = tf.argmax(logits, axis=-1)

        # tf.argmax yields int64 by default. Cast the labels to the same dtype
        # for the comparison so that labels stored with another integer dtype
        # (e.g. int32) do not trigger a dtype-mismatch error.
        comparable_labels = tf.cast(y, predictions.dtype)
        keep_mask = predictions != comparable_labels

        x2 = _filter_batch_via_mask(x, keep_mask)
        y2 = tf.boolean_mask(y, keep_mask)

        return x2, y2

    def _cache_batch_results(self, batch_dict, y):
        """Append one processed batch (as NumPy arrays) to the caches.

        Args:
            batch_dict: Output of ``fisher_computer.process_batch``; must hold
                the keys 'logits', 'values', 'col_offsets', 'row_indices' and
                'frobenius_norms'.
            y: Labels of the examples in the batch.
        """
        d = batch_dict

        self.labels_cache.append(y.numpy())
        self.logits_cache.append(d['logits'].numpy())

        self.values_cache.append(d['values'].numpy())
        self.col_offsets_cache.append(d['col_offsets'].numpy())
        self.row_indices_cache.append(d['row_indices'].numpy())

        self.pef_frobenius_norms_cache.append(d['frobenius_norms'].numpy())

    def _compute_pefs(self, ds: tf.data.Dataset):
        """Reset the caches and fill them from the misclassified examples.

        Args:
            ds: Iterable yielding ``(x, y)`` batches.
        """
        self.labels_cache = []
        self.logits_cache = []
        self.values_cache = []
        self.col_offsets_cache = []
        self.row_indices_cache = []
        self.pef_frobenius_norms_cache = []

        for x, y in ds:
            x, y = self._filter_to_only_wrong_predictions(x, y)
            # Nothing was misclassified in this batch; skip it entirely.
            if tf.size(y) == 0:
                continue
            batch_dict = self.fisher_computer.process_batch(x)
            self._cache_batch_results(batch_dict, y)

    def _save_pefs(self, file: h5py.File):
        """Write the cached results into a new 'data' group of `file`.

        Args:
            file: Open, writable HDF5 file.

        Raises:
            ValueError: If the caches are empty (raised by ``np.concatenate``).
        """
        # TODO
        self.data_grp = file.create_group('data')
        self.data_grp.attrs['n_classes'] = self.effective_n_classes
        self.data_grp.attrs['n_parameters'] = self.n_parameters

        # (dataset name, cache, on-disk dtype), written in this order.
        datasets = (
            ('labels', self.labels_cache, np.int32),
            ('logits', self.logits_cache, np.float32),
            ('values', self.values_cache, np.float32),
            ('col_offsets', self.col_offsets_cache, np.int32),
            ('row_indices', self.row_indices_cache, np.int32),
            ('pef_frobenius_norms', self.pef_frobenius_norms_cache, np.float32),
        )
        for name, cache, dtype in datasets:
            save_h5_ds(self.data_grp, name, _coalesce_cache(cache, dtype))

    def compute_and_save_pefs(self, filepath: str, ds: tf.data.Dataset):
        """Process the whole dataset and write the results to `filepath`.

        Args:
            filepath: Destination HDF5 path; opened in "w" mode, so an
                existing file at that path is overwritten.
            ds: The dataset should be batched. It must also be finite since
                the whole dataset will be processed.

        Raises:
            ValueError: If no example in `ds` was misclassified, so there is
                nothing to save. The destination file is left untouched.
        """
        if self.use_tqdm:
            ds = tqdm(ds)

        self._compute_pefs(ds)

        # Fail before opening the file: "w" mode would otherwise truncate an
        # existing file and leave a partially written one behind when the
        # concatenation of empty caches raises.
        if not self.labels_cache:
            raise ValueError(
                "No misclassified examples were found in the dataset; "
                "there are no PEFs to save."
            )

        with h5py.File(filepath, "w") as file:
            self._save_pefs(file)
