"""Utilities for generating LaTeX that displays the top examples of components over examples from the Pile.

Code here typically assumes an LM-style model that predicts the next token.
"""
import dataclasses
from typing import Optional, List, Sequence, Tuple

import numpy as np
import torch
from transformers import PreTrainedTokenizer

from npeff_torch.examination.top_examples import component_filtering
from npeff_torch.examination.top_examples import top_examples_common
from npeff_torch.examination.top_examples import top_examples_latex
from npeff_torch.util import latex_utils


###############################################################################


@dataclasses.dataclass
class NpeffComponentLatexGenerator(top_examples_latex.LatexGeneratorAbc):
    """Generates LaTeX for the top examples of NPEFF components on LM-style data.

    Each example is rendered as a monospace (texttt) paragraph made of:
      1. a top line with the example's coefficient and its decoded label token,
      2. optionally, a line listing the model's top predictions, and
      3. the example text, with the token at ``token_position`` highlighted.
    """

    # NOTE(review): the reader, n_top_examples, components_fontsize and
    # component_filter are not read by the methods below; presumably they are
    # consumed by the LatexGeneratorAbc base class -- confirm there.
    top_examples_reader: 'top_examples_common.TopExamplesReaderAbc'
    tokenizer: PreTrainedTokenizer

    n_top_examples: int

    components_fontsize: Optional[str] = 'footnotesize'

    # Name of a LaTeX environment (e.g. 'scriptsize') wrapped around each
    # example's text. Set to None to not wrap the text at all.
    example_text_fontsize: Optional[str] = 'scriptsize'

    # Maximum number of tokens shown before / after the highlighted token.
    # None means no limit (show up to the start / end of the unpadded text).
    max_preceding_tokens_shown: Optional[int] = None
    max_following_tokens_shown: Optional[int] = None

    # Number of top predictions listed per example. Set to None to omit the
    # predictions line entirely.
    n_top_predictions: Optional[int] = None
    # Set to None to not include probabilities for the top predictions.
    top_predictions_probability_digits: Optional[int] = 2

    component_filter: Optional['component_filtering.ComponentFilterAbc'] = None

    #######################################################

    def _trim_trailing_padding(self, input_ids: np.ndarray, token_position: int) -> np.ndarray:
        """Returns ``input_ids`` with its trailing run of padding tokens removed.

        Only padding at the end of the sequence is removed. A pad id occurring
        earlier (e.g. when the pad token doubles as an end-of-text separator)
        is kept, and the token at ``token_position`` is never removed.

        Args:
            input_ids: 1-D array of token ids.
            token_position: index of the token that will be highlighted.

        Returns:
            A prefix slice of ``input_ids``.
        """
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            # No pad token defined, so no position can be padding.
            return input_ids

        non_pad_positions = np.flatnonzero(input_ids != pad_token_id)
        n_keep = int(non_pad_positions[-1]) + 1 if non_pad_positions.size else 0
        # The highlighted token itself may equal the pad id (e.g. an EOS used
        # as pad); it must stay inside the kept prefix.
        n_keep = max(n_keep, token_position + 1)
        return input_ids[:n_keep]

    def _make_example_text(self, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Returns LaTeX for the example's text with its target token highlighted.

        The text is limited to at most ``max_preceding_tokens_shown`` tokens
        before and ``max_following_tokens_shown`` tokens after the token at
        ``example_info.token_position``. It is wrapped in the
        ``example_text_fontsize`` environment when that is not None.
        """
        token_position = example_info.token_position
        # Assumes example['input_ids'] is a 1-D sequence of token ids (numpy
        # array, list, or CPU tensor) -- TODO confirm against the reader.
        input_ids = self._trim_trailing_padding(
            np.asarray(example_info.example['input_ids']), token_position)

        if self.max_preceding_tokens_shown is not None:
            start_position = max(0, token_position - self.max_preceding_tokens_shown)
        else:
            start_position = 0

        if self.max_following_tokens_shown is not None:
            end_position = min(len(input_ids), 1 + token_position + self.max_following_tokens_shown)
        else:
            end_position = len(input_ids)

        # NOTE(review): the surrounding text is escaped before newlines are
        # removed, whereas the highlighted token (and the label / prediction
        # tokens) remove newlines first. Confirm the order is irrelevant for
        # latex_utils before unifying.
        preceding_text = latex_utils.escape(self.tokenizer.decode(input_ids[start_position:token_position]))
        preceding_text = latex_utils.remove_newlines(preceding_text)

        following_text = latex_utils.escape(self.tokenizer.decode(input_ids[token_position + 1:end_position]))
        following_text = latex_utils.remove_newlines(following_text)

        token = latex_utils.escape(latex_utils.remove_newlines(self.tokenizer.decode([int(input_ids[token_position])])))
        # Highlight the target token: a yellow box with magenta text.
        token = R'\colorbox[RGB]{255,255,0}{' + token + R'}'
        token = R"{\color{Magenta}" + token + R"}"

        ret = []

        if self.example_text_fontsize is not None:
            ret.append(R'\begin{' + self.example_text_fontsize + R'}')

        ret.append(f'{preceding_text}{token}{following_text}')

        if self.example_text_fontsize is not None:
            ret.append(R'\end{' + self.example_text_fontsize + R'}')

        return '\n'.join(ret)

    def _make_example_top_line(self, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Returns the '[COEFF] ... [LABEL] ...' line for an example.

        ``example_info.label`` may be a single token id or a per-position
        sequence of token ids. In the latter case the entry at
        ``example_info.token_position`` is used.
        """
        s_coeff = f'{example_info.coefficient:.4f}'

        label = np.asarray(example_info.label)
        if label.size != 1:
            # Per-position labels: pick the one for the highlighted token.
            label = label[example_info.token_position]
        # .item() handles both 0-d and shape-(1,) arrays. Passing the array
        # itself to decode() would give it a nested sequence instead of an id.
        label_id = int(label.item())

        s_label = latex_utils.escape(latex_utils.remove_newlines(self.tokenizer.decode([label_id]).strip()))

        # ' {} ' yields two spaces in the output: the empty group stops LaTeX
        # from collapsing the consecutive spaces into one.
        return R' {} '.join([f'[COEFF] {s_coeff}', f'[LABEL] {s_label}'])

    def _maybe_make_example_second_line(self, example_info: 'top_examples_common.TopExampleInfo') -> Optional[str]:
        """Returns the '[PREDS] ...' line, or None if ``n_top_predictions`` is None.

        Each entry is the decoded token, optionally followed by its probability
        rounded to ``top_predictions_probability_digits`` digits.
        """
        if self.n_top_predictions is None:
            return None

        line = []
        # get_top_probs is expected to yield (token_id, probability) pairs.
        for token_id, prob in example_info.get_top_probs(self.n_top_predictions):
            token = latex_utils.escape(latex_utils.remove_newlines(self.tokenizer.decode([token_id]).strip()))
            if self.top_predictions_probability_digits is not None:
                s_prob = "{:.{}f}".format(prob, self.top_predictions_probability_digits)
                token = f'{token} ({s_prob})'
            line.append(token)

        return f'[PREDS] {", ".join(line)}'

    #######################################################

    def make_example_for_component_latex_string(self, component_index: int, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Returns the full LaTeX paragraph for one top example of a component.

        ``component_index`` is part of the base-class interface and is not used
        here.
        """
        lines = [
            self._make_example_top_line(example_info),
            self._maybe_make_example_second_line(example_info),
            self._make_example_text(example_info),
        ]
        lines = [line for line in lines if line is not None]

        # Lines are separated by a small vertical space and a LaTeX line break.
        # The `{}` after `\\` prevents LaTeX from reading a following '[' (e.g.
        # '[PREDS]') as the optional length argument of `\\`.
        s = (R'\vspace{1mm} \\' + '\n' + R'{}').join(lines)
        s = R'\noindent\texttt{' + s + R'\vspace{2mm} \\' + '\n}'

        return s
