R"""



cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/divis/divis_nmf002.py

"""
import collections
from importlib import reload
import itertools
import os
import time

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import tensorflow as tf

from em.datasets import divisibility as div_ds
from em.fishers import per_example
from em.models import divis_models
from em.models import transformer_model_vars as tmv
from em.tools.clustering import vat
from em.tools.nmf import parallel_sklearn_nmf as p_nmf
from em.tools.nmf import nmf_common
from em.util import flat_pack

from em.analysis import bert_nmf_analysis as bna
from em.analysis import bert_nmf_analysis2 as bna2

# Seaborn's "rocket" palette as a matplotlib colormap; passed as `cmap=` to
# the `plt.imshow` heatmaps further down in this script.
rocket_cmap = sns.color_palette("rocket", as_cmap=True)

###############################################################################

# Root directory of the experiment artifacts.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/divis_mech1'
# One sub-directory of the experiment root per artifact type.
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################

# Base name of the saved model file (the ".h5" suffix is appended on load).
MODEL = "misc_divis_model_001_half"
# File name of the per-example Fishers saved for that model.
PER_EXAMPLES_FISHERS = f"{MODEL}.sparse_dynamic_raw.32k.32k.h5"

# File name of the saved NMF decomposition of those Fishers.
DECOMP = f"nmf_decomp.16k.8k.64.{PER_EXAMPLES_FISHERS}"
# Passed as `n_examples` when loading the per-example Fishers.
N_EXAMPLES = 16 * 1024
# Passed as `start_fisher_index` / `end_fisher_index` when loading the Fishers.
START_FISHER_INDEX = 0
END_FISHER_INDEX = 8 * 1024

###############################################################################


# Load the model and its config from the ".h5" file under MODELS_DIR.
model, model_config = divis_models.load_model_from_file(os.path.join(MODELS_DIR, f'{MODEL}.h5'))


# Load the saved NMF decomposition and report how long it took.
print('Starting to load saved NMF decomposition.')
start = time.time()
decomp = nmf_common.NmfDecomposition.load(os.path.join(PER_EXAMPLES_FISHERS_DIR, DECOMP))
# NOTE(review): presumably rescales each component (row of H) to unit norm;
# the later `1 - H @ H.T` is named a cosine dissimilarity, which relies on
# that — confirm in nmf_common.
decomp.normalize_components_to_unit_norm()
# Overwrite the stored H with the result of get_full_H() so that the rest of
# the script can use `decomp.H` directly.
decomp.H = decomp.get_full_H()
print('Load saved NMF decomposition time: ', time.time() - start)


# Load the per-example Fishers, together with the 'dividends' and 'divisors'
# extra fields, and report how long it took.
print('Starting to load saved per-example Fishers.')
start = time.time()
pe_fishers_data = per_example.PerExampleFlatFishers.load(
    os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS),
    n_examples=N_EXAMPLES,
    start_fisher_index=START_FISHER_INDEX,
    end_fisher_index=END_FISHER_INDEX,
    extra_data_fields=['dividends', 'divisors'],
)
print('Load saved per-example Fishers time: ', time.time() - start)


###############################################################################

# Short aliases for the two NMF factors. Below, W is indexed as
# W[example, component] and H as H[component].
W = decomp.W
H = decomp.H

###############################################################################

# Pairwise dissimilarity between components.
# NOTE(review): this is a cosine dissimilarity only if the rows of H have unit
# norm — confirm that get_full_H() preserves the normalization.
cos_dissim_matrix = 1 - H @ H.T
# Reorder the dissimilarity matrix with VAT; `permutation` is used below to
# index components in the reordered order.
vat_dissim_matrix, permutation = vat.vat_reorder_dissimilarity_matrix(cos_dissim_matrix)
# Convert the reordered dissimilarities back to similarities.
M = 1 - vat_dissim_matrix

# bna.plot_sim_matrix(M, show=True)


###############################################################################

finetuned_vars = model.trainable_variables
localizer = bna2.ComponentLocalizationInfo(variables=finetuned_vars)

# One row per NMF component: the result of `fraction_per_variable` for that
# component's row of H.
frac_in_subsets = np.array([
    localizer.fraction_per_variable(decomp.H[comp_idx])
    for comp_idx in range(decomp.H.shape[0])
])

