R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python -i em/projects/neurips2023/make_appendix_qqp_examples.py


"""

from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer

from em.tools.nmf import lrm_npeff
from em.projects.m_npeff import qqp_context
from em.projects.m_npeff import latex_generation
from em.fishers import lrm_pefs

from em.util import latex_util

from em.util.color_util import cu


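###############################################################################
# Alternative configuration (validation split). To use it, uncomment this
# section and comment out the active configuration that follows.
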
# NMF_DIR = "/fruitbasket/users/m/project_data/extract_merge1/qqp_lrm_npeff2/per_example_fishers/"
# NMF_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9.heldout_from_train.mnpeff.256comps.001.coeffs_fit_to_validation001.h5"
# NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# # Use this only to get the predictions and example token ids without having to
# # evaluate the model.
# PEFS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/qqp_lrm_npeff2/per_example_fishers/"
# PEFS_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9.validation.40430ex.65536.h5"
# PEFS_PATH = os.path.join(PEFS_DIR, PEFS_NAME)

# N_EXAMPLES = 6
# COMPONENT_INDICES = [0, 3, 9, 10]

# SPLIT = 'validation'

###############################################################################

NMF_DIR = "/fruitbasket/users/m/project_data/extract_merge1/qqp_lrm_npeff2/per_example_fishers/"
NMF_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9.heldout_from_train.50000ex.65536.wrongs_only.expansion_001.coeffs_fit001.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# Use this only to get the predictions and example token ids without having to
# evaluate the model.
PEFS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/qqp_lrm_npeff2/per_example_fishers/"
PEFS_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9.heldout_from_train.50000ex.65536.h5"
PEFS_PATH = os.path.join(PEFS_DIR, PEFS_NAME)


N_EXAMPLES = 10
COMPONENT_INDICES = [21, 31, 36]
SPLIT = "train[-50000:]"


###############################################################################

TOKENIZER = 'bert-base-uncased'

###############################################################################

# Load the NPEFF decomposition. read_G=False skips reading G, which this
# script does not use directly.
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=False)

# Context over the QQP split; used below to fetch each component's top examples.
ctx = qqp_context.QqpContext(
    nmf=nmf,
    split=SPLIT,
    tokenizer=AutoTokenizer.from_pretrained(TOKENIZER),
)

# Predictions come from the logits stored alongside the per-example Fishers,
# so the model itself never has to be evaluated.
logits = lrm_pefs.SparseLrmPefs.load_logits(PEFS_PATH)
predictions = np.argmax(logits, axis=-1)  # index of the largest logit

###############################################################################

LABEL_TO_CHAR = ("n", "d", "-")

CHAR_TO_TEX_LABEL = {
    "d": R'\qqpD',
    "n": R'\qqpN',
}


LABEL_TO_TEX_LABEL = [CHAR_TO_TEX_LABEL.get(c, None) for c in LABEL_TO_CHAR]


def _label_to_tex(label) -> str:
    R"""Return the LaTeX macro for an integer class label.

    Raises:
        ValueError: if the label has no macro. This covers the "-" entry of
            LABEL_TO_CHAR, which is also reached by label -1 via negative
            indexing. Without this check the caller would fail later with an
            opaque ``TypeError`` when concatenating ``None`` to a string.
    """
    tex_label = LABEL_TO_TEX_LABEL[label]
    if tex_label is None:
        raise ValueError(
            f'Label {int(label)} ({LABEL_TO_CHAR[label]!r}) has no LaTeX macro; '
            'only labeled examples can be rendered.'
        )
    return tex_label


def example_to_tex(example, component_index: int) -> str:
    R"""Render one example as a ``\qqpappendixex`` LaTeX macro call.

    The macro is followed by five brace groups, in order:
    - the dataset label macro;
    - the predicted label macro;
    - the example's coefficient for ``component_index``, to 4 decimals;
    - the escaped ``sentence1``;
    - the escaped ``sentence2``.

    Args:
        example: an example from ``ctx.get_top_examples``. It must expose
            ``index``, ``label``, ``sentence1`` and ``sentence2``.
        component_index: component (column of ``nmf.W``) whose coefficient
            is reported.

    Returns:
        Four newline-separated lines: the macro name, the line holding the
        label/prediction/coefficient groups, then one line per sentence.

    Raises:
        ValueError: if the dataset or predicted label has no LaTeX macro.
    """
    coeff = nmf.W[example.index, component_index]
    true_label = _label_to_tex(example.label)
    pred_label = _label_to_tex(predictions[example.index])
    lines = [
        R'\qqpappendixex',
        '{' + true_label + '}{' + pred_label + '}{' + f'{coeff:.4f}' + '}',
        '{' + latex_util.escape(example.sentence1) + '}',
        '{' + latex_util.escape(example.sentence2) + '}',
    ]
    return '\n'.join(lines)


def make_component_tex(
    component_index: int,
    *,
    n_examples=None,
    environment: str = 'snliappendixcomp',
) -> str:
    R"""Render the top examples of one component inside a LaTeX environment.

    Args:
        component_index: component (column of ``nmf.W``) to render.
        n_examples: number of top examples to request from
            ``ctx.get_top_examples``. Defaults to the module-level
            ``N_EXAMPLES``, read at call time so interactive edits to it
            take effect.
        environment: name of the wrapping LaTeX environment. The default
            keeps the original output. NOTE(review): this is the SNLI
            environment name reused for QQP; confirm that is intentional.

    Returns:
        ``\begin{<environment>}{<component_index>}``, then the rendered
        examples separated by blank lines, then ``\end{<environment>}``,
        all joined by newlines.
    """
    if n_examples is None:
        n_examples = N_EXAMPLES
    body = '\n\n'.join(
        example_to_tex(ex, component_index)
        for ex in ctx.get_top_examples(component_index, n_examples)
    )
    return '\n'.join([
        R'\begin{' + environment + '}{' + str(component_index) + '}',
        body,
        R'\end{' + environment + '}',
    ])


###############################################################################
# Component indices from other configurations, kept for reference:
#   snli: 0, 3, 8, 17
#   qqp:  0, 3, 9, 10

# Render each requested component in ascending index order and print the
# combined LaTeX, with two blank lines between components.
output = '\n\n\n'.join(
    make_component_tex(component_index)
    for component_index in sorted(COMPONENT_INDICES)
)

print(output)



