"""Things related to latex generation for top examples of components over YAT examples."""
import dataclasses
import re
from typing import Optional, Tuple

import numpy as np
from transformers import PreTrainedTokenizer

from npeff_torch.examination.top_examples import component_filtering
from npeff_torch.examination.top_examples import top_examples_common
from npeff_torch.examination.top_examples import top_examples_latex
from npeff_torch.util import latex_utils

from . import yahoo_answers_topics_datasets

###############################################################################


@dataclasses.dataclass
class NpeffComponentLatexGeneratorForLm(top_examples_latex.LatexGeneratorAbc):
    r"""LaTeX generator for top examples of components over YAT examples.

    Each example is rendered as a ``\noindent\texttt{...}`` snippet made of a
    header line (label / prediction / coefficient) followed by the question
    text recovered from the example's tokenized prompt.
    """

    # Not referenced in this class; presumably consumed by the base class (confirm).
    top_examples_reader: 'top_examples_common.TopExamplesReaderAbc'
    # Used to decode an example's input ids back into the prompt text.
    tokenizer: PreTrainedTokenizer

    # Not referenced in this class; presumably consumed by the base class (confirm).
    n_top_examples: int

    # One display string per class index. Typed as a variable-length tuple:
    # the strings are only ever looked up by class id, and nothing here
    # restricts them to exactly two entries.
    label_strings: Tuple[str, ...] = yahoo_answers_topics_datasets.LABELS

    # Not referenced in this class; presumably consumed by the base class (confirm).
    components_fontsize: Optional[str] = 'footnotesize'

    # Not referenced in this class; presumably consumed by the base class (confirm).
    component_filter: Optional['component_filtering.ComponentFilterAbc'] = None

    # Extracts the question from the decoded prompt. Deliberately left
    # unannotated so that `dataclasses` does not treat it as a field.
    # NOTE: `.` does not match newlines, so a question spanning several lines
    # does not match and is rendered as 'ERROR' by `_parse_sentence`.
    _QUESTION_PATTERN = re.compile(
        r'^Question: (.+)\nWhat broad topic is this question about\? Choose from:\n')

    def __post_init__(self):
        # Pad the label strings once up front so every label has the same
        # rendered width in the monospaced output.
        self._padded_label_strings = tuple(latex_utils.rpad_curly_to_max_length(self.label_strings))

    #######################################################

    def _to_latex_label(self, label: int) -> str:
        """Return the padded label string for class index `label`, wrapped in bold."""
        padded = self._padded_label_strings[label]
        # The bold requires `\usepackage{bold-extra}`
        return R'{\textbf{' + padded + R'}}'

    def _make_example_top_line(self, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Build the header line of an example.

        The `[LABEL]` and `[PRED]` segments are included only when the example
        carries a label / logits; the `[COEFF]` segment is always present. The
        predicted class is the argmax of the logits.
        """
        segments = []

        if example_info.label is not None:
            segments.append(f'[LABEL] {self._to_latex_label(example_info.label)}')

        if example_info.logits is not None:
            predicted = np.argmax(example_info.logits)
            segments.append(f'[PRED] {self._to_latex_label(predicted)}')

        segments.append(f'[COEFF] {example_info.coefficient:.4f}')

        # The empty group `{}` is used as the separator between segments.
        return R' {} '.join(segments)

    #######################################################

    def _parse_sentence(self, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Recover the LaTeX-escaped question text from the example's prompt.

        Returns the literal string 'ERROR' when the decoded prompt does not
        match the expected template.
        """
        example = example_info.example
        input_ids = example['input_ids']
        attention_mask = example['attention_mask']

        # Keep only the positions with a non-zero attention mask before decoding.
        # Assumes both values support boolean-mask indexing (e.g. NumPy arrays)
        # -- TODO confirm against the reader that produces the examples.
        context = self.tokenizer.decode(input_ids[attention_mask != 0])

        match = self._QUESTION_PATTERN.search(context)
        if match is None:
            # Best-effort: keep generating the document with a visible marker.
            return 'ERROR'

        return latex_utils.escape(match.group(1))

    #######################################################

    def make_example_for_component_latex_string(self, component_index: int, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Return the LaTeX snippet for one top example.

        `component_index` is accepted as part of the signature but is not used
        by this implementation.
        """
        lines = [
            self._make_example_top_line(example_info),
            self._parse_sentence(example_info),
        ]

        # Header and question are separated by a small vertical gap and a line break.
        separator = R'\vspace{1mm} \\' + '\n' + R'{}'
        body = separator.join(lines)

        return R'\noindent\texttt{' + body + R'\vspace{2mm} \\' + '\n}'
