"""Things related to latex generation for top examples of components over CLINC150 examples."""
import dataclasses
import re
from typing import Optional, Tuple

import numpy as np
from transformers import PreTrainedTokenizer

from npeff_torch.examination.top_examples import component_filtering
from npeff_torch.examination.top_examples import top_examples_common
from npeff_torch.examination.top_examples import top_examples_latex
from npeff_torch.util import latex_utils

from . import clinc150_datasets

###############################################################################


@dataclasses.dataclass
class NpeffComponentLatexGeneratorForLm(top_examples_latex.LatexGeneratorAbc):
    """Builds the LaTeX snippet shown for each top example of a component.

    Each example is rendered as a header line (label, prediction and
    coefficient, whichever are available) followed by the query sentence
    recovered by decoding the example's token ids.
    """

    # The top examples should come from the associated open qa version of the task.
    top_examples_reader: 'top_examples_common.TopExamplesReaderAbc'
    # Used to decode an example's `input_ids` back into text.
    tokenizer: PreTrainedTokenizer

    # Not read in this class; presumably consumed by the base class -- TODO confirm.
    n_top_examples: int

    # Display names of the labels, indexed by label id. Defaults to the entries
    # of `clinc150_datasets.LABELS_LAST_PART` with underscores turned into spaces.
    label_strings: Tuple[str, ...] = tuple(x.replace('_', ' ') for x in clinc150_datasets.LABELS_LAST_PART)

    # Not read in this class; presumably consumed by the base class -- TODO confirm.
    components_fontsize: Optional[str] = 'footnotesize'

    # Not read in this class; presumably consumed by the base class -- TODO confirm.
    component_filter: Optional['component_filtering.ComponentFilterAbc'] = None

    def __post_init__(self):
        """Precompute the padded label strings used when rendering labels."""
        self._padded_label_strings = tuple(latex_utils.rpad_curly_to_max_length(self.label_strings))

    #######################################################

    def _to_latex_label(self, label: int) -> str:
        """Return the bold LaTeX rendering of the label with index `label`.

        Indices that do not refer to a known label (negative, or past the end
        of the label list) are rendered as the placeholder text ERROR instead
        of raising, so that one bad example does not abort the whole document.
        """
        if 0 <= label < len(self._padded_label_strings):
            s = self._padded_label_strings[label]
        else:
            # We have negative labels for some reason; out-of-range indices
            # get the same placeholder.
            s = 'ERROR'
        # The bold requires `\usepackage{bold-extra}`
        return R'{\textbf{' + s + R'}}'

    def _make_example_top_line(self, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Return the header line of an example.

        The line holds, in order, the label (if `example_info.label` is set),
        the predicted label taken as the argmax of `example_info.logits` (if
        set), and the coefficient formatted with four decimals.
        """
        parts = []

        if example_info.label is not None:
            parts.append(f'[LABEL] {self._to_latex_label(example_info.label)}')

        if example_info.logits is not None:
            parts.append(f'[PRED] {self._to_latex_label(np.argmax(example_info.logits))}')

        s_coeff = f'{example_info.coefficient:.4f}'
        parts.append(f'[COEFF] {s_coeff}')

        # The `{}` between the spaces is emitted literally into the LaTeX.
        return R' {} '.join(parts)

    #######################################################

    def _parse_sentence(self, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Return the LaTeX-escaped query sentence of an example.

        The non-masked token ids are decoded and the text between the leading
        "Query: " and the following newline + "Intent:" is extracted. If the
        decoded text does not have that form, the placeholder text ERROR is
        returned.
        """
        input_ids = example_info.example['input_ids']
        attention_mask = example_info.example['attention_mask']

        # Drop the positions masked out by the attention mask before decoding.
        context = self.tokenizer.decode(input_ids[attention_mask != 0])

        # NOTE(review): the `^` anchor requires the decoded text to start with
        # "Query: "; a tokenizer that emits a leading special token would make
        # every example fall through to ERROR -- confirm against the tokenizer used.
        match = re.search(r'^Query: (.+)\nIntent:', context)

        if not match:
            return 'ERROR'

        return latex_utils.escape(match.group(1))

    #######################################################

    def make_example_for_component_latex_string(self, component_index: int, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Return the LaTeX snippet for a single top example of a component.

        The snippet is a monospace, non-indented paragraph with the header
        line on top and the query sentence below it, separated by small
        vertical spaces. `component_index` is not used here.
        """
        lines = [
            self._make_example_top_line(example_info),
            self._parse_sentence(example_info),
        ]

        # Line separator: small vertical gap, LaTeX line break, then a newline
        # in the generated source followed by a literal `{}`.
        separator = R'\vspace{1mm} \\' + '\n' + R'{}'
        body = separator.join(lines)

        return R'\noindent\texttt{' + body + R'\vspace{2mm} \\' + '\n}'
