R"""Fine-tune a subspace-regularized classifier on SNLI for a single NPEFF component.

Usage:

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=2 python -i local_scripts/m_npeff/gd_methods/gd_attempt001.py

"""

import dataclasses
from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.datasets import glue
from em.fishers import diagonal
from em.tools.nmf import lrm_npeff
from em.util import flat_pack

from em.projects.m_npeff import snli_context
from em.projects.pi import qqp_components_context as QCC

from em.perturbations.gradient_descent import subspace_regularized1
from em.util.color_util import cu

###############################################################################

# Location of a saved Fisher file. FISHER_PATH is assembled here but is not
# referenced again in the visible part of this script.
FISHER_DIR = "/fruitbasket/users/m/project_data/extract_merge1/pi1/fishers/"
FISHER_NAME = "feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"
FISHER_PATH = os.path.join(FISHER_DIR, FISHER_NAME)

# Location of the saved decomposition that is loaded below via
# lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, ...).
NMF_DIR = "/playpen/users/m/project_data/m_npeff1/per_example_fishers"
NMF_NAME = "test_mnpeff_002.expansion_005.coeffs_fit001.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# Identifiers passed to the HuggingFace `from_pretrained` loaders below:
# MODEL for the sequence classifier, TOKENIZER for the tokenizer.
MODEL = "connectivity/feather_berts_0"
TOKENIZER = 'bert-base-uncased'

###############################################################################

# Training hyperparameters.
LR = 1e-5  # Learning rate handed to the Adam optimizer.
SEQUENCE_LENGTH = 128  # Passed as `sequence_length` when loading the datasets.
BATCH_SIZE = 32  # Batch size for both the training and validation datasets.

# Number of validation examples kept (via `.take`) for evaluation.
N_VAL_EXAMPLES = 2048

###############################################################################
# Arguments for subspace_regularized1.find_unnormalized_reduced_perturbation.
# MAX_SIM is presumably a similarity cap and COMP_INDEX the index of the
# component to perturb -- confirm against that function's definition.
MAX_SIM = 0.35
COMP_INDEX = 34
###############################################################################

# Load the saved decomposition from NMF_PATH. read_G=True is requested because
# nmf.G is read a few lines below.
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=True)
# Must run before nmf.G is used. The return value is discarded, so this
# presumably rescales the components in place -- confirm in lrm_npeff.
nmf.normalize_components_to_unit_norm()

# Perturbation for component COMP_INDEX, computed from nmf.G with the MAX_SIM
# setting; it is handed to SingleComponentRegularized as `component_g` below.
component_g = subspace_regularized1.find_unnormalized_reduced_perturbation(nmf.G, COMP_INDEX, MAX_SIM)

###############################################################################

def _load_snli_split(split):
    """Load one SNLI split with the script's tokenizer and apply the label fix.

    Args:
        split: Name of the split to load, e.g. 'train' or 'validation'.

    Returns:
        The dataset returned by `glue.fix_text_attack_mnli_labeling` (what the
        fix does is defined in em.datasets.glue, not in this script).
    """
    raw_ds = em_datasets.load(
        'snli/default',
        split=split,
        sequence_length=SEQUENCE_LENGTH,
        tokenizer=tokenizer,
    )
    return glue.fix_text_attack_mnli_labeling(raw_ds)


tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

# from_pt=True: the checkpoint is loaded from PyTorch weights.
og_model = TFAutoModelForSequenceClassification.from_pretrained(
    MODEL,
    from_pt=True,
)

# Regularization weight for this run. Values tried so far: 1e-2, 1e0, 1e-1
# (see the logs recorded at the bottom of the file).
LMBDA_SS = 1e-1

# Wrap the pretrained classifier together with the selected component.
model = subspace_regularized1.SingleComponentRegularized(
    model=og_model,
    component_g=component_g,
    new_to_old_col_indices=nmf.new_to_old_col_indices,
    lmbda_ss=LMBDA_SS,
)

adam = tf.keras.optimizers.Adam(LR)
# The loss is configured with from_logits=True, i.e. it expects raw logits.
xent_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()
model.compile(optimizer=adam, loss=xent_loss, metrics=[accuracy_metric])

# Training stream: repeated indefinitely, shuffled with a 1000-element buffer,
# then batched.
train_ds = (
    _load_snli_split('train')
    .repeat()
    .shuffle(1000)
    .batch(BATCH_SIZE)
)

# Validation set: the first N_VAL_EXAMPLES examples, cached, then batched.
val_ds = (
    _load_snli_split('validation')
    .take(N_VAL_EXAMPLES)
    .cache()
    .batch(BATCH_SIZE)
)

###############################################################################

# Validation accuracy of the wrapped model before any training step is taken.
# evaluate() yields (loss, accuracy) for the single metric compiled above.
baseline_loss, baseline_acc = model.evaluate(val_ds)
baseline_msg = f'Baseline Val Acc: {baseline_acc}'
print(cu.hlg(baseline_msg))

###############################################################################

# train_ds repeats indefinitely, so steps_per_epoch is what ends each epoch.
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    steps_per_epoch=64,
)

###############################################################################

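"""
Recorded training logs from earlier runs of this script, grouped by the
lmbda_ss value used (10 log lines per run).

Baseline Val Acc: 0.8095703125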


#####################################################################
lmbda_ss = 1e-2

accuracy: 0.7988 - val_loss: 0.4774 - val_accuracy: 0.8311
accuracy: 0.8110 - val_loss: 0.4468 - val_accuracy: 0.8398
accuracy: 0.8203 - val_loss: 0.4490 - val_accuracy: 0.8457
accuracy: 0.8267 - val_loss: 0.4405 - val_accuracy: 0.8472
accuracy: 0.8145 - val_loss: 0.4342 - val_accuracy: 0.8472
accuracy: 0.8193 - val_loss: 0.4187 - val_accuracy: 0.8545
accuracy: 0.8125 - val_loss: 0.4160 - val_accuracy: 0.8535
accuracy: 0.8306 - val_loss: 0.4194 - val_accuracy: 0.8491
accuracy: 0.8364 - val_loss: 0.4189 - val_accuracy: 0.8491
accuracy: 0.8174 - val_loss: 0.4061 - val_accuracy: 0.8564

#####################################################################
lmbda_ss = 1e0

accuracy: 0.7803 - val_loss: 0.4784 - val_accuracy: 0.8325
accuracy: 0.8051 - val_loss: 0.4797 - val_accuracy: 0.8296
accuracy: 0.8160 - val_loss: 0.4842 - val_accuracy: 0.8286
accuracy: 0.8233 - val_loss: 0.4791 - val_accuracy: 0.8350
accuracy: 0.7993 - val_loss: 0.4797 - val_accuracy: 0.8320
accuracy: 0.7863 - val_loss: 0.4749 - val_accuracy: 0.8320
accuracy: 0.8078 - val_loss: 0.4710 - val_accuracy: 0.8315
accuracy: 0.7996 - val_loss: 0.4663 - val_accuracy: 0.8398
accuracy: 0.8089 - val_loss: 0.4662 - val_accuracy: 0.8291
accuracy: 0.8078 - val_loss: 0.4653 - val_accuracy: 0.8306


#####################################################################
lmbda_ss = 1e-1
accuracy: 0.7971 - val_loss: 0.4766 - val_accuracy: 0.8325
accuracy: 0.7990 - val_loss: 0.4539 - val_accuracy: 0.8359
accuracy: 0.8249 - val_loss: 0.4507 - val_accuracy: 0.8418
accuracy: 0.8094 - val_loss: 0.4452 - val_accuracy: 0.8467
accuracy: 0.8251 - val_loss: 0.4521 - val_accuracy: 0.8472
accuracy: 0.8149 - val_loss: 0.4402 - val_accuracy: 0.8447
accuracy: 0.8041 - val_loss: 0.4331 - val_accuracy: 0.8408
accuracy: 0.8192 - val_loss: 0.4350 - val_accuracy: 0.8472
accuracy: 0.8158 - val_loss: 0.4327 - val_accuracy: 0.8408
accuracy: 0.8187 - val_loss: 0.4277 - val_accuracy: 0.8418

"""