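R"""Print the LaTeX for the SNLI k-means appendix examples.

For each component index in COMPONENT_INDICES, the N_EXAMPLES examples with
the highest coefficient under the k-means model at KMEANS_PATH are rendered as
\snliappendixex entries inside a snliappendixcomp environment and printed.

Usage (from a shell):

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python -i em/projects/neurips2023/make_appendix_snli_kmeans_examples.py


"""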

from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer

from em.tools.nmf import lrm_npeff
from em.projects.m_npeff import snli_context
from em.projects.m_npeff import latex_generation
from em.fishers import lrm_pefs

from em.util import latex_util

from em.util.color_util import cu
from em.tools import k_means
from em.activations import bert_activations


###############################################################################

# Input artifacts. These are absolute, machine-specific paths; edit them to run
# this script elsewhere.

# NPEFF decomposition, loaded below with LrmNpeffDecomposition.load and passed
# to SnliContext.
NMF_DIR = "/fruitbasket/users/m/project_data/extract_merge1/snli3og_lrm_npeff/per_example_fishers/"
NMF_NAME = "feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.fit_to_train_skip_50k.h5"
# Alternative decomposition; swap in to use it instead of the one above.
# NMF_NAME = "feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.wrongs_only.expansion_001.no_full_joint.fit_to_train_skip_50k.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# Per-example Fisher file. This script only reads its stored logits (via
# SparseLrmPefs.load_logits) to obtain the model's predictions, so the model
# does not have to be re-evaluated.
PEFS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/snli3og_lrm_npeff/per_example_fishers/"
PEFS_NAME = "feather_berts_0.train_skip_50k.50000ex.65536.h5"
PEFS_PATH = os.path.join(PEFS_DIR, PEFS_NAME)

# Fitted k-means model, loaded with k_means.KMeans.load.
KMEANS_PATH = "/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers/kmeans.128comps.feather_berts_0.snli.train.50000ex.bert_cls_activations.h5"
# BERT CLS activations of the examples being ranked; fed to km.create_coeffs.
# Per its filename this is the train_skip_50k split, matching SPLIT below.
ACTS_PATH = "/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers/feather_berts_0.snli.train_skip_50k.50000ex.bert_cls_activations.h5"

###############################################################################

# Tokenizer name passed to AutoTokenizer.from_pretrained.
TOKENIZER = 'bert-base-uncased'
# Dataset split loaded by SnliContext. Presumably must match the split the PEFS
# and activation files above were computed on, since predictions and coeffs
# are indexed by example.index.
SPLIT = 'train_skip_50k'

###############################################################################

# Load the NPEFF decomposition. read_G=False skips loading G; nothing in this
# script reads it directly (TODO confirm SnliContext does not need it).
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=False)

# Dataset context for SPLIT; make_component_tex reads ctx.examples from it.
ctx = snli_context.SnliContext(
    split=SPLIT,
    tokenizer=AutoTokenizer.from_pretrained(TOKENIZER),
    nmf=nmf,
)

# Stored model logits. The argmax over the last axis is the predicted label id,
# later looked up as predictions[example.index].
logits = lrm_pefs.SparseLrmPefs.load_logits(PEFS_PATH)
predictions = np.argmax(logits, axis=-1)


acts = bert_activations.BertClsActivations.load(ACTS_PATH)

# Use only the representations from the last layer.
# NOTE(review): assumes per-layer CLS vectors are concatenated along axis 1
# with hidden size 768 (bert-base), so the final 768 columns are the last
# layer -- confirm against BertClsActivations.
activations = acts.activations[:, -768:]
# NOTE(review): `labels` is not used anywhere else in the visible script.
labels = acts.labels


# coeffs[example_index, component_index] is the value km.create_coeffs assigns
# to each example for each k-means component.
km = k_means.KMeans.load(KMEANS_PATH)
coeffs = km.create_coeffs(activations)
# Sort example indices by descending coefficient within each component
# (axis 0), then transpose so top_comp_inds[c] lists component c's examples,
# highest coefficient first.
top_comp_inds = np.argsort(-coeffs, axis=0).T
# Ascending alternative (lowest coefficients first):
# top_comp_inds = np.argsort(coeffs, axis=0).T


###############################################################################

# Single-character SNLI label code -> LaTeX macro used in the appendix.
CHAR_TO_TEX_LABEL = dict(
    e=R'\snliE',
    n=R'\snliN',
    c=R'\snliC',
)

# Integer label id -> single-character code. The order is assumed to match the
# dataset's integer label encoding (TODO confirm); "-" has no LaTeX macro.
LABEL_TO_CHAR = ("c", "e", "n", "-")

# Integer label id -> LaTeX macro, or None for codes without a macro ("-").
LABEL_TO_TEX_LABEL = list(map(CHAR_TO_TEX_LABEL.get, LABEL_TO_CHAR))


def _label_to_tex(label) -> str:
    """Return the LaTeX macro for integer label id ``label``.

    Args:
        label: Index into ``LABEL_TO_TEX_LABEL`` (a gold label or a model
            prediction).

    Raises:
        IndexError: if ``label`` is outside ``LABEL_TO_TEX_LABEL``.
        ValueError: if the label has no LaTeX macro (the ``"-"`` entry of
            ``LABEL_TO_CHAR``). Without this check the ``None`` entry would
            fail later with an uninformative ``TypeError`` on concatenation.
    """
    tex_label = LABEL_TO_TEX_LABEL[label]
    if tex_label is None:
        raise ValueError(
            f'Label {int(label)} ({LABEL_TO_CHAR[label]!r}) has no LaTeX macro '
            'in CHAR_TO_TEX_LABEL; filter such examples out before rendering.'
        )
    return tex_label


def _braced(*parts: str) -> str:
    """Wrap each part in braces and concatenate them: ``{a}{b}...``."""
    return ''.join('{' + part + '}' for part in parts)


def example_to_tex(example, component_index: int) -> str:
    """Render one example as a ``\\snliappendixex`` LaTeX entry.

    The result has four newline-separated lines: the macro name,
    ``{gold}{predicted}{coeff}``, ``{premise}`` and ``{hypothesis}``. The
    premise and hypothesis are passed through ``latex_util.escape``.

    Args:
        example: An element of ``ctx.examples``. Its ``index``, ``label``,
            ``premise`` and ``hypothesis`` attributes are read. ``index`` rows
            ``coeffs`` and ``predictions``, which are assumed to be
            row-aligned with ``ctx.examples`` (verify).
        component_index: Column of ``coeffs`` whose value is shown, formatted
            with four decimals.

    Raises:
        ValueError: if the gold or predicted label has no LaTeX macro.
    """
    s_coeff = f'{coeffs[example.index, component_index]:.4f}'
    gold_tex = _label_to_tex(example.label)
    pred_tex = _label_to_tex(predictions[example.index])
    lines = [
        R'\snliappendixex',
        _braced(gold_tex, pred_tex, s_coeff),
        _braced(latex_util.escape(example.premise)),
        _braced(latex_util.escape(example.hypothesis)),
    ]
    return '\n'.join(lines)


def make_component_tex(component_index: int, n_examples=None) -> str:
    """Render a component's top examples as a ``snliappendixcomp`` environment.

    Args:
        component_index: Component whose examples are rendered. It selects
            ``top_comp_inds[component_index]``, the example indices sorted by
            descending coefficient.
        n_examples: How many of the top examples to include. Defaults to the
            module-level ``N_EXAMPLES``, looked up at call time because it is
            defined further down the script, after this function.

    Returns:
        Newline-joined ``\\begin{snliappendixcomp}{<index>}``, the example
        entries (separated by blank lines), and ``\\end{snliappendixcomp}``.
    """
    if n_examples is None:
        n_examples = N_EXAMPLES
    body = '\n\n'.join(
        example_to_tex(ctx.examples[i], component_index)
        for i in top_comp_inds[component_index][:n_examples]
    )
    return '\n'.join([
        R'\begin{snliappendixcomp}{' + str(component_index) + '}',
        body,
        R'\end{snliappendixcomp}',
    ])


###############################################################################

N_EXAMPLES = 6
COMPONENT_INDICES = [1, 4, 28]

# One snliappendixcomp environment per component, in ascending component
# order, separated by two blank lines.
output = '\n\n\n'.join(map(make_component_tex, sorted(COMPONENT_INDICES)))

print(output)
