"""Things related to latex generation for top examples of components over Winogrande examples."""
import dataclasses
from typing import List, Optional, Tuple

import numpy as np
from transformers import PreTrainedTokenizer

from npeff_torch.datasets.winogrande import winogrande_analysis
from npeff_torch.examination.top_examples import component_filtering
from npeff_torch.examination.top_examples import top_examples_common
from npeff_torch.examination.top_examples import top_examples_latex
from npeff_torch.util import latex_utils

###############################################################################


@dataclasses.dataclass
class NpeffComponentLatexGeneratorForSuffixLm(top_examples_latex.LatexGeneratorAbc):
    """Renders the top examples of a component as LaTeX, one Winogrande example at a time.

    Each example becomes a small ``texttt`` group made of a header line
    (gold label, model prediction and the example's coefficient) followed by
    the LaTeX-escaped example sentence. The prediction is colored with
    ``correct_prediction_color`` when it equals the label and with
    ``incorrect_prediction_color`` otherwise.

    NOTE(review): ``top_examples_reader``, ``n_top_examples``,
    ``components_fontsize`` and ``component_filter`` are not read by the
    methods defined here; presumably they are consumed by
    ``LatexGeneratorAbc`` -- confirm against the base class.
    """

    # Source of the top examples (not read directly by the methods below).
    top_examples_reader: 'top_examples_common.TopExamplesReaderAbc'
    # Passed to ``WinograndeExample.from_top_example_info`` to rebuild each example.
    tokenizer: PreTrainedTokenizer

    # Presumably the number of top examples rendered per component (base-class setting).
    n_top_examples: int

    # LaTeX color names applied to the prediction when it matches / differs from the label.
    correct_prediction_color: str = 'ForestGreen'
    incorrect_prediction_color: str = 'BrickRed'

    # Width handed to ``latex_utils.rpad_curly`` when padding the label/prediction.
    # Technically this is the width of the label/prediction text itself,
    # not of the whole header column.
    header_column_width: int = 16

    # LaTeX font-size name, presumably applied by the base class to the component listing.
    components_fontsize: Optional[str] = 'footnotesize'

    # Optional filter over components (not read by the methods below).
    component_filter: Optional['component_filtering.ComponentFilterAbc'] = None

    def _make_latex_label(self, label: str) -> str:
        """Escape ``label`` for LaTeX and right-pad it to ``header_column_width``.

        The *unescaped* length is what gets passed to ``rpad_curly``,
        presumably so padding reflects the visible width of the text rather
        than the length of its escaped form.
        """
        escaped_label = latex_utils.escape(label)
        return latex_utils.rpad_curly(escaped_label, self.header_column_width, len(label))

    def _make_example_top_line(
        self,
        example_info: 'top_examples_common.TopExampleInfo',
        winogrande_example: 'winogrande_analysis.WinograndeExample',
    ) -> str:
        """Build the header line describing one example.

        The result is three segments joined by `` {} ``:
        ``[LABEL] <padded label>``, ``[PRED] <padded prediction, colored>`` and
        ``[COEFF] <coefficient with 4 decimals>``.
        """
        gold = winogrande_example.label
        predicted = winogrande_example.prediction

        # Color the prediction according to whether it matches the gold label.
        if gold == predicted:
            color = self.correct_prediction_color
        else:
            color = self.incorrect_prediction_color

        label_cell = self._make_latex_label(gold)
        prediction_cell = self._make_latex_label(predicted)

        segments = (
            '[LABEL] {}'.format(label_cell),
            # Renders as: [PRED] {\color{<color>}<prediction>}
            R'[PRED] {{\color{{{}}}{}}}'.format(color, prediction_cell),
            '[COEFF] ' + format(example_info.coefficient, '.4f'),
        )
        return R' {} '.join(segments)

    def make_example_for_component_latex_string(self, component_index: int, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Render a single top example of a component as a LaTeX snippet.

        Args:
            component_index: Index of the component; not used by this implementation.
            example_info: Top-example record. It is converted to a
                ``WinograndeExample`` using ``self.tokenizer``.

        Returns:
            A three-line string: a ``noindent``/``texttt`` line holding the
            header, a line holding the LaTeX-escaped sentence, and the brace
            that closes the ``texttt`` group.
        """
        example = winogrande_analysis.WinograndeExample.from_top_example_info(self.tokenizer, example_info)
        header = self._make_example_top_line(example_info, example)
        sentence = latex_utils.escape(example.sentence)

        opening = R'\noindent\texttt{' + header + R'\vspace{1mm} \\'
        # The leading {} presumably stops the sentence from being read as an
        # optional [length] argument of the preceding \\ line break.
        body = R'{}' + sentence + R'\vspace{2mm}\\'
        closing = R'}'
        return '\n'.join((opening, body, closing))
