R"""



cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/divis/divis_nmf001.py

"""
import collections
from importlib import reload
import itertools
import os
import time

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import tensorflow as tf

from em.datasets import divisibility as div_ds
from em.fishers import per_example
from em.models import divis_models
from em.tools.clustering import vat
from em.tools.nmf import parallel_sklearn_nmf as p_nmf
from em.tools.nmf import nmf_common
from em.util import flat_pack

from em.analysis import bert_nmf_analysis as bna
from em.analysis import bert_nmf_analysis2 as bna2

###############################################################################

# Root directory for this experiment's artifacts; the sub-directories below
# are all resolved relative to it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/divis_mech1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################

# Base name (no extension) of the saved model under MODELS_DIR.
MODEL = "misc_divis_model_001_half"
# File name of the saved per-example Fishers under PER_EXAMPLES_FISHERS_DIR.
PER_EXAMPLES_FISHERS = f"{MODEL}.sparse_dynamic_raw.32k.32k.h5"

# File name of the saved NMF decomposition of those Fishers.
# NOTE(review): the "16k.8k.64" tokens presumably mirror N_EXAMPLES, the Fisher
# index range below, and the number of NMF components -- confirm.
DECOMP = f"nmf_decomp.16k.8k.64.{PER_EXAMPLES_FISHERS}"
N_EXAMPLES = 16 * 1024
START_FISHER_INDEX = 0
END_FISHER_INDEX = 8 * 1024

###############################################################################


# Load the saved model (and its config) from MODELS_DIR.
_model_path = os.path.join(MODELS_DIR, f'{MODEL}.h5')
model, model_config = divis_models.load_model_from_file(_model_path)


# Load the saved NMF decomposition, normalize its components to unit norm,
# then replace H with the result of get_full_H().
print('Starting to load saved NMF decomposition.')
start = time.time()
_decomp_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, DECOMP)
decomp = nmf_common.NmfDecomposition.load(_decomp_path)
decomp.normalize_components_to_unit_norm()
decomp.H = decomp.get_full_H()
print('Load saved NMF decomposition time: ', time.time() - start)


# Load the saved per-example Fishers, restricted to the configured example
# count and Fisher index range, along with the dividend/divisor fields that
# print_top_examples reads.
print('Starting to load saved per-example Fishers.')
start = time.time()
_pe_fishers_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS)
_pe_fishers_load_kwargs = dict(
    n_examples=N_EXAMPLES,
    start_fisher_index=START_FISHER_INDEX,
    end_fisher_index=END_FISHER_INDEX,
    extra_data_fields=['dividends', 'divisors'],
)
pe_fishers_data = per_example.PerExampleFlatFishers.load(
    _pe_fishers_path, **_pe_fishers_load_kwargs
)
print('Load saved per-example Fishers time: ', time.time() - start)


###############################################################################

# Short aliases for the two NMF factors. W is indexed [example, component]
# (see print_top_examples); each row of H is one component.
W, H = decomp.W, decomp.H

###############################################################################

# Pairwise dissimilarity between components: one minus the inner product of
# their rows in H.
# NOTE(review): this is a cosine dissimilarity only if the rows of H are still
# unit-norm after get_full_H() -- confirm.
cos_dissim_matrix = 1 - (H @ H.T)
vat_dissim_matrix, permutation = vat.vat_reorder_dissimilarity_matrix(
    cos_dissim_matrix
)
# Turn the VAT-reordered dissimilarities back into similarities.
M = 1 - vat_dissim_matrix

# bna.plot_sim_matrix(M, show=True)


###############################################################################

# Every trainable variable of the loaded model.
finetuned_vars = model.trainable_variables

localizer = bna2.ComponentLocalizationInfo(variables=finetuned_vars)

# One entry per NMF component (row of decomp.H), holding whatever
# fraction_per_variable returns for that component.
frac_in_subsets = np.array([
    localizer.fraction_per_variable(decomp.H[comp_index])
    for comp_index in range(decomp.H.shape[0])
])

# Same rows, rearranged into the VAT ordering.
reordered_frac_in_subsets = frac_in_subsets[permutation]

