R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python -i local_scripts/wino1/wino_comps001.py

"""

import dataclasses
from importlib import reload
import os
import time

from colorama import Fore, Back, Style
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sps
import seaborn as sns
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.tools.nmf import nmf_common
from em.tools.nmf import nmf_transform
from em.util import flat_pack
from em.util import hf_util
from em.util.color_util import cu

from em.projects.anli import anli_misc1 as am


rocket_cmap = sns.color_palette("rocket", as_cmap=True)

###############################################################################

# Root directory for this experiment's artifacts. Models, diagonal Fishers and
# per-example Fishers each live in their own sub-directory.
# NOTE(review): "winogrange1" is presumably the actual on-disk directory name;
# do not "correct" the spelling without renaming the directory as well.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/winogrange1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, sub_dir)
    for sub_dir in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################

# Base Hugging Face checkpoint. In this script it is only used to load the
# tokenizer handed to the analysis container.
PRETRAINED_MODEL = 'bert-base-uncased'

# Alternative configurations. Exactly one configuration should be active.
# Each one names the fine-tuned MODEL, the per-example Fishers (PEF) file and
# the NMF decomposition file computed from it. In configs that mix the dev and
# test splits, `unlabeled_indicator` marks the rows that have no label.
# NOTE: the masks use the builtin `bool` dtype; the `np.bool` alias was removed
# in NumPy 1.24 and raises AttributeError there.

# MODEL = "bert_base_mnli_to_winogrande_xl_4_epochs_01"
# PER_EXAMPLES_FISHERS = f"{MODEL}.winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
# DECOMP_FILENAME = f"nmf_decomp.per_sub_block.16k.16k.256.{PER_EXAMPLES_FISHERS}"
# unlabeled_indicator = None

# MODEL = "bert_base_mnli_to_winogrande_xl_4_epochs_01"
# PER_EXAMPLES_FISHERS = f"{MODEL}.test_val_winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
# DECOMP_FILENAME = f"nmf_decomp.per_sub_block.6k.16k.64.{PER_EXAMPLES_FISHERS}"
# # Dev set has labels, test set does not. The first 2 * 1_267 rows are the
# # labeled dev rows (presumably two rows per dev example -- TODO confirm).
# unlabeled_indicator = np.ones([6_000], dtype=bool)
# unlabeled_indicator[:2 * 1_267] = False

# MODEL = "bert_base_mnli_to_winogrande_xl_4_epochs_01"
# PER_EXAMPLES_FISHERS = f"{MODEL}.test_val_winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
# DECOMP_FILENAME = f"nmf_decomp.per_sub_block.6k.16k.256.{PER_EXAMPLES_FISHERS}"
# # Dev set has labels, test set does not.
# unlabeled_indicator = np.ones([6_000], dtype=bool)
# unlabeled_indicator[:2 * 1_267] = False


# MODEL = "bert_base_mnli_to_winogrande_xl_4_epochs_01"
# PER_EXAMPLES_FISHERS = f"{MODEL}.test_val_winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
# DECOMP_FILENAME = f"nmf_decomp.per_sub_block.16k.16k.256.transformed_from_nmf_train.{PER_EXAMPLES_FISHERS}"
# # Dev set has labels, test set does not.
# unlabeled_indicator = np.ones([6_000], dtype=bool)
# unlabeled_indicator[:2 * 1_267] = False


# Active configuration: 8-epoch custom model, per-example Fishers on the
# held-out split, rank-256 NMF.
MODEL = "bert_base_mnli_to_winogrande_custom_8_epochs_01"
PER_EXAMPLES_FISHERS = f"{MODEL}.winogrande_heldout.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
DECOMP_FILENAME = f"nmf_decomp.per_sub_block.10k.16k.256.{PER_EXAMPLES_FISHERS}"
# None: no per-example unlabeled mask is passed to the container.
unlabeled_indicator = None

# Not referenced in the visible script. Presumably it toggles loading PyTorch
# weights into the TF model; kept for interactive use (TODO confirm or remove).
FROM_PT = False

# Number of NMF decompositions to load (passed as `n_nmfs`). The last one is
# named 'Pooler' when building `subset_names`.
N_DECOMPS = 25

###############################################################################

# Tokenizer for the base checkpoint. It is handed to the analysis container,
# presumably so token ids can be decoded back into text (TODO confirm in
# anli_misc1).
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# Load the per-example Fishers (PEF) file together with its NMF decomposition.
# Both files live in PER_EXAMPLES_FISHERS_DIR.
container = am.load_pef_nmf_analysis_container(
    pef_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS),
    nmf_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, DECOMP_FILENAME),
    n_nmfs=N_DECOMPS,
    tokenizer=tokenizer,
    shift_labels=True,
    unlabeled_indicator=unlabeled_indicator,
)

# HACK: the original author reports that loading with shift_labels=True, then
# overwriting the labels with the raw PEF labels and rebuilding the examples,
# is what matches the string labels to the integer labels.
# Keep these two statements in this order: presumably the private
# _make_nli_examples() reads container.labels.
container.labels = container.pef.labels
container.examples = container._make_nli_examples()


# Eagerly load every NMF component up front (presumably they are otherwise
# loaded lazily -- TODO confirm).
container.nmfs.force_load_all()

# Interactive-session helper: reload anli_misc1 and rebind the live container
# to the reloaded class without re-reading any files.
# reload(am); container.__class__ = am.PefNmfAnalysisContainer

###############################################################################

# Display names for the N_DECOMPS NMFs: N_DECOMPS - 1 per-sub-block names from
# anli_misc1, followed by 'Pooler' (presumably the final NMF covers the pooler).
subset_names = am.make_per_sub_block_subset_names(N_DECOMPS - 1)
subset_names.extend(['Pooler'])

############################################

# Number of examples shown per NMF component in the LaTeX report.
N_EXAMPLES = 8

# FILENAME = "winogrande_4_epochs_all_components"
# FILENAME = "winogrande_4_epochs_all_components_dev_test_rank_64"
# FILENAME = "winogrande_4_epochs_all_components_dev_test_rank_256"
# FILENAME = "winogrande_4_epochs_all_components_dev_test_from_train_nmf_rank_256"
FILENAME = "winogrande_8_epochs_all_components_heldout_rank_256"

# Scratch directory on the server. The .tex file is copied and compiled
# locally; see the shell notes below.
OUTPUT_DIR = '/fruitbasket/users/m/tmp'

# Render the whole document before opening the output file. Opening with 'w'
# truncates it, so an exception while rendering components would otherwise
# leave a partial .tex behind.
latex_document = ''.join([
    container.COMPONENTS_LATEX_FILE_START,
    container.make_all_components_latex_string(n_examples=N_EXAMPLES, nmf_names=subset_names),
    container.COMPONENTS_LATEX_FILE_END,
])

# Write with an explicit UTF-8 encoding: the example text may contain non-ASCII
# characters, the platform default encoding is not guaranteed to be UTF-8, and
# xelatex expects UTF-8 input.
with open(os.path.join(OUTPUT_DIR, f'{FILENAME}.tex'), 'w', encoding='utf-8') as f:
    f.write(latex_document)


R"""


