"""Utilities for labeling HANS examples with boolean indicators (by word, heuristic, template, subcase, etc.).

See:
    https://github.com/tommccoy1/hans/blob/master/templates.py
    https://github.com/tommccoy1/hans/blob/master/corpus_generator.py

"""
import collections
import dataclasses
import re
from typing import Dict, List, Optional, Sequence, Union

import numpy as np

from . import hans_labeling_constants as C

# typedefs
Indicator = np.ndarray  # 1-d boolean array of shape [n_examples]
# Indicator name (e.g. a template, subcase, or lowercased word) -> its Indicator.
IndicatorDict = Dict[str, Indicator]

###############################################################################


def _binary_op(dict1: Optional[IndicatorDict], dict2: Optional[IndicatorDict], op) -> Optional[IndicatorDict]:
    """Combine two indicator dicts key by key with the binary function ``op``.

    Returns ``None`` when either input is ``None``. Otherwise both dicts must
    have exactly the same key set (checked with ``assert``). The result maps
    every key of ``dict1``, in ``dict1``'s order, to
    ``op(dict1[key], dict2[key])``.
    """
    if dict1 is not None and dict2 is not None:
        assert set(dict1) == set(dict2)
        combined = {}
        for name, left in dict1.items():
            combined[name] = op(left, dict2[name])
        return combined
    return None


def _reduce_any(indicators: Optional[IndicatorDict]) -> Optional[Indicator]:
    """Elementwise OR across every array stored in ``indicators``.

    Returns ``None`` when ``indicators`` is ``None``. Otherwise the result is
    True at position i iff at least one of the dict's arrays is True there.
    """
    # NOTE(review): with an empty dict, np.any reduces an empty list along
    # axis 0 and yields a 0-d result instead of an array of length n_examples.
    return None if indicators is None else np.any(list(indicators.values()), axis=0)


###############################################################################


"""
Possible TODOs:
    - Map singular words to plural words.
    - Add a singular/plural indicator.
    - Combination indicators (e.g. word in both premise and hypothesis, word only in premise, word only in hypothesis, ...).
    - Preps, rels, advs: both individual-word indicators and overall-presence indicators.

Former TODOs (now done):
    - Separate premise/hypothesis indicators.
"""


