"""Generic container for BERT decompositions."""
import dataclasses
import re
from typing import Tuple

import numpy as np
from transformers import PreTrainedTokenizer

from em.util import latex_util
from em.fishers import per_example


# Opening of the generated LaTeX document: preamble, `\begin{document}` and the
# top-level section heading. The per-component subsections are appended after
# this text, so it must be followed by `_COMPONENTS_LATEX_FILE_END`.
_COMPONENTS_LATEX_FILE_START = R"""% Please use XeLaTex to handle unicode properly.
\documentclass[11pt]{article}

\usepackage[margin=1in]{geometry} 
\usepackage[dvipsnames]{xcolor}
\usepackage{bold-extra}

\begin{document}

\section{Component Top Examples}
\textit{All indices used here are 0-based unless stated otherwise. Note that section indices will be 1-based though.}

"""

# Closes the document opened by `_COMPONENTS_LATEX_FILE_START`.
_COMPONENTS_LATEX_FILE_END = "\n\\end{document}"


# Default mapping from an integer label (used as an index into this tuple) to
# a one-character label code. The codes are the keys of the two dicts below.
_LABEL_TO_CHAR = ("c", "e", "n", "-")

# LaTeX display name for each label code. Every ` {}` adds one space followed
# by an empty group; the counts are chosen so that each name plus its padding
# is as long as 'contradiction' (13 characters), which keeps the names aligned
# when typeset in a monospaced font.
_CHAR_TO_PADDED_LABEL_NAME_LATEX = {
    "-": 'n/a {} {} {} {} {} {} {} {} {} {}',
    "e": 'entailment {} {} {}',
    "n": 'neutral {} {} {} {} {} {}',
    "c": 'contradiction',
}

# Color name used for each label code in the LaTeX output.
# Requires `\usepackage[dvipsnames]{xcolor}`
_CHAR_TO_LATEX_COLOR = {
    '-': 'lightgray',
    'n': 'RoyalBlue',
    'e': 'ForestGreen',
    'c': 'BrickRed',
}

###############################################################################


def _capitalize_first_letter(s: str) -> str:
    # NOTE: Won't capitalize anything if the string has leading whitespace.
    if s == '':
        return s
    return f'{s[0].upper()}{s[1:]}'

###############################################################################


@dataclasses.dataclass
class Example:
    """One premise/hypothesis pair together with its label."""
    # TODO: Add more info.
    # Position of the example; `BertContainer._make_nli_examples` sets this to
    # the row index of the example within `input_ids`.
    index: int

    # Decoded text of the two sentences, with surrounding whitespace stripped.
    premise: str
    hypothesis: str

    # Integer label, later used as an index into a label-to-character tuple.
    # NOTE(review): `BertContainer` fills this from an element of a numpy
    # array, so at runtime it may be a numpy integer rather than a builtin int.
    label: int


