R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python local_scripts/pi/make_comps_pdf_mnli_01.py

"""
from importlib import reload
import os
import time

import numpy as np
from transformers import AutoTokenizer

from em.fishers import per_example
from em.projects.anli import anli_misc1 as am
from em.projects.ll import hans_container
from em.projects.ll import hans_util
from em.tools.nmf import nmf_common

from em.util.color_util import cu

###############################################################################
# Experiment directory layout (everything lives under EXPS_DIR).

EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir_name)
    for subdir_name in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################
# Pretrained tokenizer name (passed to AutoTokenizer.from_pretrained).

TOKENIZER = 'bert-base-uncased'

##########################################################################
# Filename templates. PEF templates take {model_number}. NMF templates take
# {n_components}, {n_vals_pe}, {min_vals_per_param} and {n_examples}, and end in
# {pef}, which is filled with the formatted PEF filename (presumably the PEF the
# decomposition was computed from -- TODO confirm).

_NMF_NAME_TAIL = "_1250Iters_{n_vals_pe}pe_mvpp{min_vals_per_param}_{n_examples}ex.{pef}"

PEF_FILENAME = "feather_berts_{model_number}.mnli_train.all_vars.100000ex.131072.h5"
NMF_FILENAME = "nmf_decomp.c{n_components}" + _NMF_NAME_TAIL

SNLI_PEF_FILENAME = "feather_berts_{model_number}.snli_train.all_vars.50000ex.65536.h5"
SNLI_NMF_FILENAME = "nmf_decomp.c{n_components}" + _NMF_NAME_TAIL

# The "spH." prefix is what make_container uses to pick the sparse NMF loader.
SNLI_NMF2_FILENAME = "spH.nmf_decomp2.c{n_components}" + _NMF_NAME_TAIL
SNLI_NMF2_REFIT_FILENAME = "refit_w." + SNLI_NMF2_FILENAME
# f"fit_w.skip50000.50000ex.65536vpe.{SNLI_NMF2_FILENAME}"
##########################################################################

# Single tokenizer instance handed to every PefNmfAnalysisContainer built in
# make_container (tokenizer=tokenizer there).
# NOTE(review): this runs at import time, so merely importing this script loads
# the TOKENIZER model files (from_pretrained may hit the network / HF cache).
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)


def make_container(
    PEF_FILENAME, NMF_FILENAME,
    model_number: int, n_components: int, n_vals_pe: int, min_vals_per_param: int, n_examples: int,
):
    """Load a per-example-Fisher file plus one NMF decomposition and wrap them.

    ``PEF_FILENAME`` is a template formatted with ``model_number``.
    ``NMF_FILENAME`` is a template formatted with the numeric arguments,
    ``model_number`` and ``pef`` (the formatted PEF filename). Both resulting
    names are resolved inside ``PER_EXAMPLES_FISHERS_DIR``.

    The NMF is loaded with ``nmf_common.SparseNmfDecomposition`` when its
    filename starts with ``"spH."`` or contains ``".spH."``. Otherwise it is
    loaded with ``nmf_common.NmfDecomposition``. Its components are normalized
    to unit norm before being wrapped.

    Returns:
        An ``am.PefNmfAnalysisContainer`` over the PEF and the single NMF, using
        the module-level ``tokenizer`` and ``shift_labels=True``.
    """
    def in_pef_dir(name):
        return os.path.join(PER_EXAMPLES_FISHERS_DIR, name)

    pef_file = PEF_FILENAME.format(model_number=model_number)
    nmf_file = NMF_FILENAME.format(
        pef=pef_file,
        model_number=model_number,
        n_components=n_components,
        n_vals_pe=n_vals_pe,
        min_vals_per_param=min_vals_per_param,
        n_examples=n_examples,
    )

    # Passing start/end Fisher index 0 means the Fishers themselves are not
    # loaded, which makes this load much faster.
    pef = per_example.PerExampleFlatFishers.load(
        in_pef_dir(pef_file),
        n_examples=None,
        start_fisher_index=0,
        end_fisher_index=0,
    )

    # An "spH." prefix / ".spH." infix is taken to mean a sparsified NMF.
    looks_sparse = nmf_file.startswith("spH.") or ".spH." in nmf_file
    decomposition_cls = (
        nmf_common.SparseNmfDecomposition if looks_sparse else nmf_common.NmfDecomposition
    )
    nmf = decomposition_cls.load(in_pef_dir(nmf_file))
    nmf.normalize_components_to_unit_norm()

    return am.PefNmfAnalysisContainer(
        pef=pef,
        nmfs=[nmf],
        tokenizer=tokenizer,
        shift_labels=True,
    )


##########################################################################
# make_tex can order components by the size of their largest coefficients in W
# (order_by_coeff_mag=True, the default).

# Default number of examples rendered per component.
N_EXAMPLES = 12


def _order_components_by_top_coeffs(W, n_top):
    """Return component (column) indices ordered by their top-``n_top`` coefficient sum.

    For each column of ``W``, the ``n_top`` largest entries are summed. Column
    indices are returned in descending order of that sum.

    Args:
        W: 2-D coefficient array whose columns index NMF components. It is used
            as ``container.nmfs[0].W``, presumably shaped
            ``(n_examples, n_components)`` -- TODO confirm.
        n_top: Number of largest coefficients per column to sum. Values larger
            than ``W.shape[0]`` simply sum the whole column.

    Returns:
        1-D integer array of column indices, largest top-``n_top`` sum first.
    """
    # Sorting -W ascending per column puts each column's largest coefficients
    # first. Argsorting the (negated) sums then yields descending order.
    negated_top_sums = np.sort(-W, axis=0)[:n_top].sum(axis=0)
    return np.argsort(negated_top_sums)


def make_tex(PEF_FILENAME, NMF_FILENAME,
             prefix: str,
             model_number: int, n_components: int, n_vals_pe: int, min_vals_per_param: int, n_examples: int,
             *, order_by_coeff_mag: bool = True, N_EXAMPLES: int = N_EXAMPLES,
             n_order_examples=None, out_dir: str = '/fruitbasket/users/m/tmp'):
    """Write a LaTeX file rendering the NMF components' examples, then print its name.

    The output file is
    ``{prefix}_{model_number}_{n_components}c_{n_vals_pe}pe_{min_vals_per_param}mvpp_{n_examples}ex.tex``
    inside ``out_dir``. It is wrapped in the container's
    ``COMPONENTS_LATEX_FILE_START`` / ``COMPONENTS_LATEX_FILE_END``.

    Args:
        PEF_FILENAME, NMF_FILENAME: Filename templates forwarded to ``make_container``.
        prefix: Leading part of the output ``.tex`` filename.
        model_number, n_components, n_vals_pe, min_vals_per_param, n_examples:
            Forwarded to ``make_container`` and embedded in the output filename.
        order_by_coeff_mag: If true, render components ordered by the sum of
            their largest coefficients in ``W`` (largest first). Otherwise use
            ``make_all_components_latex_string``.
        N_EXAMPLES: Passed as ``n_examples`` to the container's LaTeX renderer.
        n_order_examples: How many of each component's largest coefficients are
            summed for ordering. Defaults to ``N_EXAMPLES``, so ordering reflects
            the rendered examples. Pass ``n_examples`` to reproduce the earlier
            ordering, which sliced by the dataset-size argument.
        out_dir: Directory the ``.tex`` file is written to.
    """
    container = make_container(
        PEF_FILENAME, NMF_FILENAME,
        model_number=model_number,
        n_components=n_components,
        n_vals_pe=n_vals_pe,
        min_vals_per_param=min_vals_per_param,
        n_examples=n_examples,
    )
    filename = f'{prefix}_{model_number}_{n_components}c_{n_vals_pe}pe_{min_vals_per_param}mvpp_{n_examples}ex.tex'
    if order_by_coeff_mag:
        n_top = N_EXAMPLES if n_order_examples is None else n_order_examples
        ordered_comps = _order_components_by_top_coeffs(container.nmfs[0].W, n_top)
        q = container.make_latex_string_for_some_components([ordered_comps], n_examples=N_EXAMPLES)
    else:
        q = container.make_all_components_latex_string(n_examples=N_EXAMPLES)
    # The file is compiled with xelatex. Write UTF-8 explicitly so any
    # non-ASCII example text does not depend on the platform's default encoding.
    with open(os.path.join(out_dir, filename), 'w', encoding='utf-8') as f:
        f.write(container.COMPONENTS_LATEX_FILE_START)
        f.write(q)
        f.write(container.COMPONENTS_LATEX_FILE_END)
    print(cu.hlg(filename))


##########################################################################

# make_tex(
#     PEF_FILENAME, NMF_FILENAME,
#     prefix='mnli',
#     model_number=0,
#     n_components=1024,
#     n_vals_pe=65536,
#     min_vals_per_param=16,
#     n_examples=50000,
# )


# make_tex(
#     SNLI_PEF_FILENAME, SNLI_NMF_FILENAME,
#     prefix='mnli_snli',
#     model_number=0,
#     n_components=512,
#     n_vals_pe=65536,
#     min_vals_per_param=10,
#     n_examples=50000,
# )


# make_tex(
#     SNLI_PEF_FILENAME, "spH.fit_coeffs_to_sparse_H.test1.h5",
#     prefix='CUDA_TEST_mnli_snli',
#     model_number=0,
#     n_components=512,
#     n_vals_pe=65536,
#     min_vals_per_param=10,
#     n_examples=50000,
# )

# make_tex(
#     "feather_berts_{model_number}.snli_validation.all_vars.10000ex.65536.h5", "spH.fit_coeffs_to_sparse_H.test2.h5",
#     prefix='CUDA_TEST_mnli_snli_validation',
#     model_number=0,
#     n_components=512,
#     n_vals_pe=65536,
#     min_vals_per_param=10,
#     n_examples=50000,
# )


# make_tex(
#     SNLI_PEF_FILENAME, SNLI_NMF2_FILENAME,
#     prefix='mnli_snli2',
#     model_number=0,
#     n_components=512,
#     n_vals_pe=65536,
#     min_vals_per_param=10,
#     n_examples=50000,
# )
# Component 258
# Component 7


# make_tex(
#     SNLI_PEF_FILENAME, SNLI_NMF2_REFIT_FILENAME,
#     prefix='mnli_snli2_refit_w',
#     model_number=0,
#     n_components=512,
#     n_vals_pe=65536,
#     min_vals_per_param=10,
#     n_examples=50000,
# )

# make_tex(
#     'feather_berts_{model_number}.snli_train.all_vars.skip50000.250000ex.131072.h5', 
#     f"fit_w.skip50000.50000ex.65536vpe.{SNLI_NMF2_FILENAME}"[:-len('{pef}')] + SNLI_PEF_FILENAME,
#     prefix='mnli_snli2_new_train_ex',
#     model_number=0,
#     n_components=512,
#     n_vals_pe=65536,
#     min_vals_per_param=10,
#     n_examples=50000,
# )

# Active run. The NMF file name is the "fit_w.skip50000.50000ex.65536vpe." variant
# of SNLI_NMF2_FILENAME. Its trailing "{pef}" slot is replaced up front by
# SNLI_PEF_FILENAME (not the skip50000 PEF passed as the first argument), so only
# {n_components}, {n_vals_pe}, {min_vals_per_param}, {n_examples} and
# {model_number} remain to be filled in by make_container.
_new_train_nmf_file = (
    f"fit_w.skip50000.50000ex.65536vpe.{SNLI_NMF2_FILENAME}"[:-len('{pef}')]
    + SNLI_PEF_FILENAME
)
make_tex(
    'feather_berts_{model_number}.snli_train.all_vars.skip50000.250000ex.131072.h5',
    _new_train_nmf_file,
    prefix='mnli_snli2_new_train_ex__top32_',
    model_number=0,
    n_components=512,
    n_vals_pe=65536,
    min_vals_per_param=10,
    n_examples=50000,
    N_EXAMPLES=32,  # render the top 32 examples per component
)

# Component 467 => Competition
# Component 292 => Outside/inside
# Component 121 => Food
# Component 354 => Color + mostly entailment + simple passively-voiced hypothesis
# Component 343

R"""

make_pdf() {
    local filename=$1

    rsync -ra -e ssh \
        "m@banana.cs.unc.edu:/fruitbasket/users/m/tmp/${filename}.tex" \
        "$HOME/Downloads/${filename}.tex"

    xelatex -interaction=batchmode -output-directory=/tmp ~/Downloads/${filename}.tex
    mv /tmp/${filename}.pdf ~/Desktop/projects_data/extract_merge1/ll/pdfs
}

# make_pdf mnli_0_1024c_65536pe_16mvpp_50000ex

# make_pdf mnli_snli_0_512c_65536pe_10mvpp_50000ex
# make_pdf CUDA_TEST_mnli_snli_0_512c_65536pe_10mvpp_50000ex
# make_pdf mnli_snli2_0_512c_65536pe_10mvpp_50000ex
# make_pdf CUDA_TEST_mnli_snli_validation_0_512c_65536pe_10mvpp_50000ex
# make_pdf mnli_snli2_refit_w_0_512c_65536pe_10mvpp_50000ex
# make_pdf mnli_snli2_new_train_ex_0_512c_65536pe_10mvpp_50000ex
make_pdf mnli_snli2_new_train_ex__top32__0_512c_65536pe_10mvpp_50000ex

"""
