R"""The dev sets of ANLI with annotations.

See https://arxiv.org/pdf/2010.12729.pdf for the paper.

The data is available at https://github.com/facebookresearch/anli/tree/main/anlizinganli


To download and prepare:


from em.datasets import annotated_anli
from em.util import vat_da_faak_vpn

for config in annotated_anli.AnnotatedAnli.BUILDER_CONFIGS:
    builder = annotated_anli.AnnotatedAnli(config=config)
    builder.download_and_prepare()


"""
import csv
import dataclasses
import os
from typing import FrozenSet, Sequence, Tuple

import tensorflow as tf
import tensorflow_datasets as tfds

from . import anli
from . import glue


# Task names accepted by `load`; they mirror the builder config names declared
# in `AnnotatedAnli.BUILDER_CONFIGS`.
ANNOTATED_ANLI_TASK_NAMES = ('r1', 'r2', 'r3', 'hard')


###############################################################################

# BibTeX entry handed to `tfds.core.DatasetInfo(citation=...)`.
_CITATION = R"""
@article{williams2020anlizing,
  title={Anlizing the adversarial natural language inference dataset},
  author={Williams, Adina and Thrush, Tristan and Kiela, Douwe},
  journal={arXiv preprint arXiv:2010.12729},
  year={2020}
}
"""

# Human-readable summary handed to `tfds.core.DatasetInfo(description=...)`.
_DESCRIPTION = """
Adversarial NLI (ANLI) is a large-scale NLI benchmark dataset, collected via an
iterative, adversarial human-and-model-in-the-loop procedure.
"""

# URL template for the per-config CSV; the `{}` placeholder is filled with the
# builder config name in `AnnotatedAnli._split_generators`.
_FILE_PATTERN = "https://raw.githubusercontent.com/facebookresearch/anli/main/anlizinganli/anli_annot_v0.2/ANLI_analysis_{}_dev.csv"

# Version passed to every `AnnotatedAnliConfig`.
# NOTE(review): `AnnotatedAnli.VERSION` is declared separately as '1.0.0' —
# confirm the two are meant to differ.
VERSION = tfds.core.Version("0.1.0")


class AnnotatedAnliConfig(tfds.core.BuilderConfig):
    """Builder config for one annotated-ANLI subset, pinned to the module-level VERSION."""

    def __init__(self, **kwargs):
        # Every keyword is forwarded unchanged to the base config. `version`
        # is supplied here, so passing it in `kwargs` as well would be a
        # duplicate keyword argument.
        super().__init__(version=VERSION, **kwargs)


