# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script to centrally train next char prediction model in prediction_model.py.

Use this script to centrally train your own next char predictor from scratch.

To run, first build the binary:
  bazel build /path/to/this/script:centrally_train_prediction_model
Then execute it:
  ./path/to/binary/centrally_train_prediction_model --epochs=3
"""

from absl import app
from absl import flags
import tensorflow as tf

from mixed_fl.experiments.next_char_prediction import data_utils
from mixed_fl.experiments.next_char_prediction import datasets
from mixed_fl.experiments.next_char_prediction import prediction_model

# Convenience alias for the Keras model class.
Model = tf.keras.Model

# Batch size applied to every dataset this script preprocesses.
BATCH_SIZE = 64

# Command-line flags controlling the checkpoint location, the training length,
# and which split of the data is used for training.
flags.DEFINE_string(
    name='path_to_save_model_checkpoint',
    default='/tmp/next_char_prediction_checkpoint.hd5',
    help=('The path/name at which to save a trained checkpoint of the '
          'next char prediction model, after --epochs of training.'))
flags.DEFINE_integer(
    name='epochs', default=1, help='The # of epochs to train for.')
flags.DEFINE_enum(
    name='dataset_restriction',
    default='all',
    enum_values=datasets.get_possible_dataset_splits(),
    help='What restrictions to put on the training dataset.')

FLAGS = flags.FLAGS


def _load_train_dataset(split):
  """Returns the datacenter training dataset for the given `split`.

  Args:
    split: Name of the dataset split, forwarded unchanged to
      `datasets.get_datacenter_train`.

  Returns:
    Whatever `datasets.get_datacenter_train(split)` returns.
  """
  train_data = datasets.get_datacenter_train(split)
  return train_data


def _load_test_dataset(split):
  """Returns the datacenter evaluation dataset for the given `split`.

  Args:
    split: Name of the dataset split, forwarded unchanged to
      `datasets.get_datacenter_eval`.

  Returns:
    Whatever `datasets.get_datacenter_eval(split)` returns.
  """
  eval_data = datasets.get_datacenter_eval(split)
  return eval_data


def _load_and_preprocess_datasets():
  """Loads the raw text datasets and preprocesses them for training/evaluation.

  The training data comes from the split named by --dataset_restriction and is
  truncated to its first 200000 elements. Three evaluation datasets ('all',
  'stackoverflow' and 'wikipedia') are each truncated to their first 50000
  elements. Every dataset is then passed through
  `data_utils.preprocess_text_dataset` with `BATCH_SIZE` and `shuffle=True`.

  Returns:
    A 4-tuple `(train, test_all, test_stackoverflow, test_wikipedia)` of
    preprocessed datasets.
  """
  max_train_elements = 200000
  max_test_elements = 50000

  def _truncate_and_preprocess(raw_dataset, max_elements):
    # Cap the dataset size first, then apply the shared preprocessing.
    return data_utils.preprocess_text_dataset(
        raw_dataset.take(max_elements), batch_size=BATCH_SIZE, shuffle=True)

  # Load all raw text datasets before preprocessing any of them.
  raw_train = _load_train_dataset(FLAGS.dataset_restriction)
  raw_tests = [
      _load_test_dataset(split)
      for split in ('all', 'stackoverflow', 'wikipedia')
  ]

  train = _truncate_and_preprocess(raw_train, max_train_elements)
  test_all, test_so, test_wiki = [
      _truncate_and_preprocess(raw_test, max_test_elements)
      for raw_test in raw_tests
  ]

  return (train, test_all, test_so, test_wiki)


class _NumExamplesCounter(tf.keras.metrics.Sum):
  """A `tf.keras.metrics.Sum` that accumulates the batch size of each update.

  On every update the size of the first dimension of the predictions is added
  to the running sum.
  """

  def __init__(self, name='num_examples', dtype=tf.int64):
    """Creates the counter.

    Args:
      name: Name of the metric.
      dtype: Data type of the metric.
    """
    super().__init__(name=name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Adds the leading dimension of `y_pred` to the running sum.

    Args:
      y_true: Ignored.
      y_pred: The predictions, or a list whose first element is the
        predictions.
      sample_weight: Forwarded unchanged to the parent `update_state`.

    Returns:
      The result of the parent class's `update_state`.
    """
    del y_true  # Unused: only the predictions determine the count.
    predictions = y_pred[0] if isinstance(y_pred, list) else y_pred
    batch_size = tf.shape(predictions)[0]
    return super().update_state(batch_size, sample_weight)


