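R"""
Interactive script: load a diagonal Fisher and an LRM-NPEFF decomposition,
then run PerturbationFinder on individual decomposition components and print
diagnostics for the resulting solution.

Usage: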
cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python -i local_scripts/m_npeff/lrm_npeff_snli_perturbations001.py

"""

from importlib import reload
import os

import numpy as np
import tensorflow as tf

from em.fishers import diagonal
from em.tools.nmf import lrm_npeff
from em.projects.m_npeff import perturbation_finder


# Convenience alias for interactive use. It binds the class object once, at
# import time, so it will NOT reflect changes made by
# ``reload(perturbation_finder)``. process_for_component looks the class up
# through the module attribute instead.
PerturbationFinder = perturbation_finder.PerturbationFinder


###############################################################################

# Tokenizer identifier (not used by the code below in this script).
TOKENIZER = 'bert-base-uncased'

# --- Diagonal Fisher input -------------------------------------------------
# Directory, file name, and the joined path that DiagonalFisher.load reads.
FISHER_DIR = "/fruitbasket/users/m/project_data/extract_merge1/pi1/fishers/"
FISHER_NAME = "feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"
FISHER_PATH = os.path.join(FISHER_DIR, FISHER_NAME)

# --- LRM-NPEFF decomposition input -----------------------------------------
# Directory, file name, and the joined path that LrmNpeffDecomposition.load reads.
NMF_DIR = "/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers/"
NMF_NAME = "test_mnpeff_002.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

###############################################################################

# Load the diagonal Fisher from FISHER_PATH and convert its flattened form to a
# NumPy array; process_for_component passes this as ``f`` to PerturbationFinder.
# NOTE(review): presumably a 1-D per-parameter vector -- confirm shape/dtype.
fisher = diagonal.DiagonalFisher.load(FISHER_PATH)
flat_fisher = fisher.as_flat_fisher().numpy()

# Load the LRM-NPEFF decomposition from NMF_PATH. ``read_G=True`` is presumably
# required for ``get_full_g`` (used in process_for_component) -- TODO confirm.
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=True)
# Must run after loading. Judging by the name, this rescales the components to
# unit norm (presumably in place) -- verify in lrm_npeff.
nmf.normalize_components_to_unit_norm()


def process_for_component(component_index: int, delta: float = 0.01, min_fisher_value: float = 1e-6):
    """Run PerturbationFinder on one decomposition component and print diagnostics.

    Uses the module-level ``nmf`` decomposition and ``flat_fisher`` array.

    Args:
        component_index: Index of the component whose full ``g`` vector
            (``nmf.get_full_g``) is used.
        delta: Forwarded unchanged to ``PerturbationFinder.solve``.
        min_fisher_value: Forwarded unchanged to ``PerturbationFinder.solve``.

    Returns:
        The solution ``z`` produced by ``pf.solve``. Returning it lets the
        result be inspected in the interactive (``python -i``) session; it was
        previously discarded.
    """
    g = nmf.get_full_g(component_index)

    # Resolve the class through the module (not the top-level alias) so that
    # running ``reload(perturbation_finder)`` at the prompt takes effect here.
    pf = perturbation_finder.PerturbationFinder(f=flat_fisher, g=g)

    z = pf.solve(delta=delta, min_fisher_value=min_fisher_value)

    # Label each diagnostic so the printed numbers can be told apart.
    prefix = f"[component {component_index}]"
    print(f"{prefix} sum(z**2) = {np.sum(z**2)}")
    print(f"{prefix} constraint value = {pf.compute_constraint_value(z)}")
    print(f"{prefix} objective value**2 = {pf.compute_objective_value(z)**2}")

    return z


if __name__ == "__main__":
    # Analyze component 0 by default. Other components (e.g. 6) can be run by
    # calling ``process_for_component(<index>)`` at the interactive prompt.
    process_for_component(0)
