# wrong_perturb001.py
R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/m_npeff/qqp/wrong_perturb001.py

"""

import dataclasses
from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.fishers import diagonal
from em.tools.nmf import lrm_npeff
from em.fishers import lrm_pefs
from em.util import flat_pack

from em.projects.m_npeff import qqp_context
from em.projects.pi import qqp_components_context as QCC

from em.util.color_util import cu


###############################################################################

# Decomposition file read by lrm_npeff.LrmNpeffDecomposition.load below.
NMF_DIR = "/playpen/users/m/project_data/qqp_lrm_npeff2/per_example_fishers"
NMF_NAME = (
    "bert_base_qqp_50k_holdout_4_epochs_01_epoch9"
    ".heldout_from_train.50000ex.65536"
    ".wrongs_only.expansion_001"
    ".coeffs_fit_to_validation001.h5"
)
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# File handed to lrm_pefs.SparseLrmPefs.load_logits below; only the logits
# are read from it in this script.
PEFS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/qqp_lrm_npeff2/per_example_fishers"
PEFS_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9.validation.40430ex.65536.h5"
PEFS_PATH = os.path.join(PEFS_DIR, PEFS_NAME)

# Checkpoint directory passed to
# TFAutoModelForSequenceClassification.from_pretrained.
MODEL_DIR = "/fruitbasket/users/m/project_data/extract_merge1/qqp_lrm_npeff2/models"
MODEL_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9"
MODEL = os.path.join(MODEL_DIR, MODEL_NAME)

# Tokenizer identifier (for AutoTokenizer.from_pretrained) and the dataset
# split name given to QqpContext.
TOKENIZER = "bert-base-uncased"
SPLIT = "validation"

###############################################################################

# Reference copy of the fine-tuned model. perturb_weights only reads its
# variables; nothing in this file assigns to them.
og_model = TFAutoModelForSequenceClassification.from_pretrained(
    MODEL,
    from_pt=False,
)


print('Starting to read in decomposition.')
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=True)
print('Decomposition read in.')
# find_unnormalized_perturbation assumes the rows of G have unit norm.
nmf.normalize_components_to_unit_norm()
print('Decomposition components normalized.')

logits = lrm_pefs.SparseLrmPefs.load_logits(PEFS_PATH)

ctx = qqp_context.QqpContext(
    split=SPLIT,
    tokenizer=AutoTokenizer.from_pretrained(TOKENIZER),
    nmf=nmf,
    load_examples=False,
    # load_examples=True,
)
print('QQP context made.')

###############################################################################

# Number of leading examples used for the "total" evaluation in eval_ratio.
# NOTE(review): 8 * 1014 == 8112. This may have been meant as 8 * 1024, but
# the accuracies recorded at the bottom of this file are exact multiples of
# 1/8112, so changing the value would change the reported numbers.
N_TOTAL_EXAMPLES = 8 * 1014

# Interactive-session convenience (the script is run with `python -i`):
# re-import qqp_context and rebind the existing ctx object to the reloaded
# class, presumably so edits to the module apply without rebuilding ctx.
reload(qqp_context)
ctx.__class__ = qqp_context.QqpContext

eval_ctx = ctx.create_eval_ctx_given_logits(logits)
print('Eval context made.')

# Second, independently loaded copy of the same checkpoint. This is the one
# whose variables perturb_weights overwrites and that gets evaluated.
model = TFAutoModelForSequenceClassification.from_pretrained(
    MODEL,
    from_pt=False,
)

###############################################################################


def find_unnormalized_perturbation(component_index: int, max_sim: float = 1e9):
    """Build a full-length perturbation direction from one row of ``nmf.G``.

    Starts from a copy of row ``component_index``. When ``max_sim`` is
    positive, every other row is visited in index order; a row whose absolute
    dot product with the *original* selected row exceeds ``max_sim`` is
    skipped, and for each remaining row the projection of the running vector
    onto that row is subtracted. The projection formula assumes the rows of
    ``nmf.G`` have unit norm.

    Args:
        component_index: Index of the row of ``nmf.G`` to start from.
        max_sim: Similarity cutoff described above. A non-positive value
            disables the projection step entirely.

    Returns:
        A float32 array of length ``nmf.n_parameters`` that is zero everywhere
        except at ``nmf.new_to_old_col_indices``, which hold the processed
        row. The result is not renormalized.
    """
    components = nmf.G
    anchor = components[component_index]
    # Work on a copy so the decomposition itself is never modified.
    direction = np.copy(anchor)

    if max_sim > 0.0:
        for row_index in range(components.shape[0]):
            if row_index == component_index:
                continue
            other = components[row_index]
            # Similarity is measured against the untouched anchor row, not
            # against the partially projected running vector.
            too_similar = np.abs(anchor.dot(other)) > max_sim
            if not too_similar:
                direction -= direction.dot(other) * other

    # Scatter the reduced-column vector back into the full parameter space.
    full = np.zeros([nmf.n_parameters], dtype=np.float32)
    full[nmf.new_to_old_col_indices] = direction
    return full

###############################################################################


@dataclasses.dataclass
class PerturbationResults:
    """Pair of evaluation results produced by ``eval_ratio`` for one model state."""

    # Results on the examples with the largest coefficients for a component.
    top_results: QCC.QqpEvaluationResults
    # Results on the first ``n_total_examples`` examples.
    total_results: QCC.QqpEvaluationResults

    def ratio(self):
        """Return ``top_results.kl()`` divided by ``total_results.kl()``."""
        top_kl = self.top_results.kl()
        total_kl = self.total_results.kl()
        return top_kl / total_kl


def perturb_weights(offsets, multiplier):
    """Overwrite ``model``'s trainable variables with shifted reference values.

    Each trainable variable of ``model`` is assigned the matching variable of
    ``og_model`` plus ``multiplier`` times the matching entry of ``offsets``.
    Because the reference values always come from ``og_model``, repeated calls
    do not accumulate.

    Args:
        offsets: One offset per trainable variable, in the same order as
            ``model.trainable_variables``.
        multiplier: Scalar applied to every offset before it is added.
    """
    triples = zip(
        og_model.trainable_variables,
        model.trainable_variables,
        offsets,
    )
    for reference_var, target_var, delta in triples:
        target_var.assign(reference_var + multiplier * delta)


def eval_ratio(component_index: int, n_top_examples: int = 128, n_total_examples=N_TOTAL_EXAMPLES):
    """Evaluate the current ``model`` on a component's top examples and on a fixed prefix.

    Args:
        component_index: Column of ``nmf.W`` used to rank the examples.
        n_top_examples: How many of the highest-coefficient examples to use.
        n_total_examples: The "total" set is the examples with indices
            ``0 .. n_total_examples - 1``.

    Returns:
        A ``PerturbationResults`` holding both evaluation results.
    """
    coefficients = nmf.W[:, component_index]
    # Negating before argsort puts the largest coefficients first.
    descending = np.argsort(-coefficients)
    top_inds = descending[:n_top_examples]
    all_inds = np.arange(n_total_examples)
    return PerturbationResults(
        top_results=eval_ctx.evaluate(model, top_inds),
        total_results=eval_ctx.evaluate(model, all_inds),
    )


def run_and_eval(component_index: int, multiplier_mag=1.0, max_sim: float = 1e9):
    """Perturb ``model`` along a component in both directions and keep the better result.

    The direction from ``find_unnormalized_perturbation`` is scaled to unit
    Euclidean norm, unpacked into per-variable offsets, and applied first with
    ``+multiplier_mag`` and then with ``-multiplier_mag``. ``eval_ratio`` is
    run after each application.

    Note that ``model`` is left with the ``-multiplier_mag`` perturbation
    applied when this function returns.

    Args:
        component_index: Component to perturb along and to evaluate.
        multiplier_mag: Magnitude of the step along the unit direction.
        max_sim: Forwarded to ``find_unnormalized_perturbation``.

    Returns:
        The ``PerturbationResults`` of the positive step if its ``ratio()`` is
        strictly greater than that of the negative step, otherwise the
        results of the negative step.
    """
    direction = find_unnormalized_perturbation(component_index, max_sim)
    direction /= np.sqrt(np.sum(direction**2))

    # Split the flat direction into tensors shaped like the model variables.
    shapes = [variable.shape for variable in model.trainable_variables]
    offsets = flat_pack.FlatPacker(shapes).decode_tf(tf.cast(direction, tf.float32))

    outcomes = []
    for signed_multiplier in (multiplier_mag, -multiplier_mag):
        perturb_weights(offsets, signed_multiplier)
        outcomes.append(eval_ratio(component_index))

    positive, negative = outcomes
    return positive if positive.ratio() > negative.ratio() else negative

###############################################################################


def _print_results(comp, results):
    """Print the ratio and both accuracies of one ``run_and_eval`` result."""
    print(f'Component: {comp}')
    print(f'    ratio: {results.ratio()}')
    print(f'    top acc: {results.top_results.acc()}')
    print(f'    total acc: {results.total_results.acc()}')
    print('')


# Step size along the unit-norm perturbation direction.
mp = 3e-1

# Hand-picked subset used by the commented-out loop header below.
COMP_INDS = [5, 18, 34, 44]
# for comp in COMP_INDS:
# NOTE(review): assumes the decomposition has at least 64 components.
for comp in range(64):
    results = run_and_eval(comp, mp, max_sim=0.35)
    _print_results(comp, results)


# Baseline: a multiplier of 0 puts the original weights back, so this reports
# the unperturbed model on the last component visited by the loop (comp keeps
# its final loop value).
results = run_and_eval(comp, 0, max_sim=0.35)
_print_results(comp, results)

"""
2023-04-30 11:25:43.294204: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 437935112 exceeds 10% of free system memory.
Component: 5
    ratio: 6.8585638999938965
    top acc: 0.3671875
    total acc: 0.8465236686390533

