"""Set of examples used in perturbations."""
import dataclasses
from typing import Dict, Optional, Sequence, Tuple, Union

import numpy as np
import tensorflow as tf

from em.models import em_models


##########################################################################
# typedefs

# Note: every np.ndarray in these aliases is expected to have the batch
# dimension as its first dimension.

# Text inputs: a dict of arrays keyed by string (presumably one entry per
# model input feature -- confirm against the tokenizer/model in use).
TextExamples = Dict[str, np.ndarray]
# Image inputs: a single array.
ImageExamples = np.ndarray

# Either kind of example set.
Examples = Union[TextExamples, ImageExamples]
# An (examples, labels) pair.
LabeledExamples = Tuple[Examples, np.ndarray]


##########################################################################

def slice_labeled_examples(examples: LabeledExamples, example_indices: np.ndarray):
    """Pick the given examples out of an `(inputs, labels)` pair.

    Args:
        examples: `(inputs, labels)` tuple. `inputs` is either a single
            np.ndarray or a dict of arrays; `labels` is an array.
        example_indices: Index applied with `[...]` to every array.

    Returns:
        `(sliced_inputs, sliced_labels)`. `sliced_inputs` is an array when
        `inputs` was an array, otherwise a new dict with the same keys.
    """
    inputs, targets = examples
    if not isinstance(inputs, np.ndarray):
        # Dict-style inputs: index each entry independently.
        picked_inputs = {}
        for name, values in inputs.items():
            picked_inputs[name] = values[example_indices]
    else:
        picked_inputs = inputs[example_indices]
    return picked_inputs, targets[example_indices]

##########################################################################


@dataclasses.dataclass
class ExamplesContext:
    """A set of labeled examples together with reference ("og") logits.

    Attributes:
        labeled_examples: `(examples, labels)` pair.
        og_logits: Logits stored alongside the examples; axis 0 indexes the
            examples (its length is what `n_examples` reports).
        labels: Set in `__post_init__` to `labeled_examples[1]`.
    """
    labeled_examples: LabeledExamples
    og_logits: np.ndarray

    def __post_init__(self):
        # Convenience handle on the label half of the pair.
        self.labels = self.labeled_examples[1]

    @property
    def n_examples(self) -> int:
        """Number of examples, taken from the leading axis of `og_logits`."""
        return self.og_logits.shape[0]

    def _og_logits_for(self, example_indices: Optional[np.ndarray]) -> np.ndarray:
        """Return all og logits, or only the rows at `example_indices`."""
        if example_indices is None:
            return self.og_logits
        return self.og_logits[example_indices]

    def get_ds(self, example_indices: Optional[np.ndarray] = None) -> tf.data.Dataset:
        """Build an unbatched dataset of `(example, label)` elements.

        Args:
            example_indices: Optional indices restricting the dataset to a
                subset of the examples; `None` uses every example.
        """
        if example_indices is None:
            return tf.data.Dataset.from_tensor_slices(self.labeled_examples)
        return tf.data.Dataset.from_tensor_slices(
            slice_labeled_examples(self.labeled_examples, example_indices))

    def get_og_logits_labeled_ds(self, example_indices: Optional[np.ndarray] = None) -> tf.data.Dataset:
        """Build an unbatched dataset of `(example, og_logits)` elements.

        The true labels are dropped and replaced by the stored og logits.

        Args:
            example_indices: Optional indices restricting the dataset to a
                subset of the examples; `None` uses every example.
        """
        ex_ds = self.get_ds(example_indices).map(lambda x, _: x)
        logits_ds = tf.data.Dataset.from_tensor_slices(self._og_logits_for(example_indices))
        return tf.data.Dataset.zip((ex_ds, logits_ds))

    def evaluate(self, model, batch_size: int, example_indices: Optional[np.ndarray] = None) -> 'EvaluationResults':
        """Run `model` over the examples in batches and collect the results.

        Args:
            model: Model handed to `em_models.compute_logits`.
            batch_size: Batch size used when iterating the examples.
            example_indices: Optional indices restricting evaluation to a
                subset of the examples; `None` evaluates every example.

        Returns:
            An `EvaluationResults` holding the labels, the freshly computed
            logits, and the og logits for the same examples.
        """
        labels, logits = [], []
        for x, y in self.get_ds(example_indices).batch(batch_size):
            labels.append(y.numpy())
            logits.append(em_models.compute_logits(model, x).numpy())
        return EvaluationResults(
            labels=np.concatenate(labels, axis=0),
            logits=np.concatenate(logits, axis=0),
            og_logits=self._og_logits_for(example_indices),
        )

    @classmethod
    def from_dataset(cls, batched_ds: tf.data.Dataset, model: tf.keras.Model) -> 'ExamplesContext':
        """Materialize a batched `(examples, labels)` dataset and its logits.

        Examples, labels and logits are gathered in a single pass over
        `batched_ds`, so row i of `og_logits` always belongs to example i even
        when the dataset yields a different order on each iteration (for
        example when it reshuffles between iterations).

        Args:
            batched_ds: Batched dataset whose elements are `(examples, labels)`.
            model: Model handed to `em_models.compute_logits` for each batch.

        Raises:
            ValueError: If `batched_ds` yields no batches.
        """
        example_batches, label_batches, logit_batches = [], [], []
        for x, y in batched_ds:
            # `x` may be a single tensor or a dict of tensors; convert leaf-wise.
            example_batches.append(tf.nest.map_structure(lambda t: t.numpy(), x))
            label_batches.append(y.numpy())
            logit_batches.append(em_models.compute_logits(model, x).numpy())

        if not logit_batches:
            raise ValueError('Cannot build an ExamplesContext from an empty dataset.')

        def concat_batches(*parts):
            return np.concatenate(parts, axis=0)

        # Stitch the per-batch structures back together along the batch axis.
        examples = tf.nest.map_structure(concat_batches, *example_batches)
        return cls(
            labeled_examples=(examples, concat_batches(*label_batches)),
            og_logits=concat_batches(*logit_batches),
        )

    @classmethod
    def from_pef(cls, pef, tokenizer):
        """Build a context from a `pef` object.

        Uses `pef.get_full_examples(tokenizer, trim=True)` as the labeled
        examples and `pef.predicted_logits` as the og logits.
        """
        return cls(
            labeled_examples=pef.get_full_examples(tokenizer, trim=True),
            og_logits=pef.predicted_logits,
        )


