R"""Script for general finetuning.

This should supersede the `finetune_glue.py` script.


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

python3 ./scripts1/training/finetune.py  \
    --output_path=/tmp/finetune_glue_test \
    --model=roberta-base \
    --task=glue/rte \
    --batch_size=32 \
    --n_epochs=10
"""

import os

from absl import app
from absl import flags
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
from transformers import TFRobertaForSequenceClassification, TFBertForSequenceClassification

from em import datasets as em_datasets
from em.util import hf_util
from em.util import vat_da_faak_vpn

# Command-line flags for this script; values are read through FLAGS by the
# functions defined later in this file.
FLAGS = flags.FLAGS


# Expanded with os.path.expanduser in main() and handed to hf_util.HfSaverCallback.
flags.DEFINE_string("output_path", None, "Path to folder that will be created containing HF information.")

# Model name or path passed (after user expansion) to
# TFAutoModelForSequenceClassification.from_pretrained in load_model().
flags.DEFINE_string("model", None, "")
# Forwarded as the `from_pt` argument of from_pretrained.
flags.DEFINE_bool("from_pt", True, "")

# These will only be used when --model is not set. In that case, we will
# be training a BERT-style model from scratch.
# Both are passed to hf_util.make_bert_config in load_model().
flags.DEFINE_integer("hidden_size", None, "")
flags.DEFINE_integer("num_hidden_layers", None, "")

# Name or path passed to AutoTokenizer.from_pretrained in load_model().
flags.DEFINE_string("tokenizer", None, "Defaults to `model` if not set.")

# Task identifier (e.g. `glue/rte` in the usage example at the top of the file);
# passed to the em_datasets helpers.
flags.DEFINE_string("task", None, "")

# Training split; load_datasets() falls back to 'train' when unset.
flags.DEFINE_string("split", None, "Uses a default if not set.")
# Validation split; load_datasets() falls back to
# em_datasets.de_facto_validation_split(task) when unset.
flags.DEFINE_string("val_split", None, "Uses a default if not set.")
# Batch size applied to both the training and validation datasets.
flags.DEFINE_integer("batch_size", 32, "")
# Learning rate given to the Adam optimizer in load_model().
flags.DEFINE_float("learning_rate", 1e-5, "")
# Passed to em_datasets.load as `sequence_length`; also used as
# `max_position_embeddings` when a model is built from scratch.
flags.DEFINE_integer("sequence_length", 128, "")

# Upper bound on validation examples (applied with .take() before batching).
flags.DEFINE_integer("n_val_examples", 2048, "")

# Forwarded as `clipnorm` to the Adam optimizer.
flags.DEFINE_float("clipnorm", None, "")

# When set (and non-zero), used as the label count instead of
# em_datasets.n_classes_for_task(task).
flags.DEFINE_integer("force_n_labels", None, "")

# Exactly one of these must be set.
# get_train_params() turns whichever is set into a total optimizer step count.
flags.DEFINE_integer("n_epochs", None, "")
flags.DEFINE_integer("n_steps", None, "")


def load_datasets(tokenizer):
    """Build the batched training dataset and, if available, validation dataset.

    Both datasets come from `em_datasets.load` for the task in `--task`, using
    the given tokenizer and `--sequence_length`.

    Args:
        tokenizer: Tokenizer object forwarded unchanged to `em_datasets.load`.

    Returns:
        A `(train_ds, val_ds)` tuple. `train_ds` is repeated (no count given),
        shuffled with a buffer of 1000 and batched by `--batch_size`.
        `val_ds` is limited to `--n_val_examples` examples and batched the
        same way, or is `None` when neither `--val_split` nor
        `em_datasets.de_facto_validation_split` yields a split.
        NOTE(review): the datasets are presumably `tf.data.Dataset` objects --
        confirm against `em_datasets.load`.
    """
    # Arguments shared by the training and validation loads.
    shared_kwargs = dict(
        tokenizer=tokenizer,
        sequence_length=FLAGS.sequence_length,
    )

    raw_train = em_datasets.load(
        FLAGS.task,
        split=FLAGS.split or 'train',
        **shared_kwargs,
    )
    train_batches = raw_train.repeat()
    train_batches = train_batches.shuffle(1000)
    train_batches = train_batches.batch(FLAGS.batch_size)

    # An explicit --val_split wins; otherwise ask the dataset registry.
    eval_split = FLAGS.val_split or em_datasets.de_facto_validation_split(FLAGS.task)
    if eval_split is None:
        return train_batches, None

    raw_val = em_datasets.load(FLAGS.task, split=eval_split, **shared_kwargs)
    val_batches = raw_val.take(FLAGS.n_val_examples).batch(FLAGS.batch_size)
    return train_batches, val_batches


