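R"""Explore per-layer NMF decompositions of per-example Fishers for a divisibility model.

Run interactively with:
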
cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/divis/divis_per_layer_nmf003.py

"""
import collections
from importlib import reload
import itertools
import os
import time

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import tensorflow as tf

from em.datasets import divisibility as div_ds
from em.fishers import per_example
from em.models import divis_models
from em.models import transformer_model_vars as tmv
from em.tools.clustering import vat
from em.tools.nmf import parallel_sklearn_nmf as p_nmf
from em.tools.nmf import nmf_common
from em.util import flat_pack

from em.analysis import bert_nmf_analysis as bna
from em.analysis import bert_nmf_analysis2 as bna2

# Shared colormap for the heatmaps plotted below.
rocket_cmap = sns.color_palette("rocket", as_cmap=True)

###############################################################################

# Root directory of this experiment's artifacts and its subdirectories.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/divis_mech1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Model under study and the file holding its saved per-example Fishers.
MODEL = "misc_divis_model_001_deep_semi_sixteenth"
PER_EXAMPLES_FISHERS = f"{MODEL}.sparse_dynamic_raw.64k.32k.h5"

# Base file name of the saved NMF decompositions. `load_nmf_decomp` strips
# its ".h5" suffix and appends a per-subset ".ssi{i}.h5".
# NOTE(review): "32k.8k.384" presumably encodes examples / Fisher entries /
# components — confirm against the script that wrote these files.
DECOMP = f"nmf_decomp.32k.8k.384.{PER_EXAMPLES_FISHERS}"
# Number of examples and the Fisher-index range passed to the loader below.
N_EXAMPLES = 32 * 1024
START_FISHER_INDEX, END_FISHER_INDEX = (0, 8 * 1024)

###############################################################################


model, model_config = divis_models.load_model_from_file(os.path.join(MODELS_DIR, f'{MODEL}.h5'))


# Load the saved per-example Fishers together with each example's dividend
# and divisor, which `print_top_examples` uses to describe examples.
print('Starting to load saved per-example Fishers.')
start = time.time()
pe_fishers_data = per_example.PerExampleFlatFishers.load(
    os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS),
    n_examples=N_EXAMPLES,
    start_fisher_index=START_FISHER_INDEX,
    end_fisher_index=END_FISHER_INDEX,
    extra_data_fields=['dividends', 'divisors'],
)
print('Load saved per-example Fishers time: ', time.time() - start)


###############################################################################

def load_nmf_decomp(subset_index: int):
    """Load the saved NMF decomposition for one variable subset.

    The file lives in `PER_EXAMPLES_FISHERS_DIR` and is named after `DECOMP`
    with its last three characters (the ".h5" suffix) replaced by
    ".ssi{subset_index}.h5". Before returning, the components are normalized
    to unit norm and `H` is replaced with the result of `get_full_H()`.

    Args:
        subset_index: Index of the variable subset whose decomposition to load.

    Returns:
        The loaded `nmf_common.NmfDecomposition`.
    """
    print('Starting to load saved NMF decomposition.')
    t0 = time.time()
    stem = DECOMP[:-3]
    decomp_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, f"{stem}.ssi{subset_index}.h5")
    decomp = nmf_common.NmfDecomposition.load(decomp_path)
    decomp.normalize_components_to_unit_norm()
    decomp.H = decomp.get_full_H()
    print('Load saved NMF decomposition time: ', time.time() - t0)
    return decomp


# Variable-subset indices of the two decompositions being compared. They pick
# the saved decomposition files and index `homogenized_variables`.
ssi1 = 4
ssi2 = 5

decomp1 = load_nmf_decomp(ssi1)
decomp2 = load_nmf_decomp(ssi2)

# Trainable variables with kernels and biases merged by
# `tmv.homogenize_kernel_biases`.
# NOTE(review): presumably each dense kernel gets its bias appended as an
# extra input row, matching `dense_layer_on_homog`. Confirm against tmv.
variables = model.trainable_variables
homogenized_variables = tmv.homogenize_kernel_biases(variables)

###############################################################################


def dense_layer_on_homog(dense, x, use_activation=True):
    """Apply a Dense layer's weights to inputs in homogeneous coordinates.

    The layer's bias is appended to its kernel as one extra last row, so the
    last feature of `x` multiplies the bias. NOTE(review): that feature is
    presumably a constant 1 (or a bias-row factor) — confirm at call sites.

    Args:
        dense: A layer exposing `kernel`, `bias` and `activation`.
        x: Inputs whose last axis has one more entry than the kernel's rows.
        use_activation: Whether to apply `dense.activation` to the product.

    Returns:
        `x @ [kernel; bias]`, passed through the activation when requested.
    """
    homog_kernel = tf.concat([dense.kernel, dense.bias[None, :]], axis=0)
    pre_activation = x @ homog_kernel
    return dense.activation(pre_activation) if use_activation else pre_activation


def l2_normalize(x):
    """Scale every vector along the last axis of `x` to unit L2 norm.

    A 1e-12 term is added to each norm, so all-zero vectors come back as
    zeros instead of NaNs.
    """
    norms = np.sqrt((x * x).sum(axis=-1, keepdims=True))
    return x / (norms + 1e-12)


def print_top_examples(decomp, component_index: int, n_examples: int):
    """Print the examples that load most heavily on one NMF component.

    Examples are ranked by their weight in column `component_index` of
    `decomp.W`. Each printed line reads "<label>, <argmax of predicted
    logits>: <divisor>|<dividend>", with all fields read from the
    module-level `pe_fishers_data`.

    Args:
        decomp: Decomposition whose `W` has one row per example.
        component_index: Column of `decomp.W` to rank by.
        n_examples: How many top examples to print.
    """
    component_weights = decomp.W[:, component_index]
    top_indices = tf.math.top_k(component_weights, k=n_examples).indices
    for idx in top_indices:
        predicted_class = np.argmax(pe_fishers_data.predicted_logits[idx])
        print(
            f'{pe_fishers_data.labels[idx]}, {predicted_class}: '
            f'{pe_fishers_data.divisors[idx]}|{pe_fishers_data.dividends[idx]}'
        )


# Quick look at component 0 of each decomposition.
print_top_examples(decomp1, 0, n_examples=16)
print_top_examples(decomp2, 0, n_examples=16)

###############################################################################

# Reshape each component's flat H row to the shape of the homogenized
# variable it belongs to: (n_components, *variable_shape).
rsH1 = decomp1.H.reshape([decomp1.H.shape[0], *homogenized_variables[ssi1].shape])
rsH2 = decomp2.H.reshape([decomp2.H.shape[0], *homogenized_variables[ssi2].shape])

# Residual block that holds variable subset ssi1.
# NOTE(review): this assumes two variable subsets per block, with the first
# block at subset index 2 and at model.layers[3]. It also implicitly assumes
# ssi2 is in the same block, because dense2 of this block is applied to ssi1's
# factors below. Confirm against the model definition.
res_block_index = (ssi1 - 2) // 2
res_block = model.layers[3 + res_block_index]

# Rank-1 NMF of each component's variable-shaped H. It yields one factor
# along the second-to-last axis (ha/hb) and one along the last axis (ga/gb)
# per component, and the square root of each factor is kept.
# TODO: Find a better way than adding a small constant. It is needed to get
# rid of NaNs that throw errors in the NMF.
ha, ga = p_nmf.perform_nmfs(rsH1 + 1e-12, max_iter=2000, n_components=1)
ha = np.sqrt(np.squeeze(ha, axis=-1))
ga = np.sqrt(np.squeeze(ga, axis=-2))

hb, gb = p_nmf.perform_nmfs(rsH2 + 1e-12, max_iter=2000, n_components=1)
hb = np.sqrt(np.squeeze(hb, axis=-1))
gb = np.sqrt(np.squeeze(gb, axis=-2))


# Push each decomp1 component's `ha` factor through res_block.dense1 (with its
# bias), both before (lin_fha) and after (nonlin_fha) the activation.
lin_fha = dense_layer_on_homog(res_block.dense1, ha, use_activation=False).numpy()
nonlin_fha = dense_layer_on_homog(res_block.dense1, ha, use_activation=True).numpy()


# idx = 2
# plt.plot(lin_fha[idx]);plt.plot(nonlin_fha[idx]);plt.plot(ga[idx]);plt.show()
# print_top_examples(decomp1, idx, n_examples=16)

# # asdf = tf.einsum('ij,ij->i', nonlin_fha - lin_fha, ga).numpy()
# asdf = tf.einsum('ij,ij->i', l2_normalize(nonlin_fha) - l2_normalize(lin_fha), l2_normalize(ga)).numpy()
# top_inds = np.argsort(asdf)

# print_top_examples(decomp1, top_inds[0], n_examples=16)


# Send each decomp1 `ga` factor through res_block.dense2 and score it against
# every decomp2 `gb` factor: entry (i, j) is <dense2(ga[i]), gb[j]>.
# rga = dense_layer_on_homog(res_block.dense2, ga)
rga = res_block.dense2(ga)
plt.imshow(rga @ gb.T, cmap=rocket_cmap);plt.tight_layout();plt.show()

# print_top_examples(decomp1, 242, n_examples=16)
# print_top_examples(decomp2, 107, n_examples=16)


def print_top_examples_for_top_pairs(A, n_pairs: int, offset: int = 0, n_examples_per_pair: int = 8):
    """Print top examples for the highest-scoring (decomp1, decomp2) component pairs.

    Args:
        A: 2-D score matrix (NumPy array or tensor). Entry (i, j) scores
            component i of `decomp1` against component j of `decomp2`. A does
            not need to be square.
        n_pairs: Number of pairs to print.
        offset: Number of top-scoring pairs to skip first, which lets you page
            through the ranking.
        n_examples_per_pair: Number of top examples printed per component.
    """
    _, n_cols = A.shape
    _, inds = tf.math.top_k(tf.reshape(A, [-1]), k=n_pairs + offset)
    inds = inds[offset:]
    # tf.reshape flattens in row-major order, so flat index k is at row
    # k // n_cols and column k % n_cols. Dividing by the row count is only
    # correct when A is square.
    inds = tf.stack([inds // n_cols, inds % n_cols], axis=-1).numpy()
    for i, j in inds:
        print(f'Component {i}:')
        print_top_examples(decomp1, i, n_examples=n_examples_per_pair)
        print(f'Component {j}:')
        print_top_examples(decomp2, j, n_examples=n_examples_per_pair)
        print('')


print_top_examples_for_top_pairs(rga @ gb.T, n_pairs=32)

###############################################################################

# Pearson correlation, across examples, between each decomp1 component's W
# column and each decomp2 component's W column.
#
# np.corrcoef(x, y) stacks the rows of x (decomp1's n_comps1 components) above
# the rows of y (decomp2's components). The cross-correlation block is
# therefore the upper-right one:
#     A[i, j] = corr(decomp1 component i, decomp2 component j),
# with shape (n_comps1, n_comps2). Slicing [:n_comps1, n_comps1:] stays
# correct when the two decompositions have different numbers of components.
n_comps1 = decomp1.W.shape[1]
A = np.corrcoef(decomp1.W.T, decomp2.W.T)[:n_comps1, n_comps1:]
# plt.imshow(A, cmap=rocket_cmap);plt.tight_layout();plt.show()
plt.imshow(np.abs(A), cmap=rocket_cmap);plt.tight_layout();plt.show()


# print_top_examples_for_top_pairs(A, n_pairs=32)
# print_top_examples_for_top_pairs(np.abs(A), n_pairs=32)
# print_top_examples_for_top_pairs(np.abs(A), n_pairs=2048)
# print_top_examples_for_top_pairs(np.abs(A), n_pairs=32, offset=4 * 1024)
print_top_examples_for_top_pairs(np.abs(A), n_pairs=32, offset=32)


def blah(i, j):
    """Inspect one (decomp1 component i, decomp2 component j) pair.

    Prints the top 8 examples of each component. It then overlays, in one
    figure, rsH1[i] summed over axis -2 and rsH2[j] summed over axis -1.
    """
    for decomp, comp in ((decomp1, i), (decomp2, j)):
        print(f'Component {comp}:')
        print_top_examples(decomp, comp, n_examples=8)
    #
    plt.plot(rsH1[i].sum(axis=-2))  # decomp1 component, summed over axis -2
    plt.plot(rsH2[j].sum(axis=-1))  # decomp2 component, summed over axis -1
    plt.show()


# Component pairs (decomp1 index, decomp2 index) picked out from
# print_top_examples_for_top_pairs runs at the offsets noted below.
# # offset=512
# blah(222, 132)
# blah(110, 21)
# blah(42, 245)
# blah(36, 113)
# blah(141, 347)


# # offset=32
# blah(282, 23)
# blah(359, 11)
# blah(320, 83)
# blah(139, 284)
# blah(94, 118)
# blah(194, 249)
# blah(312, 26)


# blah(367, 128)
# blah(329, 304)
# blah(285, 358)
# blah(344, 4)


# plt.plot(A[:, 23]);plt.show()

# Rank decomp1 components by their correlation with decomp2 component 23.
# Plot component 23's H summed over axis -1 alongside the elementwise max,
# over the 8 most-correlated decomp1 components, of their H summed over
# axis -2.
comp_inds = np.argsort(-A[:, 23])
plt.plot(rsH2[23].sum(axis=-1))
# plt.plot(rsH1[comp_inds[0]].sum(axis=-2))
# plt.plot(rsH1[comp_inds[1]].sum(axis=-2))
# plt.plot(rsH1[comp_inds[2]].sum(axis=-2))
# plt.plot(rsH1[comp_inds[3]].sum(axis=-2))
# plt.plot(rsH1[comp_inds[:4]].sum(axis=-2).mean(axis=0))
plt.plot(rsH1[comp_inds[:8]].sum(axis=-2).max(axis=0))
plt.show()

# Maybe create some measure of overlap between sufficiently "active" outputs of
# layer1 comps and sufficiently "important" inputs of layer2 comps.

# blah(comp_inds[0], 23)
# blah(comp_inds[1], 23)
# blah(comp_inds[2], 23)
# blah(comp_inds[3], 23)
# blah(comp_inds[4], 23)
# blah(comp_inds[5], 23)
# blah(comp_inds[6], 23)


# 103, 207, 228, 48 (input units for comp2 23)
# For each such unit u, rank decomp1 components by their H at last-axis index
# u summed over axis 1. Then inspect the top 10 next to decomp2 component 23.
comp_inds0 = np.argsort(-rsH1[..., 103].sum(axis=1))
plt.plot(rsH1[..., 103].sum(axis=1));plt.show()
blah(comp_inds0[0], 23)
blah(comp_inds0[1], 23)  # This is maybe it for 103
blah(comp_inds0[2], 23)
blah(comp_inds0[3], 23)  # This is maybe it for 103
blah(comp_inds0[4], 23)  # Maybe? Selective for divisors of 5 with dividends ending with 0.
blah(comp_inds0[5], 23)
blah(comp_inds0[6], 23)  # This is maybe it for 103
blah(comp_inds0[7], 23)  # This is maybe it for 103
blah(comp_inds0[8], 23)
blah(comp_inds0[9], 23)

# print_top_examples(decomp1, comp_inds0[4], n_examples=256)

# Same ranking as above, for unit 207.
comp_inds0 = np.argsort(-rsH1[..., 207].sum(axis=1))
plt.plot(rsH1[..., 207].sum(axis=1));plt.show()
blah(comp_inds0[0], 23)
blah(comp_inds0[1], 23)
blah(comp_inds0[2], 23)  # Maybe a little bit?
blah(comp_inds0[3], 23)  # Maybe a little bit?
blah(comp_inds0[4], 23)
blah(comp_inds0[5], 23)
blah(comp_inds0[6], 23)
blah(comp_inds0[7], 23)
blah(comp_inds0[8], 23)  # Maybe a little bit? Decent number of dividends ending with 0, though the divisors are 3 and 6.
blah(comp_inds0[9], 23)


# Same ranking as above, for unit 228.
comp_inds0 = np.argsort(-rsH1[..., 228].sum(axis=1))
plt.plot(rsH1[..., 228].sum(axis=1));plt.show()
blah(comp_inds0[0], 23)
blah(comp_inds0[1], 23)
blah(comp_inds0[2], 23)
blah(comp_inds0[3], 23)
blah(comp_inds0[4], 23)
blah(comp_inds0[5], 23)
blah(comp_inds0[6], 23)
blah(comp_inds0[7], 23)
blah(comp_inds0[8], 23)
blah(comp_inds0[9], 23)


# Same ranking as above, for unit 48.
# A lot of divisors of 4 for this component.
comp_inds0 = np.argsort(-rsH1[..., 48].sum(axis=1))
plt.plot(rsH1[..., 48].sum(axis=1));plt.show()
blah(comp_inds0[0], 23)
blah(comp_inds0[1], 23)  # This is maybe it
blah(comp_inds0[2], 23)
blah(comp_inds0[3], 23)
blah(comp_inds0[4], 23)
blah(comp_inds0[5], 23)
blah(comp_inds0[6], 23)
blah(comp_inds0[7], 23)  # Maybe? Selective for divisors of 5 with dividends ending with 0.
blah(comp_inds0[8], 23)  # This is maybe it, very similar tuning.
blah(comp_inds0[9], 23)