#######################################

# Labels of the form "<variable name before ':'>:<variable index>", listed
# from the last variable to the first.
subset_labels = [
    f"{var.name.split(':')[0]}:{var_index}"
    for var_index, var in enumerate(finetuned_vars)
][::-1]

# bna.plot_component_locations(
#     frac_in_subsets=reordered_frac_in_subsets,
#     subset_labels=subset_labels,
#     vertical_stretch=3,
#     yticks_fontsize=8,
# )

###############################################################################


def print_top_examples(component_index: int, n_examples: int):
    """Print the examples that weight most heavily on one NMF component.

    Examples are ranked by their entry in column `component_index` of the
    module-level `W`. Each printed line has the form
    `<label>, <argmax of predicted logits>: <divisor>|<dividend>`.

    Args:
        component_index: Column of `W` to rank the examples by.
        n_examples: Number of top-ranked examples to print.
    """
    top_indices = tf.math.top_k(W[:, component_index], k=n_examples).indices
    for example_index in top_indices:
        true_label = pe_fishers_data.labels[example_index]
        predicted = np.argmax(pe_fishers_data.predicted_logits[example_index])
        divisor = pe_fishers_data.divisors[example_index]
        dividend = pe_fishers_data.dividends[example_index]
        print(f'{true_label}, {predicted}: {divisor}|{dividend}')


# # print_top_examples(0, n_examples=16)
# print_top_examples(permutation[0], n_examples=16)
# print_top_examples(permutation[59], n_examples=16)
# print_top_examples(permutation[60], n_examples=16)

# print_top_examples(permutation[11], n_examples=16)
# print_top_examples(permutation[12], n_examples=16)
# print_top_examples(permutation[13], n_examples=16)
# print_top_examples(permutation[14], n_examples=16)

# print_top_examples(permutation[29], n_examples=16)
# print_top_examples(permutation[30], n_examples=16)

# print_top_examples(permutation[40], n_examples=16)
# print_top_examples(permutation[41], n_examples=16)

# print_top_examples(permutation[47], n_examples=16)
# print_top_examples(permutation[48], n_examples=16)

# print_top_examples(permutation[22], n_examples=16)
# print_top_examples(permutation[23], n_examples=16)

# print_top_examples(permutation[20], n_examples=16)
# print_top_examples(permutation[21], n_examples=16)
# print_top_examples(permutation[24], n_examples=16)

# print_top_examples(permutation[42], n_examples=16)
# print_top_examples(permutation[43], n_examples=16)

# print_top_examples(permutation[8], n_examples=16)
# print_top_examples(permutation[32], n_examples=16)

###############################################################################

# Packer built from the shapes of the trainable variables; used here to split
# H back into per-variable pieces.
packer = flat_pack.FlatPacker([var.shape for var in finetuned_vars])
# One entry per variable; the first axis of each entry indexes the NMF
# components (see the plotting helpers below).
comps_by_var = packer.decode_tf(H)

# plt.imshow(comps_by_var[0][permutation[0]]);plt.show()
# plt.imshow(comps_by_var[0][permutation[59]]);plt.show()
# plt.imshow(comps_by_var[0][permutation[60]]);plt.show()

# plt.imshow(comps_by_var[0][permutation[32]]);plt.show()

# reordered_embeddings_comps = [0, 38, 51, 55, 57, 59, 60, 63]
# plt.imshow(comps_by_var[0].numpy()[permutation[reordered_embeddings_comps]].sum(axis=0));plt.show()

# plt.imshow(comps_by_var[3][permutation[15]], cmap=sns.color_palette("rocket", as_cmap=True));plt.show()
# # plt.imshow(np.log(comps_by_var[3][permutation[15]] + 1e-12), cmap=sns.color_palette("rocket", as_cmap=True));plt.show()

# plt.plot(comps_by_var[3][permutation[15]].numpy().sum(axis=-1));plt.show()
# plt.plot(comps_by_var[3][permutation[15]].numpy().sum(axis=0));plt.show()


