"""Activations for ResNets."""
import dataclasses
import os

import h5py
import numpy as np
import tensorflow as tf
from tqdm import tqdm

from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.keras import layers as keras_layers

from em.models import em_models
from em.util import monkey_patching
from em.util import hdf5_util


def create_activations_model(model_ri: str):
    """Build a model that exposes the classifier-head inputs alongside the output.

    While the pretrained model identified by ``model_ri`` is being constructed,
    ``keras_layers.GlobalAveragePooling2D.__call__`` is patched (via
    ``MonkeyPatcherContext``, presumably only for the duration of the ``with``
    block) so that the output of each pooling call is recorded. The returned
    functional model maps the original model's input to a dict with:

      * ``'activations'``: the output of the *last* GlobalAveragePooling2D call
        made during construction. Presumably this is the pooling layer that
        feeds the classifier head. Verify this for architectures that also pool
        elsewhere, e.g. squeeze-and-excitation blocks.
      * ``'logits'``: the original model's output.

    Args:
      model_ri: Identifier passed to ``em_models.from_pretrained``.

    Returns:
      A ``keras_training.Model`` named ``'resnet_activations_model'``.

    Raises:
      ValueError: If no GlobalAveragePooling2D layer was called while the
        model was built, so there is no activation tensor to expose.
    """
    activations = None

    def override_fn(og_fn, *args, **kwargs):
        # Forward to the original __call__ and remember its output. Later calls
        # overwrite earlier ones, so the last pooling layer built wins.
        # Assumes the patcher passes the original method as the first argument.
        nonlocal activations
        ret = og_fn(*args, **kwargs)
        activations = ret
        return ret

    mctx = monkey_patching.MonkeyPatcherContext()
    mctx.patch_method(keras_layers.GlobalAveragePooling2D, '__call__', override_fn)

    with mctx:
        og_model = em_models.from_pretrained(model_ri)

    # Without this check, Keras would fail later with an obscure error about a
    # None output.
    if activations is None:
        raise ValueError(
            f'No GlobalAveragePooling2D layer was called while building model '
            f'{model_ri!r}; cannot extract activations.')

    return keras_training.Model(
        og_model.input,
        {'activations': activations, 'logits': og_model.output},
        name='resnet_activations_model')

###############################################################################


@dataclasses.dataclass
class ResnetActivations:
    """Classifier-head inputs, labels, and logits for a set of examples.

    Use ``load`` to read an instance back from an HDF5 file whose ``data``
    group holds ``activations``, ``labels`` and ``logits`` datasets.
    """

    # Inputs to the classifier head; shape = [n_examples, d_activations].
    activations: np.ndarray

    # Per-example labels; shape = [n_examples].
    labels: np.ndarray

    # Model outputs; shape = [n_examples, n_classes].
    logits: np.ndarray

    @classmethod
    def load(cls, filepath: str):
        """Read an instance from the HDF5 file at ``filepath``.

        A leading ``~`` in ``filepath`` is expanded to the user's home
        directory. Each dataset is read through ``hdf5_util.load_h5_ds``.
        """
        resolved_path = os.path.expanduser(filepath)
        with h5py.File(resolved_path, "r") as h5_file:
            loaded_activations = hdf5_util.load_h5_ds(h5_file['data/activations'])
            loaded_labels = hdf5_util.load_h5_ds(h5_file['data/labels'])
            loaded_logits = hdf5_util.load_h5_ds(h5_file['data/logits'])
        return cls(
            activations=loaded_activations,
            labels=loaded_labels,
            logits=loaded_logits,
        )


###############################################################################

