R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=2 python -i local_scripts/transfer1/anli1/anli_per_sublock_nmf_dev003.py

"""
from importlib import reload
import os
import time

from colorama import Fore, Back, Style
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sps
import seaborn as sns
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.tools.nmf import nmf_common
from em.tools.nmf import nmf_transform
from em.util import flat_pack
from em.util import hf_util

# Continuous "rocket" colormap from seaborn. In the visible part of this
# script it is only referenced by a commented-out `plt.imshow` call.
rocket_cmap = sns.color_palette("rocket", as_cmap=True)

###############################################################################

# Root directory of the experiment; the directories below are derived from it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/anli_correct1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# Both the per-example Fishers file and the NMF decomposition files are read
# from this directory.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Checkpoint identifier handed to `from_pretrained`; FROM_PT is forwarded as
# its `from_pt` argument.
MODEL = "textattack/bert-base-uncased-MNLI"
FROM_PT = True

# Identifier used to load the tokenizer.
PRETRAINED_MODEL = 'bert-base-uncased'

# File name (inside PER_EXAMPLES_FISHERS_DIR) of the saved per-example Fishers.
PER_EXAMPLES_FISHERS = "textattack_bert_base_uncased_MNLI.anli_r3.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
# Base file name of the saved NMF decompositions. `load_nmf_decomp` drops the
# last three characters (".h5") and appends ".ssi<subset_index>.h5".
DECOMP = f"nmf_decomp.per_sub_block.16k.16k.256.{PER_EXAMPLES_FISHERS}"

###############################################################################

# Values forwarded to `PerExampleFlatFishers.load` below as `n_examples`,
# `start_fisher_index` and `end_fisher_index`.
# NOTE(review): the exact meaning of these three arguments is defined by
# `PerExampleFlatFishers.load`, which is not visible here -- confirm there.
# N_EXAMPLES = 16 * 1024
N_EXAMPLES = 32 * 1024
START_FISHER_INDEX, END_FISHER_INDEX = (0, 16 * 1024)

# Load the classification model from MODEL and the tokenizer from
# PRETRAINED_MODEL (note that these are two different identifiers).
model = TFAutoModelForSequenceClassification.from_pretrained(
    MODEL,
    from_pt=FROM_PT,
)
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# Load the saved per-example Fishers and report the elapsed wall-clock time.
print('Starting to load saved per-example Fishers.')
start = time.time()
pe_fishers_data = per_example.PerExampleFlatFishers.load(
    os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS),
    n_examples=N_EXAMPLES,
    start_fisher_index=START_FISHER_INDEX,
    end_fisher_index=END_FISHER_INDEX,
)
print('Load saved per-example Fishers time: ', time.time() - start)


###############################################################################

def load_nmf_decomp(subset_index: int):
    """Load the saved NMF decomposition of one subset and normalize it.

    The file name is built from ``DECOMP`` by dropping its last three
    characters (the ``.h5`` extension) and appending
    ``.ssi<subset_index>.h5``; the file is read from
    ``PER_EXAMPLES_FISHERS_DIR``. A start message and the elapsed wall-clock
    time are printed.

    Args:
        subset_index: Index of the subset whose decomposition file is loaded.

    Returns:
        The object returned by ``nmf_common.NmfDecomposition.load`` after its
        ``normalize_components_to_unit_norm()`` method has been called.
    """
    print('Starting to load saved NMF decomposition.')
    began_at = time.time()

    decomp_filename = f"{DECOMP[:-3]}.ssi{subset_index}.h5"
    decomp_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, decomp_filename)
    loaded = nmf_common.NmfDecomposition.load(decomp_path)

    # Normalization is always applied here; there is no switch to skip it.
    loaded.normalize_components_to_unit_norm()

    # Not getting the full H is intended, so nothing else is expanded here.
    elapsed = time.time() - began_at
    print('Load saved NMF decomposition time: ', elapsed)
    return loaded


# Subset indices of the two NMF decompositions inspected in this session;
# each selects the decomposition file with the matching ".ssi<index>" suffix.
# `ssi2` is also passed as `subset_index` to `get_values_for_subset` below.
ssi1 = 20
ssi2 = 21
decomp1 = load_nmf_decomp(ssi1)
decomp2 = load_nmf_decomp(ssi2)

###############################################################################


def print_top_examples(decomp, component: int, n_examples: int):
    """Print the examples that carry the largest weight for one component.

    Reads the module-level ``pe_fishers_data`` and ``tokenizer``. For every
    selected example one line is printed containing the shifted label, the
    argmax of its predicted logits, its weight for ``component`` and the
    decoded input text with pad tokens removed.

    Args:
        decomp: Either a NumPy weight matrix or an object exposing that
            matrix as ``.W``. The first axis is used as the example index and
            the second axis as the component index.
        component: Column of the weight matrix used to rank the examples.
        n_examples: Number of top-weighted examples to print.
    """
    # A bare array is used directly; anything else must expose the matrix as `.W`.
    W = decomp if isinstance(decomp, np.ndarray) else decomp.W

    _, top_indices = tf.math.top_k(W[:, component], k=n_examples)
    for ind in top_indices:
        raw_label = pe_fishers_data.labels[ind]
        raw_label = raw_label.numpy() if isinstance(raw_label, tf.Tensor) else raw_label
        # Stuff for MNLI: labels are rotated by one position modulo 3.
        # NOTE(review): presumably this aligns the dataset's label order with
        # the model's output order -- confirm.
        label = (raw_label + 1) % 3

        pred = np.argmax(pe_fishers_data.predicted_logits[ind])

        decoded = tokenizer.decode(pe_fishers_data.input_ids[ind])
        example = decoded.replace(tokenizer.pad_token, '').strip()

        print(f'{Style.BRIGHT}{Fore.YELLOW}{label}{Style.RESET_ALL}, {Style.BRIGHT}{Fore.YELLOW}{pred}{Style.RESET_ALL}, {Fore.GREEN}{W[ind, component]:.3f}{Style.RESET_ALL}: {example}')


# Show the 8 highest-weighted examples of component 0 for each decomposition.
print_top_examples(decomp1, n_examples=8, component=0)
print_top_examples(decomp2, n_examples=8, component=0)


###############################################################################

def get_values_for_subset(pe_fishers_data, packer, subset_index: int, start: int, end: int):
    """Restrict sparse per-example Fishers to the flat range of one subset.

    The flat index range ``[range_start, range_end)`` of the subset is obtained
    from ``packer.get_range_for_tensor_by_index(subset_index)``. For every
    example in ``pe_fishers_data.fishers[start:end]`` (paired with the matching
    entry of ``pe_fishers_data.fisher_indices[start:end]``) only the entries
    whose index falls inside that range are kept.

    Args:
        pe_fishers_data: Object exposing ``fishers`` and ``fisher_indices``,
            two parallel sliceable sequences of per-example values and indices.
        packer: Object providing ``get_range_for_tensor_by_index``.
        subset_index: Index of the subset passed to the packer.
        start: First example (inclusive) of the slice to process.
        end: Last example (exclusive) of the slice to process.

    Returns:
        A pair ``(subset_values, subset_indices)`` of lists with one entry per
        processed example: the kept values, and the kept indices shifted so
        that the start of the subset's range becomes 0.
    """
    range_start, range_end = packer.get_range_for_tensor_by_index(subset_index)

    example_values = pe_fishers_data.fishers[start:end]
    example_indices = pe_fishers_data.fisher_indices[start:end]

    def _restrict(row_values, row_indices):
        # Boolean mask of the entries that belong to the subset's flat range.
        keep = (range_start <= row_indices) & (row_indices < range_end)
        return row_values[keep], row_indices[keep] - range_start

    restricted = [
        _restrict(row_values, row_indices)
        for row_values, row_indices in zip(example_values, example_indices)
    ]
    subset_values = [kept_values for kept_values, _ in restricted]
    subset_indices = [kept_indices for _, kept_indices in restricted]
    return subset_values, subset_indices


###############################################################################


# Filter the model's mergeable variables with embeddings excluded
# (`merge_embeddings=False`).
variable_filter = tmv.VariableFilter(merge_embeddings=False)

variables = hf_util.get_mergeable_variables(model)
variables = variable_filter.filter_parallel_lists(variables)


# Group the remaining variables; each group becomes one entry of the packer.
subsets = tmv.group_by_sub_blocks(variables)

# We don't care about the actual shapes, just their number of parameters.
# Every subset is therefore described by a single size: the summed element
# count of its variables.
packer = flat_pack.FlatPacker([[sum(tf.size(v) for v in subset)] for subset in subsets])

# Extract the entries belonging to subset `ssi2` for examples
# [16 * 1024, 20 * 1024).
# NOTE(review): the per-example Fishers were loaded with
# start_fisher_index=0 and end_fisher_index=16 * 1024 -- confirm against
# `PerExampleFlatFishers.load` that the rows sliced here are populated.
subset_values, subset_indices = get_values_for_subset(pe_fishers_data, packer, ssi2, start=16 * 1024, end=20 * 1024)

# print('Starting NMF transform.')
# start = time.time()
# coeffs = nmf_transform.transform(decomp2, subset_values, subset_indices)
# print('NMF transform time: ', time.time() - start)

# # Needed for the print_top_examples to work.
# aligned_coeffs = np.concatenate([
#     -1 * np.ones([16 * 1024, coeffs.shape[-1]], dtype=coeffs.dtype),
#     coeffs],
#     axis=0,
# )


# plt.imshow(coeffs.T, cmap=rocket_cmap);plt.show()


###############################################################################

# # decomp2, component 0 => looks to predict correctly, mostly 2
# print_top_examples(decomp2, n_examples=8, component=0)
# print_top_examples(aligned_coeffs, n_examples=8, component=0)


# print_top_examples(decomp2, n_examples=8, component=34)
# print_top_examples(aligned_coeffs, n_examples=8, component=34)

# print_top_examples(decomp2, n_examples=8, component=35)
# print_top_examples(aligned_coeffs, n_examples=8, component=35)


# print_top_examples(decomp2, n_examples=8, component=46)
# print_top_examples(aligned_coeffs, n_examples=8, component=46)

# # print_top_examples(aligned_coeffs, n_examples=8, component=50)

###############################################################################

# acc = ((pe_fishers_data.labels + 1) % 3 == np.argmax(pe_fishers_data.predicted_logits, axis=-1)).astype(np.float64).mean()
# acc = 70%
# I'm guessing MNLI acc is like 83-84%

# Inspect two more components of the second decomposition.
print_top_examples(decomp2, n_examples=8, component=59)
print_top_examples(decomp2, n_examples=8, component=70)


###############################################################################

# Per-example flag: does the label (shifted by one modulo 3, the same shift
# `print_top_examples` applies) equal the argmax of the predicted logits?
corrects = ((pe_fishers_data.labels + 1) % 3) == np.argmax(pe_fishers_data.predicted_logits, axis=-1)
# Fraction of examples flagged as correct.
acc = corrects.astype(np.float64).mean()
print(acc)

# Complementary flag, computed with `!=` on the same two quantities.
incorrects = ((pe_fishers_data.labels + 1) % 3) != np.argmax(pe_fishers_data.predicted_logits, axis=-1)
incorrect_frac = incorrects.astype(np.float64).mean()
print(incorrect_frac)


def moo_the_cow(decomp, indicator, k):
    """Average an indicator over the top-``k`` entries of every component.

    ``tf.math.top_k`` is applied to ``decomp.W.T``, so for each row of the
    transposed weight matrix the positions of its ``k`` largest entries are
    selected. ``indicator`` is gathered at those positions, cast to float64
    and averaged over the last axis.

    Args:
        decomp: Object exposing the weight matrix as ``.W``.
        indicator: Array indexed by the selected positions; it must support
            ``.astype`` (the script passes the boolean ``corrects`` array).
        k: Number of largest entries taken per row of ``decomp.W.T``.

    Returns:
        The mean of the gathered indicator values, one per row of
        ``decomp.W.T``.
    """
    _, top_positions = tf.math.top_k(decomp.W.T, k=k)
    gathered = indicator[top_positions]
    return gathered.astype(np.float64).mean(axis=-1)


# For every component of `decomp2`: the fraction of `corrects` among its 16
# highest-weighted examples.
q = moo_the_cow(decomp2, corrects, k=16)
# q = moo_the_cow(decomp1, corrects, k=16)


# Plot the per-component fractions in ascending order.
plt.plot(np.sort(q));plt.show()

# component=-1 selects the last column of each weight matrix.
print_top_examples(decomp2, n_examples=16, component=-1)
print_top_examples(decomp1, n_examples=16, component=-1)


R'''
P(acc <= t) given null hypothesis.
'''



R"""
* All fine-tuning done here keeps the embeddings frozen.

Compare MNLI and ANLI (r3) train and dev performance for:
- fine-tune mnli ckpt on anli
- fine-tune mnli ckpt on anli with mnli EWC
- selectively Fisher ablate "bad" ANLI NMF components then finetune on ANLI
- selectively Fisher ablate "bad" ANLI NMF components then finetune on ANLI with mnli EWC
- selectively Fisher ablate "bad" ANLI NMF components then finetune on ANLI with regularizer to steer away from "bad" components
    - For that reg, maybe hard-ablate the "bad" components and then either freeze their parameters or add strong L2 regularization to them

- Also try some sort of example selection based on ANLI NMF components, probably combined with stuff from above.
"""