R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=2 python -i local_scripts/m_npeff/snli2/wrong_perturb001.py

"""

import dataclasses
from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.fishers import diagonal
from em.tools.nmf import lrm_npeff
from em.fishers import lrm_pefs
from em.util import flat_pack
from em import datasets as em_datasets

from em.projects.m_npeff import snli_context
from em.projects.pi import qqp_components_context as QCC

from em.util.color_util import cu


###############################################################################

NMF_DIR = "/fruitbasket/users/m/project_data/extract_merge1/snli2_lrm_npeff/per_example_fishers/"
NMF_NAME = "bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_2.50000ex.65536.512comps.expansion64comps.no_full_join.001.coeffs_fit001.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# Use this only to get the predictions and example token ids without having to
# evaluate the model.
PEFS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/snli2_lrm_npeff/per_example_fishers/"
PEFS_NAME = "bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_2.50000ex.65536.h5"
PEFS_PATH = os.path.join(PEFS_DIR, PEFS_NAME)

MODEL_DIR = "/fruitbasket/users/m/project_data/extract_merge1/snli2_lrm_npeff/models"
MODEL_NAME = "bert_base_snli_150k_holdout_4_epochs_01_epoch2"
MODEL = os.path.join(MODEL_DIR, MODEL_NAME)

###############################################################################

TOKENIZER = 'bert-base-uncased'
SPLIT = 'train[-100000:-50000]'

###############################################################################


###############################################################################

# Unperturbed reference copy of the model. `perturb_weights` only reads its
# weights; nothing in this script writes to it.
og_model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=False)


tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

print('Starting to load in nmf.')
# Rows of `nmf.W` index examples and its columns index components; rows of
# `nmf.G` are the component vectors (see how they are indexed below).
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=True)
# `find_unnormalized_perturbation` assumes the rows of G have unit norm.
nmf.normalize_components_to_unit_norm()

print('Starting to load in logits.')
logits = lrm_pefs.SparseLrmPefs.load_logits(PEFS_PATH)

print('Starting to create evaluation context.')
# NOTE(review): this reuses the QQP components context for SNLI. It is
# presumably task-agnostic, and it assumes the dataset order matches the order
# of the stored logits; confirm both.
eval_ctx = QCC.EvaluationContext2.create_from_ds_and_logits(
    ds=em_datasets.load('snli/default', split=SPLIT, sequence_length=128, tokenizer=tokenizer),
    logits=logits,
)
print('Evaluation context made.')

###############################################################################

# Number of leading examples of SPLIT that `eval_ratio` evaluates as the
# "total" set (8 * 1014 = 8112; the origin of this factorization is not
# documented here).
N_TOTAL_EXAMPLES = 8 * 1014

# Working copy whose weights `perturb_weights` overwrites with
# og_model + multiplier * offset.
model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=False)

###############################################################################


def find_unnormalized_perturbation(component_index: int, max_sim: float = 1e9, *, decomposition=None):
    """Build a dense parameter-space perturbation direction from one component.

    Starting from row `component_index` of the decomposition's `G`, the
    projection onto each other row is subtracted, unless that row's absolute
    dot product with the chosen row exceeds `max_sim` (such rows are left in).
    The result is scattered into a zero vector of length `n_parameters` at the
    positions given by `new_to_old_col_indices`.

    NOTE: The projections are subtracted one at a time. The rows of `G` are not
    in general mutually orthogonal, so the result is not guaranteed to be
    orthogonal to every projected-out row.

    Args:
        component_index: Row of `G` to build the perturbation from.
        max_sim: Rows with |G[component_index] . G[i]| > max_sim are not
            projected out. A value <= 0 disables projection entirely. The
            default (1e9) projects out every other row.
        decomposition: Object exposing `G`, `n_parameters` and
            `new_to_old_col_indices`. Defaults to the module-level `nmf`.

    Returns:
        A float32 vector of length `n_parameters`. Entries not covered by
        `new_to_old_col_indices` are zero. The vector is not normalized.
    """
    if decomposition is None:
        decomposition = nmf
    # Assumes rows of G have unit norm, so `x.dot(G[i]) * G[i]` is the
    # projection of x onto G[i] and the dot products are cosine similarities.
    G = decomposition.G
    target = G[component_index]
    g_main = np.copy(target)
    if max_sim > 0.0:
        # Similarities are measured against the original (unprojected) row, so
        # they can be computed once up front instead of once per iteration.
        sims = np.abs(G @ target)
        for i in range(G.shape[0]):
            if i == component_index or sims[i] > max_sim:
                continue
            g_main -= g_main.dot(G[i]) * G[i]
    g = np.zeros([decomposition.n_parameters], dtype=np.float32)
    g[decomposition.new_to_old_col_indices] = g_main
    return g

###############################################################################


@dataclasses.dataclass
class PerturbationResults:
    """Evaluation results of one perturbed model on two sets of examples."""

    # Results on the examples with the largest coefficients for the component.
    top_results: QCC.QqpEvaluationResults
    # Results on the broader reference set of examples.
    total_results: QCC.QqpEvaluationResults

    def ratio(self):
        """Return top-example KL divided by total-example KL, as a float.

        If the total KL is zero, this returns ``inf`` (or ``nan`` when the top
        KL is also zero) instead of raising. A zero total KL is presumably what
        an unperturbed model gives, e.g. the multiplier-0 baseline run.
        """
        top_kl = float(self.top_results.kl())
        total_kl = float(self.total_results.kl())
        with np.errstate(divide='ignore', invalid='ignore'):
            return float(np.divide(top_kl, total_kl))


def perturb_weights(offsets, multiplier):
    """Set each variable of `model` to the matching `og_model` variable plus `multiplier * offset`.

    The shift is applied on top of the reference weights every time, so
    successive calls replace the previous perturbation rather than stacking.
    """
    paired = zip(og_model.trainable_variables, model.trainable_variables, offsets)
    for reference_var, target_var, delta in paired:
        shifted = reference_var + multiplier * delta
        target_var.assign(shifted)


def eval_ratio(component_index: int, n_top_examples: int = 128, n_total_examples=N_TOTAL_EXAMPLES):
    """Evaluate the current `model` on a component's top examples and on a prefix.

    Despite the name, no ratio is computed here. It is derived later from the
    returned object via `PerturbationResults.ratio`.

    Args:
        component_index: Column of `nmf.W` whose largest coefficients select the
            "top" examples.
        n_top_examples: How many highest-coefficient examples to evaluate.
        n_total_examples: The "total" set is example indices
            0 .. n_total_examples - 1.

    Returns:
        A PerturbationResults with the top and total evaluations.
    """
    descending = np.argsort(-nmf.W[:, component_index])
    return PerturbationResults(
        top_results=eval_ctx.evaluate(model, descending[:n_top_examples]),
        total_results=eval_ctx.evaluate(model, np.arange(n_total_examples)),
    )


def run_and_eval(component_index: int, multiplier_mag=1.0, max_sim: float = 1e9):
    """Perturb `model` along one component's direction in both signs and evaluate.

    The direction from `find_unnormalized_perturbation` is scaled to unit norm.
    It is then added to the reference weights with +multiplier_mag and with
    -multiplier_mag, and each perturbed model is evaluated with `eval_ratio`.
    Afterwards, even if evaluation raises, `model` is restored to the reference
    (`og_model`) weights.

    Args:
        component_index: Component whose direction is used and whose top
            examples are evaluated.
        multiplier_mag: Step size along the unit-norm direction.
        max_sim: Forwarded to `find_unnormalized_perturbation`.

    Returns:
        The PerturbationResults of whichever sign gave the higher total-set
        accuracy. Ties go to the negative direction.

    Raises:
        ValueError: If the direction has zero or non-finite norm. Normalizing
            such a direction would otherwise fill the weights with NaNs.
    """
    g = find_unnormalized_perturbation(component_index, max_sim)
    norm = np.sqrt(np.sum(g**2))
    if not np.isfinite(norm) or norm == 0:
        raise ValueError(
            f'Perturbation direction for component {component_index} has '
            f'zero or non-finite norm ({norm}).'
        )
    g /= norm
    #
    packer = flat_pack.FlatPacker([v.shape for v in model.trainable_variables])
    offsets = packer.decode_tf(tf.cast(g, tf.float32))
    #
    try:
        perturb_weights(offsets, multiplier_mag)
        r1 = eval_ratio(component_index)
        #
        perturb_weights(offsets, -multiplier_mag)
        r2 = eval_ratio(component_index)
    finally:
        # Leave `model` unperturbed so later users (e.g. the interactive
        # session this script is run under) see the original weights.
        for reference_var, var in zip(og_model.trainable_variables, model.trainable_variables):
            var.assign(reference_var)
    #
    # Alternative selection criterion: r1 if r1.ratio() > r2.ratio() else r2
    return r1 if r1.total_results.acc() > r2.total_results.acc() else r2


###############################################################################


def compute_baseline_top_acc(component_index: int, n_top_examples: int = 128):
    """Accuracy of the stored logits' predictions on a component's top examples.

    No model is run. Predictions are the argmax of `eval_ctx.og_logits`
    (presumably the logits loaded from PEFS_PATH), and the labels come from
    `eval_ctx.all_examples[1]` (presumably the gold labels; confirm).
    """
    ranking = np.argsort(-nmf.W[:, component_index])
    chosen = ranking[:n_top_examples]
    gold = eval_ctx.all_examples[1][chosen]
    predicted = np.argmax(eval_ctx.og_logits, axis=-1)[chosen]
    hits = (gold == predicted).astype(np.float32)
    return np.mean(hits)


###############################################################################


def _print_results(header, results):
    """Print a labelled summary of a PerturbationResults."""
    print(header)
    print(f'    ratio: {results.ratio()}')
    print(f'    top acc: {results.top_results.acc()}')
    print(f'    total acc: {results.total_results.acc()}')
    print('')


# Baseline: a zero multiplier leaves the weights equal to og_model's.
results = run_and_eval(0, 0, max_sim=0.35)
_print_results('Baselines:', results)


# Multipliers tried previously: 3e-1, 8e-2, 1e-1, 6e-2.
mp = 1.6e-1

# Other component subsets looked at: [0, 5, 17, 22], [201, 337, 371, 569],
# range(64).
COMP_INDS = list(range(64)) + [201, 337, 371, 569]
for comp in COMP_INDS:
    results = run_and_eval(comp, mp, max_sim=0.35)
    _print_results(f'Component: {comp}', results)


# Unperturbed accuracy on each component's top examples, computed from the
# stored logits, for comparison with the "top acc" values above. Reuses
# COMP_INDS so both lists stay in sync.
for comp in COMP_INDS:
    acc = compute_baseline_top_acc(comp)
    print(f'{comp}: {acc}')



"""


