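R"""
Development script: fine-tune a pretrained sequence classifier on the
antiderivative dataset.

Setup (run from a shell):

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

Run interactively (the second form restricts CUDA to device 0):

python3 -i local_scripts/ead/ead_dev001.py
CUDA_VISIBLE_DEVICES=0 python -i local_scripts/ead/ead_dev001.py

"""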
import itertools

import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em.datasets.antiderivative import antiderivative_ds
from em.util import vat_da_faak_vpn

###################################################################################

# Dataset CSV handed to antiderivative_ds.load_raw_from_file; the commented-out
# line is an alternative file kept for quick switching.
# NOTE(review): the path starts with "~" -- confirm the loader expands the home
# directory, since the script itself does not.
# DS_FILE_PATH = "~/Desktop/projects_data/extract_merge1/antiderivative/datasets/expressions001_ead.3M.00.5s.csv"
DS_FILE_PATH = "~/Desktop/projects_data/extract_merge1/antiderivative/datasets/expressions001_ead.3M.01.5s.csv"

###################################################################################

# Checkpoint name used for both the tokenizer and the classification model.
PRETRAINED_MODEL = 'bert-base-uncased'
# Passed as `sequence_length` when converting the raw dataset to features.
SEQUENCE_LENGTH = 128
# Batch size applied to the training dataset.
BATCH_SIZE = 8
# Not referenced by the visible part of this script.
EVAL_BATCH_SIZE = 32

# Adam learning rate.
LR = 3e-5
# Gradient-norm clipping threshold given to the Adam optimizer (`clipnorm`).
CLIPNORM = 0.1

###################################################################################

# Tokenizer that matches the pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# Load only the labeled examples and report dataset statistics before any
# transformation is applied.
ds = antiderivative_ds.load_raw_from_file(DS_FILE_PATH, skip_unlabeled=True)
stats = antiderivative_ds.get_stats(ds)
print(stats)
print(stats.total_count)

# Turn raw examples into model inputs (presumably tokenized features -- the
# conversion lives in antiderivative_ds).
ds = antiderivative_ds.convert_dataset_to_features(
    ds,
    tokenizer,
    sequence_length=SEQUENCE_LENGTH,
)

# Input pipeline, one stage per line: repeat without a count, shuffle with a
# 1000-element buffer, then group into training batches.
ds = ds.repeat()
ds = ds.shuffle(1000)
ds = ds.batch(BATCH_SIZE)

# from_pt=True asks transformers to load the TF model from PyTorch weights.
model = TFAutoModelForSequenceClassification.from_pretrained(
    PRETRAINED_MODEL,
    from_pt=True,
)

# Training configuration: clipped Adam, sparse cross-entropy computed on raw
# logits, and sparse categorical accuracy as the reported metric.
adam = tf.keras.optimizers.Adam(LR, clipnorm=CLIPNORM)
xent_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
model.compile(optimizer=adam, loss=xent_loss, metrics=[accuracy])

# The dataset is repeated without a count, so the epoch length is set
# explicitly: 4 epochs of 64 batches each.
model.fit(ds, steps_per_epoch=64, epochs=4)

###################################################################################
# ds = antiderivative_ds.load_raw_from_file(DS_FILE_PATH, skip_unlabeled=False)

# for x in ds:
#     break
# stats = antiderivative_ds.get_stats(ds)
# print(stats)

# TODO: Skip long examples (probably by character count), mostly to eliminate examples with huge numbers.
# TODO: Remember to filter out the unlabeled examples once I'm done looking at stats.


# for x, y in itertools.islice(ds.as_numpy_iterator(), 64):
#     x = tf.compat.as_str(x)
#     print(tokenizer.decode(tokenizer.encode(x)))


# expr.is_algebraic_expr(x)
# expr.is_constant(x)
# expr.is_polynomial(x)
# expr.is_rational_function(x)