"""Things related to latex generation for top examples of components over trivia-qa examples."""
import dataclasses
from typing import Optional

from transformers import PreTrainedTokenizer

from npeff_torch.examination.top_examples import component_filtering
from npeff_torch.examination.top_examples import top_examples_common
from npeff_torch.examination.top_examples import top_examples_latex
from npeff_torch.util import latex_utils

###############################################################################


@dataclasses.dataclass
class NpeffComponentLatexGeneratorForOpenQa(top_examples_latex.LatexGeneratorAbc):
    """Renders the top examples of a component as LaTeX for open-QA examples.

    Each example becomes a `\\noindent\\texttt{...}` block made of up to three
    lines: the example's coefficient, the model's top predictions (optional)
    and the decoded, non-masked input tokens.

    Several fields (`top_examples_reader`, `n_top_examples`,
    `components_fontsize`, `component_filter`) are not read by the methods
    defined here; presumably they are consumed by `LatexGeneratorAbc`.
    """

    top_examples_reader: 'top_examples_common.TopExamplesReaderAbc'
    # Used to decode token ids back into text.
    tokenizer: PreTrainedTokenizer

    n_top_examples: int

    # Number of top predictions to show per example. Set to None to omit the
    # predictions line entirely.
    n_top_predictions: Optional[int] = None
    # Set to None to not include probabilities for the top predictions.
    top_predictions_probability_digits: Optional[int] = 2

    components_fontsize: Optional[str] = 'footnotesize'

    component_filter: Optional['component_filtering.ComponentFilterAbc'] = None

    def _make_top_line(self, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Return the `[COEFF]` line with the coefficient at 4 decimal places."""
        return f'[COEFF] {example_info.coefficient:.4f}'

    def _make_predictions_line(self, example_info: 'top_examples_common.TopExampleInfo') -> Optional[str]:
        """Return the `[PREDS]` line, or None when predictions are disabled.

        Each prediction is the decoded token (stripped, newlines removed and
        LaTeX-escaped), optionally followed by its probability in parentheses.
        """
        if self.n_top_predictions is None:
            return None

        digits = self.top_predictions_probability_digits
        entries = []
        for token_id, prob in example_info.get_top_probs(self.n_top_predictions):
            decoded = self.tokenizer.decode([token_id]).strip()
            entry = latex_utils.escape(latex_utils.remove_newlines(decoded))
            if digits is not None:
                entry = f'{entry} ({prob:.{digits}f})'
            entries.append(entry)

        return '[PREDS] ' + ', '.join(entries)

    def _make_context_line(self, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Return the LaTeX-escaped text decoded from the non-masked input ids."""
        example = example_info.example
        input_ids = example['input_ids']
        attention_mask = example['attention_mask']

        if isinstance(attention_mask, (list, tuple)):
            # For a plain sequence `attention_mask != 0` is a single bool, so
            # boolean-mask indexing would silently select one element instead
            # of filtering. Filter element by element instead.
            if hasattr(input_ids, 'tolist'):
                input_ids = input_ids.tolist()
            kept_ids = [
                token_id
                for token_id, mask_value in zip(input_ids, attention_mask)
                if mask_value != 0
            ]
        else:
            # Array-like mask: element-wise comparison followed by boolean-mask
            # indexing.
            kept_ids = input_ids[attention_mask != 0]

        return latex_utils.escape(self.tokenizer.decode(kept_ids))

    def make_example_for_component_latex_string(self, component_index: int, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Return the LaTeX snippet for a single top example.

        Args:
            component_index: Index of the component; not used by this
                implementation.
            example_info: The example to render.

        Returns:
            A `\\noindent\\texttt{...}` block containing the coefficient line,
            the predictions line (when enabled) and the context line.
        """
        candidate_lines = (
            self._make_top_line(example_info),
            self._make_predictions_line(example_info),
            self._make_context_line(example_info),
        )
        lines = [line for line in candidate_lines if line is not None]

        # The empty group `{}` after the line break keeps a following line that
        # starts with `[` (e.g. `[PREDS]`) from being parsed by LaTeX as the
        # optional argument of `\\`.
        separator = R'\vspace{1mm} \\' + '\n' + R'{}'
        body = separator.join(lines)

        return R'\noindent\texttt{' + body + R'\vspace{2mm} \\' + '\n}'