Baselines:
    total acc: 0.8788214990138067


mp = 8e-2:

Component: 0
    ratio: 11.419376373291016
    top acc: 0.5234375
    total acc: 0.8687130177514792

Component: 1
    ratio: 21.468156814575195
    top acc: 0.609375
    total acc: 0.872904339250493

Component: 2
    ratio: 22.423688888549805
    top acc: 0.7890625
    total acc: 0.8789447731755424

Component: 3
    ratio: 12.929420471191406
    top acc: 0.890625
    total acc: 0.8794378698224852

Component: 4
    ratio: 33.6479377746582
    top acc: 0.4609375
    total acc: 0.8703155818540433

Component: 5
    ratio: 18.097814559936523
    top acc: 0.6875
    total acc: 0.8785749506903353

Component: 6
    ratio: 17.255090713500977
    top acc: 0.765625
    total acc: 0.8793145956607495

Component: 7
    ratio: 18.654935836791992
    top acc: 0.640625
    total acc: 0.8741370808678501

Component: 8
    ratio: 23.25096321105957
    top acc: 0.7578125
    total acc: 0.8795611439842209

Component: 9
    ratio: 25.170886993408203
    top acc: 0.75
    total acc: 0.8775887573964497

Component: 10
    ratio: 24.89461326599121
    top acc: 0.6328125
    total acc: 0.8727810650887574