def plot_sing_values(var_index: int, reordered_comps, top_k: int = 16):
    """Plot the leading singular values of selected components of one variable.

    One curve is drawn per selected component. Each curve shows the first
    `top_k` singular values divided by the sum of all of that component's
    singular values. The caller is responsible for `plt.show()`.

    Args:
        var_index: Index into the module-level `comps_by_var`.
        reordered_comps: Component positions in the VAT ordering; they are
            mapped back to original component indices via `permutation`.
        top_k: Number of leading singular values to draw per component.
    """
    selected = comps_by_var[var_index].numpy()[permutation[reordered_comps]]
    singular_values = tf.linalg.svd(selected, compute_uv=False)
    for spectrum in singular_values:
        total = tf.reduce_sum(spectrum)
        plt.plot(spectrum[:top_k] / total)


# plot_sing_values(3, [10, 11, 12, 13, 14, 15, 16, 40, 41])
# plt.show()

# plot_sing_values(5, [58, 61, 62])
# plt.show()

# plot_sing_values(13, [29, 30, 47, 48, 56])
# plt.show()

# plot_sing_values(17, [42, 43, 44])
# plt.show()

# plot_sing_values(19, [4, 5])
# plt.show()

# plot_sing_values(23, [19, 20, 42, 45])
# plt.show()

# plot_sing_values(27, [19, 21, 22, 23, 34, 41, 45])
# plt.show()


###############################################################################

# kernel_variables = [v for v in finetuned_vars if 'kernel' in v.name]

# localizer_kernel = bna2.ComponentLocalizationInfo(variables=kernel_variables)

# frac_in_kernels = []
# for i in range(decomp.H.shape[0]):
#     frac_in_kernels.append(localizer_kernel.fraction_per_variable(decomp.H[i]))

# frac_in_kernels = np.array(frac_in_kernels)

# reordered_frac_in_kernels = frac_in_kernels[permutation]

# kernel_labels = [f"{v.name.split(':')[0]}:{i}" for i, v in reversed(list(enumerate(kernel_variables)))]

# bna.plot_component_locations(
#     frac_in_subsets=reordered_frac_in_kernels,
#     subset_labels=kernel_labels,
#     vertical_stretch=3,
#     yticks_fontsize=8,
# )


# ###############################################################################

# kernel.shape = [last_dim, self.units]
# kernel.shape = [d_input, d_output]

# It looks like typically either the dependencies across inputs or the dependencies
# across outputs are quite sparse while the other isn't really (although its mass
# is relatively concentrated on a few units). From a brief look, it appears that the
# d_model-sized dependency vectors are sparse while the d_ff-sized dependency vectors
# are not. I am unsure of the general pattern, but this might have to do with the residual
# structure of my network.


def plot_dependencies_across_outputs(var_index: int, reordered_comps):
    """Plot, per selected component, its profile along the last axis.

    Each selected component is averaged over its second-to-last axis and the
    resulting vector is divided by its own maximum before plotting. The caller
    is responsible for `plt.show()`.

    Args:
        var_index: Index into the module-level `comps_by_var`.
        reordered_comps: Component positions in the VAT ordering; they are
            mapped back to original component indices via `permutation`.
    """
    # assumes kernels are laid out as [d_input, d_output], so the mean over
    # axis -2 collapses the inputs -- TODO confirm for every variable used.
    selected = comps_by_var[var_index].numpy()[permutation[reordered_comps]]
    for per_output in selected.mean(axis=-2):
        plt.plot(per_output / np.max(per_output))


# plot_dependencies_across_outputs(3, [10, 11, 12, 13, 14, 15, 16, 40, 41])
# plt.show()

# plot_dependencies_across_outputs(5, [58, 61, 62])
# plt.show()

# plot_dependencies_across_outputs(13, [29, 30, 47, 48, 56])
# plt.show()

# plot_dependencies_across_outputs(17, [42, 43, 44])
# plt.show()

# plot_dependencies_across_outputs(19, [4, 5])
# plt.show()

# plot_dependencies_across_outputs(23, [19, 20, 42, 45])
# plt.show()


