R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/nmf/nmf_math_bert_small_dev002.py

"""

import collections
import dataclasses
import os
from importlib import reload
import itertools
import time
from typing import Any, List, Sequence

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import torch
from torchnmf.nmf import NMF as TorchNMF
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em.evaluation import evaluation
from em.fishers import diagonal
from em.fishers import per_example
from em.fishers import sparse_diagonal
from em.util import hf_util
from em.util import vat_da_faak_vpn

from em.tools.nmf import nmf_common

from local_scripts.nmf import nmf_dev_bert as ndb
from local_scripts.soc import soc_dev_common as sdc

###############################################################################

# Enable TensorFlow memory growth on every visible GPU so TF does not reserve
# all device memory at initialization. Presumably this lets TF share the GPU
# with PyTorch (torch / torchnmf are also imported) — TODO confirm intent.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Disabled alternative: presumably places TF and torch on separate GPUs
# (see nmf_dev_bert); not exercised by this script.
# ndb.separate_tf_torch_gpus()

###############################################################################

# Root directory holding this experiment's artifacts.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/math_datasets_dev1'
# Not referenced below in this script.
MODELS_DIR = os.path.join(EXPS_DIR, 'models1')
# Directory from which both the per-example Fishers and the saved NMF
# decomposition are loaded.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers0')

#######################################


# Identify the task/model the Fishers were computed for. TASK, MODEL and
# FROM_PT are not read below; presumably kept as a record of the setup.
TASK = 'math_dataset/original_true_false'
MODEL = "og_tf__bert_small__100k_steps"
FROM_PT = False
# File (inside PER_EXAMPLES_FISHERS_DIR) with the saved per-example Fishers.
PER_EXAMPLES_FISHERS = "og_tf__bert_small__100k_steps.no_embeddings.sparse_dynamic_raw.32k.32k.h5"

# Hugging Face model id; used here only to load the tokenizer.
PRETRAINED_MODEL = "prajjwal1/bert-small"

#######################################

# NOTE(review): the name appears to encode N_EXAMPLES (8k), the Fisher index
# range (4k) and N_NMF_COMPONENTS (128), followed by PER_EXAMPLES_FISHERS.
# Nothing ties it to the constants below, so changing one of them without
# renaming this file would silently load a mismatched decomposition — confirm
# the naming convention before deriving it automatically.
DECOMP_FILENAME = 'nmf_decomp.8k.4k.128.og_tf__bert_small__100k_steps.no_embeddings.sparse_dynamic_raw.32k.32k.h5'

# Number of examples to load from the per-example Fishers file.
N_EXAMPLES = 8 * 1024
# Slice of the flat Fisher vector to load (presumably half-open — confirm).
START_FISHER_INDEX, END_FISHER_INDEX = (0, 4096)
# Number of NMF components; not read below (the decomposition is loaded
# from disk rather than recomputed).
N_NMF_COMPONENTS = 128

###############################################################################

# Tokenizer used to decode example input_ids when printing top examples.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)


# Elapsed times below use time.perf_counter(), a monotonic high-resolution
# clock meant for measuring durations; time.time() follows the wall clock and
# can jump (e.g. NTP adjustments), producing wrong or negative timings.
print('Starting to load saved NMF decomposition.')
start = time.perf_counter()
decomp = nmf_common.NmfDecomposition.load(os.path.join(PER_EXAMPLES_FISHERS_DIR, DECOMP_FILENAME))
print('Load saved NMF decomposition time: ', time.perf_counter() - start)


print('Starting to load saved per-example Fishers.')
start = time.perf_counter()
# Load only the first N_EXAMPLES examples and the
# [START_FISHER_INDEX, END_FISHER_INDEX) slice of each flat Fisher.
pe_fishers_data = per_example.PerExampleFlatFishers.load(
    os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS),
    n_examples=N_EXAMPLES,
    start_fisher_index=START_FISHER_INDEX,
    end_fisher_index=END_FISHER_INDEX,
)
print('Load saved per-example Fishers time: ', time.perf_counter() - start)

###############################################################################

# Expose the NMF factors and per-example data as module globals so they can be
# poked at from the interactive (`python -i`) session. W is indexed below as
# W[:, component] (rows ~ examples) and H as H[component] (rows ~ components).
W, H = decomp.W, decomp.H

input_ids = pe_fishers_data.input_ids
# Dict view of the same input_ids, in the layout ndb.print_top_examples takes.
examples = {'input_ids': input_ids}
labels = pe_fishers_data.labels
predicted_logits = pe_fishers_data.predicted_logits

# Quick sanity check: presumably prints the 8 highest-weighted examples for
# component 0 (see ndb.print_top_examples).
ndb.print_top_examples(W, tokenizer, examples, labels, n_examples=8, component=0)

###############################################################################

'''

Some "simple" component indices: 2, 12, 14, 19, 23, 24, 25, 28, 29, 62


28: numbers ending in 5 do not divide odd numbers not ending in 5
62: even numbers do not divide odd numbers

28 and 62 both appear to have parameter representations concentrated in the last layer
'''

# Component 28 appears to capture examples where a number ending in 5 is
# tested as a divisor of a number whose last digit is odd but not 5.
component = 28
# How many top examples to plot (distinct from N_EXAMPLES, the number of
# examples loaded).
n_examples = 8
# Indices of the n_examples examples with the largest coefficient on
# `component` (top_k over column `component` of W).
_, inds = tf.math.top_k(W[:, component], k=n_examples)
inds = inds.numpy()
# Full coefficient rows of W for those examples: (n_examples, n_components).
top_coeffs = W[inds]

# One line per top example, x-axis = component index; shows whether each
# example's weight is concentrated on `component` or spread across others.
plt.plot(top_coeffs.T)
plt.show()


def show_selectivity_of_top_parameters(H, component, n_params, *, log_relative=False):
    """Plot how the top parameters of one NMF component load on every component.

    Selects the ``n_params`` columns of ``H`` with the largest values in row
    ``component`` and shows that sub-matrix as an image: one row per
    component, one column per selected parameter, in descending order of
    their value on ``component``. A parameter that is selective for
    ``component`` shows up bright in that row and dark elsewhere in its column.

    Args:
        H: NMF factor indexed as ``H[component]`` -> per-parameter loadings;
            presumably a 2-D (n_components, n_params) numpy array or CPU torch
            tensor — TODO confirm against nmf_common.NmfDecomposition.
        component: Row of ``H`` used to rank parameters.
        n_params: Number of top parameters to show. Clamped to the number of
            columns of ``H`` so asking for more than exist does not make
            ``tf.math.top_k`` raise.
        log_relative: If True, divide each selected column by its value at
            ``component`` and take the log, so the image shows loadings
            relative to the selected component. Zero entries become -inf.
    """
    n_params = min(n_params, H.shape[-1])
    # top_k returns indices sorted by descending value, so column j of the
    # image is the parameter with rank j on `component`.
    _, inds = tf.math.top_k(H[component], k=n_params)
    inds = inds.numpy()
    # Keep every component row, select the top parameter columns
    # (equivalent to H.T[inds].T).
    top_params = H[:, inds]
    if log_relative:
        # 1e-12 avoids division by zero for columns that are 0 at `component`.
        top_params = np.log(top_params / (top_params[component] + 1e-12))
    # aspect='auto' keeps wide selections (n_params >> n_components) from
    # being rendered as a thin horizontal strip.
    plt.imshow(top_params, aspect='auto')
    plt.xlabel('parameter rank on component {}'.format(component))
    plt.ylabel('NMF component')
    plt.colorbar()
    plt.show()


# Reuse `component` instead of hard-coding 28 so this plot stays in sync with
# the top-example plot when exploring a different component.
show_selectivity_of_top_parameters(H, component, n_params=512)


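# TODO: try merging the fine-tuned checkpoint with a second checkpoint (the
# initial/pretrained weights, zeros, or the negation of the checkpoint itself),
# using the full Fisher for the fine-tuned checkpoint and the NMF component(s)
# in place of the Fisher for the other checkpoint.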