"""Activations for BERT."""
import dataclasses
import os
import warnings

import h5py
import numpy as np
import tensorflow as tf
from tqdm import tqdm

from em.util import hdf5_util

###############################################################################


@dataclasses.dataclass
class BertClsActivations:
    """Per-example [CLS]-token activations from every layer, plus model I/O.

    `load` builds an instance from an HDF5 file whose ``data`` group holds the
    four arrays below.
    """

    # Shape [n_examples, n_layers * d_model]: per-layer [CLS] vectors, concatenated.
    activations: np.ndarray

    # Shape [n_examples, sequence_length]: the input ids of each example.
    input_ids: np.ndarray

    # Shape [n_examples]: one label per example.
    labels: np.ndarray

    # Shape [n_examples, n_classes]: the model's output logits.
    logits: np.ndarray

    @classmethod
    def load(cls, filepath: str):
        """Read a BertClsActivations from the HDF5 file at `filepath`.

        A leading ``~`` in `filepath` is expanded. The arrays come from the
        ``data/activations``, ``data/input_ids``, ``data/labels`` and
        ``data/logits`` datasets via ``hdf5_util.load_h5_ds``; ``input_ids``
        is additionally cast to int32.
        """
        resolved_path = os.path.expanduser(filepath)
        with h5py.File(resolved_path, "r") as h5:
            read = hdf5_util.load_h5_ds
            activations = read(h5['data/activations'])
            input_ids = read(h5['data/input_ids']).astype(np.int32)
            labels = read(h5['data/labels'])
            logits = read(h5['data/logits'])
            return cls(
                activations=activations,
                input_ids=input_ids,
                labels=labels,
                logits=logits,
            )


###############################################################################


@dataclasses.dataclass
class ClsActivationsComputer:
    """Extracts the position-0 ([CLS]) hidden state from each model layer.

    The wrapped model is called with ``output_hidden_states=True`` and its
    output is expected to expose ``hidden_states`` and ``logits``.
    """

    model: tf.keras.Model

    ############################################################
    # Identity-based equality and hashing. The original author believed the
    # object had to be hashable to work with tf.function.
    # NOTE(review): confirm this is still required.

    def __eq__(self, other):
        return other is self

    def __hash__(self):
        return id(self)

    ############################################################

    def process_batch(self, batch):
        """Run `batch` through the model and gather per-layer [CLS] vectors.

        Returns:
            A pair ``(cls_activations, logits)``. ``cls_activations`` joins,
            along the last axis, the flattened position-0 slice of every entry
            of ``hidden_states`` except the first, with one row per example.
            ``logits`` is the model's ``logits`` output, unchanged.
        """
        # NOTE: This cannot be a @tf.function.
        result = self.model(batch, training=False, output_hidden_states=True)

        # hidden_states[0] is the embeddings-layer output. Its [CLS] entry is
        # always the same, so only the later entries are kept.
        flattened = []
        for layer_states in result.hidden_states[1:]:
            cls_state = layer_states[:, 0]
            n_rows = tf.shape(cls_state)[0]
            flattened.append(tf.reshape(cls_state, [n_rows, -1]))

        return tf.concat(flattened, axis=-1), result.logits


###############################################################################