def plot_dependencies_across_inputs(var_index: int, reordered_comps):
    """Plot, per selected component, its profile along the second-to-last axis.

    Each selected component is averaged over its last axis and the resulting
    vector is divided by its own maximum before plotting. The caller is
    responsible for `plt.show()`.

    Args:
        var_index: Index into the module-level `comps_by_var`.
        reordered_comps: Component positions in the VAT ordering; they are
            mapped back to original component indices via `permutation`.
    """
    # assumes kernels are laid out as [d_input, d_output], so the mean over
    # axis -1 collapses the outputs -- TODO confirm for every variable used.
    selected = comps_by_var[var_index].numpy()[permutation[reordered_comps]]
    for per_input in selected.mean(axis=-1):
        plt.plot(per_input / np.max(per_input))


# plot_dependencies_across_inputs(3, [10, 11, 12, 13, 14, 15, 16, 40, 41])
# plt.show()

# plot_dependencies_across_inputs(5, [58, 61, 62])
# plt.show()

# plot_dependencies_across_inputs(13, [29, 30, 47, 48, 56])
# plt.show()

# plot_dependencies_across_inputs(17, [42, 43, 44])
# plt.show()

# plot_dependencies_across_inputs(19, [4, 5])
# plt.show()

# plot_dependencies_across_inputs(23, [19, 20, 42, 45])
# plt.show()


###############################################################################


# Leading singular-vector pair of a kernel-shaped component:
#   u = inputs_dim
#   v = outputs_dim
RankOneDecomp = collections.namedtuple('RankOneDecomp', ['u', 'v'])


def rank_one_decomps(comps_by_kernels):
    """Return the leading singular-vector pair of each entry's matrices.

    Args:
        comps_by_kernels: Iterable of tensors accepted by `tf.linalg.svd`.

    Returns:
        A list with one `RankOneDecomp` per entry, where `u` and `v` are the
        first columns of the `u` and `v` factors returned by the SVD. The
        singular values themselves are discarded.
    """

    def _leading_pair(comps):
        _, left, right = tf.linalg.svd(comps)
        return RankOneDecomp(u=left[..., 0], v=right[..., 0])

    return [_leading_pair(comps) for comps in comps_by_kernels]


# Only for the res blocks: keep the per-variable components whose variable
# name mentions both 'kernel' and 'ffw_res_block'.
comps_by_kernels = [
    comps
    for comps, var in zip(comps_by_var, finetuned_vars)
    if all(tag in var.name for tag in ('kernel', 'ffw_res_block'))
]

# TODO: Redo with rank 1 NMF decomps and/or just summing components across their axes.
# It seems that we have negative entries in these.
comps_by_kernels_rank1s = rank_one_decomps(comps_by_kernels)

# dm_to_dff_comps = comps_by_kernels_rank1s[0].v.numpy()
# dff_to_dm_comps = comps_by_kernels_rank1s[1].u.numpy()

# dm_to_dff_comps = comps_by_kernels_rank1s[-2].v.numpy()
# dff_to_dm_comps = comps_by_kernels_rank1s[-1].u.numpy()

# Output-side vectors of kernel 4 and input-side vectors of kernel 5.
dm_to_dff_comps = comps_by_kernels_rank1s[4].v.numpy()
dff_to_dm_comps = comps_by_kernels_rank1s[5].u.numpy()

# dm_to_dff_comps /= np.sqrt((dm_to_dff_comps**2).sum(axis=-1, keepdims=True)) + 1e-12
# dff_to_dm_comps /= np.sqrt((dff_to_dm_comps**2).sum(axis=-1, keepdims=True)) + 1e-12


# A: signed inner products between the two sets of vectors.
# B: the same inner products taken on absolute values.
A = np.matmul(dm_to_dff_comps, dff_to_dm_comps.T)
B = np.matmul(np.absolute(dm_to_dff_comps), np.absolute(dff_to_dm_comps).T)

# plt.imshow(A);plt.show()
# plt.imshow(B);plt.show()

# ######################
# # From third res block.