# The same rows, arranged in the VAT ordering.
reordered_frac_in_subsets = frac_in_subsets[permutation]

#######################################

# Labels of the form "<variable name without ':N' suffix>:<variable index>",
# listed from the last variable to the first.
subset_labels = [
    f"{var.name.split(':')[0]}:{var_idx}"
    for var_idx, var in list(enumerate(finetuned_vars))[::-1]
]

# bna.plot_component_locations(
#     frac_in_subsets=reordered_frac_in_subsets,
#     subset_labels=subset_labels,
#     vertical_stretch=3,
#     yticks_fontsize=8,
# )

###############################################################################


def print_top_examples(component_index: int, n_examples: int):
    """Print the examples with the largest coefficients for one component.

    Examples are ranked by column `component_index` of the module-level `W`
    using `tf.math.top_k`. One line is printed per example, in the form
    `<label>, <argmax of predicted logits>: <divisor>|<dividend>`.

    Args:
        component_index: Index of the NMF component (column of `W`).
        n_examples: Number of top-ranked examples to print.
    """
    coeffs = W[:, component_index]
    top_inds = tf.math.top_k(coeffs, k=n_examples).indices
    for example_ind in top_inds:
        true_label = pe_fishers_data.labels[example_ind]
        predicted = np.argmax(pe_fishers_data.predicted_logits[example_ind])
        # The example is rendered as "divisor|dividend".
        shown = f'{pe_fishers_data.divisors[example_ind]}|{pe_fishers_data.dividends[example_ind]}'
        print(f'{true_label}, {predicted}: {shown}')


###############################################################################

# Packer built from the shapes of the trainable variables.
# NOTE(review): assumes the last axis of H is the flattened concatenation of
# all `finetuned_vars`, in order — confirm against flat_pack.FlatPacker.
packer = flat_pack.FlatPacker([v.shape for v in finetuned_vars])
# Components split per variable; comps_by_var[k] is indexed by component first
# in the helpers below.
comps_by_var = packer.decode_tf(H)


# It looks like typically either the dependencies across inputs or the
# dependencies across outputs are quite sparse, while the other is not
# (although its mass is still relatively concentrated on a few units). From a
# brief look, it appears that the d_model-sized dependency vectors are sparse
# while the d_ff-sized dependency vectors are not. I am unsure of the general
# pattern, but this might have to do with the residual structure of my network.


def plot_dependencies_across_outputs(var_index: int, reordered_comps):
    """Plot max-scaled dependency curves over the last axis of a variable.

    For each selected component, the component's slice of variable
    `var_index` is averaged over its second-to-last axis and the resulting
    curve is divided by its maximum before plotting. Curves are drawn on the
    current matplotlib axes; `plt.show()` is left to the caller.

    Args:
        var_index: Index into the module-level `comps_by_var`.
        reordered_comps: Component position(s) in the VAT ordering; they are
            mapped to original component indices through `permutation`.
    """
    original_indices = permutation[reordered_comps]
    selected = comps_by_var[var_index].numpy()[original_indices]
    # NOTE(review): presumably axis -2 is the input axis of a kernel, leaving
    # one value per output unit — confirm for the variable being plotted.
    output_deps = np.mean(selected, axis=-2)
    for dep in output_deps:
        plt.plot(dep / dep.max())


def plot_dependencies_across_inputs(var_index: int, reordered_comps):
    """Plot max-scaled dependency curves over the second-to-last axis.

    For each selected component, the component's slice of variable
    `var_index` is averaged over its last axis and the resulting curve is
    divided by its maximum before plotting. Curves are drawn on the current
    matplotlib axes; `plt.show()` is left to the caller.

    Args:
        var_index: Index into the module-level `comps_by_var`.
        reordered_comps: Component position(s) in the VAT ordering; they are
            mapped to original component indices through `permutation`.
    """
    original_indices = permutation[reordered_comps]
    selected = comps_by_var[var_index].numpy()[original_indices]
    # NOTE(review): presumably axis -1 is the output axis of a kernel, leaving
    # one value per input unit — confirm for the variable being plotted.
    input_deps = np.mean(selected, axis=-1)
    for dep in input_deps:
        plt.plot(dep / dep.max())


###############################################################################

