R"""



cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/math_ds_writeup1/no_pooler_analysis001.py

"""
from importlib import reload
import os
import re
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.tools.clustering import vat
from em.tools.nmf import nmf_common
from em.util import flat_pack
from em.util import hf_util

from em.analysis import bert_nmf_analysis as bna
from em.analysis import bert_nmf_analysis2 as bna2
from em.experimental import selective_ablation1

###############################################################################

# Root directory under which the models and Fishers used by this script live.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/math_datasets_dev1'

# Sub-directories of EXPS_DIR: fine-tuned models, dense (diagonal) Fishers, and
# per-example Fishers (the saved NMF decompositions are read from the latter).
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models1', 'fishers0', 'per_example_fishers0')
)

###############################################################################

# Task name and sequence length (neither is referenced by the code visible in
# this script; kept as a record of the experiment setup).
TASK = 'math_dataset/original_true_false'
SEQUENCE_LENGTH = 128

# Directory name (under MODELS_DIR) of the fine-tuned model, and whether it
# should be loaded from PyTorch weights (passed as `from_pt`).
MODEL = "og_tf__bert_small__100k_steps"
FROM_PT = False

# Hub identifier used to load the tokenizer.
PRETRAINED_MODEL = "prajjwal1/bert-small"

# File names of the dense Fisher (under FISHERS_DIR) and of the per-example
# Fishers (under PER_EXAMPLES_FISHERS_DIR).
FISHER = "og_tf__bert_small__100k_steps.dense.32k.h5"
PER_EXAMPLES_FISHERS = "og_tf__bert_small__100k_steps.no_embeddings_no_pooler.sparse_dynamic_raw.32k.32k.h5"

# Alternative decompositions; to use one, swap it in for the active settings
# at the bottom of this section.
#
# DECOMP = 'nmf_decomp.8k.4k.128.og_tf__bert_small__100k_steps.no_embeddings_no_pooler.sparse_dynamic_raw.32k.32k.h5'
# N_EXAMPLES = 8 * 1024
# START_FISHER_INDEX, END_FISHER_INDEX = (0, 4096)
#
# DECOMP = 'nmf_decomp.16k.8k.128.reduced_1.og_tf__bert_small__100k_steps.no_embeddings_no_pooler.sparse_dynamic_raw.32k.32k.h5'
# N_EXAMPLES = 16 * 1024
# START_FISHER_INDEX, END_FISHER_INDEX = (0, 8 * 1024)

# Active decomposition (file under PER_EXAMPLES_FISHERS_DIR) together with the
# example count and Fisher index range handed to PerExampleFlatFishers.load.
DECOMP = 'nmf_decomp.8k.4k.256.reduced_1.og_tf__bert_small__100k_steps.no_embeddings_no_pooler.sparse_dynamic_raw.32k.32k.h5'
N_EXAMPLES = 8 * 1024
START_FISHER_INDEX = 0
END_FISHER_INDEX = 4096


###############################################################################

# Handy one-liner for inspecting a component from the interactive session:
# bna.print_top_examples(decomp.W, tokenizer, pe_fishers_data.input_ids, pe_fishers_data.labels, n_examples=8, component=14)

# Presumably the index of the NMF component under study; in the code visible
# here only the commented-out `component = H[COMPONENT]` refers to it.
# COMPONENT = 59
COMPONENT = 0

###############################################################################
# TODO: In proper code, the independent loads below could run in parallel
# (multithreading/multiprocessing).
#############################

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# Durations are measured with time.perf_counter(): it is monotonic, so the
# reported load times cannot be skewed by system clock adjustments the way
# differences of time.time() can.

print('Starting to load fisher.')
start = time.perf_counter()
dense_fisher = diagonal.DiagonalFisher.load(os.path.join(FISHERS_DIR, FISHER))
print('Load saved fisher time: ', time.perf_counter() - start)


print('Starting to load saved NMF decomposition.')
start = time.perf_counter()
decomp = nmf_common.NmfDecomposition.load(os.path.join(PER_EXAMPLES_FISHERS_DIR, DECOMP))
decomp.normalize_components_to_unit_norm()
# Overwrite H with the result of get_full_H(). Presumably this expands a
# reduced H (the DECOMP file name contains "reduced_1") -- confirm in nmf_common.
decomp.H = decomp.get_full_H()
print('Load saved NMF decomposition time: ', time.perf_counter() - start)


print('Starting to load saved per-example Fishers.')
start = time.perf_counter()
pe_fishers_data = per_example.PerExampleFlatFishers.load(
    os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS),
    n_examples=N_EXAMPLES,
    start_fisher_index=START_FISHER_INDEX,
    end_fisher_index=END_FISHER_INDEX,
)
print('Load saved per-example Fishers time: ', time.perf_counter() - start)

###############################################################################

# Short aliases for the two factors of the loaded NMF decomposition.
W, H = decomp.W, decomp.H

#######################################

finetuned_model = TFAutoModelForSequenceClassification.from_pretrained(
    os.path.join(MODELS_DIR, MODEL),
    from_pt=FROM_PT,
)

# We did not compute per-example Fishers for the embeddings or the pooler, so
# the same filter is applied to the model variables and to the dense Fisher to
# keep the two lists parallel.
variable_filter = tmv.VariableFilter(merge_embeddings=False, merge_pooler=False)
finetuned_vars, batch_fishers = variable_filter.filter_parallel_lists(
    hf_util.get_mergeable_variables(finetuned_model),
    dense_fisher.fishers,
)

###############################################################################

# component = H[COMPONENT]

