"""Human evaluation latex generation for SNLI."""
import abc
import dataclasses
import re
from typing import Optional, List, Sequence, Tuple

import numpy as np
from transformers import PreTrainedTokenizer

from npeff_torch.examination.top_examples import top_examples_common
from npeff_torch.examination.top_examples.human_evals import humev_top_examples_same_theme_latex
from npeff_torch.examination.top_examples.human_evals import humev_top_examples_theme_latex
from npeff_torch.examination.top_examples import top_examples_latex
from npeff_torch.util import latex_utils


###############################################################################
# Values accepted for the generators' `example_format` field; each one selects
# a different way of recovering the premise/hypothesis from a tokenized example.
_EXAMPLE_FORMATS = (
    'encoder',
    'lm',
)
###############################################################################
# Reviewer-facing introduction (LaTeX source) for documents in which each group
# of examples is judged on its own for the presence of a common theme.
_DEFAULT_THEME_LATEX_INTRO = R"""
This document contains groupings of examples from SNLI. SNLI is a natural language
inference task where the goal is to determine whether the premise (labeled [P]) entails,
contradicts, or is neutral with respect to the hypothesis (labeled [H]).

For each group of examples, please determine if there is some common theme among the examples
in the group. In the second column of the CSV, please write \texttt{yes}, \texttt{maybe}, or \texttt{no}
(and only those three options) depending on whether you detected the presence of a theme. If you put
\texttt{yes} or \texttt{maybe}, please put a brief description of the theme in the third column of
the CSV.
"""
###############################################################################
# Reviewer-facing introduction (LaTeX source) for documents in which pairs of
# example groups are judged on whether they share a very similar theme.
_DEFAULT_SAME_THEME_LATEX_INTRO = R"""
This document contains groupings of examples from SNLI. SNLI is a natural language
inference task where the goal is to determine whether the premise (labeled [P]) entails,
contradicts, or is neutral with respect to the hypothesis (labeled [H]).

This document contains pairs of groups of examples. For each pairing, please determine whether each group
individually contains a common theme among its examples. If both groups do contain a theme, then determine
whether the themes of the two groups are very similar. In the second column of the CSV, please write
\texttt{yes}, \texttt{maybe}, or \texttt{no} (and only those three options) depending on whether both groups
contain very similar themes. To be clear, write \texttt{yes} only if both groups contain a theme that is very
similar; write \texttt{no} if both groups contain different themes or if both groups do not contain a detectable
theme. If you put \texttt{yes} or \texttt{maybe}, please put a brief description of the theme in the third column of
the CSV.
"""
###############################################################################


def _capitalize_first_letter(s: str) -> str:
    """Upper-case the first character of `s` and leave the rest untouched.

    Unlike `str.capitalize`, the remaining characters are not lower-cased.
    An empty string comes back as an empty string.
    """
    # NOTE: If the string starts with whitespace, that whitespace is the
    # character being "upper-cased", so nothing visibly changes.
    head, tail = s[:1], s[1:]
    return head.upper() + tail


###############################################################################