# TODO: The magnitudes of W and H from these NMFs do not mean much when
# comparing across different decomps, due to the symmetry of magnitudes within
# a decomp.

# Only for the res blocks.
# Keep the per-variable component tensors whose variable name contains both
# 'kernel' and 'ffw_res_block' (biases are excluded here).
comps_by_kernels = [
    c for c, v in zip(comps_by_var, finetuned_vars)
    if 'kernel' in v.name and 'ffw_res_block' in v.name
]


# Indices into `comps_by_kernels` of the two kernels being compared.
# ker1, ker2 = 2, 3
ker1, ker2 = 8, 9

# NOTE(review): assumes two kernels per res block and that the res blocks
# start at model.layers[3] — confirm against the model definition.
res_block_index = ker1 // 2
res_block = model.layers[3 + res_block_index]


# Rank-1 NMF of each kernel's component tensor. Only the first returned factor
# is kept; its trailing singleton axis is squeezed and a square root is taken.
# TODO: Better way than adding small constant, needed to get rid of NaNs
# that throw errors in the NMF.
a, _ = p_nmf.perform_nmfs(comps_by_kernels[ker1] + 1e-12, n_components=1)
a = np.squeeze(a, axis=-1)
a = np.sqrt(a)

b, _ = p_nmf.perform_nmfs(comps_by_kernels[ker2] + 1e-12, n_components=1)
b = np.squeeze(b, axis=-1)
b = np.sqrt(b)

# Push the ker1 factors through the block's first dense layer.
# TODO: Need to incorporate biases.
ra = res_block.dense1(a).numpy()


# plt.imshow(ra @ b.T);plt.show()

# plt.imshow(np.sqrt(Us) @ np.sqrt(Vs).T);plt.show()

# Component pairs to inspect side by side: for each pair, the top 8 examples
# of the first component are printed, then the top 8 of the second.
for first_comp, second_comp in [(57, 5), (6, 5), (57, 7), (38, 25)]:
    print_top_examples(first_comp, n_examples=8)
    print_top_examples(second_comp, n_examples=8)

###############################################################################


# reload(tmv)
# Wrap every res-block component tensor (kernels and biases) with its variable
# name and hand the list to `homogenize_kernel_biases`.
# NOTE(review): presumably this merges each kernel with its bias into a single
# homogeneous matrix, giving one entry per dense layer — confirm in tmv.
homog_res_block_comps = tmv.homogenize_kernel_biases([
    tmv.NamedProxy(c.numpy(), v.name) for c, v in zip(comps_by_var, finetuned_vars)
    if 'ffw_res_block' in v.name
])

# Indices into `homog_res_block_comps` of the two layers being compared.
# ker1, ker2 = 8, 9
ker1, ker2 = 4, 5

# NOTE(review): assumes two entries per res block and that the res blocks
# start at model.layers[3] — confirm against the model definition.
res_block_index = ker1 // 2
res_block = model.layers[3 + res_block_index]

# Rank-1 NMF of each homogenized component tensor. Only the first returned
# factor is kept; its trailing singleton axis is squeezed and a square root is
# taken.
# TODO: Better way than adding small constant, needed to get rid of NaNs
# that throw errors in the NMF.
ha, _ = p_nmf.perform_nmfs(homog_res_block_comps[ker1] + 1e-12, n_components=1)
ha = np.squeeze(ha, axis=-1)
ha = np.sqrt(ha)

hb, _ = p_nmf.perform_nmfs(homog_res_block_comps[ker2] + 1e-12, n_components=1)
hb = np.squeeze(hb, axis=-1)
hb = np.sqrt(hb)


def dense_layer_on_homog(dense, x):
    """Apply a dense layer to inputs given in homogeneous coordinates.

    The layer's bias is appended to its kernel as one extra final row, so the
    last dimension of `x` must be one larger than the kernel's row count.

    Args:
        dense: Layer exposing `kernel`, `bias` and `activation`.
        x: Inputs whose last dimension matches the augmented kernel's rows.

    Returns:
        `dense.activation` applied to `x` times the augmented kernel.
    """
    bias_row = tf.expand_dims(dense.bias, axis=0)
    homog_kernel = tf.concat([dense.kernel, bias_row], axis=0)
    return dense.activation(x @ homog_kernel)


