"""Utilities related to HANS."""
from typing import Any, Dict, List

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from em import datasets as em_datasets
from em.projects.anli import anli_misc1 as am


###############################################################################


def fix_up_hans_logits(logits):
    entailment_logit = logits[:, 1]
    non_entailment_logit = np.maximum(logits[:, 0], logits[:, 2])
    return np.stack([entailment_logit, non_entailment_logit, -1e9 * np.ones_like(entailment_logit)], axis=-1)


def fix_up_hans_logits_tf(logits):
    """TensorFlow version of ``fix_up_hans_logits``.

    Takes column 1 of ``logits`` as the entailment score and the element-wise max of
    columns 0 and 2 as the non-entailment score. It then stacks them with a ``-1e9`` filler
    column so the output stays three columns wide.

    Args:
      logits: 2-D tensor indexed as ``[row, class]`` with at least three columns.

    Returns:
      Tensor of shape ``(rows, 3)``: ``[entailment, non_entailment, -1e9]``.
    """
    ent = logits[:, 1]
    non_ent = tf.maximum(logits[:, 0], logits[:, 2])
    pad = -1e9 * tf.ones_like(ent)
    columns = [ent, non_ent, pad]
    return tf.stack(columns, axis=-1)


def fix_up_hans_container(container: am.PefNmfAnalysisContainer):
    """Overwrite ``container``'s predictions in place using HANS's collapsed logits.

    Reads ``container.pef.predicted_logits`` and collapses them with
    ``fix_up_hans_logits``. It then overwrites ``predicted_logits``, ``predictions``
    (argmax over the last axis) and ``examples`` on ``container``. Returns ``None``.
    """
    hans_logits = fix_up_hans_logits(container.pef.predicted_logits)
    container.predicted_logits = hans_logits
    container.predictions = np.argmax(hans_logits, axis=-1)

    # Rebuild examples last, after the new logits/predictions are stored.
    # NOTE(review): presumably _make_nli_examples reads those attributes — confirm.
    container.examples = container._make_nli_examples()


###############################################################################


def _ds_to_hans_examples(ds: tf.data.Dataset) -> List[Dict[str, Any]]:
    """Materialize a HANS ``tf.data.Dataset`` into a list of plain Python dicts.

    Each element yielded by ``ds.as_numpy_iterator()`` becomes one dict. Every field except
    ``'idx'`` and ``'label'`` is converted with ``tf.compat.as_str``. Those two keys keep the
    NumPy values produced by the iterator.
    """
    passthrough_keys = ('idx', 'label')
    return [
        {key: (value if key in passthrough_keys else tf.compat.as_str(value)) for key, value in record.items()}
        for record in ds.as_numpy_iterator()
    ]


def get_first_hans_examples(split: str, n_examples: int, process_fn=None):
    """Load the first ``n_examples`` HANS examples of ``split`` as Python dicts.

    Args:
      split: split name forwarded to ``tfds.load("hans/hans", ...)``.
      n_examples: count passed to ``Dataset.take``; applied after ``process_fn``.
      process_fn: optional callable that takes and returns a ``tf.data.Dataset``.

    Returns:
      A list of example dicts as produced by ``_ds_to_hans_examples``.
    """
    dataset = tfds.load("hans/hans", split=split)
    dataset = dataset if process_fn is None else process_fn(dataset)
    head = dataset.take(n_examples)
    return _ds_to_hans_examples(head)


def get_hans_lone_with_flipped_examples(split: str):
    """Load HANS ``lexical_overlap`` examples with label 1, interleaved with label-flipped copies.

    The filtered dataset and the dataset returned by
    ``em_datasets.hans._make_flipped_label_ds`` are merged by ``choose_from_datasets``. The
    choice sequence is 0, 1, 0, 1, ..., so the output alternates between an original element
    and a flipped element. The merged dataset is materialized into Python dicts.

    NOTE(review): label 1 is presumably "non-entailment" in the TFDS HANS ClassLabel; confirm.
    NOTE(review): behavior when one input runs out early depends on choose_from_datasets'
    ``stop_on_empty_dataset`` default for the installed TF version.

    Args:
      split: split name forwarded to ``tfds.load("hans/hans", ...)``.

    Returns:
      A list of example dicts as produced by ``_ds_to_hans_examples``.
    """
    def filter_fn(x):
        # Inside Dataset.filter the features are symbolic tensors. Python ``and`` would call
        # bool() on a tensor and only works when AutoGraph rewrites it. ``==`` only yields an
        # element-wise tensor when TF2 tensor equality is enabled. Use explicit ops instead.
        is_lexical_overlap = tf.equal(x['heuristic'], 'lexical_overlap')
        has_label_1 = tf.equal(x['label'], 1)
        return tf.logical_and(is_lexical_overlap, has_label_1)

    ds = tfds.load("hans/hans", split=split).filter(filter_fn)
    flipped_ds = em_datasets.hans._make_flipped_label_ds(ds)

    # Infinite 0, 1, 0, 1, ... selector: one element from ds, then one from flipped_ds.
    choice_dataset = tf.data.Dataset.range(2).repeat()
    ds = tf.data.experimental.choose_from_datasets([ds, flipped_ds], choice_dataset)

    return _ds_to_hans_examples(ds)
