"""Code for generating LaTeX for QQP components."""
import dataclasses
import numpy as np
from typing import Tuple

from em.util import latex_util
from em.util.color_util import cu

from . import qqp_context as qqp_context_module

# Short module-level aliases for the project-local context types; they are used
# in the type annotations of the generator class defined in this module.
QqpContext = qqp_context_module.QqpContext
QqpExample = qqp_context_module.QqpExample

###############################################################################

# Preamble and opening of the generated LaTeX document, emitted verbatim as the
# first chunk of the output. It loads xcolor (with dvipsnames) and bold-extra,
# the two packages the colored, bold label rendering in this module relies on.
_COMPONENTS_LATEX_FILE_START = R"""% Please use XeLaTex to handle unicode properly.
\documentclass[11pt]{article}

\usepackage[margin=1in]{geometry} 
\usepackage[dvipsnames]{xcolor}
\usepackage{bold-extra}

\begin{document}

\section{Component Top Examples}
\textit{All indices used here are 0-based unless stated otherwise. Note that section indices will be 1-based though.}

"""

# Closing chunk of the generated document, appended after the last component.
_COMPONENTS_LATEX_FILE_END = "\n\\end{document}"

# Maps an integer label to a one-character code by tuple index:
# 0 -> "n" (not duplicate), 1 -> "d" (duplicate). The trailing "-" is what an
# index of 2 or -1 resolves to (presumably "no label available" -- confirm
# against the dataset's label convention).
_LABEL_TO_CHAR = ("n", "d", "-")

# Display name for each label code. Every code in `_LABEL_TO_CHAR` must have an
# entry here, otherwise rendering that label raises a KeyError.
# Shorter names are padded with ` {}` (a space followed by an empty LaTeX
# group, which keeps consecutive spaces from collapsing) so that all names
# occupy the same visible width as 'not duplicate' (13 characters).
_CHAR_TO_PADDED_LABEL_NAME_LATEX = {
    "-": 'n/a' + (10 * ' {}'),
    "d": 'duplicate' + (4 * ' {}'),
    "n": 'not duplicate',
}

# Text color for each label code.
# Requires `\usepackage[dvipsnames]{xcolor}`
_CHAR_TO_LATEX_COLOR = {
    '-': 'lightgray',
    'n': 'BrickRed',
    'd': 'ForestGreen',
}


###############################################################################


def _capitalize_first_letter(s: str) -> str:
    """Return `s` with its first character upper-cased and the rest untouched.

    An empty string comes back as an empty string. Leading whitespace is not
    skipped: if the first character is whitespace, nothing is capitalized.
    """
    # Slicing (rather than indexing) makes the empty string fall out naturally:
    # both slices are empty, so no special case is needed.
    return s[:1].upper() + s[1:]


###############################################################################


@dataclasses.dataclass
class TopExamplesLatexGenerator:
    """Builds a LaTeX document listing the top examples of each component.

    One subsection is produced per component index in
    `range(qqp_ctx.nmf.W.shape[-1])`. Each example is rendered as three
    typewriter-font lines: gold label / predicted label / coefficient, then the
    first sentence, then the second sentence.
    """

    # Context object. This class reads `qqp_ctx.nmf.W` and calls
    # `qqp_ctx.get_top_examples(component_index, n_top_examples)`.
    qqp_ctx: QqpContext

    # Indexed by `example.index`; the argmax of the selected entry is used as
    # the predicted label.
    predicted_logits: np.ndarray

    # Forwarded as the second argument of `qqp_ctx.get_top_examples`.
    n_top_examples: int

    # Name of the LaTeX environment wrapped around each component's examples.
    components_fontsize: str = 'footnotesize'

    # Lookup from integer label to its one-character code.
    label_to_char: Tuple[str, ...] = _LABEL_TO_CHAR

    def generate_all_components_latex(self):
        """Return the complete LaTeX source, from preamble to document end."""
        env_name = self.components_fontsize
        n_components = self.qqp_ctx.nmf.W.shape[-1]

        chunks = [_COMPONENTS_LATEX_FILE_START]
        for k in range(n_components):
            chunks += [
                f'\\subsection{{Component {k}}}',
                f'\\begin{{{env_name}}}',
            ]

            # Every example is followed by an empty chunk, i.e. a blank line
            # in the joined output.
            for top_example in self.qqp_ctx.get_top_examples(k, self.n_top_examples):
                chunks += [
                    self.make_example_for_component_latex_string(top_example, k),
                    '',
                ]

            chunks += [f'\\end{{{env_name}}}', '']

        chunks.append(_COMPONENTS_LATEX_FILE_END)
        return '\n'.join(chunks)

    def _to_latex_label(self, label: int) -> str:
        """Return the colored, bold LaTeX snippet for an integer label.

        Raises whatever the underlying lookups raise (IndexError / KeyError)
        when the label has no code, display name or color.
        """
        # The bold requires `\usepackage{bold-extra}`
        code = self.label_to_char[label]
        display_name = _CHAR_TO_PADDED_LABEL_NAME_LATEX[code]
        color_name = _CHAR_TO_LATEX_COLOR[code]
        return ''.join([R'{\color{', color_name, R'}\textbf{', display_name, R'}}'])

    def make_example_for_component_latex_string(self, example: QqpExample, component_index: int) -> str:
        """Render one example as a three-line typewriter-font LaTeX block.

        Args:
            example: The example to render; its `index`, `label`, `sentence1`
                and `sentence2` attributes are read.
            component_index: Column of `qqp_ctx.nmf.W` whose coefficient for
                this example is shown.

        Returns:
            The LaTeX snippet: a header line with label, prediction and
            coefficient, followed by the two escaped sentences.
        """
        row = example.index
        predicted_label = np.argmax(self.predicted_logits[row])

        first_sentence = _capitalize_first_letter(example.sentence1)
        second_sentence = _capitalize_first_letter(example.sentence2)

        coefficient = format(self.qqp_ctx.nmf.W[row, component_index], '.4f')

        header_fields = (
            ('[LABEL]', self._to_latex_label(example.label)),
            ('[PRED]', self._to_latex_label(predicted_label)),
            ('[COEFF]', coefficient),
        )
        # ` {} ` separates the fields: two spaces with an empty LaTeX group
        # between them.
        header = R' {} '.join(f'{tag} {value}' for tag, value in header_fields)

        s1_line = f'[S1] {latex_util.escape(first_sentence)}'
        s2_line = f'[S2] {latex_util.escape(second_sentence)}'

        # (prefix, content, suffix) per output line; the final line closes the
        # \texttt group opened on the first one.
        rows = (
            (R'\noindent\texttt{', header, R'\vspace{1mm} \\'),
            (R'{}', s1_line, R'\vspace{1mm} \\'),
            (R'{}', s2_line, R'\vspace{2mm}\\'),
        )
        return "\n".join([''.join(parts) for parts in rows] + [R'}'])
