R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/transfer1/anli1/anli_per_sublock_nmf_dev004.py

"""
import dataclasses
from importlib import reload
import os
import time

from colorama import Fore, Back, Style
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sps
import seaborn as sns
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.tools.nmf import nmf_common
from em.tools.nmf import nmf_transform
from em.util import flat_pack
from em.util import hf_util
from em.util.color_util import cu

from em.projects.anli import anli_misc1 as am


# Continuous "rocket" colormap.
# NOTE(review): not referenced elsewhere in this script; presumably kept for
# ad-hoc plotting in the interactive session.
rocket_cmap = sns.color_palette("rocket", as_cmap=True)

###############################################################################
# Experiment directory layout.

EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/anli_correct1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################
# Model and per-example-Fisher / NMF-decomposition files.

# NOTE(review): MODEL and FROM_PT are not used elsewhere in this script.
MODEL = "textattack/bert-base-uncased-MNLI"
FROM_PT = True

# Checkpoint the tokenizer is loaded from (the tokenizer is handed to the
# analysis containers below).
PRETRAINED_MODEL = 'bert-base-uncased'

# Per-example Fishers on ANLI round 3 and the per-sub-block NMF fit to them.
# NOTE(review): the "256" in the decomposition name is presumably the number of
# components per NMF (component indices up to 243 are inspected later) — confirm.
ANLI_PER_EXAMPLES_FISHERS = "textattack_bert_base_uncased_MNLI.anli_r3.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
ANLI_DECOMP = f"nmf_decomp.per_sub_block.16k.16k.256.{ANLI_PER_EXAMPLES_FISHERS}"

# Per-example Fishers on MNLI. Judging by its filename, MNLI_DECOMP holds the
# MNLI Fishers transformed with the ANLI r3 NMF, so its component indices line
# up with ANLI_DECOMP's — TODO confirm.
MNLI_PER_EXAMPLES_FISHERS = "textattack_bert_base_uncased_MNLI.mnli.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
MNLI_DECOMP = "transformed_nmf.nmf_anli_r3.textattack_bert_base_uncased_MNLI.mnli.no_embeddings.sparse_dynamic_raw.8k.16k.subblocks.h5"


print(cu.hr('NOT INCLUDING POOLER LAYER'))
# NOTE: It looks like the pooler-layer NMF NaNed out, so I am not including them.
# Number of per-sub-block NMFs analysed (pooler excluded, per the note above).
N_DECOMPS = 24


###############################################################################

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)


# Arguments shared by both analysis containers.
# NOTE(review): `shift_labels` presumably remaps dataset labels onto the MNLI
# checkpoint's label order (cf. the commented-out `(labels + 1) % 3` later in
# this file) — confirm in anli_misc1.
cmn_args = dict(
    n_nmfs=N_DECOMPS,
    tokenizer=tokenizer,
    shift_labels=True,
)


# Re-import anli_misc1 so edits to it are picked up when re-running this
# section inside the interactive (`python -i`) session.
reload(am)

anli_cont = am.load_pef_nmf_analysis_container(
    pef_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, ANLI_PER_EXAMPLES_FISHERS),
    nmf_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, ANLI_DECOMP),
    **cmn_args,
)

mnli_cont = am.load_pef_nmf_analysis_container(
    pef_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, MNLI_PER_EXAMPLES_FISHERS),
    nmf_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, MNLI_DECOMP),
    **cmn_args,
)

# Presumably loads every NMF decomposition eagerly rather than on first access.
anli_cont.nmfs.force_load_all()
mnli_cont.nmfs.force_load_all()

###############################################################################


def asdfasdf(container, indicator, k: int, p_threshold: float, fraction_threshold: float):
    """Select, for each NMF, the components that are selective for `indicator`.

    For every NMF index in ``range(container.n_nmfs)`` this calls
    ``container.estimate_selectivity_for_indicator(nmf_index, indicator, k=k)``,
    which must return a pair ``(fractions, p_values)`` of 1-D array-likes with
    one entry per component. A component is kept when its fraction is
    ``>= fraction_threshold`` AND its p-value is ``<= p_threshold``.
    (Presumably the fraction is how many of the component's top-``k``
    examples are flagged by ``indicator``; see anli_misc1 to confirm.)

    Args:
        container: analysis container exposing ``n_nmfs`` and
            ``estimate_selectivity_for_indicator``.
        indicator: per-example indicator, passed through unchanged.
        k: number of top examples per component used for the estimate.
        p_threshold: maximum (inclusive) p-value for a component to be kept.
        fraction_threshold: minimum (inclusive) fraction for a component to be kept.

    Returns:
        A list with one entry per NMF, in NMF order; each entry is the integer
        index array returned by ``np.nonzero`` for that NMF's kept components.
    """

    def _kept_components(nmf_index):
        fracs, pvals = container.estimate_selectivity_for_indicator(nmf_index, indicator, k=k)
        passes = (pvals <= p_threshold) & (fracs >= fraction_threshold)
        # Single-element unpack: enforces a 1-D mask, exactly as before.
        (component_ids,) = np.nonzero(passes)
        return component_ids

    return [_kept_components(nmf_index) for nmf_index in range(container.n_nmfs)]


# Per-example indicators of incorrect predictions (used as boolean masks later,
# e.g. `~anli_incorrects & ...`).
# NOTE(review): presumably computed after the label shift set via
# `shift_labels` — confirm in anli_misc1.
anli_incorrects = anli_cont.get_incorrect_prediction_indicator()
mnli_incorrects = mnli_cont.get_incorrect_prediction_indicator()

# Per NMF: components whose top-16 examples pass the selectivity test for
# incorrect predictions (fraction >= 0.5 and p-value <= 0.05).
anli_incorrect_indices = asdfasdf(anli_cont, anli_incorrects, k=16, p_threshold=0.05, fraction_threshold=0.5)
mnli_incorrect_indices = asdfasdf(mnli_cont, mnli_incorrects, k=16, p_threshold=0.05, fraction_threshold=0.5)


# Per NMF: components that pass the selection on both ANLI and MNLI.
incorrect_for_both = [
    set(a) & set(m)
    for a, m in zip(anli_incorrect_indices, mnli_incorrect_indices)
]
print(list(enumerate(incorrect_for_both)))


#######################################
# Side-by-side inspection: top ANLI examples for an (NMF, component) pair, a
# visual gap of blank lines, then the top MNLI examples for the same pair.


# Another "before" component.
anli_cont.print_top_examples(nmf_index=5, n_examples=16, component_index=36)
print(8 * '\n')
mnli_cont.print_top_examples(nmf_index=5, n_examples=16, component_index=36)


# For ANLI, the hypotheses appear to be related to "the ___ is in the article"
# and "the article says ___". The word "article" is most common, but similar
# words such as statement, quote, and report are also used. MNLI doesn't seem
# to contain many similar examples, but its top examples look consistent with
# that as well.
anli_cont.print_top_examples(nmf_index=7, n_examples=16, component_index=181)
print(8 * '\n')
mnli_cont.print_top_examples(nmf_index=7, n_examples=16, component_index=181)


# Didn't look that hard, but couldn't immediately make out a pattern.
anli_cont.print_top_examples(nmf_index=10, n_examples=16, component_index=168)
print(8 * '\n')
mnli_cont.print_top_examples(nmf_index=10, n_examples=16, component_index=168)


###############################################################################

# Stricter ANLI-only selection: fraction >= 0.75 of the top-16 examples.
anli_incorrect_indices2 = asdfasdf(anli_cont, anli_incorrects, k=16, p_threshold=0.05, fraction_threshold=0.75)
print(list(enumerate(anli_incorrect_indices2)))

#######################################

# About half of the top ANLI examples seem to be about words containing
# letters. The phrase "contains a(n) __" is common amongst those examples, so I
# can use it to look for them programmatically.
#
# The MNLI examples aren't about words containing letters (idk if MNLI has any
# such examples). However, they tend to be short, and a lot of them look wrong
# as well.
anli_cont.print_top_examples(nmf_index=2, n_examples=16, component_index=175)
print(8 * '\n')
mnli_cont.print_top_examples(nmf_index=2, n_examples=16, component_index=175)

# For ANLI, very similar to 2:175, i.e. words containing letters, but this
# component appears to have only such examples (i.e. it is more selective).
# "has a(n) __" is common in addition to "contains a(n) __". I'm guessing there
# is also a slight tuning towards short, out-of-place words that sort of look
# like they might mean characters, e.g. "port - au - prince office".
#
# The MNLI examples aren't about words containing letters (idk if MNLI has any
# such examples). However, they tend to be short. Unlike 2:175, the model gets
# most predictions correct on these MNLI examples. The coefficient values are
# also far lower for MNLI than for ANLI, so the MNLI examples here might not
# really be using this component.
anli_cont.print_top_examples(nmf_index=3, n_examples=16, component_index=154)
print(8 * '\n')
mnli_cont.print_top_examples(nmf_index=3, n_examples=16, component_index=154)


# There seems to be some pattern amongst the top examples, but I need to look
# more deeply to describe it. Maybe something to do with making inferences
# based on the premise.
anli_cont.print_top_examples(nmf_index=4, n_examples=16, component_index=36)
print(8 * '\n')
mnli_cont.print_top_examples(nmf_index=4, n_examples=16, component_index=36)


# Hypotheses comparing some property of an object discussed in the premise to
# a number, via phrases like "is more than", "is longer than", and "is less
# than" (amongst others). Interestingly, most of the predictions are neutral
# while most of the true labels are entailment.
anli_cont.print_top_examples(nmf_index=11, n_examples=16, component_index=29)

# The premise deals with "was born in", either at some time or relative to
# another event. Looking for premises containing "was born" should find these
# programmatically.
anli_cont.print_top_examples(nmf_index=11, n_examples=16, component_index=140)

# Need to look deeper for the exact pattern, but it has short premises, usually
# like "__ is/was __".
anli_cont.print_top_examples(nmf_index=11, n_examples=16, component_index=147)


# Stuff like "has met", "has been viewed", "heard", "seen", ...
# (There is a more specific pattern; see the printed examples.)
# The true labels are all neutral while the predictions tend to be contradiction.
anli_cont.print_top_examples(nmf_index=11, n_examples=16, component_index=155)


# Another "before" component.
anli_cont.print_top_examples(nmf_index=19, n_examples=16, component_index=162)

# Probably has a pattern to the examples here.
anli_cont.print_top_examples(nmf_index=19, n_examples=16, component_index=235)


# Again, mostly about words containing letters, abbreviations, and that type of
# thing, though not hyper-selective for it. Maybe other countries/languages are
# also part of the "flavor" of these problems.
anli_cont.print_top_examples(nmf_index=19, n_examples=16, component_index=243)


###########################


# THIS WORKS: after reloading anli_misc1, rebind the already-loaded containers
# to the reloaded class so they pick up new/changed methods (e.g.
# print_top_examples_latex) without having to load them again.
reload(am)
anli_cont.__class__ = am.PefNmfAnalysisContainer
mnli_cont.__class__ = am.PefNmfAnalysisContainer

anli_cont.print_top_examples_latex(nmf_index=11, n_examples=16, component_index=155)


# 

def make_subset_names(n_nmfs):
    """Return a human-readable name for each per-sub-block NMF.

    NMFs are assumed to come in (attention, feedforward) pairs, one pair per
    transformer layer in layer order: NMF ``2 * i`` is named after layer
    ``i``'s attention sub-block and NMF ``2 * i + 1`` after its feedforward
    sub-block.

    Args:
        n_nmfs: total number of NMFs; must be even.

    Returns:
        A list of ``n_nmfs`` names, e.g. ``'Layer 0 Attention Sub-Block'``.

    Raises:
        ValueError: if ``n_nmfs`` is odd.
    """
    # An explicit raise (instead of `assert`) keeps the check under `python -O`.
    if n_nmfs % 2 != 0:
        raise ValueError(
            f'n_nmfs must be even (one attention and one feedforward NMF per layer), got {n_nmfs}.'
        )
    return [
        f'Layer {layer} {sub_block} Sub-Block'
        for layer in range(n_nmfs // 2)
        for sub_block in ('Attention', 'Feedforward')
    ]


# Human-readable names for the per-sub-block NMFs, passed as `nmf_names` to the
# (currently disabled) LaTeX exports below.
subset_names = make_subset_names(N_DECOMPS)

# Export the top examples of every component of every ANLI NMF to a LaTeX file
# (disabled; uncomment to regenerate).
# # q = anli_cont.make_all_components_latex_string(n_examples=16, nmf_names=subset_names)
# q = anli_cont.make_all_components_latex_string(n_examples=8, nmf_names=subset_names)

# with open('/fruitbasket/users/m/tmp/all_anli_components_latex.tex', 'w') as f:
#     f.write(anli_cont.COMPONENTS_LATEX_FILE_START)
#     f.write(q)
#     f.write(anli_cont.COMPONENTS_LATEX_FILE_END)



# ANLI-only selection at fraction >= 0.7 ("ft07" in the export filenames below).
anli_incorrect_indices3 = asdfasdf(anli_cont, anli_incorrects, k=16, p_threshold=0.05, fraction_threshold=0.7)
# print(list(enumerate(anli_incorrect_indices3)))

# Export top ANLI examples for only the selected components (disabled).
# with open('/fruitbasket/users/m/tmp/anli_components_wrong_ft07_top16_latex.tex', 'w') as f:
#     f.write(anli_cont.COMPONENTS_LATEX_FILE_START)
#     f.write(anli_cont.make_latex_string_for_some_components(anli_incorrect_indices3, n_examples=16, nmf_names=subset_names))
#     f.write(anli_cont.COMPONENTS_LATEX_FILE_END)


# Export top MNLI examples for the same ANLI-selected components (disabled).
# NOTE(review): reusing ANLI component indices on MNLI presumably relies on
# MNLI_DECOMP being the MNLI Fishers transformed with the ANLI NMF — confirm.
# with open('/fruitbasket/users/m/tmp/anli_components_on_mnli_wrong_ft07_top16_latex.tex', 'w') as f:
#     f.write(mnli_cont.COMPONENTS_LATEX_FILE_START)
#     f.write(mnli_cont.make_latex_string_for_some_components(anli_incorrect_indices3, n_examples=16, nmf_names=subset_names))
#     f.write(mnli_cont.COMPONENTS_LATEX_FILE_END)


R"""


