R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python -i em/projects/neurips2023/make_appendix_snli_examples.py


"""

import os
from importlib import reload
from typing import Optional

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer

from em.tools.nmf import lrm_npeff
from em.projects.m_npeff import snli_context
from em.projects.m_npeff import latex_generation
from em.fishers import lrm_pefs
from em.util import latex_util
from em.util.color_util import cu


###############################################################################

NMF_DIR = "/fruitbasket/users/m/project_data/extract_merge1/snli3og_lrm_npeff/per_example_fishers/"
# NMF_NAME = "feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.fit_to_train_skip_50k.h5"
NMF_NAME = "feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.wrongs_only.expansion_001.no_full_joint.fit_to_train_skip_50k.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# Use this only to get the predictions and example token ids without having to
# evaluate the model.
PEFS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/snli3og_lrm_npeff/per_example_fishers/"
PEFS_NAME = "feather_berts_0.train_skip_50k.50000ex.65536.h5"
PEFS_PATH = os.path.join(PEFS_DIR, PEFS_NAME)

###############################################################################

TOKENIZER = 'bert-base-uncased'
SPLIT = 'train_skip_50k'

###############################################################################

# Fitted decomposition, loaded without reading G (read_G=False).
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=False)

# Dataset/decomposition context used to look up each component's top examples.
_tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
ctx = snli_context.SnliContext(split=SPLIT, tokenizer=_tokenizer, nmf=nmf)

# The model's predicted class per example is the argmax of the stored logits.
logits = lrm_pefs.SparseLrmPefs.load_logits(PEFS_PATH)
predictions = np.argmax(logits, axis=-1)

###############################################################################

# LaTeX macro for each SNLI class character (the macros are presumably
# defined in the paper's preamble). The '-' character has no macro.
CHAR_TO_TEX_LABEL = dict(
    e=R'\snliE',
    n=R'\snliN',
    c=R'\snliC',
)

# Integer label -> class character. Index 3 ('-') presumably marks examples
# without a gold label -- TODO confirm against the dataset loader.
LABEL_TO_CHAR = ("c", "e", "n", "-")

# Integer label -> TeX macro, or None where CHAR_TO_TEX_LABEL has no entry.
LABEL_TO_TEX_LABEL = list(map(CHAR_TO_TEX_LABEL.get, LABEL_TO_CHAR))


def example_to_tex(example, component_index: int) -> str:
    R"""Render one SNLI example as a call to the \snliappendixex LaTeX macro.

    The result has four newline-separated lines:
    the macro name, ``{gold}{predicted}{coeff}``, ``{premise}`` and
    ``{hypothesis}``. The premise and hypothesis are passed through
    ``latex_util.escape``.

    The gold and predicted labels are mapped through ``LABEL_TO_TEX_LABEL``.
    The predicted label comes from the module-level ``predictions``.
    ``coeff`` is ``nmf.W[example.index, component_index]``, formatted with four
    decimals.

    Args:
        example: An item from ``ctx.get_top_examples``. Its ``.index``,
            ``.label``, ``.premise`` and ``.hypothesis`` are read.
        component_index: Column of ``nmf.W`` whose coefficient is shown.

    Returns:
        The LaTeX snippet as a single string.
    """
    def tex_label(label) -> str:
        # LABEL_TO_TEX_LABEL holds None for the '-' class (and label -1 also
        # reaches that entry through negative indexing). Concatenating None
        # would raise an opaque TypeError, so render a literal '-' instead.
        tex = LABEL_TO_TEX_LABEL[label]
        return '-' if tex is None else tex

    def braced(text: str) -> str:
        return '{' + text + '}'

    s_coeff = f'{nmf.W[example.index, component_index]:.4f}'
    lines = [
        R'\snliappendixex',
        braced(tex_label(example.label))
        + braced(tex_label(predictions[example.index]))
        + braced(s_coeff),
        braced(latex_util.escape(example.premise)),
        braced(latex_util.escape(example.hypothesis)),
    ]
    return '\n'.join(lines)


def make_component_tex(component_index: int, n_examples: Optional[int] = None) -> str:
    R"""Render a component's top examples inside a snliappendixcomp environment.

    Args:
        component_index: Index of the component whose top examples are shown.
        n_examples: Number of top examples requested from
            ``ctx.get_top_examples``. When omitted, it falls back to the
            module-level ``N_EXAMPLES``, looked up at call time, so existing
            callers behave as before. Passing it explicitly removes the
            dependency on that global, which is only defined later in the
            script.

    Returns:
        ``\begin{snliappendixcomp}{<component_index>}``, the per-example
        snippets separated by blank lines, and ``\end{snliappendixcomp}``,
        joined by newlines.
    """
    if n_examples is None:
        n_examples = N_EXAMPLES

    body = '\n\n'.join(
        example_to_tex(ex, component_index)
        for ex in ctx.get_top_examples(component_index, n_examples)
    )
    return '\n'.join([
        R'\begin{snliappendixcomp}{' + str(component_index) + '}',
        body,
        R'\end{snliappendixcomp}',
    ])


###############################################################################

# Active configuration. An earlier configuration used N_EXAMPLES = 6 with
# COMPONENT_INDICES = [0, 3, 8, 17].
N_EXAMPLES = 8
COMPONENT_INDICES = [18, 34, 62]

# Components are rendered in ascending index order, with two blank lines
# between consecutive components.
output = '\n\n\n'.join(map(make_component_tex, sorted(COMPONENT_INDICES)))

print(output)
# NOTE(review): the original script ended with the bare comment "1, 4, 28";
# its meaning (perhaps other candidate component indices?) is unclear -- confirm.