Component: 11
    ratio: 25.145090103149414
    top acc: 0.8125
    total acc: 0.8780818540433925

Component: 12
    ratio: 20.405540466308594
    top acc: 0.6875
    total acc: 0.8767258382642998

Component: 13
    ratio: 14.08040714263916
    top acc: 0.78125
    total acc: 0.8767258382642998

Component: 14
    ratio: 17.534772872924805
    top acc: 0.734375
    total acc: 0.8807938856015779

Component: 15
    ratio: 41.205474853515625
    top acc: 0.640625
    total acc: 0.8731508875739645

Component: 16
    ratio: 22.876924514770508
    top acc: 0.7421875
    total acc: 0.8763560157790927

Component: 17
    ratio: 39.38975524902344
    top acc: 0.6484375
    total acc: 0.8754930966469427

Component: 18
    ratio: 22.563034057617188
    top acc: 0.6640625
    total acc: 0.876232741617357

Component: 19
    ratio: 16.196802139282227
    top acc: 0.625
    total acc: 0.8710552268244576

Component: 20
    ratio: 19.56968116760254
    top acc: 0.671875
    total acc: 0.8738905325443787

Component: 21
    ratio: 36.035587310791016
    top acc: 0.6640625
    total acc: 0.8724112426035503

Component: 22
    ratio: 36.4364013671875
    top acc: 0.640625
    total acc: 0.8754930966469427