@dataclasses.dataclass
class BertContainer:
    """Pairs per-example component coefficients with decoded NLI examples.

    On construction every row of `input_ids` is decoded with `tokenizer` and
    split into a premise and a hypothesis; the resulting `Example` objects are
    stored in `self.examples`, in the same order as the rows of `input_ids`.
    The container can then list the examples with the largest coefficient for
    a component and render them as a LaTeX document.
    """

    # shape = [n_examples, n_components]
    coeffs: np.ndarray

    # Decodes `input_ids`; its `cls_token`, `sep_token` and `pad_token` are
    # used to split the decoded text into premise and hypothesis.
    tokenizer: PreTrainedTokenizer

    # The following three attributes can be loaded from the pef file.
    labels: np.ndarray
    predicted_logits: np.ndarray
    input_ids: np.ndarray

    # Number of examples listed per component in the generated LaTeX.
    n_top_examples: int

    # Name of the LaTeX font-size environment wrapped around each component.
    components_fontsize: str = 'footnotesize'

    # Maps an integer label/prediction (as an index) to a label code.
    label_to_char: Tuple[str, ...] = _LABEL_TO_CHAR

    def __post_init__(self):
        """Decode all inputs once and cache them as `self.examples`."""
        self.examples = self._make_nli_examples()

    @classmethod
    def load_from_pef_file(cls, pef_filepath: str, **kwargs):
        """Build a container whose labels, logits and input ids come from a pef file.

        Args:
            pef_filepath: Path handed to `PerExampleFlatFishers.load`.
            **kwargs: Remaining constructor arguments (e.g. `coeffs`,
                `tokenizer`, `n_top_examples`), forwarded unchanged.

        Returns:
            A new instance of `cls`.
        """
        pef = per_example.PerExampleFlatFishers.load(
            pef_filepath,
            n_examples=None,
            # This leads to the Fishers not being loaded, which ends up being much faster.
            start_fisher_index=0,
            end_fisher_index=0,
        )

        # Need this?
        # NOTE(review): the stored labels are shifted by one modulo 3 before
        # use; the reason for the remapping is not visible in this file.
        pef.labels = (pef.labels + 1) % 3

        return cls(
            input_ids=pef.input_ids,
            labels=pef.labels,
            predicted_logits=pef.predicted_logits,
            **kwargs
        )

    def _make_nli_examples(self):
        """Decode every row of `input_ids` into an `Example`.

        Each decoded row, with pad tokens removed and whitespace stripped, is
        expected to look like `<cls> premise <sep> hypothesis <sep>`.

        Returns:
            A list with one `Example` per row of `input_ids`, in row order.

        Raises:
            ValueError: If a decoded row does not have the expected layout.
        """
        r_cls_token = re.escape(self.tokenizer.cls_token)
        r_sep_token = re.escape(self.tokenizer.sep_token)
        # The pattern is the same for every example, so build it once.
        example_regex = re.compile(
            rf'^{r_cls_token}(.+){r_sep_token}(.+){r_sep_token}$'
        )
        pad_token = self.tokenizer.pad_token

        examples = []
        for i, input_ids in enumerate(self.input_ids):
            text = self.tokenizer.decode(input_ids)
            text = text.replace(pad_token, '').strip()

            match = example_regex.search(text)
            if match is None:
                # Fail loudly: `get_top_examples` indexes `self.examples` by
                # row number, so silently dropping a row would misalign them.
                raise ValueError(
                    f'Could not split example {i} into premise and '
                    f'hypothesis; decoded text: {text!r}'
                )

            examples.append(Example(
                index=i,
                premise=match.group(1).strip(),
                hypothesis=match.group(2).strip(),
                label=self.labels[i],
            ))

        return examples

    def get_top_examples(self, component_index: int, n_examples: int):
        """Return the examples with the largest coefficients for one component.

        Args:
            component_index: Column of `coeffs` to rank by.
            n_examples: Maximum number of examples to return.

        Returns:
            A list of `Example` objects ordered by decreasing coefficient.
        """
        # Negating the column turns argsort's ascending order into descending.
        top_inds = np.argsort(-self.coeffs[:, component_index])[:n_examples]
        return [self.examples[i] for i in top_inds]

    def generate_all_components_latex(self):
        """Render the top examples of every component as one LaTeX document.

        Returns:
            The full LaTeX source as a single string, with one subsection per
            component listing its `n_top_examples` highest-coefficient examples.
        """
        ret = [_COMPONENTS_LATEX_FILE_START]

        for component_index in range(self.coeffs.shape[-1]):
            ret.append(R'\subsection{Component ' + str(component_index) + R'}')
            ret.append(R'\begin{' + self.components_fontsize + R'}')

            for x in self.get_top_examples(component_index, self.n_top_examples):
                ret.append(self.make_example_for_component_latex_string(x, component_index))
                # Blank line so each example becomes its own paragraph.
                ret.append('')

            ret.append(R'\end{' + self.components_fontsize + R'}')
            ret.append('')

        ret.append(_COMPONENTS_LATEX_FILE_END)
        return '\n'.join(ret)

    def _to_latex_label(self, label: int) -> str:
        """Return the colored, bold, padded LaTeX name for an integer label."""
        # The bold requires `\usepackage{bold-extra}`
        label_char = self.label_to_char[label]
        s = _CHAR_TO_PADDED_LABEL_NAME_LATEX[label_char]
        color = _CHAR_TO_LATEX_COLOR[label_char]
        return R'{\color{' + color + R'}\textbf{' + s + R'}}'

    def make_example_for_component_latex_string(self, example: Example, component_index: int) -> str:
        """Render one example as a three-line monospaced LaTeX snippet.

        The first line shows the gold label, the predicted label (argmax of
        the example's logits) and the example's coefficient for the given
        component; the second and third lines show premise and hypothesis.

        Args:
            example: The example to render.
            component_index: Column of `coeffs` whose value is displayed.

        Returns:
            The LaTeX source for the example.
        """
        index = example.index
        prediction = np.argmax(self.predicted_logits[index])

        premise = _capitalize_first_letter(example.premise)
        hypothesis = _capitalize_first_letter(example.hypothesis)

        s_coeff = f'{self.coeffs[index, component_index]:.4f}'

        # The empty group between fields widens the gap: LaTeX would otherwise
        # collapse the adjacent spaces into one.
        line1 = R' {} '.join([
            f'[LABEL] {self._to_latex_label(example.label)}',
            f'[PRED] {self._to_latex_label(prediction)}',
            f'[COEFF] {s_coeff}',
        ])

        line2 = f'[P] {latex_util.escape(premise)}'
        line3 = f'[H] {latex_util.escape(hypothesis)}'
        # The leading `{}` on the second and third lines stops the preceding
        # `\\` from reading `[P]` / `[H]` as its optional argument.
        return "\n".join([
            R'\noindent\texttt{' + line1 + R'\vspace{1mm} \\',
            R'{}' + line2 + R'\vspace{1mm} \\',
            R'{}' + line3 + R'\vspace{2mm}\\',
            R'}',
        ])
