"""Utility class encapsulating the data about an NLI example."""
import dataclasses
import re
from typing import Optional, Tuple, Union

import numpy as np


_UNLABELED_CHAR_ABBREV = '-'

# The '-' means no label.
_LABEL_CHAR_ABBREVS = (_UNLABELED_CHAR_ABBREV, "e", "n", "c")


def _word_chars_only(s: str) -> str:
    return re.sub(r'\W', ' ', s)


def _get_words_list(s: str) -> Tuple[str, ...]:
    """Split ``s`` into word tokens using a very simple heuristic.

    Non-word characters are first blanked out by ``_word_chars_only``; each
    remaining run of word characters becomes one token, in original order.
    The result never contains empty strings.
    """
    # After blanking, the only whitespace left is ' ', so a bare split()
    # both trims the ends and drops empty pieces between repeated spaces.
    return tuple(_word_chars_only(s).split())


@dataclasses.dataclass
class NliExample:
    """One NLI example: the text pair, its gold label and a model's output.

    Attributes:
        premise: Premise text.
        hypothesis: Hypothesis text.
        label_char: Gold label; must be one of ``_LABEL_CHAR_ABBREVS``.
            ``_UNLABELED_CHAR_ABBREV`` marks an unlabeled example.
        prediction_char: Predicted label; same character set as ``label_char``.
        predicted_logits: Model logits; must have shape ``(3,)``.
        index: Optional integer index (presumably the example's position in
            its source dataset -- confirm with callers).

    NOTE(review): the dataclass-generated ``__eq__`` also compares the
    ``predicted_logits`` array, so comparing two instances can raise
    ``ValueError`` (ambiguous truth value of an array) once the comparison
    reaches that field.
    """

    premise: str
    hypothesis: str

    label_char: str
    prediction_char: str

    predicted_logits: np.ndarray

    index: Optional[int] = None

    def __post_init__(self):
        """Validate the fields and cache lower-cased word tokens."""
        # NOTE(review): ``assert`` checks are stripped under ``python -O``.
        # They stay assertions so callers still see AssertionError, but now
        # say which field is invalid.
        assert self.label_char in _LABEL_CHAR_ABBREVS, (
            f'Invalid label_char {self.label_char!r}; '
            f'expected one of {_LABEL_CHAR_ABBREVS}.')
        assert self.prediction_char in _LABEL_CHAR_ABBREVS, (
            f'Invalid prediction_char {self.prediction_char!r}; '
            f'expected one of {_LABEL_CHAR_ABBREVS}.')
        assert list(self.predicted_logits.shape) == [3], (
            'predicted_logits must have shape (3,), '
            f'got {self.predicted_logits.shape}.')

        # Tokenize once at construction so the *_contains_word lookups are
        # cheap. These caches are NOT refreshed if ``premise`` or
        # ``hypothesis`` is reassigned afterwards.
        self._premise_lower_words = _get_words_list(self.premise.lower())
        self._hypothesis_lower_words = _get_words_list(self.hypothesis.lower())

    def is_correctly_labeled(self) -> Optional[bool]:
        """Return whether the prediction matches the gold label.

        Returns:
            ``None`` if the example is unlabeled, otherwise whether
            ``label_char == prediction_char``.
        """
        if self.label_char == _UNLABELED_CHAR_ABBREV:
            return None
        return self.label_char == self.prediction_char

    @staticmethod
    def _normalize_word(word: str) -> str:
        """Return ``word`` lower-cased after checking it is a single word.

        Raises:
            ValueError: if ``word`` is empty or contains any non-word
                character (including a trailing newline).
        """
        # fullmatch rather than re.search(r'^\w+$', ...): '$' also matches
        # just before a trailing newline, which let e.g. 'cat\n' pass.
        if not re.fullmatch(r'\w+', word):
            raise ValueError(f'Invalid word: {word}. Word must contain only word-characters.')
        return word.lower()

    def contains_word(self, word: str) -> bool:
        """Return whether ``word`` occurs (case-insensitively) in either text.

        Raises:
            ValueError: if ``word`` is not made only of word characters.
        """
        return self.premise_contains_word(word) or self.hypothesis_contains_word(word)

    def premise_contains_word(self, word: str) -> bool:
        """Return whether ``word`` (case-insensitive) is a premise token.

        Raises:
            ValueError: if ``word`` is not made only of word characters.
        """
        return self._normalize_word(word) in self._premise_lower_words

    def hypothesis_contains_word(self, word: str) -> bool:
        """Return whether ``word`` (case-insensitive) is a hypothesis token.

        Raises:
            ValueError: if ``word`` is not made only of word characters.
        """
        return self._normalize_word(word) in self._hypothesis_lower_words