Component: 23
    ratio: 21.135618209838867
    top acc: 0.7734375
    total acc: 0.8785749506903353

Component: 24
    ratio: 23.417495727539062
    top acc: 0.7421875
    total acc: 0.8782051282051282

Component: 25
    ratio: 16.9957218170166
    top acc: 0.78125
    total acc: 0.8795611439842209

Component: 26
    ratio: 28.11031723022461
    top acc: 0.7890625
    total acc: 0.875

Component: 27
    ratio: 75.3795166015625
    top acc: 0.546875
    total acc: 0.8780818540433925

Component: 28
    ratio: 20.728736877441406
    top acc: 0.75
    total acc: 0.875

Component: 29
    ratio: 18.2907772064209
    top acc: 0.7421875
    total acc: 0.8772189349112426

Component: 30
    ratio: 15.992047309875488
    top acc: 0.7734375
    total acc: 0.8790680473372781

Component: 31
    ratio: 13.520147323608398
    top acc: 0.671875
    total acc: 0.8741370808678501

Component: 32
    ratio: 14.714926719665527
    top acc: 0.6875
    total acc: 0.8769723865877712

Component: 33
    ratio: 18.643800735473633
    top acc: 0.78125
    total acc: 0.8789447731755424

Component: 34
    ratio: 14.572858810424805
    top acc: 0.7421875
    total acc: 0.8767258382642998

Component: 35
    ratio: 58.043418884277344
    top acc: 0.765625
    total acc: 0.8791913214990138

Component: 36
    ratio: 20.380661010742188
    top acc: 0.8125
    total acc: 0.878698224852071

Component: 37
    ratio: 21.40690040588379
    top acc: 0.734375
    total acc: 0.877465483234714

Component: 38
    ratio: 24.4859561920166
    top acc: 0.796875
    total acc: 0.8796844181459567

Component: 39
    ratio: 25.722169876098633
    top acc: 0.6640625
    total acc: 0.870069033530572

Component: 40
    ratio: 11.49941635131836
    top acc: 0.8046875
    total acc: 0.8764792899408284

Component: 41
    ratio: 23.663578033447266
    top acc: 0.7265625
    total acc: 0.876232741617357

Component: 42
    ratio: 21.774019241333008
    top acc: 0.7578125
    total acc: 0.8790680473372781

Component: 43
    ratio: 16.990158081054688
    top acc: 0.796875
    total acc: 0.8757396449704142

Component: 44
    ratio: 24.946969985961914
    top acc: 0.671875
    total acc: 0.8733974358974359

