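R"""Perturb a fine-tuned classifier along NPEFF component directions.

The script loads the classifier and an LRM-NPEFF decomposition, builds one
perturbation direction per selected component, picks the sign of each direction
by comparing the top-examples / all-examples ratio of the `kl()` metric, then
applies a combined perturbation and prints the resulting accuracy.

It is meant to be run interactively (note the `-i` flag):

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=1 python -i local_scripts/m_npeff/wrong_heuristics/wrong_perturb_test002.py

"""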
import dataclasses
from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.fishers import diagonal
from em.tools.nmf import lrm_npeff
from em.util import flat_pack

from em.projects.m_npeff import snli_context
from em.projects.pi import qqp_components_context as QCC

from em.util.color_util import cu


###############################################################################

# Diagonal Fisher file. In the visible part of this script FISHER_PATH is
# referenced only by the commented-out Fisher-loading code further down.
FISHER_DIR = "/fruitbasket/users/m/project_data/extract_merge1/pi1/fishers/"
FISHER_NAME = "feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"
FISHER_PATH = os.path.join(FISHER_DIR, FISHER_NAME)

# Saved decomposition that is loaded into `nmf` with
# `lrm_npeff.LrmNpeffDecomposition.load`.
NMF_DIR = "/playpen/users/m/project_data/m_npeff1/per_example_fishers"
NMF_NAME = "test_mnpeff_004.coeffs_fit001.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# Use this only to get the predictions and example token ids without having to
# evaluate the model.
PEFS_FOR_PREDICTIONS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers/"
PEFS_FOR_PREDICTIONS_NAME = "feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"
PEFS_FOR_PREDICTIONS_PATH = os.path.join(PEFS_FOR_PREDICTIONS_DIR, PEFS_FOR_PREDICTIONS_NAME)

# Identifier passed to `TFAutoModelForSequenceClassification.from_pretrained`
# (with `from_pt=True`), and the tokenizer name passed to
# `AutoTokenizer.from_pretrained`.
MODEL = "connectivity/feather_berts_0"
TOKENIZER = 'bert-base-uncased'

###############################################################################

# Reference copy of the classifier. Nothing in the visible code assigns to its
# variables; `perturb_weights` only reads them as the base for each perturbation.
og_model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)


# Disabled: load the diagonal Fisher and normalize it to unit L2 norm.
# fisher = diagonal.DiagonalFisher.load(FISHER_PATH)
# flat_fisher = fisher.as_flat_fisher().numpy()
# flat_fisher /= np.sqrt(np.sum(flat_fisher**2))

print('Starting to read in decomposition.')
# `read_G=True` because `find_perturbation` needs `nmf.G`.
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=True)
print('Decomposition read in.')
# `find_perturbation` assumes the rows of G have unit norm; this call is
# presumably what establishes that -- TODO confirm against lrm_npeff.
nmf.normalize_components_to_unit_norm()
print('Decomposition components normalized.')

ctx = snli_context.SnliContext(
    split='train_skip_50k',
    tokenizer=AutoTokenizer.from_pretrained(TOKENIZER),
    nmf=nmf,
    load_examples=False,
)
print('SNLI context made.')

###############################################################################

# Number of leading examples used for the "total" evaluation: 8 * 1014 = 8112.
# NOTE(review): 1014 looks like it could be a typo for 1024, but the accuracies
# recorded at the bottom of this file are consistent with 8112 examples (e.g.
# 0.7871055226824457 * 8112 = 6385), so changing it would make new numbers
# incomparable with those notes.
N_TOTAL_EXAMPLES = 8 * 1014
# Interactive-session hack: re-import `snli_context` and rebind the existing
# `ctx` object to the reloaded class, so edits to that module take effect
# without rebuilding the context.
reload(snli_context); ctx.__class__ = snli_context.SnliContext
# eval_ctx = ctx.create_eval_ctx(og_model)
eval_ctx = ctx.create_eval_ctx_from_pefs_file(PEFS_FOR_PREDICTIONS_PATH)
# Keep element 0 and replace element 1 with (element 1 + 1) mod 3.
# NOTE(review): presumably `all_examples` is (inputs, labels) and this rotates
# the 3-way label ids to match the model's output order -- TODO confirm.
eval_ctx.all_examples = (eval_ctx.all_examples[0], (eval_ctx.all_examples[1] + 1) % 3)
print('Eval context made.')

# Second, independent load of the same checkpoint. `perturb_weights` writes into
# this copy, computing every new value from the variables of `og_model`.
model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)

###############################################################################


def find_perturbation(component_index: int, max_sim: float = 1e9):
    """Build a unit-norm perturbation direction for one component of ``nmf.G``.

    Starts from a copy of row ``component_index`` of ``nmf.G`` and, when
    ``max_sim > 0``, removes from it (one row after another) its projection
    onto every other row whose absolute dot product with the *original* row is
    at most ``max_sim``. Rows that are more similar than ``max_sim`` are left
    in. The projection step assumes the rows of ``G`` have unit norm.

    Args:
        component_index: Row of ``nmf.G`` to start from.
        max_sim: Similarity threshold. With the default of ``1e9`` every other
            unit-norm row passes the threshold and is projected out; a value
            ``<= 0`` disables the projection step entirely.

    Returns:
        A float32 vector of length ``nmf.n_parameters`` that is zero everywhere
        except at ``nmf.new_to_old_col_indices``, where it holds the
        L2-normalized direction.
    """
    components = nmf.G
    target = components[component_index]
    # Work on a copy so the decomposition itself is never modified.
    direction = np.copy(target)

    if max_sim > 0.0:
        for row_index in range(components.shape[0]):
            other = components[row_index]
            is_self = row_index == component_index
            # Similarity is always measured against the untouched target row,
            # not against the partially projected direction.
            if is_self or np.abs(target.dot(other)) > max_sim:
                continue
            direction -= direction.dot(other) * other

    direction /= np.sqrt(np.sum(direction**2))

    # Scatter the reduced-column direction back into full parameter space.
    full_direction = np.zeros([nmf.n_parameters], dtype=np.float32)
    full_direction[nmf.new_to_old_col_indices] = direction
    return full_direction

###############################################################################


@dataclasses.dataclass
class PerturbationResults:
    """Evaluation results for a perturbed model on two sets of examples.

    NOTE(review): the fields are annotated with the QQP results type although
    the values in this script come from an SNLI evaluation context; presumably
    the result classes share the `kl()` / `acc()` interface -- TODO confirm.
    """

    # Results on the top-ranked examples of a single component.
    top_results: QCC.QqpEvaluationResults
    # Results on the overall evaluation set.
    total_results: QCC.QqpEvaluationResults

    def ratio(self):
        """Return `top_results.kl()` divided by `total_results.kl()`."""
        top_kl = self.top_results.kl()
        total_kl = self.total_results.kl()
        return top_kl / total_kl


def perturb_weights(offsets, multiplier):
    """Set each trainable variable of ``model`` to a perturbed reference value.

    For every (reference variable, live variable, offset) triple the live
    variable is assigned ``reference + multiplier * offset``. The reference
    values come from ``og_model``, so successive calls do not accumulate.

    Args:
        offsets: One offset per trainable variable, in variable order. ``zip``
            stops at the shortest input, so surplus variables stay untouched.
        multiplier: Scalar step size applied to every offset.
    """
    reference_vars = og_model.trainable_variables
    live_vars = model.trainable_variables
    for reference, live, delta in zip(reference_vars, live_vars, offsets):
        live.assign(reference + multiplier * delta)


def eval_ratio(component_index: int, n_top_examples: int = 128, n_total_examples=N_TOTAL_EXAMPLES):
    """Evaluate the current ``model`` on a component's top examples and overall.

    Despite the name, this returns the full ``PerturbationResults`` object;
    call ``.ratio()`` on it to get the ratio itself.

    Args:
        component_index: Column of ``nmf.W`` used to rank examples.
        n_top_examples: How many of the highest-coefficient rows to evaluate.
        n_total_examples: The overall evaluation uses indices
            ``0 .. n_total_examples - 1``.

    Returns:
        ``PerturbationResults`` holding both evaluation results.
    """
    # Negate so argsort's ascending order puts the largest coefficients first.
    component_coeffs = nmf.W[:, component_index]
    ranked = np.argsort(-component_coeffs)
    top_inds = ranked[:n_top_examples]
    all_inds = np.arange(n_total_examples)
    return PerturbationResults(
        top_results=eval_ctx.evaluate(model, top_inds),
        total_results=eval_ctx.evaluate(model, all_inds),
    )


def find_sign(offset, component_index: int, multiplier_mag=1.0):
    """Choose the sign of a perturbation direction for one component.

    The flat ``offset`` vector is unpacked into per-variable tensors. The
    global ``model`` is then perturbed by ``+multiplier_mag`` and by
    ``-multiplier_mag`` times that direction, and ``eval_ratio`` is run after
    each perturbation.

    Args:
        offset: Flat perturbation vector. It is cast to float32 and decoded
            with a ``FlatPacker`` built from the shapes of
            ``model.trainable_variables``.
        component_index: Component passed through to ``eval_ratio``.
        multiplier_mag: Magnitude of the step taken in each direction.

    Returns:
        ``1`` if the positive step gives the strictly larger
        ``PerturbationResults.ratio()``, otherwise ``-1``. Ties and NaN
        comparisons therefore yield ``-1``.

    Side effects:
        ``model`` is reset to the weights of ``og_model`` before returning,
        also when evaluation raises, so the probe never leaves the shared
        model in a perturbed state.
    """
    packer = flat_pack.FlatPacker([v.shape for v in model.trainable_variables])
    offsets = packer.decode_tf(tf.cast(offset, tf.float32))
    try:
        perturb_weights(offsets, multiplier_mag)
        positive = eval_ratio(component_index)
        #
        perturb_weights(offsets, -multiplier_mag)
        negative = eval_ratio(component_index)
    finally:
        # Copy the reference weights back directly rather than perturbing with
        # a zero multiplier, so a non-finite offset cannot leak into the model.
        for ogv, v in zip(og_model.trainable_variables, model.trainable_variables):
            v.assign(ogv)
    #
    return 1 if positive.ratio() > negative.ratio() else -1

###############################################################################


# Step size used for the sign probes only; `mp` is reassigned further down for
# the combined perturbation.
mp = 1e-1

# Component indices to build perturbation directions for.
COMP_INDS = [91, 182, 418, 473]

# Flat perturbation direction per component. 0.35 is `max_sim`: other
# components whose absolute dot product with this one exceeds 0.35 are not
# projected out of the direction.
offsets = {
    component_index: find_perturbation(component_index, 0.35)
    for component_index in COMP_INDS
}

# +1 / -1 per component, chosen by `find_sign`. Each call evaluates the model
# twice on both the top examples and the full evaluation range.
signs = {
    component_index: find_sign(offsets[component_index], component_index, mp)
    for component_index in COMP_INDS
}



packer = flat_pack.FlatPacker([v.shape for v in model.trainable_variables])

# Components combined into a single perturbation, and the step size for it.
comps = [91, 182]
mp = 3e-1
# comps = [91, 182, 418]
# Sum the signed flat directions, then unpack the sum into per-variable tensors.
# Note that `offset` here is the decoded per-variable form, unlike the flat
# vectors stored in the `offsets` dict above.
offset = packer.decode_tf(tf.cast(sum([signs[i] * offsets[i] for i in comps]), tf.float32))
# `perturb_weights` computes from `og_model`'s weights, so this is not
# cumulative with any earlier perturbation of `model`.
perturb_weights(offset, mp)
total_results = eval_ctx.evaluate(model, np.arange(N_TOTAL_EXAMPLES))
print(total_results.acc())


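# NOTE(review): `run_and_eval` is not defined in the visible part of this file,
# so this loop cannot be re-enabled as is. Its print format matches the
# per-component entries in the results notes below.
# for comp in COMP_INDS:
#     results = run_and_eval(comp, mp, max_sim=0.35)
#     print(f'Component: {comp}')
#     print(f'    ratio: {results.ratio()}')
#     print(f'    top acc: {results.top_results.acc()}')
#     print(f'    total acc: {results.total_results.acc()}')
#     print('')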


"""

################################################################


Components: 81, 182
    mp = 1e-1
    total acc: 0.7984467455621301

Components: 81, 182
    mp = 2e-1
    total acc: 0.8009122287968442

Components: 81, 182
    mp = 3e-1
    total acc: 0.7978303747534516


Components: 81, 182, 418
    mp = 1e-1
    total acc: 0.7995562130177515

Components: 81, 182, 418
    mp = 2e-1
    total acc: 0.7979536489151874

Components: 81, 182, 418
    mp = 5e-2
    total acc: 0.796844181459566

Components: 81, 182, 418
    mp = 8e-2
    total acc: 0.7990631163708086


################################################################
mp = 2e-1:

Component: 91
    ratio: 102.59149169921875
    top acc: 0.6796875
    total acc: 0.7938856015779092

Component: 182
    ratio: 99.8275375366211
    top acc: 0.8203125
    total acc: 0.7936390532544378

Component: 418
    ratio: 66.10822296142578
    top acc: 0.6953125
    total acc: 0.7917899408284024

Component: 473
    ratio: 34.2840461730957
    top acc: 0.671875
    total acc: 0.7896942800788954

################################################################
mp = 0.0 (baselines):

Component: 91
    top acc: 0.09375
    total acc: 0.7871055226824457

Component: 182
    top acc: 0.03125
    total acc: 0.7871055226824457

Component: 418
    top acc: 0.28125
    total acc: 0.7871055226824457

Component: 473
    top acc: 0.6796875
    total acc: 0.7871055226824457

################################################################
mp = 6e-1

Component: 91
    ratio: 20.061758041381836
    top acc: 0.625
    total acc: 0.7842702169625246

Component: 182
    ratio: 29.810922622680664
    top acc: 0.8203125
    total acc: 0.7855029585798816

Component: 418
    ratio: 20.66106605529785
    top acc: 0.6484375
    total acc: 0.7778599605522682

Component: 473
    ratio: 7.063019275665283
    top acc: 0.296875
    total acc: 0.6942800788954635

################################################################
mp = 1e-1

Component: 91
    ratio: 182.10888671875
    top acc: 0.5703125
    total acc: 0.7922830374753451

Component: 182
    ratio: 124.07206726074219
    top acc: 0.65625
    total acc: 0.7935157790927022

Component: 418
    ratio: 85.91929626464844
    top acc: 0.5859375
    total acc: 0.7919132149901381

Component: 473
    ratio: 16.77489471435547
    top acc: 0.6875
    total acc: 0.7899408284023669


"""