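R"""Exploration of NMF hyperparameters on MNIST / Fashion-MNIST.

Usage (run from the repository root; `-i` keeps an interactive session open
afterwards so the fitted model and factors can be inspected):

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

CUDA_VISIBLE_DEVICES=1 python -i em/projects/nmf_explore/mnist_hyperparams.py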

"""
import time

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import torch
from torchnmf.nmf import NMF as TorchNMF


###############################################################################

# TensorFlow would otherwise grab (nearly) all GPU memory as soon as it
# touches a GPU, leaving nothing for torchnmf. Enabling memory growth on
# every visible GPU makes TF allocate only what it actually needs.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, enable=True)

###############################################################################
###############################################################################

# Height and width of each (Fashion-)MNIST image, and the length of one
# image once flattened into a row vector.
IMG_SHAPE = (28, 28)
FLAT_SIZE = IMG_SHAPE[0] * IMG_SHAPE[1]

##########################################################################


def load_mnist_train():
    """Return the MNIST training split as ``(X, Y)``.

    ``X`` is the image array cast to float32, divided by 255 and flattened
    to shape ``(n_examples, FLAT_SIZE)``; ``Y`` is the label array exactly
    as Keras returns it. The test split is discarded.
    """
    train_split, _ = tf.keras.datasets.mnist.load_data()
    images, labels = train_split
    images = images.astype("float32") / 255
    return images.reshape([-1, FLAT_SIZE]), labels


def load_fashion_mnist_train():
    """Return the Fashion-MNIST training split as ``(X, Y)``.

    ``X`` is the image array cast to float32, divided by 255 and flattened
    to shape ``(n_examples, FLAT_SIZE)``; ``Y`` is the label array exactly
    as Keras returns it. The test split is discarded.
    """
    train_split, _ = tf.keras.datasets.fashion_mnist.load_data()
    images, labels = train_split
    images = images.astype("float32") / 255
    return images.reshape([-1, FLAT_SIZE]), labels

##########################################################################


def plot_images(images, n_rows: int, n_cols: int, *, show=True):
    """Plot images on an ``n_rows`` x ``n_cols`` grid, filled row by row.

    Args:
        images: Iterable of images; each element is passed directly to
            ``Axes.imshow`` (e.g. a 2-D array per image).
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        show: If True, call ``plt.show()`` after laying out the figure.

    Only the first ``n_rows * n_cols`` images are drawn; if there are fewer
    images than cells, the remaining cells are left blank.

    Returns:
        The ``(fig, axs)`` pair from ``plt.subplots``. ``axs`` is always a
        2-D array because ``squeeze=False``. Callers can use it to save or
        further edit the figure.
    """
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False)
    # Ticks, labels and spines only clutter image data; ``axis('off')``
    # hides all of them (spines are not drawn once the axis is off).
    for ax in axs.flat:
        ax.axis('off')
    # ``axs.flat`` walks the grid row-major. zip stops at the shorter of the
    # two, so surplus cells stay blank and surplus images are skipped.
    for ax, image in zip(axs.flat, images):
        # The *string* 'none' disables resampling so every pixel is drawn
        # as a crisp square; passing ``None`` would instead fall back to the
        # rcParams default interpolation, which can blur small images.
        ax.imshow(image, cmap='gray', interpolation='none')
    fig.tight_layout()
    if show:
        plt.show()
    return fig, axs


###############################################################################
###############################################################################

# ---- NMF hyperparameters ----------------------------------------------------
# Each value below is handed to TorchNMF further down. Values explored in
# earlier runs are listed after "tried:" so they can be swapped back in.

# Rank of the factorization (TorchNMF `rank`).  tried: 16, 32, 256
N_COMPONENTS = 64

# Stopping criteria for `fit(max_iter=..., tol=...)`.
# tried: MAX_ITER = 3_000, TOL = 1e-6
MAX_ITER = 10000
TOL = 1e-08

# `fit(alpha=...)`; presumably the regularization strength -- confirm
# against the torchnmf docs.  tried: 0.0, 5e-1, 1e-2
ALPHA = 0.1

# `fit(beta=...)`; presumably the beta-divergence exponent
# (2 = squared Euclidean, 1 = KL) -- confirm against the torchnmf docs.
# tried: 1.0
BETA = 2.0

# `fit(l1_ratio=...)`; presumably the L1 share of the `alpha` penalty
# (1.0 = pure L1) -- confirm against the torchnmf docs.  tried: 0.0, 0.9
L1_RATIO = 1.0

##########################################################################

# X, Y = load_mnist_train()
X, Y = load_fashion_mnist_train()

# Prefer the GPU (the usage line in the module docstring assumes CUDA), but
# fall back to the CPU so the script still runs on machines without one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Transpose so the pixel axis comes first: torch_X is (FLAT_SIZE, n_examples).
torch_X = torch.from_numpy(X.T).to(device)


print('Starting NMF decomposition.')
# perf_counter is a monotonic, high-resolution clock meant for durations.
start = time.perf_counter()
nmf_model = TorchNMF(torch_X.shape, rank=N_COMPONENTS).to(device)
nmf_model.fit(
    torch_X,
    verbose=True,
    max_iter=MAX_ITER,
    tol=TOL,
    alpha=ALPHA,
    beta=BETA,
    l1_ratio=L1_RATIO,
)
# Alternative solver to compare against (signature noted from torchnmf):
#     nmf_model.sparse_fit(V, beta=2, max_iter=200, verbose=False,
#                          sW=None, sH=None)
print('NMF time: ', time.perf_counter() - start)

# Not used below; kept so it can be inspected from the `python -i` session.
W = nmf_model.W.detach().cpu().numpy()

# NOTE(review): assumes torchnmf's H is (FLAT_SIZE, N_COMPONENTS), i.e. one
# column per component over the pixel axis -- confirm against torchnmf docs.
# After the transpose each row is one component, reshaped into an image.
H = nmf_model.H.detach().cpu().numpy().T
H = H.reshape([-1, *IMG_SHAPE])

# Size a near-square grid from the number of components actually learned,
# so changing N_COMPONENTS needs no edit here (64 components -> 8 x 8).
n_cols = int(np.ceil(np.sqrt(H.shape[0])))
n_rows = int(np.ceil(H.shape[0] / n_cols))
plot_images(H, n_rows=n_rows, n_cols=n_cols)


"""
NMF Variants to Try:
    - Note that some variants are meant to alleviate computational issues,
      while others are meant to change the output.


General:
    - Number of examples.


PE-Fishers:
    - Grouping of parameter subsets.
    - The k in top-k.
    - Top-k in parameter subset (or subsets of a subset) vs top-k of all parameters.
    - Different values of k for different examples (e.g., the parameters containing the top k% of Fisher mass)


Reduction:
    - Ignoring parameters with non-zero values for fewer than k examples (i.e., --reduce_threshold)


Computation (general):
    - Magnitude of each example's Fisher (e.g., scaling of each parameter's Fisher)
    - Tolerance and max iterations.
    - The fit vs sparse_fit methods.
    - The value of beta.


"""