def l2_normalize(x):
    """Scale vectors along the last axis of `x` to (approximately) unit L2 norm.

    A constant 1e-12 is added to each norm before dividing, so all-zero
    vectors map to zeros instead of producing NaNs.

    Args:
        x: NumPy array; normalization is applied over its last axis.

    Returns:
        Array of the same shape as `x`.
    """
    squared_norms = (x * x).sum(axis=-1, keepdims=True)
    return x / (np.sqrt(squared_norms) + 1e-12)


# Output of the block's first dense layer (bias folded into the kernel) on the
# ker1 factors.
rha = dense_layer_on_homog(res_block.dense1, ha).numpy()
# The same activations with a trailing column of ones appended, i.e. put back
# into homogeneous form so they can be compared against all columns of `hb`.
hrha = np.concatenate([rha, np.ones([rha.shape[0], 1], dtype=rha.dtype)], axis=-1)

# plt.imshow(rha @ hb[:, :-1].T);plt.show()
# plt.imshow(hrha @ hb.T);plt.show()

# plt.imshow(l2_normalize(rha) @ l2_normalize(hb[:, :-1]).T);plt.show()
# plt.imshow(l2_normalize(hrha) @ l2_normalize(hb).T);plt.show()


#######################################
# From: l2_normalize(rha) @ l2_normalize(hb[:, :-1]).T

# Component pairs to inspect side by side: for each pair, the top 8 examples
# of the first component are printed, then the top 8 of the second.
for first_comp, second_comp in [
    (25, 38),
    (7, 10),
    (28, 10),
    (62, 38),
    (20, 59),
    (6, 25),
    (61, 8),
]:
    print_top_examples(first_comp, n_examples=8)
    print_top_examples(second_comp, n_examples=8)


#######################################

# Second group of component pairs: for each pair, the top 8 examples of the
# first component are printed, then the top 8 of the second.
for first_comp, second_comp in [
    (19, 63),
    (20, 63),
    (47, 63),
    (20, 19),
    (16, 11),
]:
    print_top_examples(first_comp, n_examples=8)
    print_top_examples(second_comp, n_examples=8)


###############################################################################
# Let's analyze a single example (or several single examples) in detail.
# See how the relationships between components are reflected in the coefficients
# of the example.
# Per-example divisors and dividends loaded alongside the Fishers.
divisors = pe_fishers_data.divisors
dividends = pe_fishers_data.dividends

# Indices of the examples whose divisor equals the one under study.
focused_divisor = 8
divisor_example_inds, = np.nonzero(divisors == focused_divisor)

# plt.imshow(W[divisor_example_inds[:128]].T, cmap=rocket_cmap);plt.tight_layout();plt.show()

# Comp 16 appears to be selective for examples whose divisor is 8 and whose
# dividend has an even hundreds digit and last two digits forming a number
# divisible by 8. 24 and 64 appear to be the most common endings, although
# other multiples of 8 do appear.
print_top_examples(16, n_examples=32)
plt.plot(frac_in_subsets[16]);plt.show()
# plt.plot(frac_in_subsets[16, 3::2]);plt.show()

print_top_examples(22, n_examples=32)
plt.plot(frac_in_subsets[22]);plt.show()

# Comp 48 appears to be selective for examples whose divisor is 8 and whose
# dividend's last three digits are divisible by 8. Most dividends end with 0,
# with a few ending with 8. Such dividends are always divisible by 8: write
# n = 1000 * k + m with 8|m; since 1000 = 10^3 = 2^3 * 5^3 = 8 * 5^3 we have
# 8|1000, so 8|(1000 * k) and thus 8|(1000 * k + m). Having the last digit be
# 8 is equivalent to having the last digit be 0 from a divisibility point of
# view. To see why, suppose n = 10 * k + 8; then
# n = 10 * k + 8 = 10 * k (mod 8).
print_top_examples(48, n_examples=32)
# Much of the component's mass (about half) is concentrated in the embeddings and the embeddings -> d_model
# linear projection layer. About 15% of the mass is in the logits layer.
plt.plot(frac_in_subsets[48]);plt.show()

###########################################################


