R"""HANS (Heuristic Analysis for NLI Systems) dataset.

https://arxiv.org/pdf/1902.01007.pdf
https://github.com/tommccoy1/hans


Both the train and validation sets appear to contain 30k examples each.


To download and prepare:


from em.datasets import hans
from em.util import vat_da_faak_vpn

for config in hans.Hans.BUILDER_CONFIGS:
    builder = hans.Hans(config=config)
    builder.download_and_prepare()


"""
import json

import tensorflow as tf
import tensorflow_datasets as tfds

from em.projects.ll import hans_parsing

from . import glue

###############################################################################


# The three syntactic heuristics used to tag HANS examples.
_HEURISTICS = ("lexical_overlap", "subsequence", "constituent")

# Every task name accepted by `load`:
#   * 'default'          -- no filtering at all,
#   * '<heuristic>'      -- keep only examples of that heuristic,
#   * '<heuristic>_ye'   -- additionally keep only label 0,
#   * '<heuristic>_ne'   -- additionally keep only label 1,
#   * '..._with_flipped' -- interleave the filtered examples with
#                           label-flipped counterparts.
_TASK_NAMES = (
    ('default',)
    + _HEURISTICS
    + tuple(f'{h}_ye' for h in _HEURISTICS)
    + tuple(f'{h}_ne' for h in _HEURISTICS)
    + ('lexical_overlap_ne_with_flipped',)
)


def load(
    task: str,
    split: str,
    tokenizer,
    sequence_length: int,
):
    """Build the feature dataset for a single HANS task.

    The task name is decoded into up to three pieces: a heuristic to filter
    on, a label to filter on (``_ye`` -> label 0, ``_ne`` -> label 1) and
    whether label-flipped examples should be interleaved with the originals
    (``_with_flipped``).

    Args:
        task: One of the names in ``_TASK_NAMES``.
        split: Split name handed to ``tfds.load``.
        tokenizer: Forwarded to ``glue.convert_dataset_to_features``.
        sequence_length: Forwarded to ``glue.convert_dataset_to_features``.

    Returns:
        The result of ``glue.convert_dataset_to_features`` applied to the
        filtered (and optionally flip-augmented) dataset.

    Raises:
        ValueError: If ``task`` is not a known HANS task name.
    """
    if task not in _TASK_NAMES:
        raise ValueError(f'Invalid HANS task: {task}')

    # Peel the optional suffixes off the task name, outermost first.
    flipped_suffix = '_with_flipped'
    with_flipped = task.endswith(flipped_suffix)
    if with_flipped:
        task = task[:-len(flipped_suffix)]

    # Both label suffixes are three characters long; anything else maps to
    # "no label filter".
    label_filter = {'_ye': 0, '_ne': 1}.get(task[-3:])
    if label_filter is not None:
        task = task[:-3]

    ds = tfds.load("hans/hans", split=split)

    # What is left of the name is either 'default' or a heuristic.
    if task in _HEURISTICS:
        ds = ds.filter(_filter_by_heuristic_fn(task))

    if label_filter is not None:
        ds = ds.filter(_filter_by_label_fn(label_filter))

    if with_flipped:
        # The selector cycles 0, 1, 0, 1, ... so original and flipped
        # examples alternate.
        selector = tf.data.Dataset.range(2).repeat()
        ds = tf.data.experimental.choose_from_datasets(
            [ds, _make_flipped_label_ds(ds)],
            selector,
        )

    # Basically treat like MNLI except that we have two
    # classes instead of 3.
    return glue.convert_dataset_to_features(
        ds,
        tokenizer,
        sequence_length,
        task='mnli',
    )


def n_classes_for_task(task: str) -> int:
    """Return the number of label classes, which is 2 for every task."""
    num_classes = 2
    return num_classes


def de_facto_validation_split(task):
    """Return the split name used for validation, regardless of ``task``."""
    split_name = 'validation'
    return split_name


def examples_per_epoch(task):
    """Return the epoch size for ``task``; currently always ``None``.

    NOTE(review): a fixed value of 30_000 was previously considered here but
    is deliberately not returned.
    """
    epoch_size = None
    return epoch_size


###############################################################################

def _filter_by_heuristic_fn(heuristic):
    """Return a predicate that is true when ``x['heuristic']`` equals ``heuristic``."""
    def has_heuristic(example):
        is_match = example['heuristic'] == heuristic
        return is_match

    return has_heuristic


def _filter_by_label_fn(label):
    """Return a predicate that is true when ``x['label']`` equals ``label``."""
    def has_label(example):
        is_match = example['label'] == label
        return is_match

    return has_label