def load_model():
    """Create the tokenizer and a compiled sequence-classification model.

    The tokenizer is loaded from `--tokenizer`, falling back to `--model`.
    When `--model` is set, the model is loaded with
    `TFAutoModelForSequenceClassification.from_pretrained`; otherwise a
    BERT-style model is built from scratch from `--hidden_size`,
    `--num_hidden_layers` and `--sequence_length`.

    The number of labels is `--force_n_labels` when set (and non-zero),
    otherwise `em_datasets.n_classes_for_task(--task)`.

    Returns:
        A `(model, tokenizer)` tuple. The model is compiled with Adam
        (`--learning_rate`, `--clipnorm`), sparse categorical cross-entropy
        on logits, and sparse categorical accuracy.

    Raises:
        ValueError: If neither `--model` nor `--tokenizer` is set, or if
            `--model` is unset and `--hidden_size` / `--num_hidden_layers`
            are missing. (These were previously bare `assert`s, which are
            skipped under `python -O` and carried no message.)
    """
    model_str = FLAGS.model and os.path.expanduser(FLAGS.model)
    tokenizer_str = FLAGS.tokenizer and os.path.expanduser(FLAGS.tokenizer)
    num_labels = FLAGS.force_n_labels or em_datasets.n_classes_for_task(FLAGS.task)

    if FLAGS.tokenizer is None:
        if model_str is None:
            raise ValueError(
                "At least one of --model or --tokenizer must be set."
            )
        tokenizer = AutoTokenizer.from_pretrained(model_str)
    else:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_str)

    if FLAGS.model is None:
        # Training from scratch. --tokenizer is necessarily set here, since
        # the check above rejects the case where both flags are missing.
        if FLAGS.hidden_size is None or FLAGS.num_hidden_layers is None:
            raise ValueError(
                "--hidden_size and --num_hidden_layers must be set when "
                "--model is not set."
            )

        # NOTE: a RoBERTa-style model could be built the same way by swapping
        # in hf_util.make_roberta_config and TFRobertaForSequenceClassification.
        bert_config = hf_util.make_bert_config(
            tokenizer=tokenizer,
            hidden_size=FLAGS.hidden_size,
            num_hidden_layers=FLAGS.num_hidden_layers,
            max_position_embeddings=FLAGS.sequence_length,
        )
        bert_config.num_labels = num_labels
        model = TFBertForSequenceClassification(bert_config)

    else:
        model = TFAutoModelForSequenceClassification.from_pretrained(
            model_str, from_pt=FLAGS.from_pt, num_labels=num_labels
        )

    model.compile(
        optimizer=tf.keras.optimizers.Adam(FLAGS.learning_rate, clipnorm=FLAGS.clipnorm),
        # TODO: Support metrics other than accuracy based on the task.
        metrics=tf.keras.metrics.SparseCategoricalAccuracy(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    return model, tokenizer


def get_train_params():
    """Translate `--n_epochs` / `--n_steps` into Keras `fit` arguments.

    Exactly one of `--n_epochs` and `--n_steps` must be set. With
    `--n_epochs`, the total step count is
    `n_epochs * em_datasets.examples_per_epoch(task) // batch_size`.

    The total is then spread over at most 10 Keras "epochs"; this only
    controls how often Keras evaluates during training. For totals of 10
    steps or more the result is the same as before (`n_steps // 10` steps per
    Keras epoch). For smaller totals the number of Keras epochs is reduced to
    the step count so that `steps_per_epoch` is never 0.

    Returns:
        An `(epochs, steps_per_epoch)` tuple of positive ints.

    Raises:
        ValueError: If not exactly one of `--n_epochs` / `--n_steps` is set,
            if the task does not report an examples-per-epoch count, or if
            the resulting total step count is not positive.
    """
    # Explicit check rather than `assert`, which is skipped under `python -O`.
    if (FLAGS.n_epochs is None) == (FLAGS.n_steps is None):
        raise ValueError("Exactly one of --n_epochs or --n_steps must be set.")

    if FLAGS.n_epochs is not None:
        n_examples_per_epoch = em_datasets.examples_per_epoch(FLAGS.task)
        if n_examples_per_epoch is None:
            raise ValueError(f'Training by number of epochs is not supported for the task {FLAGS.task}')
        n_total_examples = FLAGS.n_epochs * n_examples_per_epoch
        n_steps = n_total_examples // FLAGS.batch_size
    else:
        n_steps = FLAGS.n_steps

    if n_steps < 1:
        raise ValueError(
            f'The training configuration results in {n_steps} training steps; '
            'at least 1 is required.'
        )

    # TODO: Let this be set, only affects how often keras evaluates during training.
    # Capped at n_steps so that short runs still get one step per Keras epoch.
    epochs = min(10, n_steps)
    steps_per_epoch = n_steps // epochs
    return epochs, steps_per_epoch


def main(_):
    """Entry point: build the model and datasets from flags, then train.

    Cheap flag validation (output path and training schedule) runs before the
    model and datasets are loaded, so a misconfigured run fails immediately
    rather than after the expensive loading work.

    Args:
        _: Positional argument supplied by `app.run`; unused.

    Raises:
        ValueError: If `--output_path` is not set. Errors from
            `get_train_params`, `load_model` and `load_datasets` propagate.
    """
    if FLAGS.output_path is None:
        # Previously this surfaced as a TypeError from os.path.expanduser(None),
        # and only after the model and datasets had been loaded.
        raise ValueError("--output_path must be set.")
    output_path = os.path.expanduser(FLAGS.output_path)
    epochs, steps_per_epoch = get_train_params()

    model, tokenizer = load_model()
    train_ds, val_ds = load_datasets(tokenizer)
    model.fit(
        train_ds,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_data=val_ds,
        callbacks=[hf_util.HfSaverCallback(output_path)]
    )


# Run only when executed as a script: absl parses the command-line flags and
# then calls main.
if __name__ == "__main__":
    app.run(main)
