R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python -i em/projects/pi/exps/mains/top_examples/make_appendix_top_examples.py
"""
import os

from transformers import AutoTokenizer

from em.fishers import per_example
from em.projects.anli import anli_misc1 as am
from em.tools.nmf import nmf_common
from em.util import latex_util

# Local alias for the private helper from anli_misc1; example_to_tex applies it to
# premises and hypotheses before LaTeX-escaping them.
_capitalize_first_letter = am._capitalize_first_letter

###############################################################################

# Root directory of this experiment's outputs; the paths below are subdirectories of it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
# NOTE(review): MODELS_DIR and FISHERS_DIR are not used anywhere in this script.
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# make_container loads both PEF_NAME and NMF_NAME from this directory.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Name passed to AutoTokenizer.from_pretrained in make_container.
TOKENIZER = 'bert-base-uncased'

###############################################################################

# Derived file names are built by prefixing the name of the file they come from:
# NMF_NAME embeds OG_H_NAME, which in turn embeds OG_PEF_NAME. The dotted tokens
# presumably record the settings each file was produced with (TODO confirm).
# The OG_* names are only used to build NMF_NAME; they are not loaded directly.
OG_PEF_NAME = "feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
OG_H_NAME = f"spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.{OG_PEF_NAME}"
NMF_NAME = f"fit_w.skip50000.50000ex.65536vpe.{OG_H_NAME}"

# Per-example Fisher file loaded into the analysis container by make_container.
PEF_NAME = "feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

###############################################################################

# Maps the one-character label codes found in example.label_char and
# example.prediction_char to the LaTeX macros used in the appendix.
# Presumably 'e'/'n'/'c' are entailment/neutral/contradiction (TODO confirm).
CHAR_TO_TEX_LABEL = {
    'e': R'\snliE',
    'n': R'\snliN',
    'c': R'\snliC',
}

# Index into container.nmfs. make_container builds a one-element nmfs list, so
# this selects that single NMF. Also passed as the first argument of
# container.get_top_examples.
SUBSET_INDEX = 0

###############################################################################

# Number of top examples requested and rendered per component.
N_EXAMPLES = 16
# NMF components to render. They are emitted in ascending order, regardless of
# the order listed here.
COMPONENT_INDICES = [345, 467, 292, 121, 354, 343, 400, 140, 11]

###############################################################################


def make_container():
    """Load the per-example Fisher file and the fitted NMF into one analysis container.

    Both files are read from ``PER_EXAMPLES_FISHERS_DIR`` (``PEF_NAME`` and
    ``NMF_NAME`` respectively). The NMF's components are normalized to unit norm
    before being stored. The tokenizer named by ``TOKENIZER`` is loaded last.

    Returns:
        An ``am.PefNmfAnalysisContainer`` holding the per-example Fishers, a
        one-element list with the NMF, and the tokenizer, built with
        ``shift_labels=True``.
    """
    def _resolve(file_name):
        # Join under the per-example Fisher directory, then expand any "~".
        return os.path.expanduser(os.path.join(PER_EXAMPLES_FISHERS_DIR, file_name))

    # An empty Fisher index range (0 to 0) means the Fisher values themselves are
    # not loaded, which makes loading much faster. n_examples=None is forwarded
    # as-is (presumably "no cap"; TODO confirm in per_example).
    per_example_fishers = per_example.PerExampleFlatFishers.load(
        _resolve(PEF_NAME),
        n_examples=None,
        start_fisher_index=0,
        end_fisher_index=0,
    )

    decomposition = nmf_common.SparseNmfDecomposition.load(_resolve(NMF_NAME))
    decomposition.normalize_components_to_unit_norm()

    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
    return am.PefNmfAnalysisContainer(
        pef=per_example_fishers,
        nmfs=[decomposition],
        tokenizer=tokenizer,
        shift_labels=True,
    )


def example_to_tex(container: am.PefNmfAnalysisContainer, example: am.NliExample, component_index: int):
    """Render one NLI example as a call to the ``snliappendixex`` LaTeX macro.

    The result has four newline-separated lines:

    1. the macro name;
    2. three brace groups: the LaTeX macro for the label character, the macro
       for the prediction character, and the example's ``W`` coefficient for
       ``component_index`` formatted to 4 decimal places;
    3. the premise, passed through ``_capitalize_first_letter`` and then
       LaTeX-escaped, in braces;
    4. the hypothesis, processed the same way.

    Raises:
        KeyError: if the label or prediction character is not in
            ``CHAR_TO_TEX_LABEL``.
    """
    coefficient = format(container.nmfs[SUBSET_INDEX].W[example.index, component_index], '.4f')
    label_tex = CHAR_TO_TEX_LABEL[example.label_char]
    prediction_tex = CHAR_TO_TEX_LABEL[example.prediction_char]

    def _braced(text):
        # Wrap a string as a single LaTeX macro argument.
        return '{' + text + '}'

    lines = [
        R'\snliappendixex',
        ''.join(_braced(part) for part in (label_tex, prediction_tex, coefficient)),
        _braced(latex_util.escape(_capitalize_first_letter(example.premise))),
        _braced(latex_util.escape(_capitalize_first_letter(example.hypothesis))),
    ]
    return '\n'.join(lines)


def make_component_tex(container: am.PefNmfAnalysisContainer, component_index: int) -> str:
    """Render one NMF component's top examples inside a ``snliappendixcomp`` environment.

    The examples come from ``container.get_top_examples(SUBSET_INDEX,
    component_index, N_EXAMPLES)``. Each is rendered with ``example_to_tex``,
    and consecutive examples are separated by one blank line. The environment
    takes the component index as its argument.
    """
    top_examples = container.get_top_examples(SUBSET_INDEX, component_index, N_EXAMPLES)
    rendered_examples = [example_to_tex(container, ex, component_index) for ex in top_examples]

    opening = R'\begin{snliappendixcomp}{' + str(component_index) + '}'
    closing = R'\end{snliappendixcomp}'
    return '\n'.join((opening, '\n\n'.join(rendered_examples), closing))


###############################################################################


# Script body. It runs at module level on purpose: the module docstring shows the
# script being run with `python -i`, so `container` and `output` stay available
# for interactive inspection afterwards.
container = make_container()

# One LaTeX environment per selected component, in ascending component-index
# order, with two blank lines between consecutive environments.
output = '\n\n\n'.join(
    make_component_tex(container, idx) for idx in sorted(COMPONENT_INDICES)
)

print(output)