Component: 45
    ratio: 26.058961868286133
    top acc: 0.640625
    total acc: 0.8766025641025641

Component: 46
    ratio: 18.04328727722168
    top acc: 0.671875
    total acc: 0.8793145956607495

Component: 47
    ratio: 24.398820877075195
    top acc: 0.765625
    total acc: 0.877095660749507

Component: 48
    ratio: 66.66757202148438
    top acc: 0.6015625
    total acc: 0.872904339250493

Component: 49
    ratio: 25.275903701782227
    top acc: 0.6171875
    total acc: 0.8784516765285996

Component: 50
    ratio: 17.10837745666504
    top acc: 0.6953125
    total acc: 0.8782051282051282

Component: 51
    ratio: 24.69113540649414
    top acc: 0.8125
    total acc: 0.8793145956607495

Component: 52
    ratio: 21.648345947265625
    top acc: 0.7109375
    total acc: 0.8752465483234714

Component: 53
    ratio: 14.260683059692383
    top acc: 0.5859375
    total acc: 0.8726577909270217

Component: 54
    ratio: 20.6578311920166
    top acc: 0.7734375
    total acc: 0.8795611439842209

Component: 55
    ratio: 17.812458038330078
    top acc: 0.859375
    total acc: 0.8727810650887574

Component: 56
    ratio: 17.93783950805664
    top acc: 0.671875
    total acc: 0.878698224852071

Component: 57
    ratio: 40.44706344604492
    top acc: 0.8046875
    total acc: 0.8791913214990138

Component: 58
    ratio: 25.95654296875
    top acc: 0.734375
    total acc: 0.8759861932938856

Component: 59
    ratio: 27.25227928161621
    top acc: 0.7578125
    total acc: 0.878698224852071

Component: 60
    ratio: 24.734777450561523
    top acc: 0.7265625
    total acc: 0.8780818540433925

Component: 61
    ratio: 37.22211837768555
    top acc: 0.671875
    total acc: 0.8784516765285996

Component: 62
    ratio: 19.126684188842773
    top acc: 0.734375
    total acc: 0.8777120315581854

Component: 63
    ratio: 17.26833724975586
    top acc: 0.828125
    total acc: 0.8742603550295858

Component: 201
    ratio: 17.649917602539062
    top acc: 0.578125
    total acc: 0.8736439842209073

Component: 337
    ratio: 18.68383026123047
    top acc: 0.6171875
    total acc: 0.8687130177514792

Component: 371
    ratio: 25.818363189697266
    top acc: 0.46875
    total acc: 0.8721646942800789

Component: 569
    ratio: 22.380155563354492
    top acc: 0.6796875
    total acc: 0.8698224852071006



mp = 3e-1:



Component: 0
    ratio: 3.2634592056274414
    top acc: 0.2109375
    total acc: 0.7410009861932939

Component: 1
    ratio: 8.019353866577148
    top acc: 0.34375
    total acc: 0.826060157790927

Component: 2
    ratio: 20.215221405029297
    top acc: 0.5234375
    total acc: 0.8579881656804734

Component: 3
    ratio: 23.834362030029297
    top acc: 0.71875
    total acc: 0.8741370808678501

Component: 4
    ratio: 13.930872917175293
    top acc: 0.28125
    total acc: 0.8436883629191322

Component: 5
    ratio: 11.782620429992676
    top acc: 0.4375
    total acc: 0.8584812623274162

Component: 6
    ratio: 13.853377342224121
    top acc: 0.5625
    total acc: 0.8624260355029586

Component: 7
    ratio: 9.541629791259766
    top acc: 0.46875
    total acc: 0.8444280078895463

Component: 8
    ratio: 17.9312686920166
    top acc: 0.5546875
    total acc: 0.8641518737672583

Component: 9
    ratio: 12.153315544128418
    top acc: 0.5
    total acc: 0.8588510848126233

Component: 10
    ratio: 11.932476043701172
    top acc: 0.5078125
    total acc: 0.8561390532544378