@dataclasses.dataclass
class HansIndicator:
    """Wrapper class containing all indicators.

    Each field is either ``None`` (not computed) or an ``IndicatorDict``. An
    ``IndicatorDict`` maps a name (template, subcase, or lowercased word) to a
    1-d boolean array with one entry per example.

    ``__post_init__`` adds derived attributes that are *not* dataclass fields:

    * For every base ``B`` in ``_PH_DERIVED_FIELD_BASES`` and every entry of
      ``_PH_DERIVED_OPS``, key by key from ``B_p`` and ``B_h``:
      ``B_ph_or`` (``p | h``), ``B_ph_and`` (``p & h``),
      ``B_p_only`` (``p & ~h``) and ``B_h_only`` (``~p & h``).
      Each is ``None`` if either ``B_p`` or ``B_h`` is ``None``.
    * ``any_prep_indicator``, ``any_rel_indicator`` and ``any_adv_indicator``
      are the elementwise OR over all arrays of the matching ``*_ph_or`` dict.
      Each is ``None`` if that dict is ``None``.
    """
    template_indicators: Optional[IndicatorDict] = None
    subcase_indicators: Optional[IndicatorDict] = None

    # The p,h suffixes stand for premise, hypothesis.
    noun_indicators_p: Optional[IndicatorDict] = None
    noun_indicators_h: Optional[IndicatorDict] = None

    # The p,h suffixes stand for premise, hypothesis.
    verb_indicators_p: Optional[IndicatorDict] = None
    verb_indicators_h: Optional[IndicatorDict] = None

    # The p,h suffixes stand for premise, hypothesis.
    preps_indicators_p: Optional[IndicatorDict] = None
    preps_indicators_h: Optional[IndicatorDict] = None

    # The p,h suffixes stand for premise, hypothesis.
    rels_indicators_p: Optional[IndicatorDict] = None
    rels_indicators_h: Optional[IndicatorDict] = None

    # The p,h suffixes stand for premise, hypothesis.
    advs_indicators_p: Optional[IndicatorDict] = None
    advs_indicators_h: Optional[IndicatorDict] = None

    # Class-level tables below are unannotated, so dataclasses ignores them.

    # Field bases that have both a ``_p`` (premise) and ``_h`` (hypothesis) field.
    # NOTE: update this when adding a new premise/hypothesis field pair.
    _PH_DERIVED_FIELD_BASES = ('noun_indicators', 'verb_indicators', 'preps_indicators', 'rels_indicators', 'advs_indicators')

    # Derived-attribute suffix -> elementwise op applied to the (p, h) arrays.
    # Shared by _create_ph_deriveds and as_dict so the two cannot drift apart.
    _PH_DERIVED_OPS = (
        ('ph_or', lambda p, h: p | h),
        ('ph_and', lambda p, h: p & h),
        ('p_only', lambda p, h: p & ~h),
        ('h_only', lambda p, h: ~p & h),
    )

    # Whole-group attribute -> field base whose ``_ph_or`` dict it reduces with OR.
    _ANY_DERIVED_SOURCES = (
        ('any_prep_indicator', 'preps_indicators'),
        ('any_rel_indicator', 'rels_indicators'),
        ('any_adv_indicator', 'advs_indicators'),
    )
    _OTHER_DERIVED_FIELDS = tuple(name for name, _ in _ANY_DERIVED_SOURCES)

    def __post_init__(self):
        # Create all of the generic premise/hypothesis derived versions.
        for key_base in self._PH_DERIVED_FIELD_BASES:
            self._create_ph_deriveds(key_base)

        # Create indicators for any member of a word group appearing
        # anywhere (premise or hypothesis) in each example.
        for name, key_base in self._ANY_DERIVED_SOURCES:
            setattr(self, name, _reduce_any(getattr(self, f'{key_base}_ph_or')))

    def _or_ph(self, key_base: str) -> Optional[IndicatorDict]:
        """Key-by-key OR of ``<key_base>_p`` and ``<key_base>_h`` (same values as ``<key_base>_ph_or``)."""
        p_inds = getattr(self, f'{key_base}_p')
        h_inds = getattr(self, f'{key_base}_h')
        return _binary_op(p_inds, h_inds, lambda x, y: x | y)

    def _create_ph_deriveds(self, key_base: str):
        """Set ``<key_base>_<suffix>`` for every entry of ``_PH_DERIVED_OPS``."""
        p = getattr(self, f'{key_base}_p')
        h = getattr(self, f'{key_base}_h')
        for suffix, op in self._PH_DERIVED_OPS:
            setattr(self, f'{key_base}_{suffix}', _binary_op(p, h, op))

    #######################################################
    # Public Methods

    def as_dict(self) -> Dict[str, Union[IndicatorDict, Indicator]]:
        """Return every field and derived attribute by name (values may be ``None``).

        The order is: dataclass fields, then for each premise/hypothesis base its
        ``ph_or``, ``ph_and``, ``p_only``, ``h_only`` variants, then the
        ``any_*`` indicators.
        """
        ret = {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}

        for base in self._PH_DERIVED_FIELD_BASES:
            for suffix, _ in self._PH_DERIVED_OPS:
                name = f'{base}_{suffix}'
                ret[name] = getattr(self, name)

        for name in self._OTHER_DERIVED_FIELDS:
            ret[name] = getattr(self, name)

        return ret

    def iterate_over_indicators(self):
        """Yield ``(path, array)`` for every indicator that is not ``None``.

        ``path`` is ``(attr_name, key)`` for an entry inside an ``IndicatorDict``
        and ``(attr_name,)`` for a bare array (the ``any_*`` indicators).
        """
        dikt = self.as_dict()
        for k1, v1 in dikt.items():
            if v1 is None:
                continue
            elif isinstance(v1, dict):
                for k2, v2 in v1.items():
                    yield (k1, k2), v2
            else:
                assert isinstance(v1, np.ndarray)
                yield (k1,), v1


###############################################################################

