# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script to centrally train CelebA classifier defined in classifier_model.py.

Use this script if you want to centrally train your own CelebA attribute
classifier from scratch.

To run, first build the binary:
  bazel build /path/to/this/script:centrally_train_classifier_model
Then execute it:
  ./path/to/binary/centrally_train_classifier_model --epochs=3
"""

from absl import app
from absl import flags
import tensorflow as tf

from mixed_fl.experiments.celeba import classifier_model
from mixed_fl.experiments.celeba import data_utils
from mixed_fl.experiments.celeba import datasets

# Command-line flags: where the trained checkpoint is written, how many epochs
# to train for, and which restriction to apply to the training data.
flags.DEFINE_string(
    'path_to_save_model_checkpoint', '/tmp/celeba_classifier_checkpoint.hd5',
    'The path/name at which to save a trained checkpoint of the '
    'CelebA classifier model, after --epochs of training.')
flags.DEFINE_integer('epochs', 1, 'The # of epochs to train for.')
# The set of accepted values is whatever
# `datasets.get_possible_dataset_splits()` returns at import time.
flags.DEFINE_enum(
    'dataset_restriction', 'all',
    datasets.get_possible_dataset_splits(),
    'What restrictions to put on the training dataset. E.g., if '
    '`no_facial_hair`, then any CelebA example is removed where the attributes '
    'indicate presence of mustache, goatee, or beard. If `all`, no data is '
    'filtered.')

FLAGS = flags.FLAGS

# Batch size passed to `data_utils.preprocess_img_dataset` for every dataset.
BATCH_SIZE = 32
# NOTE(review): not referenced anywhere in the visible part of this file --
# confirm whether it is still needed.
STEPS_PER_EPOCH = 4000
# Passed to `data_utils.preprocess_img_dataset` as `label_attribute`.
LABEL_ATTRIBUTE = 'smiling'


def _load_train_dataset(split):
  """Returns the datacenter training dataset for the given `split`.

  Args:
    split: Name of the dataset split, forwarded unchanged to
      `datasets.get_datacenter_train`.

  Returns:
    Whatever `datasets.get_datacenter_train` returns for `split`.
  """
  train_dataset = datasets.get_datacenter_train(split)
  return train_dataset


def _load_test_dataset(split):
  """Returns the datacenter evaluation dataset for the given `split`.

  Args:
    split: Name of the dataset split, forwarded unchanged to
      `datasets.get_datacenter_eval`.

  Returns:
    Whatever `datasets.get_datacenter_eval` returns for `split`.
  """
  eval_dataset = datasets.get_datacenter_eval(split)
  return eval_dataset


def _load_and_preprocess_datasets():
  """Loads the raw datasets and runs each through image preprocessing.

  The training dataset and the first test dataset use the split named by
  `--dataset_restriction`; the second test dataset always uses the 'all'
  split. Only the training dataset is preprocessed with `shuffle=True`.

  Returns:
    A 3-tuple `(train, test_split, test_all)` of preprocessed datasets.
  """

  def _preprocess(raw_dataset, shuffle):
    # Every dataset shares the same label attribute and batch size; only the
    # shuffling behavior differs between train and test.
    return data_utils.preprocess_img_dataset(
        raw_dataset,
        label_attribute=LABEL_ATTRIBUTE,
        batch_size=BATCH_SIZE,
        shuffle=shuffle)

  restriction = FLAGS.dataset_restriction

  # Load all raw datasets before preprocessing any of them.
  raw_train = _load_train_dataset(restriction)
  raw_test_split = _load_test_dataset(restriction)
  raw_test_all = _load_test_dataset('all')

  return (
      _preprocess(raw_train, shuffle=True),
      _preprocess(raw_test_split, shuffle=False),
      _preprocess(raw_test_all, shuffle=False),
  )


class _NumExamplesCounter(tf.keras.metrics.Sum):
  """Keras metric that accumulates how many examples have been seen.

  Each update adds the size of the leading dimension of the predictions to
  the running sum maintained by `tf.keras.metrics.Sum`.
  """

  def __init__(self, name='num_examples', dtype=tf.int64):
    super().__init__(name, dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Adds the leading-dimension size of `y_pred` to the running total.

    Args:
      y_true: Ignored; labels are not needed to count examples.
      y_pred: Predictions, or a list whose first element is the predictions.
      sample_weight: Forwarded unchanged to the parent `update_state`.

    Returns:
      Whatever the parent class's `update_state` returns.
    """
    del y_true
    predictions = y_pred[0] if isinstance(y_pred, list) else y_pred
    batch_size = tf.shape(predictions)[0]
    return super().update_state(batch_size, sample_weight)