# # 30, 19 (negative)
# print_top_examples(30, n_examples=8)
# print_top_examples(19, n_examples=8)

# # 44, 28 (negative)
# print_top_examples(44, n_examples=8)
# print_top_examples(28, n_examples=8)

# # 13, 44 (positive)
# print_top_examples(13, n_examples=8)
# print_top_examples(44, n_examples=8)

# # 34, 19 (positive)
# print_top_examples(34, n_examples=8)
# print_top_examples(19, n_examples=8)

# ######################
# # Between last kernel of first block and first kernel of second block.

# # 7, 9
# print_top_examples(7, n_examples=8)
# print_top_examples(9, n_examples=8)

# # 48, 9
# print_top_examples(48, n_examples=8)
# print_top_examples(9, n_examples=8)

# # 9, 56
# print_top_examples(9, n_examples=8)
# print_top_examples(56, n_examples=8)

# # 6, 28
# print_top_examples(6, n_examples=8)
# print_top_examples(28, n_examples=8)

# # 34, 33
# print_top_examples(34, n_examples=8)
# print_top_examples(33, n_examples=8)

# ######################
# # From last two res kernels.

# # 29, 39
# print_top_examples(29, n_examples=8)
# print_top_examples(39, n_examples=8)

# # 47, 39
# print_top_examples(47, n_examples=8)
# print_top_examples(39, n_examples=8)

# # 25, 44
# print_top_examples(25, n_examples=8)
# print_top_examples(44, n_examples=8)

# # 7, 11
# print_top_examples(7, n_examples=8)
# print_top_examples(11, n_examples=8)


# ######################
# # From first two res kernels.

# # 52, 12
# print_top_examples(52, n_examples=8)
# print_top_examples(12, n_examples=8)

# # 23, 56
# print_top_examples(23, n_examples=8)
# print_top_examples(56, n_examples=8)

# # 9, 12
# print_top_examples(9, n_examples=8)
# print_top_examples(12, n_examples=8)

###############################################################################


# NOTE: The NMF decomp appears to produce better connections than the SVD version,
# or at least looking at the similarity A matrix it does, but that might just be luck.


# No multiprocessing: 17.01611828804016
# n_processes=1: 45.75366687774658 sec
# n_processes=16: 44.57087278366089 sec
# n_processes=None: 51.36946225166321 sec

# reload(p_nmf)
# start = time.time()
# Us, Vs = p_nmf.perform_nmfs(comps_by_kernels[3], n_components=1)
# Us = np.squeeze(Us, axis=-1)
# Vs = np.squeeze(Vs, axis=-2)
# print(time.time() - start)


# TODO: The magnitudes of W and H from these NMFs do not mean much when
# comparing across different decomps due to the symmetry of magnitudes within
# a decomp.

# ker1, ker2 = 2, 3
# Indices into comps_by_kernels of the two kernels being compared.
ker1, ker2 = 8, 9


# TODO: Better way than adding small constant, needed to get rid of NaNs
# that throw errors in the NMF.
# Rank-1 NMF of each component of ker1; only the second factor is kept.
_, Vs = p_nmf.perform_nmfs(comps_by_kernels[ker1] + 1e-12, n_components=1)
Vs = np.squeeze(Vs, axis=-2)

# Rank-1 NMF of each component of ker2; only the first factor is kept.
Us, _ = p_nmf.perform_nmfs(comps_by_kernels[ker2] + 1e-12, n_components=1)
Us = np.squeeze(Us, axis=-1)


# Unit-normalize each vector, then re-weight it by the total mass of the
# component it was factored from. Us comes from ker2 and Vs from ker1, so each
# is weighted by its own kernel's mass. The weights were previously swapped,
# scaling each vector by the other kernel's mass.
nUs = Us / (np.sqrt((Us**2).sum(axis=-1, keepdims=True)) + 1e-12)
nnUs = (nUs * tf.reduce_sum(comps_by_kernels[ker2], axis=[1, 2])[..., None]).numpy()

