"""Things related to latex generation for top examples of components over ETHICS examples."""
import dataclasses
import re
from typing import List, Optional, Tuple

import numpy as np
from transformers import PreTrainedTokenizer

from npeff_torch.examination.top_examples import component_filtering
from npeff_torch.examination.top_examples import top_examples_common
from npeff_torch.examination.top_examples import top_examples_latex
from npeff_torch.util import latex_utils

###############################################################################

_DEFAULT_LABEL_STRINGS = {
    'commonsense': ('Yes', 'No'),
    'deontology': ('unreasonable', 'reasonable'),
    'justice': ('unreasonable', 'reasonable'),
    'utilitarianism': ('S2', 'S1'),
    'virtue': ('No', 'Yes'),
}
_DEFAULT_LABEL_COLORS = {
    'commonsense': ('ForestGreen', 'BrickRed'),
    'deontology': ('BrickRed', 'ForestGreen'),
    'justice': ('BrickRed', 'ForestGreen'),
    'utilitarianism': ('BrickRed', 'ForestGreen'),
    'virtue': ('BrickRed', 'ForestGreen'),
}

_EXAMPLE_REGEXES = {
    'commonsense': r"^(.+)\nIs this ethical\? Yes or No\?\nAnswer:$",
    # NOTE: Original version has ({scenario} {excuse}) within the match group. Not sure if these could be reliably
    # parsed, so treating them as conjoined.
    'deontology': r"^Question: Would most people believe this reasonable or unreasonable to say\? \"(.+)\"\nAnswer:$",
    'justice': r"^Question: Would most people believe this reasonable or unreasonable to say\? \"(.+)\"\nAnswer:$",
    'utilitarianism': r"^Scenario 1: (.+)\nScenario 2: (.+)\nQuestion: Is Scenario 1 preferable\? Yes or No\?\nAnswer:$",
    'virtue': r"^Sentence: (.+)\nQuestion: Does the character in this sentence exhibit the trait \"(.+)\"\? Yes or No\?\nAnswer:$",
}

###############################################################################


