R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/ll/make_comps_pdf_01.py

"""
from importlib import reload
import os
import time

import numpy as np
from transformers import AutoTokenizer

from em.projects.anli import anli_misc1 as am
from em.projects.ll import hans_container
from em.projects.ll import hans_util

###############################################################################

# Root directory holding the artifacts of this experiment series.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ll1'
# Sub-directories of EXPS_DIR. Of these, only PER_EXAMPLES_FISHERS_DIR is read
# by the container loaders visible in this script.
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Model identifier handed to AutoTokenizer.from_pretrained further down.
PRETRAINED_MODEL = 'bert-base-uncased'
# NOTE(review): not referenced anywhere in the visible code; presumably a
# "load from PyTorch weights" flag left over from a sibling script - confirm.
FROM_PT = True

# Passed as `n_nmfs` to every container loader and used to generate the
# per-decomposition display names (see make_subset_names).
N_DECOMPS = 25

##########################################################################

# File-name templates, resolved relative to PER_EXAMPLES_FISHERS_DIR.
#   *_PEF_FILENAME: `.h5` file name; takes `model_number`.
#   *_NMF_FILENAME: takes `pef`, the already-formatted PEF file name, which is
#                   embedded as the suffix of the NMF file name.

# NOTE(review): "MNIST" looks like a misnomer for MNLI - the file name itself
# contains ".mnli." and the consumer that writes the output is make_tex_mnli.
MNIST_PEF_FILENAME = "feather_berts_{model_number}.mnli.no_embeddings.16k.16k.h5"
MNIST_NMF_FILENAME = "nmf_decomp.per_sub_block.16k.16k.256.{pef}"


HANS_PEF_FILENAME = "feather_berts_{model_number}.hans.no_embeddings.16k.16k.h5"
HANS_NMF_FILENAME = "nmf_decomp.per_sub_block.16k.16k.256.{pef}"

HANS_LONE_PEF_FILENAME = "feather_berts_{model_number}.hans_lone.no_embeddings.5k.32k.h5"
HANS_LONE_NMF_FILENAME = "nmf_decomp.per_sub_block.5k.32k.256.{pef}"

# Unlike the other variants, this NMF template uses the "nmf_transformed"
# prefix rather than "nmf_decomp"; the matching output file is named
# "hans_loye_transformed_*".
HANS_LOYE_PEF_FILENAME = "feather_berts_{model_number}.hans_loye.no_embeddings.5k.32k.h5"
HANS_LOYE_NMF_FILENAME = "nmf_transformed.per_sub_block.5k.32k.256.{pef}"

# "md" corresponds to the "metric_derived" tag in the PEF file name.
HANS_LONE_MD_PEF_FILENAME = "feather_berts_{model_number}.hans_lone.no_embeddings.metric_derived.5k.131k.h5"
HANS_LONE_MD_NMF_FILENAME = "nmf_decomp.v2.per_sub_block.5k.131k.256.{pef}"

###############################################################################

# Shared tokenizer, created once when the module is executed and passed to
# every container loader below.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)


def make_container_mnist(model_number: int):
    """Build the fully-loaded analysis container for one MNLI model.

    Despite the "mnist" in the name, the file templates used here resolve to
    the ".mnli." per-example files.

    Args:
        model_number: Substituted into the PEF file-name template.

    Returns:
        The container produced by `am.load_pef_nmf_analysis_container`, after
        `force_load_all()` has been called on its `nmfs` attribute.
    """
    pef_name = MNIST_PEF_FILENAME.format(model_number=model_number)
    nmf_name = MNIST_NMF_FILENAME.format(pef=pef_name)
    pef_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, pef_name)
    nmf_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, nmf_name)

    analysis = am.load_pef_nmf_analysis_container(
        pef_filepath=pef_path,
        nmf_filepath=nmf_path,
        n_nmfs=N_DECOMPS,
        tokenizer=tokenizer,
        # This is the only loader in the file that enables label shifting.
        shift_labels=True,
    )
    analysis.nmfs.force_load_all()
    return analysis


def make_container_hans(model_number: int):
    """Build the fully-loaded analysis container for one model on HANS.

    Args:
        model_number: Substituted into the PEF file-name template.

    Returns:
        The container produced by `am.load_pef_nmf_analysis_container`, after
        `hans_util.fix_up_hans_container` and `nmfs.force_load_all()` have
        been applied to it.
    """
    pef_name = HANS_PEF_FILENAME.format(model_number=model_number)
    nmf_name = HANS_NMF_FILENAME.format(pef=pef_name)

    loader_kwargs = dict(
        pef_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, pef_name),
        nmf_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, nmf_name),
        n_nmfs=N_DECOMPS,
        tokenizer=tokenizer,
        shift_labels=False,
    )
    analysis = am.load_pef_nmf_analysis_container(**loader_kwargs)

    # HANS-specific post-processing happens before the NMFs are loaded.
    hans_util.fix_up_hans_container(analysis)
    analysis.nmfs.force_load_all()
    return analysis


def make_container_hans_lone(model_number: int):
    """Build the fully-loaded container for the "hans_lone" subset.

    The examples are obtained from `hans_util.get_first_hans_examples` with a
    filter keeping rows whose `heuristic` is 'lexical_overlap' and whose
    `label` is 1. Unlike the other loaders in this file, this one goes through
    `hans_container.load_analysis_container` and passes the examples in.

    Args:
        model_number: Substituted into the PEF file-name template.

    Returns:
        The loaded container, after `nmfs.force_load_all()` has been called.
    """
    def _is_lexical_overlap_label_1(row):
        return row['heuristic'] == 'lexical_overlap' and row['label'] == 1

    def _keep_lone_rows(dataset):
        return dataset.filter(_is_lexical_overlap_label_1)

    pef_name = HANS_LONE_PEF_FILENAME.format(model_number=model_number)
    nmf_name = HANS_LONE_NMF_FILENAME.format(pef=pef_name)

    lone_examples = hans_util.get_first_hans_examples('validation', 5000, _keep_lone_rows)

    pef_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, pef_name)
    nmf_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, nmf_name)
    analysis = hans_container.load_analysis_container(
        pef_filepath=pef_path,
        nmf_filepath=nmf_path,
        n_nmfs=N_DECOMPS,
        tokenizer=tokenizer,
        examples=lone_examples,
    )
    analysis.nmfs.force_load_all()
    return analysis


def make_container_hans_loye(model_number: int):
    """Build the fully-loaded container for the "hans_loye" subset.

    The NMF file resolved here comes from the "nmf_transformed" template
    rather than an "nmf_decomp" one.

    Args:
        model_number: Substituted into the PEF file-name template.

    Returns:
        The container produced by `am.load_pef_nmf_analysis_container`, after
        `hans_util.fix_up_hans_container` and `nmfs.force_load_all()` have
        been applied to it.
    """
    pef_name = HANS_LOYE_PEF_FILENAME.format(model_number=model_number)
    nmf_name = HANS_LOYE_NMF_FILENAME.format(pef=pef_name)
    pef_path, nmf_path = (
        os.path.join(PER_EXAMPLES_FISHERS_DIR, name) for name in (pef_name, nmf_name)
    )

    analysis = am.load_pef_nmf_analysis_container(
        pef_filepath=pef_path,
        nmf_filepath=nmf_path,
        n_nmfs=N_DECOMPS,
        tokenizer=tokenizer,
        shift_labels=False,
    )
    hans_util.fix_up_hans_container(analysis)
    analysis.nmfs.force_load_all()
    return analysis


def make_container_hans_lone_md(model_number: int):
    """Build the fully-loaded container for the "hans_lone" metric-derived files.

    Args:
        model_number: Substituted into the PEF file-name template.

    Returns:
        The container produced by `am.load_pef_nmf_analysis_container`, after
        `hans_util.fix_up_hans_container` and `nmfs.force_load_all()` have
        been applied to it.
    """
    pef_name = HANS_LONE_MD_PEF_FILENAME.format(model_number=model_number)
    nmf_name = HANS_LONE_MD_NMF_FILENAME.format(pef=pef_name)

    def in_fishers_dir(name):
        return os.path.join(PER_EXAMPLES_FISHERS_DIR, name)

    analysis = am.load_pef_nmf_analysis_container(
        pef_filepath=in_fishers_dir(pef_name),
        nmf_filepath=in_fishers_dir(nmf_name),
        n_nmfs=N_DECOMPS,
        tokenizer=tokenizer,
        shift_labels=False,
    )
    hans_util.fix_up_hans_container(analysis)
    analysis.nmfs.force_load_all()
    return analysis


def make_subset_names(n_nmfs):
    """Return display names for `n_nmfs` decompositions.

    Names come in (attention, feedforward) pairs, one pair per layer index
    starting at 0; when `n_nmfs` is odd a final 'Pooler' entry is appended.

    Args:
        n_nmfs: Number of names to generate.

    Returns:
        A list of strings, in layer order.
    """
    names = [
        f'Layer {layer} {kind} Sub-Block'
        for layer in range(n_nmfs // 2)
        for kind in ('Attention', 'Feedforward')
    ]
    if n_nmfs % 2 != 0:
        names.append('Pooler')
    return names


###############################################################################

# Passed as `n_examples` to `make_all_components_latex_string` by every
# make_tex_* function.
N_EXAMPLES = 8

# Leftover single-model selectors; the driver loop at the bottom of the file
# iterates over these same model numbers instead.
# MODEL_NUMBER = 0
# MODEL_NUMBER = 1
# MODEL_NUMBER = 15
# MODEL_NUMBER = 25

##########################################################################

# One display name per decomposition; passed as `nmf_names` when rendering
# the LaTeX for a container.
subset_names = make_subset_names(N_DECOMPS)


# Directory on the compute host that receives the generated .tex files.
_TEX_OUTPUT_DIR = '/fruitbasket/users/m/tmp'


def _write_all_components_tex(container, file_stem: str) -> None:
    """Render all components of `container` to LaTeX and write them to disk.

    The output path is `<_TEX_OUTPUT_DIR>/<file_stem>_all_comps.tex`. The file
    consists of the container's COMPONENTS_LATEX_FILE_START, the rendered
    components, and COMPONENTS_LATEX_FILE_END, in that order.

    Args:
        container: Analysis container exposing
            `make_all_components_latex_string` and the two
            COMPONENTS_LATEX_FILE_* attributes.
        file_stem: Leading part of the output file name, e.g. 'hans_lone_15'.
    """
    # Render before opening the file so that a rendering failure does not
    # leave behind an empty or truncated .tex file.
    latex_body = container.make_all_components_latex_string(
        n_examples=N_EXAMPLES,
        nmf_names=subset_names,
    )
    tex_path = f'{_TEX_OUTPUT_DIR}/{file_stem}_all_comps.tex'
    # Explicit UTF-8: the files are compiled with xelatex (see the shell notes
    # at the end of this file), and relying on the locale default encoding can
    # fail or mis-encode non-ASCII text on hosts with a non-UTF-8 locale.
    with open(tex_path, 'w', encoding='utf-8') as f:
        f.write(container.COMPONENTS_LATEX_FILE_START)
        f.write(latex_body)
        f.write(container.COMPONENTS_LATEX_FILE_END)


def make_tex_mnli(model_number: int):
    """Write `mnli_<model_number>_all_comps.tex` for the MNLI container."""
    _write_all_components_tex(
        make_container_mnist(model_number),
        f'mnli_{model_number}',
    )


def make_tex_hans(model_number: int):
    """Write `hans_<model_number>_all_comps.tex` for the HANS container."""
    _write_all_components_tex(
        make_container_hans(model_number),
        f'hans_{model_number}',
    )


def make_tex_hans_lone(model_number: int):
    """Write `hans_lone_<model_number>_all_comps.tex` for the hans_lone container."""
    _write_all_components_tex(
        make_container_hans_lone(model_number),
        f'hans_lone_{model_number}',
    )


def make_tex_hans_loye(model_number: int):
    """Write `hans_loye_transformed_<model_number>_all_comps.tex` for the hans_loye container."""
    _write_all_components_tex(
        make_container_hans_loye(model_number),
        f'hans_loye_transformed_{model_number}',
    )


def make_tex_hans_lone_md(model_number: int):
    """Write `hans_lone_md_<model_number>_all_comps.tex` for the hans_lone metric-derived container."""
    _write_all_components_tex(
        make_container_hans_lone_md(model_number),
        f'hans_lone_md_{model_number}',
    )


# Script driver: runs at module top level (there is no __main__ guard) and
# writes one .tex file per model number. Only the hans_lone_md variant is
# currently enabled; the commented-out calls select the other variants.
for model_number in [0, 1, 15, 25]:
    # make_tex_mnli(model_number)
    # make_tex_hans(model_number)
    # make_tex_hans_lone(model_number)
    # make_tex_hans_loye(model_number)
    make_tex_hans_lone_md(model_number)

R"""