def compute_full_indicator(examples, label_only: Optional[str] = None) -> HansIndicator:
    """Compute every HANS indicator for ``examples`` and bundle them.

    Args:
        examples: iterable of examples, each indexable by ``'template'``,
            ``'subcase'``, ``'premise'`` and ``'hypothesis'``.
        label_only: forwarded to the template and subcase indicator builders.

    Returns:
        A ``HansIndicator`` holding the template and subcase indicators plus
        premise/hypothesis word indicators for nouns, verbs, preps, rels, advs.
    """
    # ``examples`` is traversed once per indicator group below. A one-shot
    # iterator (e.g. a generator) would be exhausted by the first pass, and
    # every later indicator would silently come out empty. Materialize once.
    examples = list(examples)

    indicator = HansIndicator(
        template_indicators=compute_template_indicators(examples, label_only),
        subcase_indicators=compute_subcase_indicators(examples, label_only),
        #
        noun_indicators_p=_compute_word_indicators(examples, C.nouns, 'premise'),
        noun_indicators_h=_compute_word_indicators(examples, C.nouns, 'hypothesis'),
        #
        verb_indicators_p=_compute_word_indicators(examples, C.verbs, 'premise'),
        verb_indicators_h=_compute_word_indicators(examples, C.verbs, 'hypothesis'),
        #
        preps_indicators_p=_compute_word_indicators(examples, C.preps, 'premise'),
        preps_indicators_h=_compute_word_indicators(examples, C.preps, 'hypothesis'),
        #
        rels_indicators_p=_compute_word_indicators(examples, C.rels, 'premise'),
        rels_indicators_h=_compute_word_indicators(examples, C.rels, 'hypothesis'),
        #
        advs_indicators_p=_compute_word_indicators(examples, C.advs, 'premise'),
        advs_indicators_h=_compute_word_indicators(examples, C.advs, 'hypothesis'),
    )
    return indicator


###############################################################################


def _compute_field_indicators(examples, values: Sequence[str], field: str) -> IndicatorDict:
    """Return ``{value: bool array of example[field] == value}`` for each value.

    ``values`` is deduplicated first, preserving order. A repeated value would
    otherwise be appended to twice per example, producing an indicator longer
    than the number of examples.
    """
    values = list(dict.fromkeys(values))
    ret = {v: [] for v in values}
    for x in examples:
        for v in values:
            ret[v].append(x[field] == v)

    return {
        k: np.array(v, dtype=bool)
        for k, v in ret.items()
    }


def compute_template_indicators(examples, label_only: Optional[str] = None) -> IndicatorDict:
    """One indicator per template name: True where ``example['template']`` equals it.

    Args:
        examples: iterable of examples indexable by ``'template'``.
        label_only: forwarded to ``C.get_templates``, which supplies the
            template names to build indicators for.
    """
    return _compute_field_indicators(examples, C.get_templates(label_only), 'template')


def compute_subcase_indicators(examples, label_only: Optional[str] = None) -> IndicatorDict:
    """One indicator per subcase name: True where ``example['subcase']`` equals it.

    Args:
        examples: iterable of examples indexable by ``'subcase'``.
        label_only: forwarded to ``C.get_subcases``, which supplies the
            subcase names to build indicators for.
    """
    return _compute_field_indicators(examples, C.get_subcases(label_only), 'subcase')


###############################################################################


def _compute_word_indicators(examples, words: Sequence[str], key: str) -> IndicatorDict:
    """Per-word indicators of whether each word occurs in ``example[key]``.

    Args:
        examples: iterable of examples; ``example[key]`` must be a string.
        words: words or multi-word phrases to look for. Matching is
            case-insensitive and uses the same normalization as the text
            (``_split_to_normalized_words``), matching whole tokens only.
        key: which text field to search, e.g. ``'premise'`` or ``'hypothesis'``.

    Returns:
        Mapping from each lowercased word to a boolean array with one entry per
        example. Words that collide after lowercasing share a single entry.
    """
    # Lowercased word (the dict key) -> its normalized token tuple.
    # setdefault dedupes words that collide after lowercasing. Without it, the
    # same list would be appended to twice per example.
    targets = {}
    for w in words:
        targets.setdefault(w.lower(), tuple(_split_to_normalized_words(w)))

    ret = {w: [] for w in targets}
    for x in examples:
        tokens = _split_to_normalized_words(x[key])
        token_set = set(tokens)  # O(1) lookups for the common single-word case.
        for w, phrase in targets.items():
            ret[w].append(_contains_phrase(tokens, token_set, phrase))

    return {
        k: np.array(v, dtype=bool)
        for k, v in ret.items()
    }


def _contains_phrase(tokens: List[str], token_set, phrase) -> bool:
    """Return True iff the token tuple ``phrase`` occurs contiguously in ``tokens``."""
    n = len(phrase)
    if n == 0:
        # An empty or punctuation-only word normalizes to no tokens; it never matches.
        return False
    if n == 1:
        return phrase[0] in token_set
    return any(tuple(tokens[i:i + n]) == phrase for i in range(len(tokens) - n + 1))


def _split_to_normalized_words(s: str) -> List[str]:
    s = s.lower()
    s = re.sub(r'\W', lambda m: ' ' if m.group(0) == ' ' else '', s)  # Get rid of non-word characters.
    words = s.split(' ')
    return [w for w in words if w]
