"""Code for generating LaTeX for SNLI components."""
import dataclasses
import numpy as np
from typing import Tuple

from em.util import latex_util
from em.util.color_util import cu

from . import snli_context as snli_context_module

# Module-level aliases for the two types from `snli_context` that are used in
# the annotations further down in this file.
SnliContext = snli_context_module.SnliContext
SnliExample = snli_context_module.SnliExample

###############################################################################

# Text placed at the very start of the generated LaTeX document: document
# class, package imports, `\begin{document}` and the opening of the
# "Component Top Examples" section. The `xcolor` (dvipsnames) and `bold-extra`
# packages loaded here are the ones the label markup in this file relies on.
_COMPONENTS_LATEX_FILE_START = R"""% Please use XeLaTex to handle unicode properly.
\documentclass[11pt]{article}

\usepackage[margin=1in]{geometry} 
\usepackage[dvipsnames]{xcolor}
\usepackage{bold-extra}

\begin{document}

\section{Component Top Examples}
\textit{All indices used here are 0-based unless stated otherwise. Note that section indices will be 1-based though.}

"""

# Text placed at the very end of the generated LaTeX document; it closes the
# `document` environment opened by the preamble above.
_COMPONENTS_LATEX_FILE_END = "\n\\end{document}"


# Position `i` holds the one-character code for integer label `i`.
_LABEL_TO_CHAR = ("c", "e", "n", "-")

# Display name for each label code. Every name is followed by one " {}" group
# per character it is shorter than the longest name ("contradiction") --
# presumably so that all names take up the same width in monospaced output.
_CHAR_TO_PADDED_LABEL_NAME_LATEX = {
    code: name + " {}" * (len("contradiction") - len(name))
    for code, name in (
        ("-", "n/a"),
        ("e", "entailment"),
        ("n", "neutral"),
        ("c", "contradiction"),
    )
}

# Colour name for each label code.
# Requires `\usepackage[dvipsnames]{xcolor}`
_CHAR_TO_LATEX_COLOR = {
    "-": "lightgray",
    "n": "RoyalBlue",
    "e": "ForestGreen",
    "c": "BrickRed",
}


###############################################################################


def _capitalize_first_letter(s: str) -> str:
    """Return `s` with its first character upper-cased and the rest unchanged.

    Only the character at position 0 is considered, so a string that starts
    with whitespace (or any other uncased character) comes back as-is. The
    empty string is returned unchanged.
    """
    if len(s) == 0:
        return s
    head, tail = s[:1], s[1:]
    return head.upper() + tail

###############################################################################


@dataclasses.dataclass
class TopExamplesLatexGenerator:
    """Render the top examples of every component as one LaTeX document.

    Coefficients are read from `snli_ctx.nmf.W`, indexed as
    `W[example_index, component_index]`; the size of its last axis is used as
    the number of components. The examples listed under each component come
    from `snli_ctx.get_top_examples`.
    """

    # Supplies the coefficient matrix (`nmf.W`) and `get_top_examples`.
    snli_ctx: SnliContext

    # Indexed by `example.index`; the argmax of the selected entry is used as
    # the predicted label (presumably one row of class scores per example).
    predicted_logits: np.ndarray

    # Forwarded to `snli_ctx.get_top_examples` as its second argument.
    n_top_examples: int

    # Name of the LaTeX environment wrapped around each component's examples.
    components_fontsize: str = 'footnotesize'

    # Position `i` holds the one-character code for integer label `i`.
    label_to_char: Tuple[str, ...] = _LABEL_TO_CHAR

    def generate_all_components_latex(self) -> str:
        """Build and return the full LaTeX document as a single string.

        The result is the file preamble, then one subsection per component
        (its examples wrapped in the `components_fontsize` environment), then
        the file ending, all joined with newlines.
        """
        n_components = self.snli_ctx.nmf.W.shape[-1]
        pieces = [_COMPONENTS_LATEX_FILE_START]

        for k in range(n_components):
            section = [
                R'\subsection{Component ' + str(k) + R'}',
                R'\begin{' + self.components_fontsize + R'}',
            ]
            for example in self.snli_ctx.get_top_examples(k, self.n_top_examples):
                # Each example is followed by an empty entry, which becomes a
                # blank line (a paragraph break) once everything is joined.
                section += (self.make_example_for_component_latex_string(example, k), '')
            section += (R'\end{' + self.components_fontsize + R'}', '')
            pieces.extend(section)

        pieces.append(_COMPONENTS_LATEX_FILE_END)
        return '\n'.join(pieces)

    def _to_latex_label(self, label: int) -> str:
        """Return the coloured, bold LaTeX markup for an integer label.

        The label is first translated to its one-character code through
        `label_to_char`; that code selects both the padded display name and
        the colour.
        """
        # The bold requires `\usepackage{bold-extra}`
        code = self.label_to_char[label]
        padded_name = _CHAR_TO_PADDED_LABEL_NAME_LATEX[code]
        colour = _CHAR_TO_LATEX_COLOR[code]
        return ''.join((R'{\color{', colour, R'}\textbf{', padded_name, R'}}'))

    def make_example_for_component_latex_string(self, example: SnliExample, component_index: int) -> str:
        """Return the LaTeX snippet describing one example under one component.

        The snippet is a single monospaced group of three rows: a header with
        the gold label, the predicted label and the example's coefficient for
        `component_index` (four decimals), then the escaped premise, then the
        escaped hypothesis. Premise and hypothesis get their first letter
        capitalised before escaping.
        """
        row = example.index
        # NOTE: A "quick hack" that swapped predictions 0 and 1 used to sit
        # here; it is disabled, so the plain argmax is used as the label.
        predicted = np.argmax(self.predicted_logits[row])

        premise_text = _capitalize_first_letter(example.premise)
        hypothesis_text = _capitalize_first_letter(example.hypothesis)
        coeff_text = format(self.snli_ctx.nmf.W[row, component_index], '.4f')

        header_fields = (
            '[LABEL] ' + self._to_latex_label(example.label),
            '[PRED] ' + self._to_latex_label(predicted),
            '[COEFF] ' + coeff_text,
        )
        header = R' {} '.join(header_fields)
        premise_line = '[P] {}'.format(latex_util.escape(premise_text))
        hypothesis_line = '[H] {}'.format(latex_util.escape(hypothesis_text))

        # The group opened on the first row is closed by the lone brace on the
        # last one; the leading `{}` on the text rows is kept as in the output
        # format this file has always produced.
        rendered = [
            R'\noindent\texttt{' + header + R'\vspace{1mm} \\',
            R'{}' + premise_line + R'\vspace{1mm} \\',
            R'{}' + hypothesis_line + R'\vspace{2mm}\\',
            R'}',
        ]
        return '\n'.join(rendered)
