R"""



cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/divis/divis_ds_nmf001.py

"""
import collections
from importlib import reload
import itertools
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import torch
from torchnmf.nmf import NMF as TorchNMF

from em.datasets import divisibility as divis_ds
from em.fishers import per_example
from em.models import divis_models
from em.models import transformer_model_vars as tmv
from em.tools.clustering import vat
from em.tools.nmf import parallel_sklearn_nmf as p_nmf
from em.tools.nmf import nmf_common
from em.util import flat_pack

from em.analysis import bert_nmf_analysis as bna
from em.analysis import bert_nmf_analysis2 as bna2

###############################################################################

# By default TensorFlow reserves nearly all GPU memory as soon as it touches a
# device. Turning on memory growth makes it allocate on demand instead, which
# leaves room on the same GPU for torchnmf further down.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(device=gpu, enable=True)

###############################################################################

# Commented-out experiment directories; nothing below refers to them.
# EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/divis_mech1'
# MODELS_DIR = os.path.join(EXPS_DIR, 'models')
# FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Forwarded to divis_ds.create_ds as `buffer_size` (4 * 2**20 == 4_194_304).
_DS_BUFFER_SIZE = 4 * 2**20

# Divisor bounds 2..9 and dividend bounds 100..999_999_999
# (NOTE(review): presumably inclusive — confirm in divis_ds).
ds_config = divis_ds.DivisibilityDatasetConfig(
    min_divisor=2, max_divisor=9,
    min_dividend=100, max_dividend=999_999_999,
)
ds = divis_ds.create_ds(ds_config, buffer_size=_DS_BUFFER_SIZE)

###############################################################################

N_EXAMPLES = 32 * 1024

# Take exactly one batch of (inputs, labels) from the dataset.
[(x, y)] = ds.batch(N_EXAMPLES).take(1)

# Append each label as an extra trailing column of the inputs, one-hot encode
# every entry with depth 10, then flatten each example into a single row of z.
# Assumes x is (n_examples, n_tokens) and every value lies in [0, 10) —
# tf.one_hot silently encodes out-of-range values as all zeros. TODO confirm.
z = tf.one_hot(
    tf.concat([x, tf.expand_dims(y, axis=1)], axis=-1),
    depth=10,
    axis=-1,
)
z = tf.reshape(z, [z.shape[0], -1]).numpy()

###############################################################################

NMF_N_COMPONENTS = 256
MAX_ITER = 3000
TOL = 1e-8

# Prefer the GPU (the reason TF memory growth is enabled above), but fall back
# to the CPU so the script still runs on a machine without CUDA.
_NMF_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# torchnmf is given z transposed: rows are one-hot features, columns are
# examples.
torch_z = torch.from_numpy(z.T).to(_NMF_DEVICE)

print('Starting NMF decomposition.')
# perf_counter is monotonic and high-resolution; time.time() can jump if the
# system clock is adjusted, so it is unsuitable for measuring durations.
start = time.perf_counter()
nmf_model = TorchNMF(torch_z.shape, rank=NMF_N_COMPONENTS).to(_NMF_DEVICE)
nmf_model.fit(
    torch_z,
    verbose=True,
    max_iter=MAX_ITER,
    tol=TOL,
)
print('NMF time: ', time.perf_counter() - start)


# NOTE(review): assumes torchnmf's V ≈ H @ W.T convention (torchnmf >= 0.3) —
# confirm. With V = z.T, W is then (n_examples, rank): per-example component
# weights. H is stored transposed to (rank, n_features), so z ≈ W @ H.
W = nmf_model.W.detach().cpu().numpy()
H = nmf_model.H.detach().cpu().numpy().T

###############################################################################

# Decode every example's divisor and dividend from its raw inputs using the
# dataset config's helpers (divisors first, then dividends).
divisors, dividends = (
    ds_config.get_divisors_from_examples(x.numpy()),
    ds_config.get_dividends_from_examples(x.numpy()),
)


def print_top_examples(
    component_index: int,
    n_examples: int,
    *,
    weights=None,
    labels=None,
    example_divisors=None,
    example_dividends=None,
):
    """Print the examples with the largest weight on one NMF component.

    Examples are ranked by column ``component_index`` of the per-example
    weight matrix, highest first. Equal weights keep the lower example index
    first, which matches ``tf.math.top_k``. Each example is printed as
    ``'<label>, <divisor>|<dividend>'``.

    Args:
        component_index: Column of ``weights`` to rank examples by.
        n_examples: Number of top-ranked examples to print. A value larger
            than the number of examples prints every example.
        weights: 2-D array with one row per example and one column per
            component. Defaults to the module-level ``W``.
        labels: Per-example labels. Defaults to the module-level ``y``.
        example_divisors: Per-example divisors. Defaults to the module-level
            ``divisors``.
        example_dividends: Per-example dividends. Defaults to the
            module-level ``dividends``.

    Raises:
        ValueError: If ``n_examples`` is negative.
    """
    if n_examples < 0:
        raise ValueError(f'n_examples must be non-negative, got {n_examples}')
    # Unspecified inputs fall back to the results computed by this script.
    if weights is None:
        weights = W
    if labels is None:
        labels = y
    if example_divisors is None:
        example_divisors = divisors
    if example_dividends is None:
        example_dividends = dividends

    scores = np.asarray(weights)[:, component_index]
    # A stable ascending sort of the negated scores yields a descending order
    # in which ties keep ascending example indices. Casting to float64 first
    # keeps the negation safe even for unsigned integer weights.
    order = np.argsort(-scores.astype(np.float64), kind='stable')
    # Convert once rather than indexing a (possibly TF) tensor per example.
    labels = np.asarray(labels)
    for ind in order[:n_examples]:
        example = f'{example_divisors[ind]}|{example_dividends[ind]}'
        print(f'{labels[ind]}, {example}')


# Show the 16 examples weighted most heavily on the first NMF component.
print_top_examples(component_index=0, n_examples=16)