class AnnotatedAnli(tfds.core.GeneratorBasedBuilder):
    """TFDS builder for the annotated ANLI dev sets.

    There is one builder config per upstream CSV (``r1``, ``r2``, ``r3`` and
    ``hard``); each config yields a single validation split.
    """

    # NOTE(review): every entry in BUILDER_CONFIGS carries the module-level
    # VERSION ("0.1.0"), which differs from this class-level value — confirm
    # which one is intended before bumping either.
    VERSION = tfds.core.Version('1.0.0')

    BUILDER_CONFIGS = [
        AnnotatedAnliConfig(name='r1', description='Round one.'),
        AnnotatedAnliConfig(name='r2', description='Round two.'),
        AnnotatedAnliConfig(name='r3', description='Round three.'),
        AnnotatedAnliConfig(name='hard', description='Hard.'),
    ]

    def _info(self):
        """Describe the dataset: feature schema, homepage and citation."""
        return tfds.core.DatasetInfo(
            builder=self,
            description=_DESCRIPTION,
            features=tfds.features.FeaturesDict({
                "uid": tfds.features.Text(),
                "context": tfds.features.Text(),
                "hypothesis": tfds.features.Text(),
                "label": tfds.features.ClassLabel(names=["e", "n", "c"]),
                # Feature not in original ANLI:
                "annotations": tfds.features.Sequence(tfds.features.Text()),
                "reason": tfds.features.Text(),
            }),
            supervised_keys=None,
            homepage="https://github.com/facebookresearch/anli",
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Download the CSV for the active config and expose it as one split.

        Args:
            dl_manager: TFDS download manager used to fetch the CSV.

        Returns:
            A one-element list holding the validation split generator.
        """
        csv_filepath = dl_manager.download(
            _FILE_PATTERN.format(self._builder_config.name))

        return [
            tfds.core.SplitGenerator(
                name=tfds.Split.VALIDATION,
                gen_kwargs={
                    "csv_filepath": csv_filepath
                }),
        ]

    def _generate_examples(self, csv_filepath):
        """Yield ``(key, example)`` pairs, one per row of the downloaded CSV.

        Args:
            csv_filepath: Path of the CSV returned by the download manager.

        Yields:
            Tuples of the row's ``uid`` and a dict matching the features
            declared in ``_info``. The CSV column ``statement`` is exposed as
            ``hypothesis`` and the comma-separated ``A1Code`` column becomes
            the ``annotations`` list.
        """
        # newline='' is what the csv module asks for so that line breaks
        # embedded in quoted fields reach the parser untouched; the explicit
        # encoding keeps decoding independent of the machine's locale.
        with open(csv_filepath, 'r', newline='', encoding='utf-8') as f:
            reader = csv.DictReader(f, quotechar='"', delimiter=',',
                                    quoting=csv.QUOTE_ALL, skipinitialspace=True)
            for row in reader:
                # Split the code list, trim whitespace, drop empty entries.
                annotations = [
                    code.strip()
                    for code in row["A1Code"].split(',')
                    if code.strip()
                ]

                yield row["uid"], {
                    "uid": row["uid"],
                    "context": row["context"],
                    "hypothesis": row["statement"],
                    "label": row["gold_label"],
                    "reason": row["reason"],
                    "annotations": annotations
                }

###############################################################################


def load(
    task: str,
    split: str,
    tokenizer,
    sequence_length: int,
):
    """Load one annotated-ANLI task and convert it to model features.

    Args:
        task: One of ``ANNOTATED_ANLI_TASK_NAMES``.
        split: Split specifier forwarded to ``tfds.load``.
        tokenizer: Forwarded unchanged to ``glue.convert_dataset_to_features``.
        sequence_length: Forwarded unchanged to
            ``glue.convert_dataset_to_features``.

    Returns:
        Whatever ``glue.convert_dataset_to_features`` returns for the
        re-keyed dataset.

    Raises:
        ValueError: If ``task`` is not a known annotated-ANLI task name.
    """
    is_known_task = task in ANNOTATED_ANLI_TASK_NAMES
    if not is_known_task:
        raise ValueError(f'Invalid anli task: {task}')

    dataset_name = f"annotated_anli/{task}"
    raw_ds = tfds.load(dataset_name, split=split)

    # From here on the data goes through the MNLI pipeline: first rename the
    # keys, then featurize under the 'mnli' task.
    mnli_like_ds = anli.rekey_to_be_like_mnli(raw_ds)
    return glue.convert_dataset_to_features(
        mnli_like_ds,
        tokenizer,
        sequence_length,
        task='mnli',
    )


def n_classes_for_task(task: str) -> int:
    """Return the number of label classes; always 3, whatever `task` is."""
    # `task` is not consulted.
    return 3


def de_facto_validation_split(task: str) -> str:
    """Return the split name to validate on; always 'validation', whatever `task` is."""
    # `task` is not consulted; the builder in this file only generates a
    # validation split.
    return 'validation'


def examples_per_epoch(task: str) -> None:
    """Return None for every `task`.

    NOTE(review): presumably signals that no fixed epoch size is defined for
    this dev-only dataset — confirm against the callers.
    """
    return None


###############################################################################
###############################################################################

def _get_all_annotation_prefixes(annotations: Sequence[str]) -> FrozenSet[str]:
    # We do NOT treat the empty string as a prefix.
    prefixes = set()
    for a in annotations:
        splits = a.split('-')
        for i, _ in enumerate(splits):
            prefixes.add('-'.join(splits[:i + 1]))
    return frozenset(prefixes)


###########################################################


# Annotation tags, written in upper case with '-' separating hierarchy levels
# (top-level category first).
# NOTE(review): presumably the set of tags expected in the CSV "A1Code"
# column — confirm against the upstream data.
ANNOTATIONS = [
    # NUMERICAL
    'NUMERICAL-CARDINAL',
    'NUMERICAL-CARDINAL-AGE',
    'NUMERICAL-CARDINAL-COUNTING',
    'NUMERICAL-CARDINAL-DATES',
    'NUMERICAL-CARDINAL-NOMINAL',
    'NUMERICAL-CARDINAL-NOMINAL-DATES',
    'NUMERICAL-ORDINAL',
    'NUMERICAL-ORDINAL-DATES',
    # BASIC
    'BASIC-0',
    'BASIC-CAUSEEFFECT',
    'BASIC-COMPARATIVESUPERLATIVE',
    'BASIC-CONJUNCTION',
    'BASIC-IDIOM',
    'BASIC-LEXICAL-0',
    'BASIC-NEGATION',
    # REASONING and REFERENCE (listed together, with no separator between them)
    'REASONING-0',
    'REASONING-CONTAINMENT-LOCATION',
    'REASONING-CONTAINMENT-TIME',
    'REASONING-DEBATABLE',
    'REASONING-FACTS',
    'REFERENCE-COREFERENCE',
    'REFERENCE-FAMILY',
    'REFERENCE-NAMES',
    # TRICKY
    'TRICKY-EXHAUSTIFICATION',
    'TRICKY-PRESUPPOSITION',
    'TRICKY-SYNTACTIC',
    'TRICKY-TRANSLATION',
    'TRICKY-WORDPLAY',
    # IMPERFECTIONS
    'IMPERFECTIONS-0',
    'IMPERFECTIONS-AMBIGUITY',
    'IMPERFECTIONS-ERROR',
    'IMPERFECTIONS-NONNATIVE',
    'IMPERFECTIONS-SPELLING',
    # EVENTCOREF (single-level tag)
    'EVENTCOREF',
]


# Every tag above together with each of its dash-delimited leading prefixes
# (e.g. 'NUMERICAL', 'NUMERICAL-CARDINAL', ...), as a frozenset.
ANNOTATION_PREFIXES = _get_all_annotation_prefixes(ANNOTATIONS)


###########################################################

def _ensure_np(x):
    """Return ``x.numpy()`` when ``x`` is a ``tf.Tensor``; otherwise return ``x`` unchanged."""
    return x.numpy() if isinstance(x, tf.Tensor) else x


@dataclasses.dataclass
class AnnotatedExample:
    """One annotated ANLI example plus helpers for querying its annotation tags.

    Tags are upper-cased on construction. The query methods compare against
    their argument as given, without changing its case.
    """

    idx: str

    premise: str
    hypothesis: str

    label: int

    annotations: Tuple[str, ...]
    reason: str

    def __post_init__(self):
        # Normalise the tags to upper case and store them as a tuple,
        # whatever iterable was passed in.
        upper_cased = [tag.upper() for tag in self.annotations]
        self.annotations = tuple(upper_cased)

    # ---- Prefix-based queries -------------------------------------------

    def contains_annotation(self, annotation_prefix: str) -> bool:
        """Return True if at least one tag starts with ``annotation_prefix``."""
        for tag in self.annotations:
            if tag.startswith(annotation_prefix):
                return True
        return False

    def contains_all_annotations(self, annotation_prefixes: Sequence[str]) -> bool:
        """Return True if every given prefix matches at least one tag."""
        # A bare string would be iterated character by character and give a
        # silently wrong answer, so it is rejected up front.
        assert not isinstance(annotation_prefixes, str)
        for prefix in annotation_prefixes:
            if not self.contains_annotation(prefix):
                return False
        return True

    # ---- Exact-match queries --------------------------------------------

    def contains_annotation_exactly(self, annotation: str) -> bool:
        """Return True if some tag is equal to ``annotation``."""
        for tag in self.annotations:
            if tag == annotation:
                return True
        return False

    def contains_all_annotations_exactly(self, annotations: Sequence[str]) -> bool:
        """Return True if every given annotation equals one of the tags."""
        # A bare string would be iterated character by character and give a
        # silently wrong answer, so it is rejected up front.
        assert not isinstance(annotations, str)
        for wanted in annotations:
            if not self.contains_annotation_exactly(wanted):
                return False
        return True

    # ---- Construction ---------------------------------------------------

    @classmethod
    def from_tfds(cls, example):
        """Build an instance from a mapping-like example.

        Reads the keys 'annotations', 'idx', 'premise', 'hypothesis', 'label'
        and 'reason'. Each value goes through ``_ensure_np`` first, then
        text fields are converted with ``tf.compat.as_str`` and the label
        with ``int``.
        """

        def decode(value) -> str:
            return tf.compat.as_str(_ensure_np(value))

        tags = tuple(decode(raw_tag) for raw_tag in example['annotations'])

        return cls(
            idx=decode(example['idx']),
            premise=decode(example['premise']),
            hypothesis=decode(example['hypothesis']),
            label=int(_ensure_np(example['label'])),
            annotations=tags,
            reason=decode(example['reason']),
        )
        