@dataclasses.dataclass
class StreamingActivationsSaver:
    """Streams model activations, logits, and labels into an HDF5 file.

    The file layout (group ``data`` holding datasets ``activations``,
    ``labels`` and ``logits``) matches what ``ResnetActivations.load`` reads.
    Rows are written batch by batch, so the full result never needs to be held
    in memory.
    """

    # Model returning a dict with 'activations' and 'logits' entries; it must
    # come from create_activations_model.
    model: tf.keras.Model

    # Number of rows allocated in each dataset; iteration stops once this many
    # examples have been processed.
    n_examples: int
    # Width of the 'logits' dataset.
    n_classes: int
    # Width of the 'activations' dataset.
    d_activations: int

    # Wrap the dataset in a tqdm progress bar while saving.
    use_tqdm: bool = True

    def __post_init__(self):
        # Row offset for the next write; None until a save has started.
        self.n_examples_processed = None

    def _initialize_file(self, file: h5py.File):
        """Create the ``data`` group and its fixed-size datasets in ``file``."""
        n_examples = self.n_examples
        self.data_grp = file.create_group('data')

        int_ds = lambda n, s: self.data_grp.create_dataset(n, s, dtype=np.int32)
        flt_ds = lambda n, s: self.data_grp.create_dataset(n, s, dtype=np.float32)

        self.activations_ds = flt_ds('activations', [n_examples, self.d_activations])
        self.labels_ds = int_ds('labels', [n_examples])
        self.logits_ds = flt_ds('logits', [n_examples, self.n_classes])

    def _write(self, h5_ds: h5py.Dataset, data):
        """Write ``data`` into ``h5_ds`` starting at the current row offset.

        ``data`` may be a ``tf.Tensor`` or anything ``np.asarray`` accepts.
        Rows that would fall past ``n_examples`` are dropped. Values are cast
        to the dataset's dtype.
        """
        # np.asarray works for eager tensors and numpy arrays alike, whereas
        # .numpy() exists only on tensors.
        rows = np.asarray(data)
        i1 = self.n_examples_processed
        i2 = min(i1 + rows.shape[0], self.n_examples)
        h5_ds[i1:i2] = rows[:i2 - i1].astype(h5_ds.dtype)

    def _write_batch_results_to_file(self, file, activations, logits, labels):
        """Write one batch of outputs at the current row offset.

        ``file`` is unused; the datasets created by ``_initialize_file`` are
        written directly.
        """
        self._write(self.activations_ds, activations)
        self._write(self.labels_ds, labels)
        self._write(self.logits_ds, logits)

    def _compute_and_save_activations(self, file: h5py.File, ds):
        """Run the model over ``ds`` and stream its outputs into ``file``.

        ``ds`` must yield ``(examples, labels)`` batches. Iteration stops once
        at least ``n_examples`` examples have been processed. If ``ds`` is
        exhausted first, a warning is issued, because the unwritten rows still
        hold the dataset fill value and would otherwise be indistinguishable
        from real data on load.
        """
        import warnings

        self.n_examples_processed = 0

        self._initialize_file(file)

        for examples, labels in ds:
            output = self.model(examples, training=False)
            activations, logits = output['activations'], output['logits']
            self._write_batch_results_to_file(file, activations, logits, labels)
            # Advance the offset only after all three datasets for this batch
            # have been written, since _write uses it as the start row.
            self.n_examples_processed += labels.shape[0]
            if self.n_examples_processed >= self.n_examples:
                break

        if self.n_examples_processed < self.n_examples:
            warnings.warn(
                f'Dataset yielded only {self.n_examples_processed} examples but '
                f'{self.n_examples} were allocated; the remaining rows were not '
                f'written.')

    def compute_and_save_activations(self, filepath: str, ds: tf.data.Dataset):
        """Compute outputs for the batched dataset ``ds`` and save them.

        The file at ``filepath`` (``~`` is expanded, as in
        ``ResnetActivations.load``) is created or truncated.
        """
        # The dataset should be batched.
        if self.use_tqdm:
            ds = tqdm(ds)

        with h5py.File(os.path.expanduser(filepath), "w") as file:
            self._compute_and_save_activations(file, ds)