##########################################################################

@dataclasses.dataclass
class EvaluationResults:
    """Labels and logits from an evaluation run, with summary metrics.

    Attributes:
        labels: Label per example.
        logits: Logits being evaluated; the last axis is the class axis.
        og_logits: Reference logits that `kl` compares `logits` against.
        predictions: Argmax of `logits` over the last axis (derived).
        correct_predictions: `predictions == labels` (derived).
    """
    labels: np.ndarray
    logits: np.ndarray

    og_logits: np.ndarray

    def __post_init__(self):
        self.set_up_derived_attributes()

    def set_up_derived_attributes(self):
        """(Re)compute `predictions` and `correct_predictions` from the fields."""
        predicted = np.argmax(self.logits, axis=-1)
        self.predictions = predicted
        self.correct_predictions = predicted == self.labels

    @staticmethod
    def _sorted_index_array(example_indices):
        """Return the indices sorted ascending as an int32 array."""
        ordered = sorted(example_indices)
        return np.array(ordered, dtype=np.int32)

    ############################################################

    def acc(self) -> float:
        """Fraction of examples whose prediction equals the label."""
        hits = self.correct_predictions.astype(np.float64)
        return hits.mean()

    def acc_for_examples(self, example_indices: Sequence[int]) -> float:
        """Accuracy restricted to the examples at `example_indices`."""
        rows = self._sorted_index_array(example_indices)
        hits = self.correct_predictions[rows].astype(np.float64)
        return hits.mean()

    ############################################################

    def _loss(self, labels, logits) -> float:
        """Mean sparse categorical cross-entropy of `logits` against `labels`."""
        per_example = tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
        return per_example.numpy().mean()

    def loss(self) -> float:
        """Mean cross-entropy loss over every example."""
        return self._loss(self.labels, self.logits)

    def loss_for_examples(self, example_indices: Sequence[int]) -> float:
        """Mean cross-entropy loss restricted to `example_indices`."""
        rows = self._sorted_index_array(example_indices)
        return self._loss(self.labels[rows], self.logits[rows])

    ############################################################

    def _kl(self, og_logits, logits) -> float:
        """Mean Keras KL divergence between the two softmax distributions.

        `softmax(logits)` is passed as the first argument of
        `tf.keras.losses.kl_divergence` and `softmax(og_logits)` as the second.
        """
        assert og_logits is not None
        new_probs = tf.math.softmax(logits)
        og_probs = tf.math.softmax(og_logits)
        return tf.keras.losses.kl_divergence(new_probs, og_probs).numpy().mean()

    def kl(self) -> float:
        """Mean KL divergence over every example."""
        return self._kl(self.og_logits, self.logits)

    def kl_for_examples(self, example_indices: Sequence[int]) -> float:
        """Mean KL divergence restricted to `example_indices`."""
        rows = self._sorted_index_array(example_indices)
        return self._kl(self.og_logits[rows], self.logits[rows])
