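R"""Perturb model weights along selected NPEFF component directions and compare
the effect on each component's top examples against a larger example subset.

Run from the project root with the project on PYTHONPATH:

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=2 python -i local_scripts/m_npeff/wrong_heuristics/wrong_perturb_test003.py

"""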
import dataclasses
from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.fishers import diagonal
from em.tools.nmf import lrm_npeff
from em.util import flat_pack

from em.projects.m_npeff import snli_context
from em.projects.pi import qqp_components_context as QCC

from em.util.color_util import cu


###############################################################################

# Diagonal Fisher file. Within this script it is referenced only by the
# commented-out Fisher-loading lines in the setup section.
FISHER_DIR = (
    "/fruitbasket/users/m/project_data"
    "/extract_merge1/pi1/fishers/"
)
FISHER_NAME = "feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"
FISHER_PATH = os.path.join(FISHER_DIR, FISHER_NAME)

# Saved decomposition that is loaded into `nmf` below.
NMF_DIR = (
    "/playpen/users/m/project_data"
    "/m_npeff1/per_example_fishers"
)
NMF_NAME = "test_mnpeff_002.expansion_005.coeffs_fit001.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# Per-example Fishers file. It is used only as a source of the stored
# predictions and example token ids, so that the model does not have to be
# evaluated to obtain them.
PEFS_FOR_PREDICTIONS_DIR = (
    "/fruitbasket/users/m/project_data"
    "/extract_merge1/pi1/per_example_fishers/"
)
PEFS_FOR_PREDICTIONS_NAME = (
    "feather_berts_0.snli_train.all_vars"
    ".skip50000.250000ex.131072.h5"
)
PEFS_FOR_PREDICTIONS_PATH = os.path.join(
    PEFS_FOR_PREDICTIONS_DIR,
    PEFS_FOR_PREDICTIONS_NAME,
)

# Identifiers handed to the `from_pretrained` loaders below.
MODEL = "connectivity/feather_berts_0"
TOKENIZER = "bert-base-uncased"

###############################################################################

# Reference copy of the model. `perturb_weights` reads its trainable variables
# as the base values and never assigns to them within this script.
og_model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)


# Disabled: load the diagonal Fisher, flatten it and scale it to unit L2 norm.
# fisher = diagonal.DiagonalFisher.load(FISHER_PATH)
# flat_fisher = fisher.as_flat_fisher().numpy()
# flat_fisher /= np.sqrt(np.sum(flat_fisher**2))

# Load the saved decomposition together with its G matrix (read_G=True), then
# normalize its components to unit norm. `find_unnormalized_perturbation`
# relies on the rows of G having unit norm.
print('Starting to read in decomposition.')
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=True)
print('Decomposition read in.')
nmf.normalize_components_to_unit_norm()
print('Decomposition components normalized.')

# Build the SNLI context around the decomposition. Examples are not loaded
# here (load_examples=False); the evaluation context below is created from the
# per-example Fishers file instead.
ctx = snli_context.SnliContext(
    split='train_skip_50k',
    tokenizer=AutoTokenizer.from_pretrained(TOKENIZER),
    nmf=nmf,
    load_examples=False,
)
print('SNLI context made.')

###############################################################################

# Default number of leading examples that `eval_ratio` evaluates as the
# "total" set.
# NOTE(review): 8 * 1014 == 8112; this may be a typo for 8 * 1024. Confirm
# before changing: the accuracies recorded at the bottom of this file are
# consistent with 8112 examples (e.g. 0.7871055... == 6385 / 8112).
N_TOTAL_EXAMPLES = 8 * 1014
# Reload the snli_context module and point the existing `ctx` object at the
# reloaded class, so the already-built context picks up the current class
# definition without being reconstructed (presumably for interactive
# iteration, since the script is run with `python -i`).
reload(snli_context); ctx.__class__ = snli_context.SnliContext
# Alternative that evaluates the model to build the eval context (disabled).
# eval_ctx = ctx.create_eval_ctx(og_model)
eval_ctx = ctx.create_eval_ctx_from_pefs_file(PEFS_FOR_PREDICTIONS_PATH)
# Keep element 0 of `all_examples` unchanged and cyclically shift element 1 by
# one modulo 3.
# NOTE(review): presumably element 1 holds the 3-way class labels and the shift
# remaps them to the model's label order — TODO confirm.
eval_ctx.all_examples = (eval_ctx.all_examples[0], (eval_ctx.all_examples[1] + 1) % 3)
print('Eval context made.')

# Second copy of the model. This is the one whose trainable variables are
# overwritten by `perturb_weights` and that gets evaluated.
model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)

###############################################################################


def find_unnormalized_perturbation(component_index: int, max_sim: float = 1e9):
    """Build a full-length perturbation vector from one row of `nmf.G`.

    Starts from a copy of row `component_index` of `nmf.G`. When `max_sim` is
    positive, every other row is visited in index order; unless the absolute
    dot product between that row and the ORIGINAL selected row exceeds
    `max_sim`, the running vector's component along that row is subtracted.
    The subtractions are sequential: each one uses the running vector as left
    by the previous ones. The projection formula assumes the rows of G have
    unit norm.

    Args:
        component_index: Row of `nmf.G` to start from.
        max_sim: Rows whose absolute dot product with the selected row is
            greater than this value are left alone (not projected out). A
            value that is not greater than 0.0 disables projection entirely.

    Returns:
        A float32 array of length `nmf.n_parameters` that is zero everywhere
        except at `nmf.new_to_old_col_indices`, where it holds the resulting
        vector. The result is NOT normalized.
    """
    components = nmf.G
    anchor = components[component_index]
    # Work on a copy so the stored decomposition is never modified.
    residual = np.copy(anchor)

    if max_sim > 0.0:
        for other_index in range(components.shape[0]):
            if other_index == component_index:
                continue
            other = components[other_index]
            # Similarity is always measured against the original row, not the
            # partially projected residual.
            too_similar = np.abs(anchor.dot(other)) > max_sim
            if too_similar:
                continue
            residual -= residual.dot(other) * other

    # Scatter the reduced-column vector back into the full parameter space.
    full = np.zeros((nmf.n_parameters,), dtype=np.float32)
    full[nmf.new_to_old_col_indices] = residual
    return full

###############################################################################


@dataclasses.dataclass
class PerturbationResults:
    """Evaluation results for one perturbed model on two example sets."""

    # Results on the examples selected as "top" for a component.
    top_results: QCC.QqpEvaluationResults
    # Results on the larger "total" example set.
    total_results: QCC.QqpEvaluationResults

    def ratio(self):
        """Return `top_results.kl()` divided by `total_results.kl()`."""
        top_kl = self.top_results.kl()
        total_kl = self.total_results.kl()
        return top_kl / total_kl


def perturb_weights(offsets, multiplier):
    """Overwrite the trainable variables of `model` with perturbed values.

    Each trainable variable of `model` is assigned the matching `og_model`
    variable plus `multiplier` times the matching entry of `offsets`. Because
    the base values always come from `og_model`, repeated calls do not
    accumulate.

    Args:
        offsets: One offset per trainable variable, in the same order as the
            models' `trainable_variables`.
        multiplier: Scale applied to every offset.
    """
    base_vars = og_model.trainable_variables
    perturbed_vars = model.trainable_variables
    for base, target, delta in zip(base_vars, perturbed_vars, offsets):
        target.assign(base + multiplier * delta)


def eval_ratio(component_index: int, n_top_examples: int = 128, n_total_examples=N_TOTAL_EXAMPLES):
    """Evaluate the current `model` on a component's top examples and on a larger set.

    Args:
        component_index: Column of `nmf.W` used to rank the examples.
        n_top_examples: How many of the highest-ranked examples to evaluate.
        n_total_examples: The "total" set is the examples with indices
            0 .. n_total_examples - 1.

    Returns:
        A `PerturbationResults` holding both evaluation results.
    """
    # Negate so that argsort yields indices in order of decreasing coefficient.
    # NOTE(review): assumes the rows of nmf.W line up with the example indices
    # that eval_ctx.evaluate expects — TODO confirm.
    coefficients = nmf.W[:, component_index]
    ranked = np.argsort(-coefficients)
    top_inds = ranked[:n_top_examples]
    return PerturbationResults(
        top_results=eval_ctx.evaluate(model, top_inds),
        total_results=eval_ctx.evaluate(model, np.arange(n_total_examples)),
    )


def run_and_eval(component_index: int, multiplier_mag=1.0, max_sim: float = 1e9):
    """Perturb `model` along a component direction in both signs and keep the better result.

    The direction from `find_unnormalized_perturbation` is scaled to unit L2
    norm, unpacked into one offset per trainable variable, and applied with
    `+multiplier_mag` and then `-multiplier_mag`. Each perturbed model is
    scored with `eval_ratio`.

    On return `model` is left in the `-multiplier_mag` state; it is not
    restored to the `og_model` weights.

    Args:
        component_index: Component whose direction is used.
        multiplier_mag: Magnitude of the step along the unit direction.
        max_sim: Forwarded to `find_unnormalized_perturbation`.

    Returns:
        The `PerturbationResults` of the positive step if its ratio is strictly
        greater than that of the negative step, otherwise those of the
        negative step.
    """
    direction = find_unnormalized_perturbation(component_index, max_sim)
    # NOTE(review): an all-zero direction would make this a division by zero
    # and produce NaN offsets.
    l2_norm = np.sqrt(np.square(direction).sum())
    direction /= l2_norm

    # Split the flat direction into per-variable offsets.
    variable_shapes = [v.shape for v in model.trainable_variables]
    packer = flat_pack.FlatPacker(variable_shapes)
    offsets = packer.decode_tf(tf.cast(direction, tf.float32))

    perturb_weights(offsets, multiplier_mag)
    positive = eval_ratio(component_index)

    perturb_weights(offsets, -multiplier_mag)
    negative = eval_ratio(component_index)

    if positive.ratio() > negative.ratio():
        return positive
    return negative

###############################################################################


# Perturbation magnitude passed to `run_and_eval` as `multiplier_mag`. The
# commented-out lines are other values that were tried; results for 2e-1,
# 1e-1 and 3e-1 are recorded in the string at the bottom of this file.
# mp = 2e-1
# mp = 0.0
# mp = 6e-1
# mp = 1e-1
mp = 3e-1

# Component indices to perturb along, one run per component.
COMP_INDS = [13, 16, 18, 34, 36, 40, 210, 233, 315]
for comp in COMP_INDS:
    # max_sim=0.35: components whose absolute dot product with this component
    # exceeds 0.35 are not projected out of the perturbation direction.
    results = run_and_eval(comp, mp, max_sim=0.35)
    print(f'Component: {comp}')
    print(f'    ratio: {results.ratio()}')
    print(f'    top acc: {results.top_results.acc()}')
    print(f'    total acc: {results.total_results.acc()}')
    print('')


"""
baseline total acc: 0.7871055226824457

################################################################
mp = 2e-1:

Component: 13
    ratio: 45.647857666015625
    top acc: 0.7734375
    total acc: 0.7936390532544378

Component: 16
    ratio: 67.78596496582031
    top acc: 0.8828125
    total acc: 0.7919132149901381

Component: 18
    ratio: 18.042158126831055
    top acc: 0.75
    total acc: 0.7909270216962525

Component: 34
    ratio: 5.065206050872803
    top acc: 0.5078125
    total acc: 0.7432199211045365

Component: 36
    ratio: 38.61678695678711
    top acc: 0.59375
    total acc: 0.7910502958579881

Component: 40
    ratio: 17.60877799987793
    top acc: 0.625
    total acc: 0.7941321499013807

Component: 210
    ratio: 60.44377899169922
    top acc: 0.734375
    total acc: 0.7917899408284024

Component: 233
    ratio: 75.42974853515625
    top acc: 0.828125
    total acc: 0.7904339250493096

Component: 315
    ratio: 51.33586502075195
    top acc: 0.6953125
    total acc: 0.7908037475345168


################################################################
mp = 1e-1:

Component: 13
    ratio: 87.96298217773438
    top acc: 0.671875
    total acc: 0.7909270216962525

Component: 16
    ratio: 103.35576629638672
    top acc: 0.8125
    total acc: 0.792776134122288

Component: 18
    ratio: 21.091678619384766
    top acc: 0.75
    total acc: 0.7873520710059172

Component: 34
    ratio: 1.4242366552352905
    top acc: 0.640625
    total acc: 0.7859960552268245

Component: 36
    ratio: 110.92489624023438
    top acc: 0.6640625
    total acc: 0.7889546351084813

Component: 40
    ratio: 9.553682327270508
    top acc: 0.4296875
    total acc: 0.7924063116370809

Component: 210
    ratio: 106.35462951660156
    top acc: 0.6796875
    total acc: 0.7909270216962525

Component: 233
    ratio: 104.81703186035156
    top acc: 0.7109375
    total acc: 0.7904339250493096

Component: 315
    ratio: 70.19462585449219
    top acc: 0.5703125
    total acc: 0.789447731755424


################################################################
mp = 3e-1:

Component: 13
    ratio: 25.205078125
    top acc: 0.78125
    total acc: 0.7959812623274162

Component: 16
    ratio: 42.955631256103516
    top acc: 0.890625
    total acc: 0.7928994082840237

Component: 18
    ratio: 12.648873329162598
    top acc: 0.6640625
    total acc: 0.7847633136094675

Component: 34
    ratio: 4.717387676239014
    top acc: 0.3203125
    total acc: 0.6411489151873767

Component: 36
    ratio: 20.563859939575195
    top acc: 0.59375
    total acc: 0.7885848126232742

Component: 40
    ratio: 10.17428207397461
    top acc: 0.671875
    total acc: 0.7842702169625246

Component: 210
    ratio: 37.13119125366211
    top acc: 0.71875
    total acc: 0.7936390532544378

Component: 233
    ratio: 51.07826614379883
    top acc: 0.8203125
    total acc: 0.7931459566074951

Component: 315
    ratio: 37.035728454589844
    top acc: 0.7109375
    total acc: 0.7919132149901381


"""