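R"""Finetunes a model on GLUE.

DEPRECATED: `main` raises immediately. Use `finetune.py` instead with a
'glue/' prefix on the task. The invocation below is the historical usage.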

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

python3 ./scripts1/training/finetune_glue.py  \
    --output_path=/tmp/finetune_glue_test \
    --model=roberta-base \
    --task=rte \
    --batch_size=32 \
    --n_epochs=10

"""
import os


from absl import app
from absl import flags
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em.datasets import glue
from em.util import vat_da_faak_vpn

# Command-line flags for this script. Every flag now carries a help string so
# that `--help` describes what it controls; names, types and defaults are
# unchanged.
FLAGS = flags.FLAGS


flags.DEFINE_string("output_path", None, "Path to folder that will be created containing HF information.")

flags.DEFINE_string("model", None, "Model name or path handed to `from_pretrained`; a leading `~` is expanded.")
flags.DEFINE_string("task", None, "Name of the GLUE task to finetune on, e.g. `rte`.")

flags.DEFINE_bool("from_pt", True, "Value passed as `from_pt` when loading the model with `from_pretrained`.")

flags.DEFINE_string("split", "train", "Dataset split used for training.")
flags.DEFINE_string("val_split", "validation", "Dataset split used for validation.")
flags.DEFINE_integer("batch_size", 32, "Batch size used for both the training and validation datasets.")
flags.DEFINE_float("learning_rate", 1e-5, "Learning rate of the Adam optimizer.")
flags.DEFINE_integer("sequence_length", 128, "Passed as `max_length` when loading the datasets.")

flags.DEFINE_integer("n_val_examples", 2048, "Maximum number of validation examples taken from the validation split.")


# Exactly one of these must be set.
flags.DEFINE_integer(
    "n_epochs", None, "Number of passes over the training set. Set exactly one of --n_epochs and --n_steps."
)
flags.DEFINE_integer(
    "n_steps", None, "Total number of training steps. Set exactly one of --n_epochs and --n_steps."
)


def load_datasets(tokenizer):
    """Build the batched training and validation datasets for `FLAGS.task`.

    Both splits are loaded through `glue.load_glue_dataset` with the same
    task, tokenizer and `max_length`; only the split name differs.

    Args:
        tokenizer: Tokenizer forwarded unchanged to `glue.load_glue_dataset`.

    Returns:
        A `(train_ds, val_ds)` tuple. The training dataset is repeated,
        shuffled with a buffer of 1000 and batched; the validation dataset is
        truncated to `FLAGS.n_val_examples` examples and batched.
        (Presumably both are `tf.data.Dataset` objects -- TODO confirm
        against `glue.load_glue_dataset`.)
    """
    # Arguments shared by the two split loads.
    common_kwargs = dict(
        task=FLAGS.task,
        tokenizer=tokenizer,
        max_length=FLAGS.sequence_length,
    )
    raw_train = glue.load_glue_dataset(split=FLAGS.split, **common_kwargs)
    raw_val = glue.load_glue_dataset(split=FLAGS.val_split, **common_kwargs)

    # Training stream: repeat first, then shuffle, then batch.
    train_batches = raw_train.repeat()
    train_batches = train_batches.shuffle(1000)
    train_batches = train_batches.batch(FLAGS.batch_size)

    # Validation stream: capped at n_val_examples, no shuffling.
    val_subset = raw_val.take(FLAGS.n_val_examples)
    val_batches = val_subset.batch(FLAGS.batch_size)

    return train_batches, val_batches


def load_model():
    """Load the pretrained classifier and its tokenizer, and compile the model.

    `FLAGS.model` is user-expanded (`~`) and used for both the model and the
    tokenizer. The classification head is sized with
    `glue.get_n_classes(FLAGS.task)`.

    Returns:
        A `(model, tokenizer)` tuple. The model is compiled with Adam at
        `FLAGS.learning_rate`, sparse categorical cross-entropy on logits and
        sparse categorical accuracy.
    """
    model_str = os.path.expanduser(FLAGS.model)
    model = TFAutoModelForSequenceClassification.from_pretrained(
        model_str, from_pt=FLAGS.from_pt, num_labels=glue.get_n_classes(FLAGS.task)
    )
    tokenizer = AutoTokenizer.from_pretrained(model_str)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(FLAGS.learning_rate),
        # TODO: Support metrics other than accuracy based on the task.
        # `compile` documents `metrics` as a list, so pass one explicitly
        # rather than a bare metric instance.
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    return model, tokenizer


def get_train_params():
    """Turn the --n_epochs / --n_steps flags into arguments for `model.fit`.

    Exactly one of `FLAGS.n_epochs` and `FLAGS.n_steps` must be set. With
    `n_epochs`, the total step count is
    `n_epochs * glue.NUM_GLUE_TRAIN_EXAMPLES[FLAGS.task] // FLAGS.batch_size`;
    with `n_steps` it is used as given.

    Returns:
        A `(epochs, steps_per_epoch)` tuple. `epochs` is 10, or the total
        step count when that is smaller, and `steps_per_epoch` is the total
        step count floor-divided by `epochs`.

    Raises:
        ValueError: If neither or both of the flags are set, or if the
            resulting total number of training steps is less than one.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`
    # and would then let --n_epochs silently win when both flags are given.
    if (FLAGS.n_epochs is None) == (FLAGS.n_steps is None):
        raise ValueError("Exactly one of --n_epochs and --n_steps must be set.")

    if FLAGS.n_epochs is not None:
        n_examples_per_epoch = glue.NUM_GLUE_TRAIN_EXAMPLES[FLAGS.task]
        n_total_examples = FLAGS.n_epochs * n_examples_per_epoch
        n_steps = n_total_examples // FLAGS.batch_size
    else:
        n_steps = FLAGS.n_steps

    if n_steps < 1:
        raise ValueError("Training must run for at least one step, got {}.".format(n_steps))

    # TODO: Let this be set, only affects how often keras evaluates during training.
    # Capped at n_steps so that steps_per_epoch never becomes zero on short runs.
    epochs = min(10, n_steps)
    steps_per_epoch = n_steps // epochs
    return epochs, steps_per_epoch


def main(_):
    """Entry point handed to `app.run`; always fails with a deprecation error.

    The raise below is unconditional, so the finetuning steps written after
    it (load the model, fit, save) are unreachable.

    Args:
        _: Positional argument supplied by `app.run`; unused.

    Raises:
        Exception: Always, pointing the user at `finetune.py`.
    """
    deprecation_msg = "This script is deprecated. Use `finetune.py` instead with 'glue/' prefix on the task."
    raise Exception(deprecation_msg)

    # Unreachable from here on.
    classifier, tok = load_model()
    train_batches, val_batches = load_datasets(tok)
    n_rounds, steps_per_round = get_train_params()
    fit_kwargs = dict(
        steps_per_epoch=steps_per_round,
        epochs=n_rounds,
        validation_data=val_batches,
    )
    classifier.fit(train_batches, **fit_kwargs)
    # save_directory
    save_dir = os.path.expanduser(FLAGS.output_path)
    classifier.save_pretrained(save_dir)


# Script entry point: when executed directly, hand `main` to absl's `app.run`.
if __name__ == "__main__":
    app.run(main)
