# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script to centrally train EMNIST classifier defined in classifier_model.py.

Use this script to centrally train your own EMNIST classifier from scratch.

To run, first build the binary:
  bazel build /path/to/this/script:centrally_train_classifier_model
Then execute it:
  ./path/to/binary/centrally_train_classifier_model --epochs=3
"""

from absl import app
from absl import flags
import tensorflow as tf

from mixed_fl.experiments.emnist import classifier_model
from mixed_fl.experiments.emnist import data_utils
from mixed_fl.experiments.emnist import datasets

# Destination for the trained weights, written once training has finished.
flags.DEFINE_string(
    name='path_to_save_model_checkpoint',
    default='/tmp/emnist_classifier_checkpoint.hd5',
    help=(
        'The path/name at which to save a trained checkpoint of the '
        'EMNIST classifier model, after --epochs of training.'))

# Number of passes made over the training dataset.
flags.DEFINE_integer(
    name='epochs', default=1, help='The # of epochs to train for.')

# Which subset of the data is used for training. The accepted values are
# whatever `datasets.get_possible_dataset_splits()` returns.
flags.DEFINE_enum(
    name='dataset_restriction',
    default='only_digits_and_lowercase',
    enum_values=datasets.get_possible_dataset_splits(),
    help=(
        'What restrictions to put on the training dataset. If `only_digits`, '
        'then only digits are in the dataset. If `only_digits_and_uppercase`, '
        'then only digits and uppercase letters are in the dataset, etc. If '
        '`all`, no data is filtered.'))

FLAGS = flags.FLAGS

# Batch size passed to `data_utils.preprocess_img_dataset` for every dataset
# (training and evaluation alike).
BATCH_SIZE = 32


def _load_train_dataset(split):
  """Fetches the datacenter training data for `split`.

  Args:
    split: Name of the dataset split, forwarded unchanged to
      `datasets.get_datacenter_train`.

  Returns:
    Whatever `datasets.get_datacenter_train` returns for that split.
  """
  train_data = datasets.get_datacenter_train(split)
  return train_data


def _load_test_dataset(split):
  """Fetches the datacenter evaluation data for `split`.

  Args:
    split: Name of the dataset split, forwarded unchanged to
      `datasets.get_datacenter_eval`.

  Returns:
    Whatever `datasets.get_datacenter_eval` returns for that split.
  """
  eval_data = datasets.get_datacenter_eval(split)
  return eval_data


def _load_and_preprocess_datasets():
  """Builds the preprocessed training and test datasets used by this script.

  The training data and the first test dataset use the split selected by
  --dataset_restriction; the second test dataset always uses the 'all' split.
  Each raw dataset is passed through `data_utils.preprocess_img_dataset` with
  `batch_size=BATCH_SIZE`; shuffling is requested for the training data only.

  Returns:
    A 3-tuple `(train, test_on_split, test_on_all)` of preprocessed datasets.
  """

  def _preprocess(raw_dataset, shuffle):
    # Single place where the shared batch size is applied.
    return data_utils.preprocess_img_dataset(
        raw_dataset, batch_size=BATCH_SIZE, shuffle=shuffle)

  restriction = FLAGS.dataset_restriction

  # Load every raw dataset first, then preprocess them in the same order.
  raw_train = _load_train_dataset(restriction)
  raw_test_on_split = _load_test_dataset(restriction)
  raw_test_on_all = _load_test_dataset('all')

  train = _preprocess(raw_train, shuffle=True)
  test_on_split = _preprocess(raw_test_on_split, shuffle=False)
  test_on_all = _preprocess(raw_test_on_all, shuffle=False)

  return (train, test_on_split, test_on_all)


def _train_and_evaluate(preprocessed_train_dataset,
                        preprocessed_test_dataset_split,
                        preprocessed_test_dataset_all,
                        epochs):
  """Fits the EMNIST classifier, then prints its evaluation loss and accuracy.

  The model comes from `classifier_model.get_emnist_classifier_model` and is
  trained with the Adam optimizer on a sparse categorical cross-entropy loss.
  After training it is evaluated on both test datasets, and the resulting loss
  and accuracy values are printed.

  Args:
    preprocessed_train_dataset: Dataset passed to `fit`.
    preprocessed_test_dataset_split: Test dataset for the same split that was
      used in training.
    preprocessed_test_dataset_all: Test dataset covering all data.
    epochs: Number of epochs passed to `fit`.

  Returns:
    The trained model.
  """
  # NOTE(review): 36 classes is presumably 10 digits + 26 letters -- confirm
  # against classifier_model / datasets.
  model = classifier_model.get_emnist_classifier_model(num_classes=36)

  # The model outputs logits rather than probabilities, hence
  # `from_logits=True` on the loss.
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.Adam(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')])

  # Training.
  model.fit(preprocessed_train_dataset, epochs=epochs)

  # Evaluation: first on the training split's test data, then on all data.
  evaluations = (
      (preprocessed_test_dataset_split,
       'Evaluation loss and accuracy (on same dataset split as used in '
       'training) are...'),
      (preprocessed_test_dataset_all,
       'Evaluation loss and accuracy (on all data in the overall dataset) '
       'are...'),
  )
  for test_dataset, header in evaluations:
    eval_loss, eval_accuracy = model.evaluate(test_dataset)
    print(header)
    print(eval_loss)
    print(eval_accuracy)

  return model


def _save(model, path_to_save_model_checkpoint):
  """Writes the weights of `model` to a checkpoint in HDF5 ('h5') format.

  Args:
    model: Object exposing `save_weights(path, save_format=...)`.
    path_to_save_model_checkpoint: Destination path for the checkpoint.
  """
  checkpoint_path = path_to_save_model_checkpoint
  model.save_weights(checkpoint_path, save_format='h5')


def main(argv):
  """Loads the data, trains and evaluates the classifier, and saves weights.

  Args:
    argv: Command-line arguments; anything beyond the first entry is rejected.

  Raises:
    app.UsageError: If more than one command-line argument is present.
  """
  if argv[1:]:
    raise app.UsageError('Too many command-line arguments.')

  # Build the training set and the two evaluation sets.
  train_data, test_data_split, test_data_all = _load_and_preprocess_datasets()

  # Train the model and print its evaluation metrics.
  trained_model = _train_and_evaluate(
      train_data, test_data_split, test_data_all, epochs=FLAGS.epochs)

  # Persist the trained weights to the path given by the flag.
  _save(trained_model, FLAGS.path_to_save_model_checkpoint)


# Script entry point: hand `main` to absl's `app.run` when executed directly.
if __name__ == '__main__':
  app.run(main)