def _train_and_evaluate(train_dataset,
                        test_dataset_all,
                        test_dataset_so,
                        test_dataset_wiki,
                        epochs):
  """Compiles, fits and evaluates the next char prediction model.

  The model from `prediction_model.get_next_char_prediction_model()` is
  compiled with Adam and a from-logits sparse categorical crossentropy loss,
  fit on `train_dataset`, and then evaluated on each of the three test
  datasets in turn. Training history and evaluation results are printed.

  Args:
    train_dataset: Dataset passed to `fit()`.
    test_dataset_all: Evaluation dataset reported as 'all'.
    test_dataset_so: Evaluation dataset reported as 'stackoverflow'.
    test_dataset_wiki: Evaluation dataset reported as 'wikipedia'.
    epochs: Number of epochs passed to `fit()`.

  Returns:
    The trained model.
  """
  model = prediction_model.get_next_char_prediction_model()
  crossentropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  adam = tf.keras.optimizers.Adam()

  print('Compiling Keras model...')
  model.compile(
      optimizer=adam,
      loss=crossentropy,
      metrics=[
          prediction_model.FlattenedCategoricalAccuracy(),
          _NumExamplesCounter(),
      ])
  print('... done.')

  print('Training Keras model...')
  # NOTE(review): `steps_per_epoch` is fixed at 4000 batches; confirm that
  # `train_dataset` yields enough batches (or repeats) for `epochs * 4000`
  # steps, otherwise training may stop early when the data runs out.
  fit_history = model.fit(
      train_dataset, epochs=epochs, verbose=2, steps_per_epoch=4000)
  print('... done.')

  print('Train metrics...')
  print(fit_history.history)

  # Each entry: (message before evaluating, message before results, dataset).
  evaluation_runs = (
      ('Evaluating Keras model (on all)...',
       'Evaluation loss and accuracy (on all data in the overall dataset) '
       'are...',
       test_dataset_all),
      ('Evaluating Keras model (on stackoverflow)...',
       'Evaluation loss and accuracy (on stackoverflow data) are...',
       test_dataset_so),
      ('Evaluating Keras model (on wikipedia)...',
       'Evaluation loss and accuracy (on wikipedia data) are...',
       test_dataset_wiki),
  )
  for start_message, results_message, test_dataset in evaluation_runs:
    print(start_message)
    # `evaluate` yields the loss followed by the two compiled metrics.
    eval_loss, eval_accuracy, eval_num_examples = model.evaluate(test_dataset)
    print('... done.')
    print(results_message)
    print(eval_loss)
    print(eval_accuracy)
    print(eval_num_examples)

  return model


def _save(model, path_to_save_model_checkpoint):
  """Writes the weights of `model` to disk in the 'h5' format.

  Args:
    model: Object exposing a Keras-style `save_weights` method.
    path_to_save_model_checkpoint: Destination path for the weights file.
  """
  model.save_weights(filepath=path_to_save_model_checkpoint, save_format='h5')

def main(argv):
  """Loads the datasets, trains and evaluates the model, then saves it.

  Args:
    argv: Command-line arguments; at most one element is accepted.

  Raises:
    app.UsageError: If `argv` holds more than one element.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  print('Loading datasets...')
  train, test_all, test_so, test_wiki = _load_and_preprocess_datasets()
  print('... done.')

  # Train for --epochs epochs, evaluating on each test dataset afterwards.
  trained_model = _train_and_evaluate(
      train, test_all, test_so, test_wiki, epochs=FLAGS.epochs)

  # Persist the trained weights at the path given on the command line.
  _save(trained_model, FLAGS.path_to_save_model_checkpoint)


# Run `main` through `app.run` only when this file is executed as a script.
if __name__ == '__main__':
  app.run(main)

