R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


python3 -i local_scripts/soc/nmf_cifar_dev001.py



CUDA_VISIBLE_DEVICES=0 python -i local_scripts/soc/nmf_cifar_dev001.py
CUDA_VISIBLE_DEVICES=2,3 python -i local_scripts/soc/nmf_cifar_dev001.py


"""
import collections
import dataclasses
import os
from importlib import reload
import itertools
import time
from typing import Any, List, Sequence

import h5py
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import NMF
import tensorflow as tf
import tensorflow_probability as tfp
import torch
from torchnmf.nmf import NMF as TorchNMF
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em.datasets import glue
from em.evaluation import evaluation
from em.fishers import diagonal
from em.fishers import per_example
from em.fishers import sparse_diagonal
from em.merging import merging
from em.models.generative import vae
# from em.models.generative import soc
from em.util import hf_util
from em.util import vat_da_faak_vpn

# from em.tools import bionmf_gpu as bionmf

from local_scripts.soc import soc_dev_common as sdc
from local_scripts.soc import soc_dev_cifar as cf


# Short alias for TensorFlow Probability's distributions module; it is not
# referenced in the visible part of this script.
tfd = tfp.distributions

###############################################################################

# Turn on memory growth for every GPU TensorFlow can see. Presumably this keeps
# TensorFlow from reserving whole GPUs, since PyTorch also uses the GPU later
# in this script -- TODO confirm.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)


###############################################################################
###############################################################################

# Batch sizes used only by the (currently commented-out) SOC dataset setup
# further down in this script.
PE_BATCH_SIZE = 256
SOC_BATCH_SIZE = 256

#######################################

# Load the trained CIFAR model and attach a `num_labels` attribute to it.
# Presumably a downstream helper (e.g. the Fisher computation) reads this
# attribute -- TODO confirm.
trained_model = cf.load_trained_model()
trained_model.num_labels = 10

# NOTE: every trainable variable is used here. The original intent was to
# remove the first and last layers, to hopefully prevent components from being
# tuned too closely to low-level features or trivially to class identity, but
# no such filtering is applied at present.
variables = trained_model.trainable_variables

###############################################################################


# n_components = 512
# component_size = cf.compute_component_size(trained_model)

# soc_ds = cf.create_soc_dataset(
#     trained_model,
#     soc_batch_size=SOC_BATCH_SIZE,
#     pe_batch_size=PE_BATCH_SIZE,
# )


# for _, x in soc_ds:
#     break


# sx = tf.sort(x, axis=-1, direction="DESCENDING")
# plt.plot(tf.math.log(tf.transpose(sx[:10])))
# plt.show()

###############################################################################

# nmf_model = NMF(n_components=32)
# W = nmf_model.fit_transform(x.numpy())

# plt.imshow(W)
# plt.show()

# Load the CIFAR-10 training split (the second split returned by `load_data`
# is discarded) and divide the pixel values by 255.
(train_images, train_labels), _ = tf.keras.datasets.cifar10.load_data()
train_images = train_images / 255.00

# Number of leading training examples to analyse.
# Other sizes that have been tried: 512 and 2 * 2048.
nx = 2048
x = train_images[:nx]
y = train_labels[:nx]

# Build one row per example: the per-example Fisher result for each variable
# is flattened to [nx, -1] and the pieces are concatenated along the last axis.
fisher = tf.concat(
    [
        tf.reshape(variable_fisher, [nx, -1])
        for variable_fisher in diagonal.compute_exact_fisher_for_batch(
            tf.cast(x, tf.float32),
            trained_model,
            variables,
            expectation_wrt_logits=True,
            per_example=True,
        )
    ],
    axis=-1,
)
# Rescale each example's row to unit L2 norm.
fisher = tf.linalg.l2_normalize(fisher, axis=-1)

# # nmf_model = NMF(n_components=32, max_iter=2000)
# nmf_model = NMF(n_components=128, max_iter=750)
# W = nmf_model.fit_transform(fisher.numpy())


# Convert the TensorFlow tensor to a NumPy array so PyTorch can consume it.
fisher = fisher.numpy()

# Run on CUDA when it is available, otherwise fall back to the CPU instead of
# crashing on a hard-coded `.cuda()` call.
nmf_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The matrix handed to torchnmf is the transpose of `fisher`, i.e. its first
# axis runs over flattened Fisher entries and its second over examples.
torch_fisher = torch.from_numpy(fisher.T).to(nmf_device)

# Rank-32 non-negative factorization of the transposed Fisher matrix.
nmf_model = TorchNMF(torch_fisher.shape, rank=32).to(nmf_device)
nmf_model.fit(torch_fisher)

# Pull the learned factor back to host memory as a NumPy array. The code below
# indexes its rows by example, so W is presumably (nx, rank) -- confirm against
# the torchnmf shape conventions.
W = nmf_model.W.detach().cpu().numpy()

# max_fisher_values = tf.reduce_max(fisher, axis=0)
# threshold = 1e-2
# gpu_model = bionmf.NmfGpu()
# W, Ht = gpu_model.run_nmf(
#     fisher.numpy()[:, max_fisher_values.numpy() >= threshold],
#     # fisher.numpy(),
#     # n_components=128,
#     # n_components=32,
#     n_components=2,
#     # # TODO: See how this choice of GPU interacts TF's GPU memory hogging.
#     # gpu_device=3,
#     stop_threshold=1000,
#     max_iters=10_000,
# )

# raise ValueError


# plt.imshow(tf.transpose(W))
# plt.show()

# Reorder the rows of W (one per example) by label, highest label first, so
# that examples sharing a label end up next to each other.
inds = tf.math.top_k(tf.reshape(y, [-1]), k=nx).indices
sW = tf.gather(W, inds)

# Reorder the columns (components) by their mean weight, largest first.
inds2 = tf.math.top_k(tf.reduce_mean(sW, axis=0), k=sW.shape[1]).indices
ssW = tf.gather(sW, inds2, axis=-1)


# Plot the doubly-reordered weights, transposed so components run along the
# vertical axis. Pass `tf.transpose(sW)` instead to see the original component
# order.
plt.imshow(tf.transpose(ssW))
plt.show()


# Components of the unsorted W ordered by descending mean weight. These are
# kept for interactive use; the loop below indexes W directly.
inds3 = tf.math.top_k(tf.reduce_mean(W, axis=0), k=W.shape[1]).indices
sW3 = tf.gather(W, inds3, axis=-1)


# For each of the first 8 components, print the labels of the 32 examples with
# the largest weight on that component and display those images in rows of 8.
# Use `sW3[:, c_ind]` instead of `W[:, c_ind]` to walk the components in
# mean-weight order.
flat_labels = tf.reshape(y, [-1])
for c_ind in range(8):
    ex_inds = tf.math.top_k(W[:, c_ind], k=32).indices
    print(tf.gather(flat_labels, ex_inds).numpy())
    cf.imshow_cifar_multi(tf.gather(x, ex_inds), row_size=8)


# _, ex_inds = tf.math.top_k(W[:, 90], k=32)
# print(tf.gather(tf.reshape(y, [-1]), ex_inds).numpy())
# cf.imshow_cifar_multi(tf.gather(x, ex_inds), row_size=8)

R"""

for _, x in soc_ds:
    break

y = soc_model(x)

y[:10,:5]

tf.reduce_all(y[0] == y[1]).numpy()


x[:10,:5]

"""