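R"""Fine-tune a sequence-classification model on SNLI, then compare its
validation accuracy before and after perturbing its weights along a single
LRM-NPEFF component.

Run from the project root:

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

CUDA_VISIBLE_DEVICES=1 python -i local_scripts/m_npeff/gd_methods/gd_attempt_thing001.py

"""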

import dataclasses
from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.datasets import glue
from em.fishers import diagonal
from em.tools.nmf import lrm_npeff
from em.util import flat_pack

from em.projects.m_npeff import snli_context
from em.projects.pi import qqp_components_context as QCC

from em.perturbations.gradient_descent import subspace_regularized1
from em.util.color_util import cu

###############################################################################

# ---------------------------------------------------------------------------
# Input artifacts
# ---------------------------------------------------------------------------

# Precomputed Fisher file. NOTE(review): FISHER_PATH is not read anywhere else
# in this script as shown; it appears to be kept only for reference.
FISHER_DIR: str = "/fruitbasket/users/m/project_data/extract_merge1/pi1/fishers/"
FISHER_NAME: str = "feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"
FISHER_PATH: str = os.path.join(FISHER_DIR, FISHER_NAME)

# Saved LRM-NPEFF decomposition that the component to perturb is taken from.
NMF_DIR: str = "/playpen/users/m/project_data/m_npeff1/per_example_fishers"
NMF_NAME: str = "test_mnpeff_002.expansion_005.coeffs_fit001.h5"
NMF_PATH: str = os.path.join(NMF_DIR, NMF_NAME)

# Hugging Face model checkpoint and the tokenizer used to encode SNLI.
MODEL: str = "connectivity/feather_berts_0"
TOKENIZER: str = 'bert-base-uncased'

# ---------------------------------------------------------------------------
# Training / evaluation settings
# ---------------------------------------------------------------------------

LR: float = 1e-5             # Adam learning rate.
SEQUENCE_LENGTH: int = 128   # Token sequence length requested from the dataset loader.
BATCH_SIZE: int = 32         # Used for both the training and validation pipelines.
N_VAL_EXAMPLES: int = 2048   # Size of the fixed validation subset.

# ---------------------------------------------------------------------------
# Perturbation settings
# ---------------------------------------------------------------------------

MAX_SIM: float = 0.35        # Forwarded to find_unnormalized_reduced_perturbation.
# Index of the NPEFF component to perturb along. Other indices tried: 34, 40, 13, 210.
COMP_INDEX: int = 233
###############################################################################

# Load the decomposition together with its G matrix, and normalize its
# components to unit norm before one of them is selected.
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=True)
nmf.normalize_components_to_unit_norm()
# Perturbation direction built from component COMP_INDEX of G. MAX_SIM is
# passed through unchanged (presumably a similarity cap — confirm in
# subspace_regularized1).
component_g = subspace_regularized1.find_unnormalized_reduced_perturbation(
    nmf.G,
    COMP_INDEX,
    MAX_SIM,
)

###############################################################################

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

# from_pt=True: the checkpoint is loaded from PyTorch weights.
model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)
model.compile(
    optimizer=tf.keras.optimizers.Adam(LR),
    # from_logits=True: the loss applies the softmax itself, so the model's
    # outputs are treated as raw logits.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)


def _load_snli_split(split):
    """Load one tokenized SNLI split and run it through the MNLI label fix.

    NOTE(review): raw SNLI contains examples with no gold label (label -1 in
    the TFDS / HF releases). Confirm that em_datasets.load or
    glue.fix_text_attack_mnli_labeling removes them; -1 is not a valid class
    for SparseCategoricalCrossentropy.
    """
    raw_ds = em_datasets.load(
        'snli/default',
        split=split,
        sequence_length=SEQUENCE_LENGTH,
        tokenizer=tokenizer,
    )
    return glue.fix_text_attack_mnli_labeling(raw_ds)


# Endless, shuffled training stream; fit() bounds each epoch via steps_per_epoch.
train_ds = _load_snli_split('train').repeat().shuffle(1000).batch(BATCH_SIZE)

# Fixed validation subset of N_VAL_EXAMPLES examples, cached after the first pass.
val_ds = _load_snli_split('validation').take(N_VAL_EXAMPLES).cache().batch(BATCH_SIZE)

###############################################################################

###############################################################################

# Short fine-tuning run: 5 epochs x 64 steps, validating on the fixed subset
# after each epoch.
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=5,
    steps_per_epoch=64,
)

###############################################################################

# Accuracy of the fine-tuned, unperturbed model. evaluate() returns
# [loss, accuracy] because exactly one metric was compiled in.
_, baseline_acc = model.evaluate(val_ds)
print(cu.hlg(f'Baseline Val Acc: {baseline_acc}'))

# Re-import so that edits to subspace_regularized1 are picked up when this
# section is re-run from the interactive (`python -i`) session.
reload(subspace_regularized1)
model2 = subspace_regularized1.SingleComponentRegularized(
    model=model,
    component_g=component_g,
    new_to_old_col_indices=nmf.new_to_old_col_indices,
    # Other values tried: 1e-2, 1e0.
    lmbda_ss=1e-1,
)

# Signed perturbation magnitude along the selected component.
mp = 8e-2
model2.perturb_weights_by_component(mp)
try:
    # model2 wraps `model`, so evaluating `model` measures the perturbed weights.
    _, perturbed_acc = model.evaluate(val_ds)
    print(cu.hlg(f'Perturbed Val Acc: {perturbed_acc}'))
finally:
    # Always apply the opposite perturbation, even if evaluation raises or is
    # interrupted. Otherwise the model left in the interactive session stays
    # perturbed. Assumes perturb_weights_by_component is additive in its
    # argument, so that -mp undoes +mp — confirm.
    model2.perturb_weights_by_component(-mp)

"""
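Results from earlier runs. The first line is the unperturbed validation
accuracy. Each later entry reads `comp <COMP_INDEX>, <mp>` and is followed by
the validation accuracy with the weights perturbed by mp along that component.

Baseline Val Acc: 0.84375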

comp 34, -1e-3
Perturbed Val Acc: 0.84423828125

comp 40, 1e-1
Perturbed Val Acc: 0.8466796875

comp 13, 3e-2
Perturbed Val Acc: 0.84619140625

comp 210, -1e-1
Perturbed Val Acc: 0.84619140625

comp 233, 7e-2
Perturbed Val Acc: 0.84814453125

"""