Component: 11
    ratio: 20.855587005615234
    top acc: 0.6484375
    total acc: 0.8667406311637081

Component: 12
    ratio: 7.325865745544434
    top acc: 0.5703125
    total acc: 0.8498520710059172

Component: 13
    ratio: 8.041543960571289
    top acc: 0.359375
    total acc: 0.8329635108481263

Component: 14
    ratio: 11.716567039489746
    top acc: 0.578125
    total acc: 0.8581114398422091

Component: 15
    ratio: 15.258824348449707
    top acc: 0.359375
    total acc: 0.8397435897435898

Component: 16
    ratio: 18.31674575805664
    top acc: 0.4921875
    total acc: 0.8623027613412229

Component: 17
    ratio: 14.221606254577637
    top acc: 0.5078125
    total acc: 0.8651380670611439

Component: 18
    ratio: 13.400150299072266
    top acc: 0.53125
    total acc: 0.8500986193293886

Component: 19
    ratio: 7.8321051597595215
    top acc: 0.3203125
    total acc: 0.8227317554240631

Component: 20
    ratio: 11.183816909790039
    top acc: 0.5859375
    total acc: 0.8616863905325444

Component: 21
    ratio: 11.564948081970215
    top acc: 0.609375
    total acc: 0.8563856015779092

Component: 22
    ratio: 14.047127723693848
    top acc: 0.3984375
    total acc: 0.8524408284023669

Component: 23
    ratio: 13.584965705871582
    top acc: 0.46875
    total acc: 0.8526873767258383

Component: 24
    ratio: 13.2267427444458
    top acc: 0.484375
    total acc: 0.8553994082840237

Component: 25
    ratio: 15.080923080444336
    top acc: 0.5859375
    total acc: 0.86267258382643

Component: 26
    ratio: 13.634769439697266
    top acc: 0.46875
    total acc: 0.849112426035503

Component: 27
    ratio: 18.416284561157227
    top acc: 0.4296875
    total acc: 0.8605769230769231

Component: 28
    ratio: 10.815831184387207
    top acc: 0.546875
    total acc: 0.8486193293885601

Component: 29
    ratio: 10.91618537902832
    top acc: 0.515625
    total acc: 0.8521942800788954

Component: 30
    ratio: 14.462736129760742
    top acc: 0.5625
    total acc: 0.8621794871794872

Component: 31
    ratio: 7.466119766235352
    top acc: 0.5234375
    total acc: 0.852810650887574

Component: 32
    ratio: 12.842744827270508
    top acc: 0.421875
    total acc: 0.8584812623274162

Component: 33
    ratio: 10.704041481018066
    top acc: 0.6015625
    total acc: 0.8510848126232742

Component: 34
    ratio: 7.806983947753906
    top acc: 0.3515625
    total acc: 0.8399901380670611

Component: 35
    ratio: 21.823965072631836
    top acc: 0.5625
    total acc: 0.8708086785009862

Component: 36
    ratio: 19.9368896484375
    top acc: 0.5625
    total acc: 0.8646449704142012

Component: 37
    ratio: 8.41324234008789
    top acc: 0.5234375
    total acc: 0.8515779092702169

Component: 38
    ratio: 21.862409591674805
    top acc: 0.59375
    total acc: 0.8685897435897436

Component: 39
    ratio: 11.340542793273926
    top acc: 0.265625
    total acc: 0.8296351084812623

Component: 40
    ratio: 7.51206636428833
    top acc: 0.5234375
    total acc: 0.8452909270216963

Component: 41
    ratio: 13.415058135986328
    top acc: 0.7109375
    total acc: 0.8647682445759369

Component: 42
    ratio: 13.128835678100586
    top acc: 0.4921875
    total acc: 0.8588510848126233

Component: 43
    ratio: 10.084392547607422
    top acc: 0.46875
    total acc: 0.8482495069033531

Component: 44
    ratio: 12.435408592224121
    top acc: 0.3984375
    total acc: 0.8418392504930966

Component: 45
    ratio: 9.830524444580078
    top acc: 0.2421875
    total acc: 0.8351824457593688