2023-04-30 11:26:24.033175: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 437935112 exceeds 10% of free system memory.
Component: 18
    ratio: 7.807177543640137
    top acc: 0.328125
    total acc: 0.8657544378698225

2023-04-30 11:27:06.613205: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 437935112 exceeds 10% of free system memory.
Component: 34
    ratio: 23.40241241455078
    top acc: 0.59375
    total acc: 0.8984220907297831

2023-04-30 11:27:49.270168: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 437935112 exceeds 10% of free system memory.
Component: 44
    ratio: 9.904426574707031
    top acc: 0.53125
    total acc: 0.8852317554240631

>>> 
>>> results = run_and_eval(comp, 0, max_sim=0.35)
2023-04-30 11:28:55.440169: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 437935112 exceeds 10% of free system memory.
>>> print(f'Component: {comp}')
Component: 44
>>> print(f'    ratio: {results.ratio()}')
    ratio: 3.987076997756958
>>> print(f'    top acc: {results.top_results.acc()}')
    top acc: 0.640625
>>> print(f'    total acc: {results.total_results.acc()}')
    total acc: 0.9000246548323472
>>> print('')


Component: 0
    ratio: 9.856927871704102
    top acc: 0.5703125
    total acc: 0.8831360946745562

Component: 1
    ratio: 2.7159457206726074
    top acc: 0.984375
    total acc: 0.879930966469428