def _make_flipped_label_ds(ds: tf.data.Dataset):
    """Map ``ds`` to a dataset of label-flipped examples.

    For every example the hypothesis is replaced by the output of
    ``hans_parsing.get_flipped_label_hypothesis(template, premise)`` and the
    label is replaced by ``1 - label``. All other fields are carried over
    unchanged.
    """

    def flipped_hypothesis_py(template_t, premise_t):
        # Executed through tf.py_function, so the tensors are converted to
        # Python strings via .numpy() before calling the parser.
        template_str = tf.compat.as_str(template_t.numpy())
        premise_str = tf.compat.as_str(premise_t.numpy())
        return hans_parsing.get_flipped_label_hypothesis(template_str, premise_str)

    def flip_example(example):
        flipped_hypothesis = tf.py_function(
            flipped_hypothesis_py,
            [example['template'], example['premise']],
            tf.string,
        )

        # Shallow copy so the incoming example mapping is left untouched.
        flipped = dict(example)
        flipped.update(
            hypothesis=flipped_hypothesis,
            label=1 - example['label'],
        )
        return flipped

    return ds.map(flip_example)


###############################################################################
###############################################################################


# URL template for the raw HANS JSONL files on GitHub. `{split}` is filled in
# by `Hans._split_generators` with 'train' and 'evaluation'.
_FILE_PATTERN = "https://raw.githubusercontent.com/tommccoy1/hans/master/heuristics_{split}_set.jsonl"


class Hans(tfds.core.GeneratorBasedBuilder):
    """TFDS builder for HANS (Heuristic Analysis for NLI Systems).

    Downloads the train and evaluation JSONL files named by ``_FILE_PATTERN``
    and exposes them as the ``train`` and ``validation`` splits.
    """

    VERSION = tfds.core.Version('1.0.2')

    BUILDER_CONFIGS = [
        tfds.core.BuilderConfig(name='hans', description='HANS (Heuristic Analysis for NLI Systems)', version=VERSION)
    ]

    def _info(self):
        """Describe the features of a HANS example."""
        return tfds.core.DatasetInfo(
            builder=self,
            description="TODO",
            features=tfds.features.FeaturesDict({
                "idx": tf.int64,
                "premise": tfds.features.Text(),
                "hypothesis": tfds.features.Text(),
                "label": tfds.features.ClassLabel(names=["entailment", "non-entailment"]),
                # Extra per-example metadata read from the HANS JSONL records.
                # heuristic belongs to {"lexical_overlap", "subsequence", "constituent"}
                "heuristic": tfds.features.Text(),
                "template": tfds.features.Text(),
                "subcase": tfds.features.Text(),
            }),
            supervised_keys=None,
            homepage="https://github.com/tommccoy1/hans",
            citation="TODO",
        )

    def _split_generators(self, dl_manager):
        """Download both JSONL files and declare the train/validation splits."""
        # The upstream "evaluation" file is exposed as the validation split.
        filepaths = dl_manager.download({
            'train': _FILE_PATTERN.format(split='train'),
            'validation': _FILE_PATTERN.format(split='evaluation'),
        })

        return [
            tfds.core.SplitGenerator(
                name=tfds.Split.TRAIN,
                gen_kwargs={
                    "filepath": filepaths['train']
                }),
            tfds.core.SplitGenerator(
                name=tfds.Split.VALIDATION,
                gen_kwargs={
                    "filepath": filepaths['validation']
                }),
        ]

    def _make_example(self, line: str):
        """Convert one JSONL line into a feature dict.

        Args:
            line: A single JSON-encoded HANS record.

        Returns:
            A dict with the keys declared in ``_info``.

        Raises:
            ValueError: If the record's ``pairID`` does not have the form
                ``ex<integer>``.
        """
        js = json.loads(line)

        pair_id = js['pairID']
        # Explicit check instead of `assert`: assertions are stripped when
        # Python runs with -O, which would let malformed IDs through.
        if not pair_id.startswith('ex'):
            raise ValueError(f'Unexpected HANS pairID: {pair_id!r}')
        idx = int(pair_id[2:])

        return {
            'idx': idx,
            'premise': js['sentence1'],
            'hypothesis': js['sentence2'],
            'label': js['gold_label'],
            'heuristic': js['heuristic'],
            'template': js['template'],
            'subcase': js['subcase'],
        }

    def _generate_examples(self, filepath):
        """Yield ``(idx, example)`` pairs for every record in ``filepath``."""
        # Decode as UTF-8 explicitly so the result does not depend on the
        # platform's default locale encoding.
        with open(filepath, 'rt', encoding='utf-8') as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline at EOF), which
                # json.loads would otherwise reject.
                if not line.strip():
                    continue
                ex = self._make_example(line)
                yield ex['idx'], ex
    