R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/transfer1/classifier_head_dev001.py

"""
from importlib import reload
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.tools.nmf import nmf_common
from em.util import flat_pack
from em.util import hf_util

###############################################################################

# Checkpoint name handed to both `AutoTokenizer.from_pretrained` and
# `TFAutoModelForSequenceClassification.from_pretrained` below.
# PRETRAINED_MODEL = "roberta-base"
PRETRAINED_MODEL = "bert-base-uncased"
# Forwarded as `from_pt=` when loading the model (presumably: convert the
# weights from a PyTorch checkpoint -- see the Transformers docs).
FROM_PT = True

# Task identifier passed to `em_datasets.load` and
# `em_datasets.n_classes_for_task`.
# TARGET_TASK = 'glue/rte'
TARGET_TASK = 'glue/qnli'

# Forwarded as `sequence_length=` to `em_datasets.load`.
SEQUENCE_LENGTH = 128
# Batch size applied to both the training and the validation dataset.
# BATCH_SIZE = 32
BATCH_SIZE = 64
# Adam learning rate and gradient clipnorm; only referenced by the
# (currently commented-out) `output_model.compile(...)` call further down.
# LR = 4e-5
# LR = 3e-4
LR = 1e-3
CLIPNORM = 0.1
# CLIPNORM = None

# Number of examples taken (before batching) to build `val_ds`.
N_VAL_EXAMPLES = 4096

###############################################################################

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# Size the classification head for the target task.
n_labels = em_datasets.n_classes_for_task(TARGET_TASK)
output_model = TFAutoModelForSequenceClassification.from_pretrained(
    PRETRAINED_MODEL, from_pt=FROM_PT, num_labels=n_labels)


def _load_train_split():
    """Load the tokenized 'train' split of TARGET_TASK via `em_datasets`."""
    return em_datasets.load(
        TARGET_TASK,
        split='train',
        tokenizer=tokenizer,
        sequence_length=SEQUENCE_LENGTH,
    )


# Training stream: repeated, shuffled with a 1000-element buffer, then batched.
train_ds = _load_train_split().repeat().shuffle(1000).batch(BATCH_SIZE)

# Validation set: the first N_VAL_EXAMPLES examples, batched.
# NOTE(review): this is drawn from the same 'train' split as `train_ds`, so
# these examples may also be seen during training -- confirm this is intended.
val_ds = _load_train_split().take(N_VAL_EXAMPLES).batch(BATCH_SIZE)

###########################################################

# Try training the classifier head first: freeze the pretrained encoder.
#
# The encoder lives under an architecture-specific attribute (`.bert`,
# `.roberta`, ...). Resolve it through the model's `base_model_prefix` so this
# keeps working when PRETRAINED_MODEL is switched, instead of hand-editing the
# attribute name here and in the forward pass below.
backbone = getattr(output_model, output_model.base_model_prefix)
backbone.trainable = False

# Head-only training run (currently disabled).
# output_model.compile(
#     optimizer=tf.keras.optimizers.Adam(LR, clipnorm=CLIPNORM),
#     metrics=tf.keras.metrics.SparseCategoricalAccuracy(),
#     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
# )

# output_model.fit(
#     train_ds,
#     steps_per_epoch=512,
#     epochs=16,
#     validation_data=val_ds,
# )

# Run a single training batch through the frozen encoder and keep the result
# in `arf` for interactive inspection (the script is launched with
# `python -i`). `x`, `y` stay bound to that batch as well.
for x, y in train_ds.take(1):
    arf = backbone(x)
    # arf.pooler_output

# TODO: cache the encoder embeddings first, then train the classifier head on
# the cached embeddings.