def plot_embeddings_projection_layer_for_comp(component_index: int):
    """Show a heatmap combining a component's first two variables.

    The component's slice of variable 1 is reshaped to the shape of its slice
    of variable 0 plus the original last axis, contracted with the variable-0
    slice over the shared middle axis, flattened to two dimensions and drawn
    with `plt.imshow`.

    NOTE(review): presumably variable 0 is the embeddings table and variable 1
    the embeddings -> d_model projection kernel — confirm against the model.

    Args:
        component_index: Index of the NMF component to visualise.
    """
    # Alternative considered in the original: use the model's actual table,
    # `model.layers[0].embeddings_table`, instead of the component's slice.
    emb_comp = comps_by_var[0][component_index]
    proj_comp = comps_by_var[1][component_index]
    target_shape = emb_comp.shape.concatenate(proj_comp.shape[-1:])
    proj_comp = tf.reshape(proj_comp, target_shape)
    # Alternative output layout considered in the original: 'ijk,lj->ilk'.
    combined = tf.einsum('ijk,lj->lik', proj_comp, emb_comp)
    flattened = tf.reshape(combined, [-1, combined.shape[-1]])
    plt.imshow(flattened, cmap=rocket_cmap)
    plt.tight_layout()
    plt.show()


# plot_embeddings_projection_layer_for_comp(16)
# plot_embeddings_projection_layer_for_comp(48)

# Plot the coefficients of the top 128 examples of component 48.
# The second argument is the identity ordering 0..n-1 with the same length as
# `permutation`.
# NOTE(review): presumably this argument is the component display order, so the
# identity keeps components in their original (non-VAT) order — confirm in bna.
bna.plot_coeffs_for_top_examples(
    decomp,
    np.arange(permutation.shape[0]),
    component=48,
    # component=16,
    n_examples=128,
    show=True
)


def plot_coeffs_for_top_examples_masked(component_index, n_examples=128):
    """Heatmap of the coefficients of a component's top examples, with masking.

    The column of the component itself, and the columns of every component
    whose fraction in variable 0 (`frac_in_subsets[:, 0]`) exceeds 0.9, are
    set to NaN so that the remaining components are easier to see.

    Args:
        component_index: Component whose top examples are selected.
        n_examples: Number of top examples (rows) to show.
    """
    _, top_inds = tf.math.top_k(decomp.W[:, component_index], k=n_examples)
    coeffs = np.copy(decomp.W[top_inds, :])
    # Components almost entirely concentrated on the first variable.
    first_var_heavy = frac_in_subsets[:, 0] > 0.9
    coeffs[:, component_index] = np.nan
    coeffs[:, first_var_heavy] = np.nan
    plt.imshow(coeffs, cmap=rocket_cmap)
    plt.show()


# plot_coeffs_for_top_examples_masked(48)
# plot_coeffs_for_top_examples_masked(16, n_examples=64)


# # Some of the common co-occurring components for top examples of component 48:
# # 3, 4, 23, 26, 29, 39, 52 (there are also a few more)

# # Comp 3 is not super-selective for dividends ending with zero, but is fairly
# # highly enriched in them. Generally has dividends containing a lot of zeros
# # with a lot of them ending in zero as well. Most of mass is in embeddings with
# # a little in the kernel of the first dense layer of the first res block.
# print_top_examples(3, n_examples=16)
# plt.plot(frac_in_subsets[3]);plt.show()

# # I think comp 4 is similar to comp 3 but maybe with sixes instead of zeros and
# # also without a preference for having dividends ending with six.
# print_top_examples(4, n_examples=16)
# plt.plot(frac_in_subsets[4]);plt.show()

# # Comp 23 is similar to comp 3 but I think with twos instead of zeros.
# print_top_examples(23, n_examples=16)
# plt.plot(frac_in_subsets[23]);plt.show()

# print_top_examples(26, n_examples=16)
# plt.plot(frac_in_subsets[26]);plt.show()

# print_top_examples(29, n_examples=16)
# plt.plot(frac_in_subsets[29]);plt.show()

# print_top_examples(39, n_examples=16)
# plt.plot(frac_in_subsets[39]);plt.show()

# print_top_examples(52, n_examples=16)
# plt.plot(frac_in_subsets[52]);plt.show()


# ############################
# # From the plotting top masked coeffs for comp 48.
# # 15, 18, 22
# print_top_examples(15, n_examples=16)
# print_top_examples(18, n_examples=16)
# print_top_examples(22, n_examples=16)
###############################################################################