# flat_packer = flat_pack.FlatPacker([v.shape for v in merge_vars])
# assert flat_packer.flat_size == component.shape[0]

###############################################################################

# Pairwise inner products between the rows of H, and their complement.
# NOTE(review): this is a true cosine similarity only if the rows of H have
# unit norm; normalize_components_to_unit_norm() ran before H was replaced by
# get_full_H() -- confirm the full H is still normalized.
cos_sim_matrix = H @ H.T
cos_dissim_matrix = 1 - cos_sim_matrix

# iVAT variant, kept for reference:
# ivat_dissim_matrix = vat.ivat_reorder_dissimilarity_matrix(cos_dissim_matrix)
# ivat_sim_matrix = 1 - ivat_dissim_matrix

# reload(vat)
# VAT reordering yields the reordered dissimilarity matrix and the permutation
# that produced it.
vat_dissim_matrix, permutation = vat.vat_reorder_dissimilarity_matrix(cos_dissim_matrix)
vat_sim_matrix = 1 - vat_dissim_matrix

# plt.imshow(vat_sim_matrix, vmin=0, vmax=1);plt.show()

# Apply the permutation to rows, then to columns, of the similarity matrix.
# M equals vat_sim_matrix (up to numerical precision).
M = cos_sim_matrix[permutation][:, permutation]

bna.plot_sim_matrix(M, show=True)


###############################################################################

localizer = bna2.ComponentLocalizationInfo(variables=finetuned_vars)

# One entry per row of decomp.H, stacked into a single array. Presumably each
# entry is the share of that component falling in every variable -- confirm in
# bna2. (localizer.fraction_per_layer is the alternative used previously.)
frac_in_subsets = np.array([
    localizer.fraction_per_variable(decomp.H[row])
    for row in range(decomp.H.shape[0])
])

# Put the components in the same order as the VAT-reordered similarity matrix.
reordered_frac_in_subsets = frac_in_subsets[permutation]

#######################################

# Labels used with the per-layer variant:
# subset_labels = [
#     'Layer 4',
#     'Layer 3',
#     'Layer 2',
#     'Layer 1',
# ]
# NOTE(review): labels are generated in reversed variable order; presumably
# plot_component_locations expects them that way -- confirm against bna.
subset_labels = list(map(tmv.to_nice_name, reversed(finetuned_vars)))

bna.plot_component_locations(
    frac_in_subsets=reordered_frac_in_subsets,
    subset_labels=subset_labels,
    # A vertical_stretch of 5 was used previously.
    vertical_stretch=3,
    yticks_fontsize=9,
)
###############################################################################


def print_top_examples(
    component: int,
    n_examples: int,
) -> None:
    """Print the examples with the largest entries in one column of `decomp.W`.

    Reads the module-level `decomp`, `pe_fishers_data` and `tokenizer`. One
    line is printed per selected example in the form
    `<label>, <prediction>: <text>`, where the prediction is the argmax of the
    example's predicted logits and the text is the decoded input ids with the
    pad, cls and sep tokens removed and surrounding whitespace stripped.

    Args:
        component: Column index of `decomp.W` to rank the examples by.
        n_examples: Number of top-ranked examples to print.
    """
    _, top_rows = tf.math.top_k(decomp.W[:, component], k=n_examples)
    # Stripped in this order: pad, cls, sep.
    special_tokens = (tokenizer.pad_token, tokenizer.cls_token, tokenizer.sep_token)
    for row in top_rows:
        example = tokenizer.decode(pe_fishers_data.input_ids[row])
        for special_token in special_tokens:
            example = example.replace(special_token, '')
        example = example.strip()
        label = pe_fishers_data.labels[row]
        prediction = np.argmax(pe_fishers_data.predicted_logits[row])
        print(f'{label}, {prediction}: {example}')


print_top_examples(n_examples=16, component=0)

# Notes from other decompositions, kept for reference:
#
# # Component 43 of the 16k.8.64 is pretty cool. Looks like it is correlated
# # with the prediction meaning ``unequal'' regardless of whether the
# # example says equal or unequal. So it is correlated with a specific
# # meaning rather than the label.
# print_top_examples(n_examples=16, component=43)
#
# 8k.4k.128
# # Let's focus on component 23 for now.
# print_top_examples(n_examples=16, component=23)


# 8k.4k.256
# Looks simple: basically, big odd numbers are not composite. Should be easy
# to analyze.
for component_index in (16, 38, 53, 60, 62):
    print_top_examples(n_examples=16, component=component_index)

# Same thing as 16, but basically says big odd numbers are prime:
for component_index in (20, 31, 32):
    print_top_examples(n_examples=16, component=component_index)


# Numbers ending in 5 are composite:
print_top_examples(n_examples=16, component=48)


# Odd numbers are not divisible by even numbers:
for component_index in (28, 67):
    print_top_examples(n_examples=16, component=component_index)

# Interesting: handles cases where stuff is divisible by 2, 5, or 10, which
# can be determined trivially by looking at the last digit.
print_top_examples(n_examples=16, component=56)

# Same as 56, but only for 5s and 10s.
print_top_examples(n_examples=16, component=57)

###############################################################################

# # numerics = [str(d) for d in range(10)]
# keywords = ['divide', 'factor', 'multiple']

# assert all(kw in tokenizer.vocab for kw in keywords)

# asdf = []
# for token in tokenizer.vocab.keys():
#     # if any(n in token for n in numerics):
#     if re.search(r'^\d+$', token):
#         # print(token)
#         asdf.append(token)

# asdf = list(sorted(asdf))
# print(asdf)