# FILENAME=winogrande_4_epochs_all_components
# FILENAME=winogrande_4_epochs_all_components_dev_test_rank_64
# FILENAME=winogrande_4_epochs_all_components_dev_test_rank_256
# FILENAME=winogrande_4_epochs_all_components_dev_test_from_train_nmf_rank_256
FILENAME=winogrande_8_epochs_all_components_heldout_rank_256

rsync -ra -e ssh \
    "m@banana.cs.unc.edu:/fruitbasket/users/m/tmp/$FILENAME.tex" \
    "$HOME/Downloads/$FILENAME.tex"

xelatex -interaction=batchmode -output-directory=/tmp "$HOME/Downloads/$FILENAME.tex"
mv "/tmp/$FILENAME.pdf" "$HOME/Downloads"


"""


# 1.14.4 Component 3 [from dev_test_from_train_nmf] looks like stuff for babies/children
# 1.11.173 Component 172 [from dev_test_from_train_nmf] feels vaguely medical/health-related


# Use the entropy of the predicted logits to select for overfit examples?

###############################################################################

# Filter out components where the highest coefficient is n times bigger than the second highest.
# Look for components that appear correct, and make sure to handle the unlabeled examples.
# Maybe transfer from training set components to dev/val components.
# Some automatic selection/annotation of "types/characteristics" of examples.
# Something with pairs of examples showing up in the tops.
# Maybe look at examples with up to c * the coefficient of the top example, where c < 1.
