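R"""
Baseline for the gd_methods experiments. It evaluates the pretrained
connectivity/feather_berts_0 checkpoint on a fixed subset of the SNLI
validation split. It then fine-tunes the checkpoint on SNLI train with Adam
and reports validation loss/accuracy after each epoch.

Usage:

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=2 python -i local_scripts/m_npeff/gd_methods/gd_baseline001.py

"""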

import dataclasses
from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.datasets import glue
from em.fishers import diagonal
from em.tools.nmf import lrm_npeff
from em.util import flat_pack

from em.projects.m_npeff import snli_context
from em.projects.pi import qqp_components_context as QCC

from em.util.color_util import cu

###############################################################################

# --- Data locations ----------------------------------------------------------
# NOTE(review): FISHER_PATH and NMF_PATH are not read by the code visible in
# this script; presumably kept for parity with sibling gd_methods scripts.
FISHER_DIR = "/fruitbasket/users/m/project_data/extract_merge1/pi1/fishers/"
FISHER_NAME = "feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"
NMF_DIR = "/playpen/users/m/project_data/m_npeff1/per_example_fishers"
NMF_NAME = "test_mnpeff_002.expansion_005.coeffs_fit001.h5"

FISHER_PATH, NMF_PATH = (
    os.path.join(FISHER_DIR, FISHER_NAME),
    os.path.join(NMF_DIR, NMF_NAME),
)

# --- Model / tokenizer identifiers (Hugging Face Hub ids) --------------------
MODEL = "connectivity/feather_berts_0"
TOKENIZER = 'bert-base-uncased'

###############################################################################

# --- Optimization and input pipeline -----------------------------------------
LR = 1e-5  # Adam learning rate.
SEQUENCE_LENGTH = 128  # Passed as `sequence_length` to em_datasets.load.
BATCH_SIZE = 32  # Shared by the train and validation pipelines.

# Size of the fixed validation subset (first N examples of the split).
N_VAL_EXAMPLES = 2048

###############################################################################

# Tokenizer handed to em_datasets.load to encode the SNLI examples.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

# The checkpoint is loaded from PyTorch weights (from_pt=True). The loss is
# configured for unnormalized logits (from_logits=True), and accuracy is the
# only tracked metric.
model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)
model.compile(
    optimizer=tf.keras.optimizers.Adam(LR),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)


def _load_snli_split(split):
    """Load one SNLI split via em_datasets and apply the label fix-up.

    NOTE(review): glue.fix_text_attack_mnli_labeling presumably remaps labels
    to the ordering this MNLI-trained checkpoint expects; confirm in
    em.datasets.glue.
    """
    raw = em_datasets.load(
        'snli/default',
        split=split,
        sequence_length=SEQUENCE_LENGTH,
        tokenizer=tokenizer,
    )
    return glue.fix_text_attack_mnli_labeling(raw)


# Endless training stream: repeated, shuffled with a 1000-example buffer, batched.
train_ds = (
    _load_snli_split('train')
    .repeat()
    .shuffle(1000)
    .batch(BATCH_SIZE)
)

# Fixed validation subset: the first N_VAL_EXAMPLES examples, cached so that
# repeated evaluations reuse the same examples.
val_ds = (
    _load_snli_split('validation')
    .take(N_VAL_EXAMPLES)
    .cache()
    .batch(BATCH_SIZE)
)

###############################################################################

# Accuracy of the untouched checkpoint on the validation subset, measured
# before any fine-tuning. `evaluate` returns [loss, *metrics]. This script
# compiles exactly one metric (SparseCategoricalAccuracy), so the second entry
# is accuracy.
_, baseline_acc = model.evaluate(val_ds)
print(cu.hlg(f'Baseline Val Acc: {baseline_acc}'))

###############################################################################

# train_ds repeats indefinitely, so an "epoch" here is just STEPS_PER_EPOCH
# batches rather than a full pass over SNLI train.
STEPS_PER_EPOCH = 64
N_EPOCHS = 10

# Keep the History object so per-epoch train/val metrics remain inspectable
# in the interactive (`python -i`) session after training finishes.
history = model.fit(
    train_ds,
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=N_EPOCHS,
    validation_data=val_ds,
)

###############################################################################

"""
Recorded output from a previous run. The first line is the pre-fine-tuning
validation accuracy. After that there is one line per epoch: end-of-epoch
training accuracy, then validation loss and accuracy on the fixed validation
subset.

Baseline Val Acc: 0.8095703125

accuracy: 0.7915 - val_loss: 0.4561 - val_accuracy: 0.8311
accuracy: 0.8203 - val_loss: 0.4458 - val_accuracy: 0.8418
accuracy: 0.8218 - val_loss: 0.4386 - val_accuracy: 0.8428
accuracy: 0.8232 - val_loss: 0.4276 - val_accuracy: 0.8472
accuracy: 0.8203 - val_loss: 0.4280 - val_accuracy: 0.8472
accuracy: 0.8311 - val_loss: 0.4175 - val_accuracy: 0.8525
accuracy: 0.8086 - val_loss: 0.4170 - val_accuracy: 0.8481
accuracy: 0.8369 - val_loss: 0.4131 - val_accuracy: 0.8569
accuracy: 0.8281 - val_loss: 0.4201 - val_accuracy: 0.8491
accuracy: 0.8213 - val_loss: 0.4024 - val_accuracy: 0.8564

"""