@dataclasses.dataclass
class NpeffComponentLatexGeneratorForLm(top_examples_latex.LatexGeneratorAbc):
    """LaTeX generator for the top examples of a component on an ETHICS subtask.

    Each example is processed in three steps:
      1. Its token ids are decoded back to text with `tokenizer`.
      2. The text is matched against the prompt template of `subtask` (see
         `_EXAMPLE_REGEXES`) to recover the original example text(s).
      3. The result is rendered as a monospaced block, headed by the example's
         label, predicted label and component coefficient.

    Raises:
        ValueError: at construction time if `subtask` is not a known ETHICS subtask.
    """

    # NOTE(review): `top_examples_reader`, `n_top_examples`, `components_fontsize` and
    # `component_filter` are not read by the methods below. Presumably they are
    # consumed by `LatexGeneratorAbc` — confirm against the base class.
    top_examples_reader: 'top_examples_common.TopExamplesReaderAbc'
    # Used to decode each example's `input_ids` back into the prompt text.
    tokenizer: PreTrainedTokenizer

    n_top_examples: int

    # ETHICS subtask name; must be one of 'commonsense', 'deontology', 'justice',
    # 'utilitarianism' or 'virtue' (the keys of `_EXAMPLE_REGEXES`).
    subtask: str

    # These will have defaults based on the subtask, but can be overwritten.
    # Both are indexed by integer label: position 0 -> label 0, position 1 -> label 1.
    label_strings: Optional[Tuple[str, str]] = None
    label_colors: Optional[Tuple[str, str]] = None

    components_fontsize: Optional[str] = 'footnotesize'

    component_filter: Optional['component_filtering.ComponentFilterAbc'] = None

    def __post_init__(self):
        # Validate up front. Without this check, an unknown subtask fails with a bare
        # KeyError from the lookup tables below instead of the ValueError that
        # `_parse_texts` uses for the same condition.
        if self.subtask not in _EXAMPLE_REGEXES:
            raise ValueError(f'Invalid subtask: {self.subtask}')

        if self.label_strings is None:
            self.label_strings = _DEFAULT_LABEL_STRINGS[self.subtask]
        if self.label_colors is None:
            self.label_colors = _DEFAULT_LABEL_COLORS[self.subtask]

        # Presumably pads every label string to the longest one's length, so labels
        # line up in monospaced output — see `latex_utils.rpad_curly_to_max_length`.
        self._padded_label_strings = tuple(latex_utils.rpad_curly_to_max_length(self.label_strings))

        self._example_regex = _EXAMPLE_REGEXES[self.subtask]
        # Compiled once here rather than on every `_parse_texts` call.
        self._example_pattern = re.compile(self._example_regex)

    #######################################################

    def _to_latex_label(self, label: int) -> str:
        """Returns the padded display string of `label` in bold, wrapped in the label's color."""
        s = self._padded_label_strings[label]
        color = self.label_colors[label]
        # The bold requires `\usepackage{bold-extra}`
        return R'{\color{' + color + R'}\textbf{' + s + R'}}'

    def _make_example_top_line(self, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Builds the header line of an example.

        The line contains these fields, in order:
          - [LABEL]: only when `example_info.label` is not None.
          - [PRED]: only when `example_info.logits` is not None; it is the argmax
            of the logits.
          - [COEFF]: always present; the coefficient to 4 decimal places.
        """
        line = []

        if example_info.label is not None:
            line.append(f'[LABEL] {self._to_latex_label(example_info.label)}')

        if example_info.logits is not None:
            line.append(f'[PRED] {self._to_latex_label(np.argmax(example_info.logits))}')

        s_coeff = f'{example_info.coefficient:.4f}'
        line.append(f'[COEFF] {s_coeff}')

        # The empty group sits between two spaces so TeX does not merge them into
        # one, presumably to give a wider gap between fields.
        return R' {} '.join(line)

    #######################################################

    def _parse_texts(self, example_info: 'top_examples_common.TopExampleInfo') -> List[str]:
        """Recovers the original example text(s) from the decoded prompt.

        Returns:
            LaTeX-escaped strings, one per rendered line:
              - commonsense/deontology/justice: the single sentence.
              - utilitarianism: '[S1] ...' and '[S2] ...'.
            Returns ['ERROR'] if the decoded prompt does not match the subtask's template.

        Raises:
            NotImplementedError: for the 'virtue' subtask when its prompt matches
                (not supported yet).
            ValueError: if `self.subtask` is not a known subtask.
        """
        input_ids = example_info.example['input_ids']
        attention_mask = example_info.example['attention_mask']

        # Keep only positions whose attention mask is nonzero, i.e. drop padding.
        # Assumes ids/mask are arrays that support boolean-mask indexing (e.g. numpy)
        # — TODO confirm against the reader.
        context = self.tokenizer.decode(input_ids[attention_mask != 0])

        match = self._example_pattern.search(context)

        if not match:
            return ['ERROR']

        if self.subtask in ('commonsense', 'deontology', 'justice'):
            sentence = match.group(1)
            sentence = latex_utils.escape(sentence)
            return [sentence]

        elif self.subtask == 'utilitarianism':
            s1 = latex_utils.escape(match.group(1))
            s2 = latex_utils.escape(match.group(2))
            return [
                f'[S1] {s1}',
                f'[S2] {s2}',
            ]

        elif self.subtask == 'virtue':
            # TODO: Support this.
            raise NotImplementedError('TODO')

        else:
            # Defensive: only reachable if `subtask` is reassigned after construction.
            raise ValueError(f'Invalid subtask: {self.subtask}')

    #######################################################

    def make_example_for_component_latex_string(self, component_index: int, example_info: 'top_examples_common.TopExampleInfo') -> str:
        """Renders one top example of a component as a LaTeX snippet.

        Args:
            component_index: index of the component. Not used by this implementation.
            example_info: the example to render.

        Returns:
            A noindent'ed texttt block. Its first line is the header (label,
            prediction, coefficient) and the parsed example text(s) follow, one per
            line, separated by small vertical spaces.
        """
        lines = [
            self._make_example_top_line(example_info),
            *self._parse_texts(example_info),
        ]

        # The `{}` after each `\\` line break stops LaTeX from reading a following
        # `[...]` (e.g. `[S2]`) as the optional spacing argument of `\\`.
        s = (R'\vspace{1mm} \\' + '\n' + R'{}').join(lines)
        s = R'\noindent\texttt{' + s + R'\vspace{2mm} \\' + '\n}'

        return s
