"""Utilities for generating LaTeX that displays the top SST-2 examples of NPEFF components."""
import dataclasses
import re
from typing import Optional, Tuple

import numpy as np
from transformers import PreTrainedTokenizer

from npeff_torch.examination.top_examples import component_filtering
from npeff_torch.examination.top_examples import top_examples_common
from npeff_torch.examination.top_examples import top_examples_latex
from npeff_torch.util import latex_utils

###############################################################################


@dataclasses.dataclass
class NpeffComponentLatexGeneratorForLm(top_examples_latex.LatexGeneratorAbc):
    r"""LaTeX generator for the top examples of NPEFF components on LM prompts.

    Every example is rendered as one ``\noindent\texttt{...}`` paragraph made
    of two lines:

    1. a header holding the optional gold label (``[LABEL]``), the optional
       label at the argmax of the example's logits (``[PRED]``), and the
       component coefficient with four decimals (``[COEFF]``);
    2. the LaTeX-escaped review sentence recovered from the decoded prompt,
       or the literal text ``ERROR`` when the prompt does not match the
       expected ``Review: ...\nSentiment:`` template.
    """

    # Not read by the methods below; presumably consumed by
    # LatexGeneratorAbc to fetch the top examples -- confirm.
    top_examples_reader: 'top_examples_common.TopExamplesReaderAbc'
    # Decodes an example's `input_ids` back into prompt text.
    tokenizer: PreTrainedTokenizer

    # Not read by the methods below; presumably consumed by the base class.
    n_top_examples: int

    # Display text and xcolor color name, indexed by integer label.
    label_strings: Tuple[str, str] = ("Negative", "Positive")
    label_colors: Tuple[str, str] = ("BrickRed", "ForestGreen")

    # Not read by the methods below; presumably consumed by the base class.
    components_fontsize: Optional[str] = 'footnotesize'

    # Not read by the methods below; presumably consumed by the base class.
    component_filter: Optional['component_filtering.ComponentFilterAbc'] = None

    # When True, the prompt is treated as an in-context-learning prompt and
    # the review is taken from its final "\n\nReview: ..." segment.
    icl_format: bool = False

    def __post_init__(self):
        # Precompute the label display strings once, padded via
        # latex_utils.rpad_curly_to_max_length (presumably to a common width
        # so the header fields line up in monospace).
        padded = latex_utils.rpad_curly_to_max_length(self.label_strings)
        self._padded_label_strings = tuple(padded)

    # ----------------------------------------------------------------------
    # Header line

    def _to_latex_label(self, label: int) -> str:
        """Return the padded display string for `label`, bold and colored.

        Args:
            label: integer index into `label_strings` / `label_colors`.
        """
        text = self._padded_label_strings[label]
        color_name = self.label_colors[label]
        # \textbf inside \texttt needs `\usepackage{bold-extra}`.
        return ''.join((R'{\color{', color_name, R'}\textbf{', text, R'}}'))

    def _make_example_top_line(self, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Build the header line for one example.

        The [LABEL] field is emitted only when `example_info.label` is not
        None. The [PRED] field, the label at `np.argmax` of the logits, is
        emitted only when `example_info.logits` is not None. The [COEFF]
        field is always emitted.
        """
        fields = []

        if example_info.label is not None:
            fields.append('[LABEL] ' + self._to_latex_label(example_info.label))

        if example_info.logits is not None:
            predicted = np.argmax(example_info.logits)
            fields.append('[PRED] ' + self._to_latex_label(predicted))

        fields.append('[COEFF] ' + format(example_info.coefficient, '.4f'))

        # The empty group `{}` stops LaTeX from merging the two surrounding
        # spaces into one, which widens the gap between fields.
        return R' {} '.join(fields)

    # ----------------------------------------------------------------------
    # Sentence extraction

    def _parse_sentence(self, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Recover the review sentence from an example's prompt, LaTeX-escaped.

        Returns:
            The escaped text captured after "Review: ", or the string 'ERROR'
            if the decoded prompt does not match the expected template.
        """
        example = example_info.example
        ids = example['input_ids']
        mask = example['attention_mask']

        # Keep only positions with a nonzero attention mask before decoding.
        # NOTE(review): boolean-mask indexing assumes array-like (numpy /
        # torch) inputs, not plain Python lists -- confirm against the reader.
        prompt = self.tokenizer.decode(ids[mask != 0])

        # ICL prompts carry earlier demonstrations; requiring "Sentiment:" at
        # the very end (and `.` not crossing newlines) selects the last review.
        pattern = (
            r'\n\nReview: (.+)\nSentiment:$'
            if self.icl_format
            else r'^Review: (.+)\nSentiment:$'
        )
        found = re.search(pattern, prompt)

        if not found:
            return 'ERROR'
        return latex_utils.escape(found.group(1))

    # ----------------------------------------------------------------------
    # Public entry point

    def make_example_for_component_latex_string(self, component_index: int, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Render one top example as a LaTeX \\texttt paragraph.

        Args:
            component_index: not used here; kept for the caller-facing
                signature (presumably required by LatexGeneratorAbc).
            example_info: the example to render.

        Returns:
            The header line and the sentence, separated by a 1mm-spaced line
            break and wrapped in `\\noindent\\texttt{...}` with a trailing
            2mm-spaced line break.
        """
        header = self._make_example_top_line(example_info)
        sentence = self._parse_sentence(example_info)

        # The `{}` after the line break keeps a leading `[` in the sentence
        # from being read as the optional argument of `\\`.
        body = header + R'\vspace{1mm} \\' + '\n' + R'{}' + sentence
        return R'\noindent\texttt{' + body + R'\vspace{2mm} \\' + '\n}'
