R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python -i local_scripts/m_npeff/snli2/extension_debug_001.py

"""
import os
import numpy as np
from em.fishers import lrm_pefs
from em.tools.nmf import lrm_npeff

import matplotlib.pyplot as plt
import seaborn as sns

###############################################################################


# Location of the LRM-NPEFF decomposition file to inspect. The commented-out
# NMF_NAME is an alternative file (the "coeffs_fit001" variant), kept so the
# target can be switched by hand.
NMF_DIR = "/fruitbasket/users/m/project_data/extract_merge1/snli2_lrm_npeff/per_example_fishers/"
# NMF_NAME = "bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.512comps.coeffs_fit001.h5"
NMF_NAME = "bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.512comps.expansion64comps.001.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)


# Load the decomposition from disk.
# NOTE(review): read_G=False presumably skips reading the G factor, so only W
# (and other light fields) get loaded. Confirm against
# LrmNpeffDecomposition.load.
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=False)

# Coefficient matrix of the decomposition. It must be 2-D, because the
# `a, b = np.nonzero(q)` unpacking below expects exactly two index arrays.
# NOTE(review): the shape is presumably (n_examples, n_components). TODO: confirm.
W = nmf.W

# A hard finiteness check, left disabled. The script keeps running so the
# offending entries can be located and inspected below instead.
# if not np.isfinite(W).all():
#     raise ValueError

# q: boolean mask, True wherever W holds NaN or +/-inf.
q = ~np.isfinite(W)
# a, b: row and column indices of the non-finite entries of W.
# Nothing is printed. The script is meant to be run with `python -i` (see the
# module docstring), so these are examined in the REPL afterwards.
a, b = np.nonzero(q)

# Optional visualizations; uncomment to plot.
# NOTE(review): np.sqrt presumably compresses W's dynamic range for display,
# which assumes W is non-negative.
# plt.imshow(np.sqrt(W.T));plt.show()

# Plot row 1 of W.
# plt.plot(W[1]);plt.show()

# First 64 columns of W.
# NOTE(review): 64 likely matches "expansion64comps" in NMF_NAME, but nothing
# here establishes that the expansion components are the *first* 64 columns.
# Confirm.
new_W = W[:, :64]
# plt.imshow(np.sqrt(new_W.T));plt.show()


###############################################################################

# PEFS_DIR = "/playpen/users/m/project_data/qqp_lrm_npeff2/per_example_fishers/"
# PEFS_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9.heldout_from_train.50000ex.65536.wrongs_only.h5"
# PEFS_PATH = os.path.join(PEFS_DIR, PEFS_NAME)

# # PEFS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/qqp_lrm_npeff2/per_example_fishers/"
# # PEFS_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9.DUMMY_TEST_DEBUG_001.100ex.65536.h5"
# # PEFS_PATH = os.path.join(PEFS_DIR, PEFS_NAME)

# pefs = lrm_pefs.SparseLrmPefs.load(PEFS_PATH)


# for i, offsets in enumerate(pefs.col_offsets):
#     for j in range(pefs.n_classes - 1):
#         q = pefs.row_indices[i, offsets[j]:offsets[j + 1]]
#         if len(set(q)) != offsets[j + 1] - offsets[j]:
#             raise ValueError

# if not np.isfinite(pefs.values).all():
#     raise ValueError