Component: 2
    ratio: 30.466766357421875
    top acc: 0.796875
    total acc: 0.8981755424063116

Component: 3
    ratio: 4.927387714385986
    top acc: 0.578125
    total acc: 0.844551282051282

Component: 4
    ratio: 6.092838764190674
    top acc: 0.28125
    total acc: 0.8500986193293886

Component: 5
    ratio: 6.8585638999938965
    top acc: 0.3671875
    total acc: 0.8465236686390533

Component: 6
    ratio: 8.305885314941406
    top acc: 0.4921875
    total acc: 0.877095660749507

Component: 7
    ratio: 1.7026410102844238
    top acc: 1.0
    total acc: 0.8982988165680473

Component: 8
    ratio: 4.481677055358887
    top acc: 0.390625
    total acc: 0.8267998027613412

Component: 9
    ratio: 10.14175033569336
    top acc: 0.3671875
    total acc: 0.8671104536489151

Component: 10
    ratio: 15.046124458312988
    top acc: 0.6640625
    total acc: 0.8899161735700197

Component: 11
    ratio: 6.196431636810303
    top acc: 0.421875
    total acc: 0.8553994082840237

Component: 12
    ratio: 8.097901344299316
    top acc: 0.3671875
    total acc: 0.8563856015779092

Component: 13
    ratio: 5.725308418273926
    top acc: 0.2265625
    total acc: 0.8240877712031558

Component: 14
    ratio: 6.17897891998291
    top acc: 0.390625
    total acc: 0.8595907297830375

Component: 15
    ratio: 5.607953071594238
    top acc: 0.9375
    total acc: 0.8958333333333334

Component: 16
    ratio: 26.723108291625977
    top acc: 0.8359375
    total acc: 0.8975591715976331

Component: 17
    ratio: 25.12314796447754
    top acc: 0.2265625
    total acc: 0.8843688362919132

Component: 18
    ratio: 7.807177543640137
    top acc: 0.328125
    total acc: 0.8657544378698225

Component: 19
    ratio: 3.3026788234710693
    top acc: 0.140625
    total acc: 0.6586538461538461

Component: 20
    ratio: 6.022406101226807
    top acc: 0.8671875
    total acc: 0.8960798816568047

Component: 21
    ratio: 11.894643783569336
    top acc: 0.1875
    total acc: 0.8615631163708086

Component: 22
    ratio: 8.934998512268066
    top acc: 0.59375
    total acc: 0.8789447731755424

Component: 23
    ratio: 3.938901901245117
    top acc: 0.9296875
    total acc: 0.895956607495069