Component: 46
    ratio: 11.948302268981934
    top acc: 0.3984375
    total acc: 0.8494822485207101

Component: 47
    ratio: 12.035072326660156
    top acc: 0.6171875
    total acc: 0.8668639053254438

Component: 48
    ratio: 15.00649642944336
    top acc: 0.125
    total acc: 0.8242110453648915

Component: 49
    ratio: 13.962278366088867
    top acc: 0.453125
    total acc: 0.8571252465483234

Component: 50
    ratio: 8.091094017028809
    top acc: 0.71875
    total acc: 0.863905325443787

Component: 51
    ratio: 17.552921295166016
    top acc: 0.515625
    total acc: 0.8671104536489151

Component: 52
    ratio: 11.470945358276367
    top acc: 0.6484375
    total acc: 0.8599605522682445

Component: 53
    ratio: 7.6435956954956055
    top acc: 0.390625
    total acc: 0.8385108481262328

Component: 54
    ratio: 12.715217590332031
    top acc: 0.5625
    total acc: 0.8630424063116371

Component: 55
    ratio: 9.43368148803711
    top acc: 0.234375
    total acc: 0.8022682445759369

Component: 56
    ratio: 10.024519920349121
    top acc: 0.5234375
    total acc: 0.8584812623274162

Component: 57
    ratio: 21.281476974487305
    top acc: 0.421875
    total acc: 0.8563856015779092

Component: 58
    ratio: 11.099900245666504
    top acc: 0.5625
    total acc: 0.860207100591716

Component: 59
    ratio: 17.22976303100586
    top acc: 0.4296875
    total acc: 0.8583579881656804

Component: 60
    ratio: 17.185077667236328
    top acc: 0.46875
    total acc: 0.8597140039447732

Component: 61
    ratio: 18.8862247467041
    top acc: 0.453125
    total acc: 0.8631656804733728

Component: 62
    ratio: 9.493195533752441
    top acc: 0.5546875
    total acc: 0.854043392504931

Component: 63
    ratio: 9.577665328979492
    top acc: 0.3828125
    total acc: 0.8271696252465484

Component: 201
    ratio: 3.6429529190063477
    top acc: 0.34375
    total acc: 0.7554240631163708

Component: 337
    ratio: 3.746276378631592
    top acc: 0.265625
    total acc: 0.734344181459566

Component: 371
    ratio: 3.3350846767425537
    top acc: 0.609375
    total acc: 0.8344428007889546

Component: 569
    ratio: 7.134460926055908
    top acc: 0.46875
    total acc: 0.8169378698224852



Baseline top accs:
0: 0.6796875
1: 0.7421875
2: 0.8046875
3: 0.8984375
4: 0.6796875
5: 0.71875
6: 0.7578125
7: 0.6640625
8: 0.7890625
9: 0.734375
10: 0.75
11: 0.8125
12: 0.65625
13: 0.78125
14: 0.78125
15: 0.765625
16: 0.765625
17: 0.65625
18: 0.7421875
19: 0.7265625
20: 0.7421875
21: 0.7890625
22: 0.6875
23: 0.8046875
24: 0.7265625
25: 0.734375
26: 0.828125
27: 0.65625
28: 0.796875
29: 0.734375
30: 0.7109375
31: 0.75
32: 0.703125
33: 0.7734375
34: 0.7734375
35: 0.7734375
36: 0.8125
37: 0.8125
38: 0.8125
39: 0.7578125
40: 0.78125
41: 0.7890625
42: 0.7421875
43: 0.8203125
44: 0.7578125
45: 0.6953125
46: 0.71875
47: 0.765625
48: 0.7578125
49: 0.703125
50: 0.7421875
51: 0.859375
52: 0.78125
53: 0.7265625
54: 0.8046875
55: 0.9140625
56: 0.7421875
57: 0.8203125
58: 0.7421875
59: 0.8125
60: 0.6953125
61: 0.7109375
62: 0.71875
63: 0.828125
201: 0.65625
337: 0.75
371: 0.7109375
569: 0.734375

"""