# FILENAME=all_anli_components_latex
# FILENAME=anli_components_wrong_ft07_top16_latex
FILENAME=anli_components_on_mnli_wrong_ft07_top16_latex

rsync -ra -e ssh \
    "m@banana.cs.unc.edu:/fruitbasket/users/m/tmp/$FILENAME.tex" \
    "$HOME/Downloads/$FILENAME.tex"

xelatex -interaction=batchmode -output-directory=/tmp ~/Downloads/$FILENAME.tex
mv /tmp/$FILENAME.pdf ~/Downloads


"""

###########################


###############################################################################
# Keyword indicators over ANLI hypotheses, motivated by the component notes
# above ("before", "was born", "contains a(n)"). Each commented-out expression
# computes the model's accuracy on the flagged subset: the fraction of flagged
# examples that are NOT incorrect predictions.

before_indicator = anli_cont.get_indicator_by_example_fn(lambda e: e.hypothesis_contains_word('before'))

# (~anli_incorrects & before_indicator).sum() / before_indicator.sum()

# NOTE(review): the notes on component 11:140 describe *premises* containing
# "was born", but this checks the *hypothesis* for "born" — confirm which is
# intended.
born_indicator = anli_cont.get_indicator_by_example_fn(lambda e: e.hypothesis_contains_word('born'))
# (~anli_incorrects & born_indicator).sum() / born_indicator.sum()

contains_indicator = anli_cont.get_indicator_by_example_fn(lambda e: e.hypothesis_contains_word('contains'))
# (~anli_incorrects & contains_indicator).sum() / contains_indicator.sum()


###############################################################################
# Scratch (disabled): selectivity of NMF 19's components for incorrect
# predictions at k = 8, 16 and 32, first on ANLI and then on MNLI.

# nmf_index = 19
# incorrects = anli_cont.get_incorrect_prediction_indicator()

# fracs8, p_values8 = anli_cont.estimate_selectivity_for_indicator(nmf_index, incorrects, k=8)
# fracs16, p_values16 = anli_cont.estimate_selectivity_for_indicator(nmf_index, incorrects, k=16)
# fracs32, p_values32 = anli_cont.estimate_selectivity_for_indicator(nmf_index, incorrects, k=32)


# nmf_index = 19
# incorrects = mnli_cont.get_incorrect_prediction_indicator()

# fracs8, p_values8 = mnli_cont.estimate_selectivity_for_indicator(nmf_index, incorrects, k=8)
# fracs16, p_values16 = mnli_cont.estimate_selectivity_for_indicator(nmf_index, incorrects, k=16)
# fracs32, p_values32 = mnli_cont.estimate_selectivity_for_indicator(nmf_index, incorrects, k=32)



# TODO: Maybe see if I can figure out a way to selectively train on data whose
# hypothesis contains "before", using some information from the NMF to help do
# this better.

# # The hypothesis contains "before" (or an equivalent term).
# anli_cont.print_top_examples(nmf_index=nmf_index, n_examples=16, component_index=162)
# mnli_cont.print_top_examples(nmf_index=nmf_index, n_examples=16, component_index=162)

###############################################################################
###############################################################################

# NOTE(review): older scratch code. `pe_fishers_data`, `decomps` and
# `am.estimate_selectivity_for_indicator` are not defined or used anywhere else
# in this script (the container method `estimate_selectivity_for_indicator` is
# used instead). `(labels + 1) % 3` presumably is a manual form of the label
# shift that `shift_labels=True` now handles — confirm.

# incorrects = ((pe_fishers_data.labels + 1) % 3) != np.argmax(pe_fishers_data.predicted_logits, axis=-1)
# incorrect_frac = incorrects.astype(np.float64).mean()
# print(incorrect_frac)


# Q8 = [
#     am.estimate_selectivity_for_indicator(decomp.W, incorrects, k=8)
#     for decomp in decomps
# ]
# Q16 = [
#     am.estimate_selectivity_for_indicator(decomp.W, incorrects, k=16)
#     for decomp in decomps
# ]
# Q32 = [
#     am.estimate_selectivity_for_indicator(decomp.W, incorrects, k=32)
#     for decomp in decomps
# ]


R"""
*All fine-tuning I do here is done with the embeddings frozen.

Compare MNLI and ANLI (r3) train and dev performance for:
- fine-tune mnli ckpt on anli
- fine-tune mnli ckpt on anli with mnli EWC
- selectively Fisher ablate "bad" ANLI NMF components, then fine-tune on ANLI
- selectively Fisher ablate "bad" ANLI NMF components, then fine-tune on ANLI with mnli EWC
- selectively Fisher ablate "bad" ANLI NMF components, then fine-tune on ANLI with a regularizer to steer away from "bad" components
    - For that regularizer, maybe ablate the "bad" components hard and then freeze their parameters (or add strong L2 regularization on them)

- Also try some sort of example selection based on ANLI NMF components, probably combined with stuff from above.
"""
