R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python local_scripts/pi/make_comps_pdf_qqp_01.py

"""
from importlib import reload
import os
import time

import numpy as np
from transformers import AutoTokenizer

from em.fishers import per_example
from em.projects.anli import anli_misc1 as am
from em.projects.ll import hans_container
from em.projects.ll import hans_util
from em.tools.nmf import nmf_common


###############################################################################

# Base directory under which the paths below are resolved.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
# Subdirectories of EXPS_DIR. Only PER_EXAMPLES_FISHERS_DIR is referenced by
# the functions visible in this script; the other two are defined but unused here.
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Name passed to AutoTokenizer.from_pretrained when building the module-level tokenizer.
TOKENIZER = 'bert-base-uncased'

##########################################################################

# QQP_PEF_FILENAME = "bert_base_qqp.qqp_val.all_vars.first_20k.131072.h5"
# QQP_NMF_FILENAME = "nmf_decomp.c768_2kIters_32768pe_5000ex.bert_base_qqp.qqp_val.all_vars.first_20k.131072.h5"
# QQP_NMF_FILENAME = "nmf_decomp.c1024_2kIters_65536pe_20000ex_mvpp8.bert_base_qqp.qqp_val.all_vars.first_20k.131072.h5"


# # QQP_PEF_FILENAME = "enriched_incorrects_to_0_5.bert_base_qqp.qqp_val.all_vars.first_20k.131072.h5"
# QQP_PEF_FILENAME = "enriched_incorrects_to_0_25.bert_base_qqp.qqp_val.all_vars.first_20k.131072.h5"
# QQP_NMF_FILENAME = f"nmf_decomp.c1024_2kIters_65536pe_mvpp8.{QQP_PEF_FILENAME}"

# QQP_PEF_FILENAME = "bert_base_qqp.paws_final_train.all_vars.all_ex.131072.h5"
# QQP_NMF_FILENAME = f"nmf_decomp.c{1024}_2kIters_{65536}pe_mvpp{16}.{QQP_PEF_FILENAME}"

# QQP_PEF_FILENAME = "bert_base_rte.sci_tail_train.all_vars.all_ex.65536.h5"
# QQP_NMF_FILENAME = f"nmf_decomp.c{1024}_2kIters_{65536}pe_mvpp{8}.{QQP_PEF_FILENAME}"

# Active configuration. Both files are looked up under PER_EXAMPLES_FISHERS_DIR
# by make_container_qqp.
QQP_PEF_FILENAME = 'bert_base_rte.lexical_overlap.validation.all_vars.all_ex.131072.h5'
# The NMF filename is the PEF filename with a settings prefix; the braces only
# interpolate literal numbers. The leading "spH." makes make_container_qqp load
# it with nmf_common.SparseNmfDecomposition.
QQP_NMF_FILENAME = f"spH.nmf_decomp.c{256}_{2500}Iters_{131072}pe_mvpp{1}.{QQP_PEF_FILENAME}"

##########################################################################

# Created once at import time; make_container_qqp passes this module-level
# object to the analysis container.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)


def make_container_qqp():
    """Assemble the analysis container for the configured PEF and NMF files.

    Both ``QQP_PEF_FILENAME`` and ``QQP_NMF_FILENAME`` are resolved relative to
    ``PER_EXAMPLES_FISHERS_DIR``. The loaded NMF decomposition has its
    components normalized to unit norm before being wrapped.

    Returns:
        An ``am.PefNmfAnalysisContainer`` built from the loaded per-example
        Fishers object, the single NMF decomposition, and the module-level
        tokenizer, with ``shift_labels=False``.
    """
    pef_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, QQP_PEF_FILENAME)
    # An empty [0, 0) Fisher index range means the Fishers themselves are not
    # loaded, which ends up being much faster.
    pef = per_example.PerExampleFlatFishers.load(
        pef_path,
        n_examples=None,
        start_fisher_index=0,
        end_fisher_index=0,
    )

    # Append one extra trailing column filled with -1e9 to the predicted logits.
    # NOTE(review): presumably a dummy extra class that can never win an
    # argmax -- confirm against PefNmfAnalysisContainer.
    logits = pef.predicted_logits
    filler_column = -1e9 * np.ones_like(logits[:, :1])
    pef.predicted_logits = np.concatenate([logits, filler_column], axis=-1)

    # A dot-delimited "spH" segment (anywhere but the final segment) is taken
    # to mean that this is a sparsified NMF decomposition.
    is_sparse = "spH" in QQP_NMF_FILENAME.split(".")[:-1]
    decomposition_cls = (
        nmf_common.SparseNmfDecomposition if is_sparse else nmf_common.NmfDecomposition
    )
    nmf_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, QQP_NMF_FILENAME)
    nmf = decomposition_cls.load(nmf_path)
    print(nmf.W.shape)
    nmf.normalize_components_to_unit_norm()

    return am.PefNmfAnalysisContainer(
        pef=pef,
        nmfs=[nmf],
        tokenizer=tokenizer,
        shift_labels=False,
    )

##########################################################################


# Passed as `n_examples` to the container's LaTeX rendering call in make_tex_qqp.
N_EXAMPLES = 12


def make_tex_qqp(output_path='/fruitbasket/users/m/tmp/rte_hans_validation.tex'):
    """Render the NMF components of the configured container to a LaTeX file.

    Builds the container with ``make_container_qqp``, orders the components of
    its first NMF decomposition, asks the container for the LaTeX body with
    ``N_EXAMPLES`` examples, and writes it between the container's
    ``COMPONENTS_LATEX_FILE_START`` and ``COMPONENTS_LATEX_FILE_END`` strings.

    Args:
        output_path: Destination ``.tex`` file; overwritten if it exists.
            Defaults to the path this script previously hard-coded.
    """
    container = make_container_qqp()

    # Ascending argsort of the negated sums over axis 0 of W.
    # NOTE(review): sorting along axis 0 and then summing along that same axis
    # leaves the sums unchanged (up to floating-point rounding). If a top-k sum
    # was intended, a slice after np.sort is missing -- confirm. The expression
    # is kept exactly as it was.
    ordered_comps = np.argsort(np.sort(-container.nmfs[0].W, axis=0).sum(axis=0))
    latex_body = container.make_latex_string_for_some_components(
        [ordered_comps], n_examples=N_EXAMPLES)

    # Write explicitly as UTF-8 rather than the locale's default encoding: the
    # file is later compiled with xelatex, which reads its input as UTF-8.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(container.COMPONENTS_LATEX_FILE_START)
        f.write(latex_body)
        f.write(container.COMPONENTS_LATEX_FILE_END)


##########################################################################


# Script entry point. There is no `__main__` guard, so this also runs whenever
# the module is imported or reloaded.
make_tex_qqp()


R"""

make_pdf() {
    local filename=$1

    rsync -ra -e ssh \
        "m@banana.cs.unc.edu:/fruitbasket/users/m/tmp/${filename}.tex" \
        "$HOME/Downloads/${filename}.tex"

    xelatex -interaction=batchmode -output-directory=/tmp ~/Downloads/${filename}.tex
    mv /tmp/${filename}.pdf ~/Desktop/projects_data/extract_merge1/ll/pdfs
}

make_pdf qqp_dev_all_comps_001
make_pdf qqp_dev_all_comps_002
make_pdf qqp_dev_all_comps_003
make_pdf qqp_dev_all_comps_004

make_pdf paws_dev_all_comps_001

make_pdf scitail_dev_all_comps_001
make_pdf rte_hans_validation

"""