nVs = Vs / (np.sqrt((Vs**2).sum(axis=-1, keepdims=True)) + 1e-12)
nnVs = (nVs * tf.reduce_sum(comps_by_kernels[ker1], axis=[1, 2])[..., None]).numpy()


# Rows index components of ker2 (Us); columns index components of ker1 (Vs).
# A: raw NMF factors, B: unit-normalized, C: normalized and mass-weighted.
A = Us @ Vs.T
B = nUs @ nVs.T
C = nnUs @ nnVs.T

# plt.imshow(A);plt.show()
# plt.imshow(B);plt.show()
# plt.imshow(C);plt.show()

# plt.imshow(np.sqrt(Us) @ np.sqrt(Vs).T);plt.show()

##################################
# For kernels 8 -> 9

# #################
# # From A

# # 59, 42
# print_top_examples(59, n_examples=8)
# print_top_examples(42, n_examples=8)

# # 10, 57
# print_top_examples(10, n_examples=8)
# print_top_examples(57, n_examples=8)

# # 32, 57
# print_top_examples(32, n_examples=8)
# print_top_examples(57, n_examples=8)

# # 10, 5
# print_top_examples(10, n_examples=8)
# print_top_examples(5, n_examples=8)


# #################
# # From B

# # 4, 1
# print_top_examples(4, n_examples=8)
# print_top_examples(1, n_examples=8)

# # 61, 9
# print_top_examples(61, n_examples=8)
# print_top_examples(9, n_examples=8)


# #################
# # From C

# # 42, 59
# print_top_examples(42, n_examples=8)
# print_top_examples(59, n_examples=8)

# # 38, 54
# print_top_examples(38, n_examples=8)
# print_top_examples(54, n_examples=8)

# # 38, 5
# print_top_examples(38, n_examples=8)
# print_top_examples(5, n_examples=8)


# ##################################
# # For kernels 2 -> 3

# #################
# # From C

# # 55, 53
# print_top_examples(55, n_examples=8)
# print_top_examples(53, n_examples=8)

# # 43, 50
# print_top_examples(43, n_examples=8)
# print_top_examples(40, n_examples=8)

# # 31, 50
# print_top_examples(31, n_examples=8)
# print_top_examples(40, n_examples=8)

# # 63, 50
# print_top_examples(63, n_examples=8)
# print_top_examples(40, n_examples=8)

# #################
# # From B

# # 4, 39
# print_top_examples(4, n_examples=8)
# print_top_examples(39, n_examples=8)

# # 28, 44
# print_top_examples(28, n_examples=8)
# print_top_examples(44, n_examples=8)

# # 28, 6
# print_top_examples(28, n_examples=8)
# print_top_examples(6, n_examples=8)

# # 56, 9
# print_top_examples(56, n_examples=8)
# print_top_examples(9, n_examples=8)

# #################
# # From A

# # 9, 53
# print_top_examples(9, n_examples=8)
# print_top_examples(53, n_examples=8)

# # 9, 1
# print_top_examples(9, n_examples=8)
# print_top_examples(1, n_examples=8)

# # 44, 53
# print_top_examples(44, n_examples=8)
# print_top_examples(53, n_examples=8)

# # 14, 50
# print_top_examples(14, n_examples=8)
# print_top_examples(50, n_examples=8)

###############################################################################

# Show the components of variable 27 as a single image: swap the last two axes
# of each component, then flatten everything into a 128 x 1024 array.
plt.imshow(
    tf.reshape(
        tf.transpose(comps_by_var[27], perm=[0, 2, 1]),
        shape=[128, 1024],
    )
)
plt.show()


###############################################################################

# TODO: "trace" units within and across examples, maybe make graph. Remember the
# residual connections when doing this!
# Trace from logits to inputs versus trace from inputs to logits?


# Let's look at ranks of the components for each kernel. Most have singular values
# concentrated in a single value. This means they are close to rank 1. This is the
# case for the gradient of a scalar function of the kernel with respect to the kernel,
# which arises, for example, when computing the per-example Fisher using a single sample.
#
# Note that this exact reasoning only holds for the case of a simple fully connected
# network though (my residual connections are probably fine).