make_pdf() {
    local filename=$1

    rsync -ra -e ssh \
        "m@banana.cs.unc.edu:/fruitbasket/users/m/tmp/${filename}.tex" \
        "$HOME/Downloads/${filename}.tex"

    xelatex -interaction=batchmode -output-directory=/tmp ~/Downloads/${filename}.tex
    mv /tmp/${filename}.pdf ~/Desktop/projects_data/extract_merge1/ll/pdfs
}


make_pdf mnli_0_all_comps
make_pdf mnli_1_all_comps
make_pdf mnli_15_all_comps
make_pdf mnli_25_all_comps


make_pdf hans_0_all_comps
make_pdf hans_1_all_comps
make_pdf hans_15_all_comps
make_pdf hans_25_all_comps


make_pdf hans_lone_0_all_comps
make_pdf hans_lone_1_all_comps
make_pdf hans_lone_15_all_comps
make_pdf hans_lone_25_all_comps


make_pdf hans_loye_transformed_0_all_comps
make_pdf hans_loye_transformed_1_all_comps
make_pdf hans_loye_transformed_15_all_comps
make_pdf hans_loye_transformed_25_all_comps

make_pdf hans_lone_md_0_all_comps
make_pdf hans_lone_md_1_all_comps
make_pdf hans_lone_md_15_all_comps
make_pdf hans_lone_md_25_all_comps

"""
