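R"""
Interactive debugging script: loads an LRM-NPEFF decomposition and locates the non-finite (NaN/inf) entries of its W matrix.
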
cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python -i local_scripts/m_npeff/qqp/extension_debug_001.py

"""
import os
import numpy as np
from em.fishers import lrm_pefs
from em.tools.nmf import lrm_npeff

###############################################################################

# NMF_DIR = "/playpen/users/m/project_data/qqp_lrm_npeff2/per_example_fishers/"
# NMF_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9.heldout_from_train.50000ex.65536.wrongs_only.coeffs_fit001.h5"
# # NMF_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9.heldout_from_train.mnpeff.256comps.001.h5"
# NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# NMF_DIR = "/fruitbasket/users/m/project_data/extract_merge1/m_npeff_ims1/per_example_fishers/"
# NMF_NAME = "resnet50_imagenet.train.tkc5.mnpeff.512comps.001.fit_to_validation.001.h5"
# NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)


NMF_DIR = "/fruitbasket/users/m/project_data/extract_merge1/qqp_lrm_npeff2/per_example_fishers/"
NMF_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9.heldout_from_train.50000ex.65536.wrongs_only.coeffs_fit001.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)


# Load the decomposition stored at NMF_PATH. read_G=False presumably skips
# reading the G factor, since only W is inspected below -- TODO confirm.
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=False)

W = nmf.W

# Stricter variant: abort as soon as W holds any NaN or +/-inf entry.
# if not np.isfinite(W).all():
#     raise ValueError

# q is a boolean mask that is True wherever W is NaN or +/-inf. (a, b) hold
# the indices of those entries along W's first and second axes. Unpacking
# into exactly two arrays means W is expected to be 2-D. All of these names
# stay at module level so they can be inspected after `python -i`.
q = np.logical_not(np.isfinite(W))
a, b = q.nonzero()

###############################################################################

# PEFS_DIR = "/playpen/users/m/project_data/qqp_lrm_npeff2/per_example_fishers/"
# PEFS_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9.heldout_from_train.50000ex.65536.wrongs_only.h5"
# PEFS_PATH = os.path.join(PEFS_DIR, PEFS_NAME)

# # PEFS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/qqp_lrm_npeff2/per_example_fishers/"
# # PEFS_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9.DUMMY_TEST_DEBUG_001.100ex.65536.h5"
# # PEFS_PATH = os.path.join(PEFS_DIR, PEFS_NAME)

# pefs = lrm_pefs.SparseLrmPefs.load(PEFS_PATH)


# for i, offsets in enumerate(pefs.col_offsets):
#     for j in range(pefs.n_classes - 1):
#         q = pefs.row_indices[i, offsets[j]:offsets[j + 1]]
#         if len(set(q)) != offsets[j + 1] - offsets[j]:
#             raise ValueError

# if not np.isfinite(pefs.values).all():
#     raise ValueError
