R"""Makes a set-up for figuring out interpretations of the top examples for components.

This essentially creates a directory full of files to examine and edit so that
we can create interpretations for components and justify them.

NOTE: I might later want to create some code that "reads" the interpretations and
their justifications and puts that in a single place.
"""
import os
import subprocess

from absl import app
from absl import flags

from transformers import AutoTokenizer

from em.fishers import per_example
from em.projects.anli import anli_misc1 as am
from em.tools.nmf import nmf_common

FLAGS = flags.FLAGS

# Inputs: paths handed to `PerExampleFlatFishers.load` and
# `SparseNmfDecomposition.load` respectively (after `~` expansion).
flags.DEFINE_string(name="pef_path", default=None, help="")
flags.DEFINE_string(name="nmf_path", default=None, help="")

# Forwarded to `AutoTokenizer.from_pretrained`.
flags.DEFINE_string(name="tokenizer", default=None, help="")

# Rendering options for the generated LaTeX.
flags.DEFINE_integer(
    name="n_examples",
    default=None,
    help="Number of examples to show per component.",
)
# Used as the name of the LaTeX environment that wraps the examples.
flags.DEFINE_string(name="components_fontsize", default="footnotesize", help="")

# Outputs
flags.DEFINE_string(
    name="output_dir",
    default=None,
    help="Path to directory to write output to. Must already exist.",
)
flags.DEFINE_string(
    name="output_prefix",
    default="",
    help="Optional prefix to prepend to output file names.",
)


"""
Will have top-level directory filled with folders named like `comp234` for each component.

Possible contents of each component's directory:
- Pdf of top examples.
"""

# Passed as the first argument to the container's `get_top_examples` and
# `make_example_for_component_latex_string` calls.
# NOTE(review): presumably selects an entry of the container's `nmfs` list (only
# one decomposition is loaded) — confirm against `PefNmfAnalysisContainer`.
SUBSET_INDEX = 0

# Preamble written at the top of every component's .tex file; it ends right after
# `\begin{document}` so the per-component content can be appended directly.
COMPONENT_LATEX_FILE_START = R"""% Please use XeLaTex to handle unicode properly.
\documentclass[11pt]{article}

\usepackage[margin=.25in]{geometry} 
\usepackage[dvipsnames]{xcolor}
\usepackage{bold-extra}

\begin{document}

"""

# Footer that closes the document opened by COMPONENT_LATEX_FILE_START.
COMPONENT_LATEX_FILE_END = "\n\\end{document}"


def make_container(pef_path, nmf_path):
    """Load the per-example Fishers and the NMF, and bundle them for analysis.

    Args:
        pef_path: Path (``~`` is expanded) given to `PerExampleFlatFishers.load`.
        nmf_path: Path (``~`` is expanded) given to `SparseNmfDecomposition.load`.

    Returns:
        An `am.PefNmfAnalysisContainer` holding the loaded objects together with
        the tokenizer named by `FLAGS.tokenizer`.
    """
    # Using 0 for both the start and end Fisher index means the Fishers themselves
    # are not loaded, which ends up being much faster.
    per_example_fishers = per_example.PerExampleFlatFishers.load(
        os.path.expanduser(pef_path),
        n_examples=None,
        start_fisher_index=0,
        end_fisher_index=0,
    )

    decomposition = nmf_common.SparseNmfDecomposition.load(
        os.path.expanduser(nmf_path)
    )
    decomposition.normalize_components_to_unit_norm()

    tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer)

    return am.PefNmfAnalysisContainer(
        pef=per_example_fishers,
        nmfs=[decomposition],
        tokenizer=tokenizer,
        shift_labels=True,
    )

# Need to number examples per component.


def make_tex_for_component(container, component_index: int) -> str:
    """Return the full LaTeX source for one component's top-examples document.

    The document is the shared preamble, a centered section title, the top
    `FLAGS.n_examples` examples wrapped in the `FLAGS.components_fontsize`
    environment, and the closing footer, all joined with newlines.

    Args:
        container: Object providing `get_top_examples` and
            `make_example_for_component_latex_string`.
        component_index: Index of the component to render.

    Returns:
        The LaTeX source as a single string.
    """
    fontsize = FLAGS.components_fontsize

    lines = [
        COMPONENT_LATEX_FILE_START,
        rf'\section*{{\center Component {component_index}}}',
        R'\begin{' + fontsize + R'}',
    ]

    # Each example is followed by an empty entry, i.e. a blank line in the output.
    for example in container.get_top_examples(SUBSET_INDEX, component_index, FLAGS.n_examples):
        example_tex = container.make_example_for_component_latex_string(
            example, SUBSET_INDEX, component_index
        )
        lines.extend((example_tex, ''))

    lines.extend((R'\end{' + fontsize + R'}', '', COMPONENT_LATEX_FILE_END))

    return '\n'.join(lines)


