R"""




cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python -i local_scripts/m_npeff/lrm_npeff_snli_norm_ratios001.py

"""

from importlib import reload
import os

import h5py
import numpy as np

from em.tools.nmf import lrm_npeff
from em.util import hdf5_util

###############################################################################

# HDF5 file handed to read_pef_norms(), which reads its
# 'data/pef_frobenius_norms' dataset.
PEF_FILEPATH = "/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers/feather_berts_0.train.50000ex.65536.h5"

# HDF5 file handed to read_diag_pef_norms(), which reads its
# 'data/dense_fisher_norms' dataset. Built from a directory and a file name.
DIAG_PEFS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers/"
DIAG_PEFS_NAME = "feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
DIAG_PEFS_PATH = os.path.join(DIAG_PEFS_DIR, DIAG_PEFS_NAME)

# Saved decomposition loaded via LrmNpeffDecomposition.load() further down.
# NOTE(review): this path is rooted at /playpen and has no extract_merge1
# segment, unlike the two paths above -- confirm this is the intended location.
NMF_DIR = "/playpen/users/m/project_data/m_npeff1/per_example_fishers"
NMF_NAME = "test_mnpeff_002.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

###############################################################################


def read_pef_norms(pef_filepath: str):
    """Load the ``data/pef_frobenius_norms`` dataset from an HDF5 file.

    Args:
        pef_filepath: Path of the HDF5 file; it is opened read-only.

    Returns:
        Whatever ``hdf5_util.load_h5_ds`` returns for that dataset.
    """
    with h5py.File(pef_filepath, "r") as h5_file:
        norms_dataset = h5_file['data/pef_frobenius_norms']
        # Load while the file is still open; the handle closes on block exit.
        norms = hdf5_util.load_h5_ds(norms_dataset)
    return norms


def read_diag_pef_norms(pef_filepath: str):
    """Load the ``data/dense_fisher_norms`` dataset from an HDF5 file.

    Args:
        pef_filepath: Path of the HDF5 file; it is opened read-only.

    Returns:
        Whatever ``hdf5_util.load_h5_ds`` returns for that dataset.
    """
    with h5py.File(pef_filepath, "r") as h5_file:
        norms_dataset = h5_file['data/dense_fisher_norms']
        # Load while the file is still open; the handle closes on block exit.
        norms = hdf5_util.load_h5_ds(norms_dataset)
    return norms


def compute_norm_ratio(nmf, pef_norms: np.ndarray, component_index: int, n_top_examples=128):
    """Compare the mean norm of a component's top examples with the overall mean.

    Examples are ranked by their value in column ``component_index`` of
    ``nmf.W``, largest first, and the leading ``n_top_examples`` are kept.

    Args:
        nmf: Object exposing a ``W`` array indexed as ``W[example, component]``.
        pef_norms: Per-example norms, indexed by the same example index as
            the rows of ``nmf.W``.
        component_index: Column of ``nmf.W`` to rank the examples by.
        n_top_examples: Number of top-ranked examples to average over. If it
            exceeds the number of examples, all of them are used.

    Returns:
        Mean of ``pef_norms`` over the selected examples divided by the mean
        of ``pef_norms`` over all examples.
    """
    component_coeffs = nmf.W[:, component_index]
    # Sorting the negated coefficients ascending gives a descending ranking.
    descending_order = np.argsort(-component_coeffs)
    selected = descending_order[:n_top_examples]

    overall_mean = np.mean(pef_norms)
    selected_mean = np.mean(pef_norms[selected])
    return selected_mean / overall_mean


###############################################################################

# Load the decomposition (read_G=False) and both sets of per-example norms
# before printing anything.
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=False)
pef_norms = read_pef_norms(PEF_FILEPATH)
diag_pef_norms = read_diag_pef_norms(DIAG_PEFS_PATH)

# Print one norm ratio per line for component indices 0..255, first using
# pef_norms and then diag_pef_norms, with a run of newlines between the two.
for pass_index, norms in enumerate((pef_norms, diag_pef_norms)):
    if pass_index > 0:
        print(5 * '\n')
    for i in range(256):
        nr = compute_norm_ratio(nmf, norms, i, n_top_examples=128)
        print(nr)


###############################################################################

"""
x This
- Perturbation grid searches (of the last method).
- Perturbation methods latex.
- Fit coeffs for LRM-NPEFF.
- Touch ups to LRM_NPEFF factorization.
    - Back to int32 indices.
    - Logging.
    - Some sort of L1/L2 regularization on the Gs.
- Input feature salience via PEF factoring stuff.
"""
