import argparse
import logging
import sys
import time

import tensorflow as tf
from datasets import load_dataset
from packaging.version import parse

from transformers import AutoTokenizer, TFAutoModelForSequenceClassification


# Prefer the backwards-compatible `tf_keras` package. When it is not importable,
# fall back to the plain `keras` package, which is only accepted if its major
# version is 2 or lower.
try:
    import tf_keras as keras
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    import keras

    installed_keras_major = parse(keras.__version__).major
    if installed_keras_major > 2:
        raise ValueError(
            "Your currently installed version of Keras is Keras 3, but this is not yet supported in "
            "Transformers. Please install the backwards-compatible tf-keras package with "
            "`pip install tf-keras`."
        )


def _str_to_bool(value):
    """Convert a command-line string such as ``"False"`` into a real ``bool``.

    ``argparse``'s ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) would be parsed as ``True``.

    Args:
        value: The raw command-line string (or an actual ``bool``).

    Returns:
        The boolean the string spells out. The empty string maps to ``False``,
        which is what ``bool("")`` produced before.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the hyperparameters that are passed to this script as command-line arguments.

    Args:
        argv: Optional list of argument strings; ``None`` makes ``argparse`` read ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. Unknown arguments are ignored rather than rejected.
    """
    parser = argparse.ArgumentParser()

    # Hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--per_device_train_batch_size", type=int, default=16)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=8)
    parser.add_argument("--model_name_or_path", type=str)
    # Must be a float: a string learning rate cannot be used by the optimizer.
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    # NOTE: `do_train` / `do_eval` are parsed but not consulted anywhere in this script.
    parser.add_argument("--do_train", type=_str_to_bool, default=True)
    parser.add_argument("--do_eval", type=_str_to_bool, default=True)
    parser.add_argument("--output_dir", type=str)

    args, _ = parser.parse_known_args(argv)
    return args


def _build_tf_dataset(dataset, tokenizer, batch_size: int) -> "tf.data.Dataset":
    """Tokenize the ``text`` column of ``dataset`` and wrap it in a batched ``tf.data.Dataset``.

    Args:
        dataset: A ``datasets`` split that has ``text`` and ``label`` columns.
        tokenizer: The tokenizer applied to the ``text`` column.
        batch_size: Number of examples per batch of the returned dataset.

    Returns:
        A batched ``tf.data.Dataset`` of ``(features, label)`` pairs, where ``features``
        is a dict with the keys ``input_ids`` and ``attention_mask``.
    """
    dataset = dataset.map(
        lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length"), batched=True
    )
    dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"])

    # NOTE(review): `.to_tensor` assumes the formatted columns come back as ragged
    # tensors -- confirm against the installed `datasets` version.
    features = {
        column: dataset[column].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length])
        for column in ["input_ids", "attention_mask"]
    }
    return tf.data.Dataset.from_tensor_slices((features, dataset["label"])).batch(batch_size)


if __name__ == "__main__":
    args = parse_args()

    # overwrite batch size until we have tf_glue.py
    args.per_device_train_batch_size = 16
    args.per_device_eval_batch_size = 16

    # Set up logging
    logger = logging.getLogger(__name__)

    logging.basicConfig(
        level=logging.INFO,
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Load model and tokenizer
    model = TFAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    # Load dataset
    train_dataset, test_dataset = load_dataset("imdb", split=["train", "test"])
    train_dataset = train_dataset.shuffle().select(range(5000))  # shrink the train dataset to 5k examples
    test_dataset = test_dataset.shuffle().select(range(500))  # shrink the test dataset to 500 examples

    # Preprocess train dataset
    tf_train_dataset = _build_tf_dataset(train_dataset, tokenizer, args.per_device_train_batch_size)

    # Preprocess test dataset
    # NOTE: the test dataset is built here but not consumed anywhere in this script.
    tf_test_dataset = _build_tf_dataset(test_dataset, tokenizer, args.per_device_eval_batch_size)

    # Define optimizer, loss and metrics
    optimizer = keras.optimizers.Adam(learning_rate=args.learning_rate)
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    metrics = [keras.metrics.SparseCategoricalAccuracy()]
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    # `batch_size` is not passed to `fit`: the dataset is already batched above.
    start_train_time = time.perf_counter()
    train_results = model.fit(tf_train_dataset, epochs=args.epochs)
    train_runtime = time.perf_counter() - start_train_time

    logger.info("*** Train ***")
    logger.info("train_runtime = %s", train_runtime)
    for key, value in train_results.history.items():
        logger.info("  %s = %s", key, value)