@dataclasses.dataclass
class StreamingClsActivationsSaver:
    """Streams [CLS] activations, logits, input ids and labels into an HDF5 file.

    The file gets a ``data`` group with the datasets ``activations``,
    ``input_ids``, ``labels`` and ``logits``, each with `n_examples` rows.
    This is the layout ``BertClsActivations.load`` reads.
    """

    # Produces the (activations, logits) pair for each batch.
    computer: ClsActivationsComputer

    # Number of rows allocated and written; extra dataset examples are dropped.
    n_examples: int
    # Row length of `input_ids`. Assumed to match the padded length of the
    # dataset's batches — TODO confirm against the input pipeline.
    sequence_length: int

    # Whether to wrap the dataset in a tqdm progress bar (counts batches).
    use_tqdm: bool = True

    def __post_init__(self):
        self.model = self.computer.model
        self.n_classes = self.model.num_labels

        self.d_model = self.model.config.hidden_size
        # Number of hidden layers (attribute name kept for compatibility).
        self.d_layers = self.model.config.num_hidden_layers
        self.d_activations = self.d_model * self.d_layers

        # Reset to 0 at the start of each save; counts examples consumed.
        self.n_examples_processed = None

    def _initialize_file(self, file: h5py.File):
        """Create the ``data`` group and its fixed-size datasets in `file`."""
        n_examples = self.n_examples
        self.data_grp = file.create_group('data')

        def int_ds(name, shape):
            return self.data_grp.create_dataset(name, shape, dtype=np.int32)

        def flt_ds(name, shape):
            return self.data_grp.create_dataset(name, shape, dtype=np.float32)

        self.activations_ds = flt_ds('activations', [n_examples, self.d_activations])
        # Input ids are integers, so store them as int32 like `labels`.
        # `BertClsActivations.load` casts to int32, so older float32 files still load.
        self.input_ids_ds = int_ds('input_ids', [n_examples, self.sequence_length])
        self.labels_ds = int_ds('labels', [n_examples])
        self.logits_ds = flt_ds('logits', [n_examples, self.n_classes])

    def _write(self, h5_ds: h5py.Dataset, data: np.ndarray):
        """Write the next rows of `data` into `h5_ds`, clipped to `n_examples`.

        `data` may be an eager tensor (anything with ``.numpy()``) or an
        array-like. Rows that would land past ``self.n_examples`` are dropped.
        """
        start = self.n_examples_processed
        stop = min(start + data.shape[0], self.n_examples)
        chunk = data[:stop - start]
        # Eager tensors expose `.numpy()`; plain arrays go through np.asarray.
        chunk = chunk.numpy() if hasattr(chunk, 'numpy') else np.asarray(chunk)
        h5_ds[start:stop] = chunk.astype(h5_ds.dtype)

    def _write_batch_results_to_file(self, file, input_ids, activations, logits, labels):
        """Append one batch to every dataset (`file` is unused; kept for compatibility)."""
        self._write(self.input_ids_ds, input_ids)
        self._write(self.activations_ds, activations)
        self._write(self.labels_ds, labels)
        self._write(self.logits_ds, logits)

    def _compute_and_save_activations(self, file: h5py.File, ds: tf.data.Dataset):
        """Create the datasets in `file`, then fill them batch by batch from `ds`."""
        self.n_examples_processed = 0

        self._initialize_file(file)

        for examples, labels in ds:
            activations, logits = self.computer.process_batch(examples)
            self._write_batch_results_to_file(
                file, examples['input_ids'], activations, logits, labels)
            self.n_examples_processed += labels.shape[0]
            if self.n_examples_processed >= self.n_examples:
                break

        # The datasets are preallocated, so a short dataset would otherwise leave
        # unwritten rows that load back indistinguishable from real data.
        if self.n_examples_processed < self.n_examples:
            warnings.warn(
                f"Dataset ended after {self.n_examples_processed} examples, "
                f"fewer than n_examples={self.n_examples}; the remaining rows "
                "hold the HDF5 fill value.")

    def compute_and_save_activations(self, filepath: str, ds: tf.data.Dataset):
        """Compute activations over `ds` and write them to `filepath`.

        Args:
            filepath: Destination HDF5 path, opened in "w" mode (overwritten).
                A leading ``~`` is expanded, matching ``BertClsActivations.load``.
            ds: A batched dataset yielding ``(examples, labels)`` pairs, where
                ``examples['input_ids']`` holds the batch's input ids.
        """
        progress = tqdm(ds) if self.use_tqdm else None
        iterable = ds if progress is None else progress
        try:
            with h5py.File(os.path.expanduser(filepath), "w") as file:
                self._compute_and_save_activations(file, iterable)
        finally:
            # The loop may `break` before exhausting the iterator, and then
            # the bar is never finalized; close it explicitly.
            if progress is not None:
                progress.close()
