"""Stuff for analysis NMFs on the annotated ANLI dev sets."""
import dataclasses
from typing import List

import numpy as np
import tensorflow_datasets as tfds
from transformers import AutoTokenizer

from em.datasets import anli
from em.datasets import annotated_anli
from em.tools.mi import continuous_discrete_mi as cd_mi


# Short module-level alias for the example type defined in `annotated_anli`,
# so the code below can refer to it without the module prefix.
AnnotatedExample = annotated_anli.AnnotatedExample


@dataclasses.dataclass
class AaaContext:
    """Annotated ANLI validation examples paired with per-example model outputs.

    Constructor arguments:
        task: Dataset config name; the data is loaded from
            ``annotated_anli/{task}`` (validation split).
        coeffs: Per-example coefficients. Must have one entry along axis 0 per
            example, in the same order as the loaded dataset.
        predicted_logits: Per-example logits, in the same order as the loaded
            dataset. The argmax over the last axis is used as the prediction.

    Attributes derived in ``__post_init__`` (not dataclass fields):
        examples: List of examples built with ``AnnotatedExample.from_tfds``.
        predicted_labels: ``np.argmax(predicted_logits, axis=-1)``.
    """

    task: str
    coeffs: np.ndarray

    predicted_logits: np.ndarray

    def __post_init__(self):
        """Load the examples and derive the predicted labels.

        Raises:
            AssertionError: If the number of examples, logits and coefficients
                do not all match.
        """
        self.examples = self._load_examples_into_memory()

        # These must all be the same length. The coeffs and predicted logits
        # are assumed to be passed in the same order as the dataset we read;
        # only the lengths can be checked here, not the ordering itself.
        assert len(self.examples) == len(self.predicted_logits) == len(self.coeffs), (
            "examples, predicted_logits and coeffs must have the same length: "
            f"{len(self.examples)}, {len(self.predicted_logits)}, {len(self.coeffs)}"
        )

        self.predicted_labels = np.argmax(self.predicted_logits, axis=-1)

    def _load_examples_into_memory(self) -> List["AnnotatedExample"]:
        """Read the whole validation split for ``self.task`` into a list."""
        ds = tfds.load(f"annotated_anli/{self.task}", split='validation')
        # NOTE(review): presumably renames feature keys to the MNLI naming
        # scheme; confirm against `anli.rekey_to_be_like_mnli`.
        ds = anli.rekey_to_be_like_mnli(ds)
        return [AnnotatedExample.from_tfds(x) for x in ds.as_numpy_iterator()]

    #######################################################

    def get_contains_annotation_indicator(self, annotation_prefix: str) -> np.ndarray:
        """Return a boolean mask with one entry per example.

        Entry ``i`` is the result of
        ``self.examples[i].contains_annotation(annotation_prefix)``.

        Args:
            annotation_prefix: Passed through unchanged to each example's
                ``contains_annotation`` method.

        Returns:
            A 1-D boolean array of length ``len(self.examples)``.
        """
        # Use the builtin `bool` as the dtype: the `np.bool` alias was
        # deprecated in NumPy 1.20 and removed in 1.24, where it raises
        # AttributeError.
        return np.array(
            [e.contains_annotation(annotation_prefix) for e in self.examples],
            dtype=bool,
        )


"""
Metrics for relating coefficients to binary/categorical data:
    - Mutual information?
    - AUC?
"""