def get_filename(component_index: int, extension: str) -> str:
    """Return the file name (no directory) for a component's output file.

    The name is `FLAGS.output_prefix`, then ``comp`` and the component index
    zero-padded to at least three digits, then ``.`` and `extension`.
    """
    prefix = FLAGS.output_prefix
    return f'{prefix}comp{component_index:03d}.{extension}'


def write_tex_for_component(container, output_dir: str, component_index: int):
    """Generate one component's LaTeX source and write it into `output_dir`.

    Args:
        container: Analysis container forwarded to `make_tex_for_component`.
        output_dir: Existing directory to write the ``.tex`` file into.
        component_index: Index of the component to render.
    """
    content = make_tex_for_component(container, component_index)
    filepath = os.path.join(output_dir, get_filename(component_index, 'tex'))
    # The document is meant for XeLaTeX so that unicode is handled properly; write
    # UTF-8 explicitly rather than depending on the platform's default encoding.
    with open(filepath, 'wt', encoding='utf-8') as f:
        f.write(content)


def compile_pdf_for_component(output_dir: str, component_index: int):
    """Compile one component's ``.tex`` file to PDF with XeLaTeX.

    Assumes the ``.tex`` file has already been written into `output_dir` (see
    `write_tex_for_component`). When XeLaTeX exits successfully, the ``.tex``,
    ``.aux`` and ``.log`` files are removed. When it exits with a non-zero
    status, a warning is logged and the ``.tex`` and ``.log`` files are kept so
    the failure can be inspected; only the ``.aux`` file is removed.

    Args:
        output_dir: Directory holding the ``.tex`` file; also used as XeLaTeX's
            output directory.
        component_index: Index of the component to compile.

    Raises:
        FileNotFoundError: If the ``xelatex`` executable cannot be found.
    """
    # Local import so this function does not depend on the module's import block.
    import logging

    tex_filepath = os.path.join(output_dir, get_filename(component_index, 'tex'))
    aux_filepath = os.path.join(output_dir, get_filename(component_index, 'aux'))
    log_filepath = os.path.join(output_dir, get_filename(component_index, 'log'))

    # Equivalent shell command:
    #   xelatex -interaction=batchmode -output-directory=<output_dir> <tex_filepath>
    cmd = [
        'xelatex',
        '-interaction=batchmode',
        f'-output-directory={output_dir}',
        tex_filepath,
    ]
    completed = subprocess.run(cmd, check=False)

    if completed.returncode == 0:
        leftovers = (tex_filepath, aux_filepath, log_filepath)
    else:
        logging.warning(
            'xelatex exited with status %d for component %d; keeping %s and %s.',
            completed.returncode,
            component_index,
            tex_filepath,
            log_filepath,
        )
        leftovers = (aux_filepath,)

    for path in leftovers:
        try:
            os.remove(path)
        except FileNotFoundError:
            # XeLaTeX never produced this by-product; nothing to clean up.
            pass


def main(_):
    """Write a ``.tex`` file for every component, then compile each one to PDF.

    Raises:
        FileNotFoundError: If `FLAGS.output_dir` is not an existing directory.
    """
    output_dir = os.path.expanduser(FLAGS.output_dir)
    # Raise explicitly instead of asserting: asserts are stripped under `python -O`,
    # and the path must be a directory, not merely exist.
    if not os.path.isdir(output_dir):
        raise FileNotFoundError(f'Please create the directory: {output_dir}')

    container = make_container(FLAGS.pef_path, FLAGS.nmf_path)

    # NOTE(review): assumes the last axis of `W` indexes components — confirm
    # against `SparseNmfDecomposition`.
    n_components = container.nmfs[0].W.shape[-1]

    # All .tex files are written first; compilation happens in a second pass.
    for component_index in range(n_components):
        write_tex_for_component(container, output_dir, component_index)

    for component_index in range(n_components):
        compile_pdf_for_component(output_dir, component_index)


if __name__ == "__main__":
    # These flags default to None and are passed straight to `os.path.expanduser`,
    # which raises a TypeError on None. Marking them required makes absl report a
    # clear error at startup instead.
    flags.mark_flags_as_required(["pef_path", "nmf_path", "output_dir"])
    app.run(main)