class _HumevThemeMixin:
    """Mixin that turns SNLI top-example records into LaTeX snippets.

    The host class must provide the following attributes, which are read here
    but not defined here:
        tokenizer: Object with a `decode(...)` method and, for the 'encoder'
            format, a `sep_token_id` attribute.
        example_format: 'encoder' or 'lm'; selects the parsing strategy.
        _all_special_tokens_ids: Token ids that are dropped before decoding
            in the 'encoder' format.
    """

    def _remove_special_tokens(self, token_ids: Sequence[int]) -> List[int]:
        """Return `token_ids` as plain ints, with special-token ids filtered out."""
        return [int(t) for t in token_ids if t not in self._all_special_tokens_ids]

    #######################################################

    def _parse_ph(self, example_info: 'top_examples_common.TopExampleInfo') -> Tuple[str, str]:
        """Return the LaTeX-escaped (premise, hypothesis) pair for an example.

        Dispatches on `self.example_format`.

        Raises:
            ValueError: If `self.example_format` is neither 'encoder' nor 'lm'.
        """
        if self.example_format == 'encoder':
            return self._parse_ph_encoder(example_info)
        elif self.example_format == 'lm':
            return self._parse_ph_lm(example_info)
        else:
            raise ValueError(self.example_format)

    def _parse_ph_encoder(self, example_info: 'top_examples_common.TopExampleInfo') -> Tuple[str, str]:
        """Split an encoder-style input at its first separator token.

        Everything up to and including the first separator is treated as the
        premise and the remainder as the hypothesis. Special tokens are removed
        from each half before decoding, and each decoded string is stripped,
        has its first letter capitalized, and is LaTeX-escaped.

        Raises:
            IndexError: If the input contains no separator token.
        """
        # Coerce to an array so that the element-wise comparison below also
        # works when the ids are stored as a plain list rather than an ndarray.
        input_ids = np.asarray(example_info.example['input_ids'])

        # assumes input_ids is 1-D (the single-name unpacking requires it) — TODO confirm
        sep_indices, = np.nonzero(input_ids == self.tokenizer.sep_token_id)
        if sep_indices.size == 0:
            raise IndexError(
                'No separator token found in input_ids; '
                'cannot split the premise from the hypothesis.')
        split_ind = sep_indices[0] + 1

        premise_token_ids = self._remove_special_tokens(input_ids[:split_ind])
        premise = _capitalize_first_letter(self.tokenizer.decode(premise_token_ids).strip())
        premise = latex_utils.escape(premise)

        hypothesis_token_ids = self._remove_special_tokens(input_ids[split_ind:])
        hypothesis = _capitalize_first_letter(self.tokenizer.decode(hypothesis_token_ids).strip())
        hypothesis = latex_utils.escape(hypothesis)

        return premise, hypothesis

    def _parse_ph_lm(self, example_info: 'top_examples_common.TopExampleInfo') -> Tuple[str, str]:
        """Recover premise and hypothesis from a decoded LM prompt.

        The tokens selected by the attention mask are decoded and matched
        against the prompt layout
        `<premise>\\nQuestion: <hypothesis>. True, False or Neither?\\nAnswer:`.
        If the decoded text does not match, both returned strings are 'ERROR'.
        """
        # Coerce both to arrays: with plain lists, `attention_mask != 0` would
        # be a single bool and the indexing below would pick one element
        # instead of applying the mask.
        input_ids = np.asarray(example_info.example['input_ids'])
        attention_mask = np.asarray(example_info.example['attention_mask'])

        # Keep only the positions whose attention-mask entry is non-zero.
        context = self.tokenizer.decode(input_ids[attention_mask != 0])

        match = re.search(
            r'^(.+)\nQuestion: (.+)\. True, False or Neither\?\nAnswer:$',
            context)

        if match:
            premise = match.group(1)
            premise = latex_utils.escape(premise)

            # The pattern consumes the hypothesis' final period; put it back.
            hypothesis = f'{match.group(2)}.'
            hypothesis = latex_utils.escape(hypothesis)
        else:
            # Best effort: flag the example in the output rather than raising.
            premise = 'ERROR'
            hypothesis = 'ERROR'

        return premise, hypothesis

    #######################################################

    def make_example_latex_string(self, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Render one example as a typewriter-font LaTeX block.

        The premise line is prefixed with '[P]' and the hypothesis line with
        '[H]'; both go inside a single `\\texttt{...}` group.
        """
        premise, hypothesis = self._parse_ph(example_info)

        premise = f'[P] {premise}'
        hypothesis = f'[H] {hypothesis}'

        return "\n".join([
            R'\noindent\texttt{' + premise + R' \\',
            # Presumably the leading `{}` keeps LaTeX from reading '[H]' as the
            # optional argument of the preceding `\\` — confirm before removing.
            R'{}' + hypothesis + R'\vspace{2mm}\\',
            R'}',
        ])

###############################################################################


@dataclasses.dataclass
class SnliHumevTopExamplesThemeLatexGenerator(_HumevThemeMixin, humev_top_examples_theme_latex.HumevTopExamplesThemeLatexGeneratorAbc):
    """LaTeX generator for the SNLI "does this group have a theme?" human evaluation.

    Example rendering comes from `_HumevThemeMixin`; this class supplies the
    SNLI-specific introduction text.

    Raises:
        ValueError: On construction, if `example_format` is not one of
            `_EXAMPLE_FORMATS`.
    """
    # Tokenizer used to decode example token ids back into text.
    tokenizer: PreTrainedTokenizer

    # How examples are laid out; must be one of `_EXAMPLE_FORMATS`.
    example_format: str

    # Not read in this class; presumably consumed by the base generator — TODO confirm.
    components_fontsize: Optional[str] = 'footnotesize'

    def __post_init__(self):
        # Validate explicitly rather than with `assert`, which is stripped
        # when Python runs with -O.
        if self.example_format not in _EXAMPLE_FORMATS:
            raise ValueError(
                f'example_format must be one of {_EXAMPLE_FORMATS}, '
                f'got {self.example_format!r}')
        # Cached once here; the mixin reads it when filtering special tokens.
        self._all_special_tokens_ids: List[int] = self.tokenizer.all_special_ids

    #######################################################

    def make_latex_intro(self) -> str:
        """Return the introduction shown to reviewers at the top of the document."""
        # Should contain information to help reviewers about the task.
        return _DEFAULT_THEME_LATEX_INTRO


@dataclasses.dataclass
class SnliHumevTopExamplesSameThemeLatexGenerator(_HumevThemeMixin, humev_top_examples_same_theme_latex.HumevTopExamplesSameThemeLatexGeneratorAbc):
    """LaTeX generator for the SNLI "do these two groups share a theme?" human evaluation.

    Example rendering comes from `_HumevThemeMixin`; this class supplies the
    SNLI-specific introduction text.

    Raises:
        ValueError: On construction, if `example_format` is not one of
            `_EXAMPLE_FORMATS`.
    """
    # Tokenizer used to decode example token ids back into text.
    tokenizer: PreTrainedTokenizer

    # How examples are laid out; must be one of `_EXAMPLE_FORMATS`.
    example_format: str

    # Not read in this class; presumably consumed by the base generator — TODO confirm.
    components_fontsize: Optional[str] = 'footnotesize'

    def __post_init__(self):
        # Validate explicitly rather than with `assert`, which is stripped
        # when Python runs with -O.
        if self.example_format not in _EXAMPLE_FORMATS:
            raise ValueError(
                f'example_format must be one of {_EXAMPLE_FORMATS}, '
                f'got {self.example_format!r}')
        # Cached once here; the mixin reads it when filtering special tokens.
        self._all_special_tokens_ids: List[int] = self.tokenizer.all_special_ids

    #######################################################

    def make_latex_intro(self) -> str:
        """Return the introduction shown to reviewers at the top of the document."""
        # Should contain information to help reviewers about the task.
        return _DEFAULT_SAME_THEME_LATEX_INTRO
