# factorization_debug001.py
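R"""
Debug script: load an LRM-NPEFF decomposition and visually inspect its W
matrix, including where any non-finite (NaN/inf) entries occur.

Usage (`-i` leaves an interactive session open to inspect the results):

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python -i local_scripts/m_npeff/imagenet2/factorization_debug001.py

"""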
import os
import numpy as np
from em.fishers import lrm_pefs
from em.tools.nmf import lrm_npeff

import matplotlib.pyplot as plt
import seaborn as sns

###############################################################################


# Location of the saved decomposition. Toggle which NMF_NAME line is active
# to inspect a different file.
NMF_DIR = "/fruitbasket/users/m/project_data/extract_merge1/imagenet2_lrm_npeff/per_example_fishers/"
# NMF_NAME = "resnet50_imagenet.train.20000ex.65536.mpc3e-3.mnpeff.256comps.001.h5"
NMF_NAME = "TEST_LVRM_NPEFF_001.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# Load the decomposition. read_G=False is forwarded to the loader
# (presumably skips reading the G factor -- TODO confirm in lrm_npeff).
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=False)
W = nmf.W

# Boolean mask of the NaN/inf entries of W, plus the (row, column) indices
# where they occur. These stay as module-level names so they can be
# inspected in the `python -i` session after the script finishes.
q = np.logical_not(np.isfinite(W))
a, b = q.nonzero()

# Stricter alternative: fail outright instead of inspecting interactively.
# if not np.isfinite(W).all():
#     raise ValueError

# Heatmap of the first 600 rows of W, transposed so that those rows run
# along the horizontal axis of the image.
plt.imshow(np.transpose(W[:600]))
plt.show()

# Other views that have been useful while debugging:
# plt.plot(W[1]);plt.show()
# new_W = W[:, :64]
# plt.imshow(np.sqrt(new_W.T));plt.show()