Component: 24
    ratio: 1.8111484050750732
    top acc: 0.2109375
    total acc: 0.6367110453648915

Component: 25
    ratio: 5.519041538238525
    top acc: 0.96875
    total acc: 0.8948471400394478

Component: 26
    ratio: 12.799822807312012
    top acc: 0.4140625
    total acc: 0.872904339250493

Component: 27
    ratio: 14.145071029663086
    top acc: 0.5078125
    total acc: 0.8867110453648915

Component: 28
    ratio: 4.056584358215332
    top acc: 0.09375
    total acc: 0.7493836291913215

Component: 29
    ratio: 9.235061645507812
    top acc: 0.921875
    total acc: 0.8949704142011834

Component: 30
    ratio: 14.710883140563965
    top acc: 0.5546875
    total acc: 0.8944773175542406

Component: 31
    ratio: 3.534416913986206
    top acc: 0.0703125
    total acc: 0.7107988165680473

Component: 32
    ratio: 0.751900851726532
    top acc: 0.984375
    total acc: 0.8936143984220908

Component: 33
    ratio: 12.172329902648926
    top acc: 0.203125
    total acc: 0.8710552268244576

Component: 34
    ratio: 23.40241241455078
    top acc: 0.59375
    total acc: 0.8984220907297831

Component: 35
    ratio: 5.392189025878906
    top acc: 0.4296875
    total acc: 0.8502218934911243

Component: 36
    ratio: 6.423355579376221
    top acc: 0.1796875
    total acc: 0.8237179487179487

Component: 37
    ratio: 6.934769153594971
    top acc: 0.5390625
    total acc: 0.8827662721893491

Component: 38
    ratio: 7.094363689422607
    top acc: 0.8125
    total acc: 0.8990384615384616

Component: 39
    ratio: 11.014043807983398
    top acc: 0.484375
    total acc: 0.8827662721893491

Component: 40
    ratio: 4.019016265869141
    top acc: 0.0546875
    total acc: 0.7369329388560157

Component: 41
    ratio: 17.845386505126953
    top acc: 0.453125
    total acc: 0.8910256410256411

Component: 42
    ratio: 5.2566142082214355
    top acc: 0.9453125
    total acc: 0.8948471400394478

Component: 43
    ratio: 4.054774284362793
    top acc: 1.0
    total acc: 0.8872041420118343

Component: 44
    ratio: 9.904426574707031
    top acc: 0.53125
    total acc: 0.8852317554240631

Component: 45
    ratio: 14.50318431854248
    top acc: 0.2421875
    total acc: 0.8745069033530573

Component: 46
    ratio: 4.603786945343018
    top acc: 0.0859375
    total acc: 0.7622041420118343

Component: 47
    ratio: 0.7701278924942017
    top acc: 0.984375
    total acc: 0.8925049309664694

Component: 48
    ratio: 1.0656296014785767
    top acc: 0.9140625
    total acc: 0.8999013806706114

Component: 49
    ratio: 5.732612133026123
    top acc: 0.546875
    total acc: 0.8708086785009862

Component: 50
    ratio: 5.972711563110352
    top acc: 0.6875
    total acc: 0.8847386587771203

Component: 51
    ratio: 19.757625579833984
    top acc: 0.46875
    total acc: 0.8872041420118343

Component: 52
    ratio: 3.608445644378662
    top acc: 0.0703125
    total acc: 0.7299063116370809

Component: 53
    ratio: 8.64585018157959
    top acc: 0.4296875
    total acc: 0.8775887573964497

Component: 54
    ratio: 10.77694034576416
    top acc: 0.6640625
    total acc: 0.8931213017751479

Component: 55
    ratio: 4.215232849121094
    top acc: 0.1484375
    total acc: 0.7437130177514792

Component: 56
    ratio: 12.4645357131958
    top acc: 0.7890625
    total acc: 0.8981755424063116

Component: 57
    ratio: 12.896743774414062
    top acc: 0.515625
    total acc: 0.8884368836291914

Component: 58
    ratio: 8.852889060974121
    top acc: 0.3203125
    total acc: 0.871301775147929

Component: 59
    ratio: 16.52035903930664
    top acc: 0.2578125
    total acc: 0.8790680473372781

Component: 60
    ratio: 13.159024238586426
    top acc: 0.4296875
    total acc: 0.8858481262327417

Component: 61
    ratio: 5.165931224822998
    top acc: 0.8515625
    total acc: 0.8943540433925049

Component: 62
    ratio: 6.486883163452148
    top acc: 0.2890625
    total acc: 0.8397435897435898

Component: 63
    ratio: 0.7975021004676819
    top acc: 0.9375
    total acc: 0.8927514792899408

>>> 

"""