"""Things related to latex generation for top examples of components over SNLI examples."""
import abc
import dataclasses
import re
from typing import Optional, List, Sequence, Tuple

import numpy as np
from transformers import PreTrainedTokenizer

from npeff_torch.examination.top_examples import component_filtering
from npeff_torch.examination.top_examples import top_examples_common
from npeff_torch.examination.top_examples import top_examples_latex
from npeff_torch.util import latex_utils


###############################################################################


def _capitalize_first_letter(s: str) -> str:
    # NOTE: Won't capitalize anything if the string has leading whitespace.
    if s == '':
        return s
    return f'{s[0].upper()}{s[1:]}'


###############################################################################

class _NpeffComponentLatexGeneratorBase(top_examples_latex.LatexGeneratorAbc):
    """Shared rendering of a single top example as a LaTeX snippet.

    Subclasses implement `_parse_ph` to recover the premise and hypothesis
    text of an example. The methods below read `self._padded_label_strings`
    and `self.label_colors`, neither of which is defined in this class, so a
    concrete subclass has to provide both.
    """

    @abc.abstractmethod
    def _parse_ph(self, example_info: 'top_examples_common.TopExampleInfo') -> Tuple[str, str]:
        """Return the `(premise, hypothesis)` strings for `example_info`.

        The returned strings are inserted into the LaTeX output as-is.
        """
        raise NotImplementedError

    #######################################################

    def _to_latex_label(self, label: int) -> str:
        """Render the label with index `label` as colored, bold LaTeX text."""
        text = self._padded_label_strings[label]
        color_name = self.label_colors[label]
        # The bold requires `\usepackage{bold-extra}`
        return ''.join([R'{\color{', color_name, R'}\textbf{', text, R'}}'])

    def _make_example_top_line(self, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Build the header line: optional label, optional prediction, coefficient.

        The label segment is emitted only when `example_info.label` is set and
        the prediction segment (argmax of the logits) only when
        `example_info.logits` is set. The coefficient is always emitted,
        formatted with four decimal places.
        """
        segments = []

        label = example_info.label
        if label is not None:
            segments.append(f'[LABEL] {self._to_latex_label(label)}')

        logits = example_info.logits
        if logits is not None:
            predicted = np.argmax(logits)
            segments.append(f'[PRED] {self._to_latex_label(predicted)}')

        coefficient_text = format(example_info.coefficient, '.4f')
        segments.append(f'[COEFF] {coefficient_text}')

        # The segments are separated by an empty LaTeX group surrounded by spaces.
        return R' {} '.join(segments)

    #######################################################

    def make_example_for_component_latex_string(self, component_index: int, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Return the LaTeX block for one example: header, premise and hypothesis.

        `component_index` is accepted but not used when building the string.
        """
        premise, hypothesis = self._parse_ph(example_info)
        header = self._make_example_top_line(example_info)

        # Each row is (prefix, content, suffix). The typewriter group opened by
        # the first row is closed by the lone brace appended after the rows.
        rows = (
            (R'\noindent\texttt{', header, R'\vspace{1mm} \\'),
            (R'{}', f'[P] {premise}', R'\vspace{1mm} \\'),
            (R'{}', f'[H] {hypothesis}', R'\vspace{2mm}\\'),
        )
        rendered = [''.join(row) for row in rows]
        rendered.append(R'}')
        return "\n".join(rendered)


###############################################################################


@dataclasses.dataclass
class NpeffComponentLatexGenerator(_NpeffComponentLatexGeneratorBase):
    """LaTeX generator for examples stored as one tokenized premise/hypothesis pair.

    The premise and hypothesis are recovered from `input_ids` by splitting
    just after the first occurrence of the tokenizer's separator token.
    """

    top_examples_reader: 'top_examples_common.TopExamplesReaderAbc'
    tokenizer: PreTrainedTokenizer

    n_top_examples: int

    label_strings: Tuple[str, str, str] = ('entailment', 'neutral', 'contradiction')
    label_colors: Tuple[str, str, str] = ('ForestGreen', 'RoyalBlue', 'BrickRed')

    components_fontsize: Optional[str] = 'footnotesize'

    component_filter: Optional['component_filtering.ComponentFilterAbc'] = None

    def __post_init__(self):
        """Cache the padded label strings and the tokenizer's special token ids."""
        self._padded_label_strings = tuple(latex_utils.rpad_curly_to_max_length(self.label_strings))
        self._all_special_tokens_ids: List[int] = self.tokenizer.all_special_ids

    #######################################################

    def _remove_special_tokens(self, token_ids: Sequence[int]) -> List[int]:
        """Return `token_ids` as plain ints with every special token id dropped."""
        # Convert to `int` before the membership test so the lookup does not
        # depend on how the element type (e.g. a numpy scalar) compares or
        # hashes, and use a set so each lookup is constant time.
        special_ids = frozenset(self._all_special_tokens_ids)
        return [tid for tid in map(int, token_ids) if tid not in special_ids]

    def _decode_segment(self, token_ids: Sequence[int]) -> str:
        """Decode one segment into capitalized, LaTeX-escaped text."""
        text = self.tokenizer.decode(self._remove_special_tokens(token_ids)).strip()
        return latex_utils.escape(_capitalize_first_letter(text))

    def _parse_ph(self, example_info: 'top_examples_common.TopExampleInfo') -> Tuple[str, str]:
        """Split the example's `input_ids` into premise and hypothesis text.

        Returns:
            `(premise, hypothesis)`, each decoded without special tokens,
            stripped, capitalized and LaTeX-escaped.

        Raises:
            IndexError: If `input_ids` contains no separator token.
        """
        # `np.asarray` makes the element-wise comparison below work for plain
        # sequences as well as arrays; it is a no-op for numpy arrays.
        input_ids = np.asarray(example_info.example['input_ids'])

        sep_indices, = np.nonzero(input_ids == self.tokenizer.sep_token_id)
        if sep_indices.size == 0:
            # Same exception type as indexing the empty result would raise,
            # but with a message that names the actual problem.
            raise IndexError(
                'No separator token (id '
                f'{self.tokenizer.sep_token_id!r}) found in input_ids; cannot '
                'split the example into premise and hypothesis.'
            )
        # Everything up to and including the first separator is the premise
        # segment; the separator itself is removed with the special tokens.
        split_ind = int(sep_indices[0]) + 1

        premise = self._decode_segment(input_ids[:split_ind])
        hypothesis = self._decode_segment(input_ids[split_ind:])

        return premise, hypothesis


###############################################################################


@dataclasses.dataclass
class NpeffComponentLatexGeneratorForLm(_NpeffComponentLatexGeneratorBase):
    """LaTeX generator for examples whose input is a decoded text prompt.

    The premise and hypothesis are extracted from the decoded prompt with a
    regular expression. When the prompt does not have the expected layout,
    both are reported as the literal string 'ERROR'.
    """

    top_examples_reader: 'top_examples_common.TopExamplesReaderAbc'
    tokenizer: PreTrainedTokenizer

    n_top_examples: int

    label_strings: Tuple[str, str, str] = ('True', 'False', 'Neither')
    label_colors: Tuple[str, str, str] = ('ForestGreen', 'BrickRed', 'RoyalBlue')

    components_fontsize: Optional[str] = 'footnotesize'

    component_filter: Optional['component_filtering.ComponentFilterAbc'] = None

    def __post_init__(self):
        """Cache the padded label strings and the tokenizer's special token ids."""
        padded = latex_utils.rpad_curly_to_max_length(self.label_strings)
        self._padded_label_strings = tuple(padded)
        self._all_special_tokens_ids: List[int] = self.tokenizer.all_special_ids

    #######################################################

    def _remove_special_tokens(self, token_ids: Sequence[int]) -> List[int]:
        """Return `token_ids` as plain ints, skipping the tokenizer's special ids.

        Not called by `_parse_ph` in this class.
        """
        special_ids = self._all_special_tokens_ids
        kept: List[int] = []
        for token_id in token_ids:
            if token_id in special_ids:
                continue
            kept.append(int(token_id))
        return kept

    def _parse_ph(self, example_info: 'top_examples_common.TopExampleInfo') -> Tuple[str, str]:
        """Extract the LaTeX-escaped premise and hypothesis from the prompt.

        Only positions with a non-zero attention mask are decoded. Returns
        `('ERROR', 'ERROR')` when the decoded prompt does not match the
        expected pattern.
        """
        example = example_info.example
        input_ids = example['input_ids']
        attention_mask = example['attention_mask']

        # Boolean-mask indexing: drop the positions the attention mask zeroes out.
        attended = attention_mask != 0
        context = self.tokenizer.decode(input_ids[attended])

        # Expected layout: the premise on one line, then a "Question: ..." line
        # ending in ". True, False or Neither?", then a final "Answer:" line.
        prompt_pattern = re.compile(
            r'^(.+)\nQuestion: (.+)\. True, False or Neither\?\nAnswer:$')
        match = prompt_pattern.search(context)

        if match is None:
            return 'ERROR', 'ERROR'

        raw_premise, raw_question = match.groups()
        premise = latex_utils.escape(raw_premise)
        # The pattern consumes the question's final period, so add it back.
        hypothesis = latex_utils.escape(f'{raw_question}.')

        return premise, hypothesis