def _train_and_evaluate(preprocessed_train_dataset,
                        preprocessed_test_dataset_split,
                        preprocessed_test_dataset_all,
                        epochs):
  """Fits the classifier, prints train/eval metrics, returns the logits model.

  Args:
    preprocessed_train_dataset: Dataset passed to `fit`.
    preprocessed_test_dataset_split: Evaluation dataset for the same dataset
      split as used in training.
    preprocessed_test_dataset_all: Evaluation dataset covering all data.
    epochs: Number of epochs to train for.

  Returns:
    The logits-producing model obtained from `classifier_model` (the model
    that the trained probability model wraps).
  """
  # The underlying classifier emits a logit rather than a probability.
  logits_model = (
      classifier_model.get_celeba_attribute_binary_classifier_model())

  # Training goes through a wrapper model that applies a sigmoid, so that the
  # loss and metrics below are computed on probabilities.
  image_input = tf.keras.Input(shape=(84, 84, 3))
  probabilities = tf.math.sigmoid(logits_model(image_input))
  probs_model = tf.keras.Model(inputs=image_input, outputs=probabilities)

  bce_loss = tf.keras.losses.BinaryCrossentropy()
  adam = tf.keras.optimizers.Adam()
  tracked_metrics = [
      tf.keras.metrics.BinaryAccuracy(name='accuracy'),
      tf.keras.metrics.AUC(name='auc'),
      tf.keras.metrics.FalsePositives(name='fp'),
      tf.keras.metrics.FalseNegatives(name='fn'),
      _NumExamplesCounter(),
  ]
  probs_model.compile(optimizer=adam, loss=bce_loss, metrics=tracked_metrics)

  def _evaluate_and_print(dataset, header):
    # `evaluate` yields the loss followed by the compiled metrics, in the
    # order they were passed to `compile`.
    (eval_loss, eval_accuracy, eval_auc, eval_fp, eval_fn,
     eval_num_examples) = probs_model.evaluate(dataset)
    print(header)
    for value in (eval_loss, eval_accuracy, eval_auc, eval_fp, eval_fn,
                  eval_num_examples):
      print(value)

  # Training.
  print('Training Keras model...')
  fit_history = probs_model.fit(preprocessed_train_dataset, epochs=epochs)
  print('... done.')

  print('Train metrics...')
  print(fit_history.history)

  # Evaluation, first on the training split and then on the full dataset.
  _evaluate_and_print(
      preprocessed_test_dataset_split,
      'Evaluation loss and accuracy (on same dataset split as used in '
      'training) are...')
  _evaluate_and_print(
      preprocessed_test_dataset_all,
      'Evaluation loss and accuracy (on all data in the overall dataset) '
      'are...')

  return logits_model


def _save(model, path_to_save_model_checkpoint):
  """Writes the weights of `model` to the given path in 'h5' format.

  Args:
    model: Object exposing `save_weights`; here, the trained logits model.
    path_to_save_model_checkpoint: Destination path for the weights file.
  """
  checkpoint_path = path_to_save_model_checkpoint
  model.save_weights(checkpoint_path, save_format='h5')


def main(argv):
  """Loads data, trains and evaluates the classifier, then saves its weights.

  Args:
    argv: Command-line arguments; more than one element is rejected.

  Raises:
    app.UsageError: If `argv` contains more than one element.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Datasets.
  train_data, test_data_split, test_data_all = _load_and_preprocess_datasets()

  # Train and evaluate; the returned model is the one whose weights we keep.
  trained_model = _train_and_evaluate(
      train_data, test_data_split, test_data_all, epochs=FLAGS.epochs)

  # Persist the trained weights.
  _save(trained_model, FLAGS.path_to_save_model_checkpoint)


# When executed as a script, hand `main` to absl's `app.run`.
if __name__ == '__main__':
  app.run(main)