# Pick the top `n_examples` examples of one component and compare their actual
# per-example Fishers with the NMF reconstruction.
component_index = 16
n_examples = 8
_, top_example_inds = tf.math.top_k(decomp.W[:, component_index], k=n_examples)
top_example_inds = top_example_inds.numpy()

#
# Densify the stored sparse Fishers of the selected examples.

# NOTE(review): `_to_coo_indices` is a private helper of `per_example`; its
# result is transposed before being used as SparseTensor indices, so it
# presumably has shape (2, nnz) — confirm.
fisher_indices = per_example._to_coo_indices(pe_fishers_data.fisher_indices[top_example_inds])
fisher_values = pe_fishers_data.fishers[top_example_inds].reshape([-1])
dense_shape = (top_example_inds.shape[0], pe_fishers_data.fisher_dense_size)

sparse_example_fishers = tf.sparse.SparseTensor(fisher_indices.T, fisher_values, dense_shape)
# Put the indices into canonical order before densifying.
sparse_example_fishers = tf.sparse.reorder(sparse_example_fishers)
dense_example_fishers = tf.sparse.to_dense(sparse_example_fishers)

# Split the flat Fishers into one tensor per variable.
unpacked_example_fishers = packer.decode_tf(dense_example_fishers)

#
# NMF reconstruction of the same examples: their rows of W times H.

top_example_coeffs = W[top_example_inds]
recon_example_fishers = top_example_coeffs @ H
unpacked_recon_example_fishers = packer.decode_tf(recon_example_fishers)

#

# plt.imshow(unpacked_example_fishers[0].numpy().reshape([n_examples, -1]));plt.show()
# plt.imshow(unpacked_recon_example_fishers[0].numpy().reshape([n_examples, -1]));plt.show()

# tf.linalg.norm(unpacked_example_fishers[0] - unpacked_recon_example_fishers[0]).numpy()


# plt.imshow(unpacked_example_fishers[5].numpy().reshape([-1, unpacked_example_fishers[5].shape[-1]]).T);plt.tight_layout();plt.show()
# plt.imshow(unpacked_recon_example_fishers[5].numpy().reshape([-1, unpacked_recon_example_fishers[5].shape[-1]]).T);plt.tight_layout();plt.show()

###############################################################################

# Rank-1 NMF of every homogenized res-block component tensor. Entry k holds
# the pair (sqrt(alpha), sqrt(beta)) for the k-th entry of
# `homog_res_block_comps`, with the singleton rank axis squeezed out of each
# factor.
rank1_res_block_nmfs = []
for ker in homog_res_block_comps:
    # Factor the tensor of the layer being iterated over. The original code
    # indexed `homog_res_block_comps[ker1]` here, so every entry of the list
    # was the factorization of the same layer.
    # The small constant gets rid of NaNs that throw errors in the NMF.
    alpha, beta = p_nmf.perform_nmfs(ker + 1e-12, n_components=1)
    alpha = np.squeeze(alpha, axis=-1)
    beta = np.squeeze(beta, axis=-2)
    rank1_res_block_nmfs.append((np.sqrt(alpha), np.sqrt(beta)))


# Compare g to layer(a) (in terms of activated units), this could be used to determine what other components
# are needed since for ReLUs the g should have a 0 for inactive units.
# or not?

# Rank-1 factors stored at index 3, restricted to component 16.
aa, gg = rank1_res_block_nmfs[3]
aa = aa[16]
gg = gg[16]

# NOTE(review): earlier sections select the block with `3 + ker // 2` and pair
# an even index with `dense1`; index 3 is odd, so confirm that `dense1` of
# `model.layers[3 + 1]` is really the layer these factors belong to.
res_block = model.layers[3 + 1]

# `aa[None, ...]` adds a leading batch axis of size 1; the one-element unpack
# strips it again.
hh, = dense_layer_on_homog(res_block.dense1, aa[None, ...])
# Fraction of units for which "hh > 1e-5" and "gg > 1e-5" agree.
frac_same = ((hh > 1e-5) == (gg > 1e-5)).numpy().astype(np.float64).mean()
print(frac_same)
# plt.plot(hh > 1e-5)
# plt.plot(gg > 1e-5)
# plt.show()
