R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/pi/qnli_dev01.py

"""
from importlib import reload


import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets

##########################################################################

# Checkpoint name handed to AutoTokenizer.from_pretrained.
TOKENIZER = 'bert-base-uncased'

# Classifier checkpoint and the dataset it is evaluated on; switch both together.
# QNLI alternative: MODEL = "textattack/bert-base-uncased-QNLI", DS = 'glue/qnli'
MODEL, DS = "textattack/bert-base-uncased-QQP", 'glue/qqp'

# Sequence length forwarded to em_datasets.load (previous setting: 128).
SEQUENCE_LENGTH = 64

# Batch size used when batching the datasets for evaluation (previous setting: 16).
BATCH_SIZE = 32

# Number of examples taken from the start of each split.
N_EXAMPLES = 5_000

##########################################################################

# Tokenizer for the base checkpoint; also passed to the dataset loader below.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

# from_pt=True loads the classifier from a PyTorch checkpoint. The model is
# compiled with a loss and a metric only so `model.evaluate` reports both.
model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)
eval_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
eval_metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
model.compile(loss=eval_loss, metrics=eval_metrics)


def _load_split(split):
    """Load one split of DS using the shared sequence length and tokenizer."""
    return em_datasets.load(
        DS,
        split=split,
        sequence_length=SEQUENCE_LENGTH,
        tokenizer=tokenizer,
    )


train_ds = _load_split('train')
val_ds = _load_split('validation')

# Evaluate on the first N_EXAMPLES of each split: train first, then validation.
for _split_ds in (train_ds, val_ds):
    model.evaluate(_split_ds.take(N_EXAMPLES).batch(BATCH_SIZE))

"""
QNLI:
313/313 [==============================] - 36s 93ms/step - loss: 0.0796 - sparse_categorical_accuracy: 0.9785
313/313 [==============================] - 29s 93ms/step - loss: 0.2557 - sparse_categorical_accuracy: 0.9148

QQP: [seqlen=128]
313/313 [==============================] - 34s 91ms/step - loss: 0.0807 - sparse_categorical_accuracy: 0.9698
313/313 [==============================] - 28s 88ms/step - loss: 0.2520 - sparse_categorical_accuracy: 0.9074

QQP: [seqlen=64]
157/157 [==============================] - 21s 99ms/step - loss: 0.0809 - sparse_categorical_accuracy: 0.9696
157/157 [==============================] - 16s 103ms/step - loss: 0.2527 - sparse_categorical_accuracy: 0.9072


QQP has a large test set.
- Let's use it later.
- For now, let's use the first 20k examples of validation for Fisher stuff, then the last 20k for evaluation.

"""
