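R"""Sweep `lmbda` values for the perturbation finder on a few decomposition components.

For each selected component, the model is perturbed with both signs of the
multiplier, and the larger of the two (top examples / all examples) KL ratios
is printed.

Usage:

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=1 python -i local_scripts/m_npeff/perturbations/perturbation_hyperparameters_search001.py

"""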

from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.fishers import diagonal
from em.tools.nmf import lrm_npeff
from em.util import flat_pack

from em.projects.m_npeff import perturbation_finder
from em.projects.m_npeff import snli_context
from em.projects.pi import qqp_components_context as QCC


###############################################################################

# Location of the saved diagonal Fisher; read with `diagonal.DiagonalFisher.load`.
FISHER_DIR = "/fruitbasket/users/m/project_data/extract_merge1/pi1/fishers/"
FISHER_NAME = "feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"
FISHER_PATH = os.path.join(FISHER_DIR, FISHER_NAME)

# Location of the saved decomposition; read with
# `lrm_npeff.LrmNpeffDecomposition.load`.
NMF_DIR = "/playpen/users/m/project_data/m_npeff1/per_example_fishers"
NMF_NAME = "test_mnpeff_002.coeffs_fit001.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# Use this only to get the predictions and example token ids without having to
# evaluate the model.
PEFS_FOR_PREDICTIONS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers/"
PEFS_FOR_PREDICTIONS_NAME = "feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"
PEFS_FOR_PREDICTIONS_PATH = os.path.join(PEFS_FOR_PREDICTIONS_DIR, PEFS_FOR_PREDICTIONS_NAME)

# Names handed to `from_pretrained` for the classifier checkpoint and the
# tokenizer, respectively.
MODEL = "connectivity/feather_berts_0"
TOKENIZER = 'bert-base-uncased'

###############################################################################

# Reference copy of the checkpoint. `run` reads its trainable variables as the
# starting point for every perturbation and never assigns to them.
og_model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)

fisher = diagonal.DiagonalFisher.load(FISHER_PATH)
flat_fisher = fisher.as_flat_fisher().numpy()
# Rescale the flattened Fisher in place so that it has unit L2 norm.
flat_fisher /= np.sqrt(np.sum(flat_fisher**2))

nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=True)
nmf.normalize_components_to_unit_norm()

ctx = snli_context.SnliContext(
    split='train_skip_50k',
    tokenizer=AutoTokenizer.from_pretrained(TOKENIZER),
    nmf=nmf,
    load_examples=False,
)

###############################################################################

# Default number of examples `eval` uses for its "total" evaluation.
N_TOTAL_EXAMPLES = 1024
# Re-import `snli_context` and re-point the existing `ctx` instance at the
# reloaded class -- presumably so that source edits take effect in a
# `python -i` session without rebuilding `ctx`.
reload(snli_context); ctx.__class__ = snli_context.SnliContext
# eval_ctx = ctx.create_eval_ctx(og_model)
eval_ctx = ctx.create_eval_ctx_from_pefs_file(PEFS_FOR_PREDICTIONS_PATH)

# Second copy of the same checkpoint. `run` overwrites its trainable variables
# with the perturbed weights and `eval` evaluates it.
model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)


###############################################################################


def run(component_index: int, multiplier=1.0, lmbda=None, min_fisher_value=1e-9):
    """Overwrite the weights of `model` with a perturbed copy of `og_model`.

    A flat perturbation is solved for from the normalized Fisher and the full
    `g` vector of the requested component, unpacked into one offset per
    trainable variable, and added to the corresponding `og_model` variable.
    Because each call starts again from `og_model`, perturbations from
    successive calls do not stack.

    Args:
        component_index: Index of the component whose `g` is fetched from `nmf`.
        multiplier: Forwarded unchanged to `PerturbationFinder3.solve`.
        lmbda: Forwarded unchanged to `PerturbationFinder3.solve`.
        min_fisher_value: Forwarded unchanged to `PerturbationFinder3.solve`.
    """
    component_g = nmf.get_full_g(component_index)

    # Reload on every call so that edits to the solver module are picked up
    # while working in an interactive session.
    reload(perturbation_finder)
    solver = perturbation_finder.PerturbationFinder3(f=flat_fisher, g=component_g)
    flat_perturbation = solver.solve(
        multiplier=multiplier,
        lmbda=lmbda,
        min_fisher_value=min_fisher_value,
    )

    # Split the flat solution back into tensors shaped like the model variables.
    variable_shapes = [variable.shape for variable in model.trainable_variables]
    unpacker = flat_pack.FlatPacker(variable_shapes)
    per_variable_offsets = unpacker.decode_tf(tf.cast(flat_perturbation, tf.float32))

    # TODO: Both +/- offset would work here. Right now, just doing plus.
    aligned = zip(
        og_model.trainable_variables,
        model.trainable_variables,
        per_variable_offsets,
    )
    for reference_var, perturbed_var, delta in aligned:
        perturbed_var.assign(reference_var + delta)


def eval(component_index: int, n_top_examples: int = 128, n_total_examples=N_TOTAL_EXAMPLES):
    """Return the KL ratio between a component's top examples and a baseline set.

    NOTE: this function shadows the builtin `eval` inside this module.

    Args:
        component_index: Column of `nmf.W` used to rank the examples.
        n_top_examples: How many of the highest-coefficient examples to
            evaluate.
        n_total_examples: The baseline is evaluated on example indices
            `0 .. n_total_examples - 1`.

    Returns:
        `kl()` of the top-example results divided by `kl()` of the baseline
        results, both obtained by evaluating the current `model`.
    """
    coefficients = nmf.W[:, component_index]
    # Negating before argsort yields indices in descending coefficient order.
    descending_order = np.argsort(-coefficients)
    strongest_inds = descending_order[:n_top_examples]

    top_results = eval_ctx.evaluate(model, strongest_inds)
    baseline_results = eval_ctx.evaluate(model, np.arange(n_total_examples))

    top_kl = top_results.kl()
    baseline_kl = baseline_results.kl()
    # Uncomment for per-call diagnostics:
    # print('Top:', top_kl)
    # print('Total:', baseline_kl)
    # print('Ratio:', top_kl / baseline_kl)
    return top_kl / baseline_kl

###############################################################################


# Components to sweep over. The trailing number on each line is the value
# recorded for that component when the list was put together.
# Alternative: component_indices = [0, 6]
component_indices = [
    159,  # 0.6703704596
    180,  # 11.68937302
    184,  # 0.08225885779
    191,  # 0.07570057362
    213,  # 1.329680085
]

# Candidate `lmbda` values. Other grids that have been tried:
#   lmbdas = [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
#   lmbdas = [None, 0, 1e-6]
lmbdas = [None, 0.01, 0.05, 0.1, 0.15, 0.25, 0.35, 0.5, 0.65, 0.75, 0.85, 0.9, 0.95, 0.99]

# Magnitude of the multiplier; both signs are tried for every setting.
mp = 2e-1

# Observation so far: lmbda=0 generally looks best whenever the component ends
# up having a positive ratio. That setting is the same as a simple offset in
# the direction of g.
for i in component_indices:
    print(f'Component {i}')
    for lmbda in lmbdas:
        # Perturb with +mp and then -mp, and report the better of the two ratios.
        ratios = []
        for signed_multiplier in (mp, -mp):
            run(i, multiplier=signed_multiplier, lmbda=lmbda)
            ratios.append(eval(i, n_top_examples=128))
        r = max(ratios)
        print(f'    {mp}, {lmbda}: {r}')


# for i in range(32):
#     mp = 2e-1
#     # mp = 3e-1
#     lmbda = 0.25
#     #
#     print('Component', i)
#     run(i, multiplier=mp, lmbda=lmbda)
#     _ = eval(i, n_top_examples=128)
#     run(i, multiplier=-mp, lmbda=lmbda)
#     _ = eval(i, n_top_examples=128)
#     print('')

