# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script to perform mixed learning of federated and datacenter data."""

import asyncio
import collections
import functools
import os.path
import random
from typing import Any, Callable, Dict

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
import tensorflow_federated as tff

from mixed_fl import mixing_process_lib
from mixed_fl import process_with_gradient_transfer_lib
from mixed_fl.experiments import datasets
from mixed_fl.experiments import tasks


# Maps each --server_optimizer choice to a factory that takes the learning
# rate as its first positional argument and returns a Keras optimizer.
SERVER_OPTIMIZER_FNS_MAP = dict(
    sgd=tf.keras.optimizers.SGD,
    adam=functools.partial(
        tf.keras.optimizers.Adam, beta_2=0.99, epsilon=0.001),
)


# OPTIMIZATION FLAGS.
flags.DEFINE_float(
    name='client_learning_rate',
    default=0.02,
    help=('Learning rate to use during the client '
          'optimization. Note that in the case of mixing with parallel datacenter '
          'training, the datacenter learning rate is derived as an \'equivalent\' '
          'to this learning rate, adjusting for the difference in batch sizes.'))
flags.DEFINE_float(
    name='server_learning_rate',
    default=1.0,
    help=('Learning rate to use during the server '
          'optimization.'))
# The choices are exactly the keys of SERVER_OPTIMIZER_FNS_MAP.
flags.DEFINE_enum(
    name='server_optimizer',
    default='sgd',
    enum_values=list(SERVER_OPTIMIZER_FNS_MAP.keys()),
    help=('Which '
          'optimizer to use when aggregating the client updates at the server.'))
# MIXING FLAGS.
# When --mixing_mode is left unset, plain federated averaging is run (see
# `get_process`).
_MIXING_MODE = flags.DEFINE_enum(
    'mixing_mode', None, [
        'parallel_training', 'grad_transfer', 'two_way_grad_transfer',
        'example_transfer'
    ], 'Which variation of federated/datacenter mixed training to employ. In '
    'short: `parallel_training` involves training on datacenter data and '
    'then weighted averaging the results with the results from a round of '
    'FL, `grad_transfer` uses datacenter data to compute gradients that are '
    'then sent to clients to augment the local client gradients, '
    '`two_way_grad_transfer` is similar but sends augmenting gradients in '
    'both directions (client->datacenter as well as datacenter->client), '
    '`example_transfer` involves adding some datacenter data to the client '
    'caches while performing FL.')
# The two weights below are meant to be set together; the help text names the
# real --mixing_mode values (`grad_transfer`, `two_way_grad_transfer`).
flags.DEFINE_float(
    'mixing_datacenter_weight', 1.0, 'A scalar weight for how much the '
    'optimization should be influenced by the datacenter data. This should '
    'be chosen in concert with the `mixing_client_weight`, to properly '
    'weight the two terms relative to each other as desired. These weights '
    'are used in different ways, depending on --mixing_mode. In '
    'grad_transfer and two_way_grad_transfer, they are used to weight the '
    'local and augmenting gradients when summing them together. In '
    'parallel_training and two_way_grad_transfer, they are used to '
    'relatively weight the centralized and federated model updates when '
    'merging at end of the round.')
flags.DEFINE_float(
    'mixing_client_weight', 1.0, 'A scalar weight for how much the '
    'optimization should be influenced by the federated data. This should be '
    'chosen in concert with the `mixing_datacenter_weight`, to properly '
    'weight the two terms relative to each other as desired. These weights '
    'are used in different ways, depending on --mixing_mode. In '
    'grad_transfer and two_way_grad_transfer, they are used to weight the '
    'local and augmenting gradients when summing them together. In '
    'parallel_training and two_way_grad_transfer, they are used to '
    'relatively weight the centralized and federated model updates when '
    'merging at end of the round.')
# Datacenter-training flags. Despite the `parallel_training_` prefix, the
# first two are also read when --mixing_mode=two_way_grad_transfer.
flags.DEFINE_integer(
    'parallel_training_num_effective_clients_for_training', 1, 'If '
    '--mixing_mode=parallel_training or --mixing_mode=two_way_grad_transfer, '
    'this parameter is used to set how much datacenter data is used in the '
    'datacenter training. This number is multiplied with the '
    '--parallel_training_effective_client_cache_size to get the total amount '
    'of training data used, and with --client_batch_size to get the '
    'datacenter batch size.')
flags.DEFINE_integer(
    'parallel_training_effective_client_cache_size', 0, 'If '
    '--mixing_mode=parallel_training or --mixing_mode=two_way_grad_transfer, '
    'this parameter is used to set how much datacenter data is trained on in '
    'the centralized training process. This number is multiplied with the '
    '--parallel_training_num_effective_clients_for_training to get the total '
    'amount of training data used. The per-client amount is also multiplied '
    'by --num_epochs.')
flags.DEFINE_float(
    'parallel_training_central_optimizer_global_clipnorm', None, 'If '
    '--mixing_mode=parallel_training, the value to clip the l2 norm of the '
    'gradient at each step of central training. If `None`, no clipping is '
    'applied.')
flags.DEFINE_integer(
    'grad_transfer_batch_size_multiplier', 1, 'Sets the size of the batch '
    'used to compute the augmenting gradients. A multiplicative factor on '
    'top of the batch size used in client training.')
flags.DEFINE_enum_class(
    'two_way_grad_transfer_gradients_option', None,
    mixing_process_lib.GradientsComputationOption,
    'Which option will be used when calculating (at the datacenter) the '
    'gradients that will be used for augmenting at the clients.')
flags.DEFINE_float(
    'two_way_grad_transfer_central_optimizer_global_clipnorm', None, 'If '
    '--mixing_mode=two_way_grad_transfer, value to clip the l2 norm of the '
    'gradient at each step of central training. If `None`, no clipping is '
    'applied.')
flags.DEFINE_integer(
    'example_transfer_num_examples', 1, 'The number of examples of '
    'datacenter data to add to a client cache, if '
    '--mixing_mode=example_transfer.')
flags.DEFINE_integer(
    'example_transfer_num_repetitions', 1, 'The number of times to repeat '
    'using each unique datacenter example during training on a client, if '
    '--mixing_mode=example_transfer. The total number of unique datacenter '
    'examples used per client will be `example_transfer_num_examples` '
    'divided by `example_transfer_num_repetitions`.')
# Client-training flags.
flags.DEFINE_integer(
    'num_epochs', 1, 'The number of epochs to repeat training over the '
    'client dataset.')
_CLIENT_BATCH_SIZE = flags.DEFINE_integer(
    'client_batch_size', None, 'The batch size to use in client training. '
    'Must be specified.')
flags.DEFINE_integer(
    'limit_num_batches', -1, 'If set to value greater than 0, only this many '
    'batches are processed during client training (or datacenter training). '
    'If less than or equal to zero, training will go through the entire '
    'cache.')
# DATASET FLAGS.
flags.DEFINE_enum(
    'dataset', None, ['emnist', 'celeba', 'ncp'], 'Which dataset to '
    'experiment on. Must be specified.')
_CLIENT_REST = flags.DEFINE_enum(
    'client_restriction', None, datasets.get_all_possible_splits(),
    'This flag indicates what restrictions to put on the federated train '
    'dataset. The restriction must match the dataset selected via --dataset '
    'flag; e.g., if --dataset=emnist, then this can be `only_digits`, but if '
    'instead `no_facial_hair` is specified, a ValueError is raised (as that '
    'restriction only applies to the `celeba` dataset). If `all`, no data is '
    'filtered.')
_DATACENTER_REST = flags.DEFINE_enum(
    'datacenter_restriction', None, datasets.get_all_possible_splits(),
    'This flag indicates what restrictions to put on the datacenter training '
    'dataset. The restriction must match the dataset selected via --dataset '
    'flag; e.g., if --dataset=emnist, then this can be `only_digits`, but if '
    'instead `no_facial_hair` is specified, a ValueError is raised (as that '
    'restriction only applies to the `celeba` dataset). If `all`, no data is '
    'filtered.')
# Each listed restriction yields one centralized and/or one federated eval
# (depending on --central_eval / --federated_eval).
_EVAL_REST = flags.DEFINE_list(
    'eval_restriction', None,
    'This flag indicates what restrictions to put on the evaluation '
    'dataset(s), used for both the centralized and the federated '
    'evaluations. The restrictions must match the dataset selected via '
    '--dataset flag; e.g., if --dataset=emnist, then this can be '
    '`only_digits`, but if instead `no_facial_hair` is specified, a '
    'ValueError is raised (as that restriction only applies to the `celeba` '
    'dataset). If `all`, no data is filtered. Note this flag can take '
    'multiple arguments, for running multiple evaluations.')
# GENERAL SIMULATION FLAGS.
flags.DEFINE_string(
    name='experiment_name',
    default='test',
    help='Unique name for the experiment, suitable for use in filenames.')
flags.DEFINE_string(
    name='root_output_dir',
    default='/tmp/mixing/',
    help='Base directory to write experimental output.')
flags.DEFINE_integer(
    name='rounds_per_checkpoint',
    default=10,
    help=('How often to write and save a model checkpoint. These checkpoints are '
          'most typically used in case of a simulation restart, e.g., due to Borg '
          'preemption. If 0, no checkpoints are saved.'))
flags.DEFINE_integer(
    name='rounds_per_metrics',
    default=10,
    help=('How often to write and save metrics, to CSV and tensorboard. If 0, no '
          'metrics are saved.'))
flags.DEFINE_integer(
    name='rounds_per_eval',
    default=10,
    help=('How often to calculate evaluation metrics. If 0, '
          'no evaluation metrics are calculated.'))
flags.DEFINE_boolean(
    name='central_eval',
    default=True,
    help=('Whether to perform the evaluations as centralized '
          'computations.'))
flags.DEFINE_boolean(
    name='federated_eval',
    default=True,
    help=('Whether to perform the evaluations as federated '
          'computations.'))
flags.DEFINE_integer(
    name='total_rounds',
    default=10,
    help='The total # of federated rounds to run for.')
flags.DEFINE_integer(
    name='clients_per_round',
    default=10,
    help='The # of clients participating in a federated round.')

FLAGS = flags.FLAGS

# Shuffle-buffer size passed to every datacenter-data process builder.
DATACENTER_SHUFFLE_BUFFER = 700_000


def validate_flag_settings():
  """Check that input flags have valid values.

  Raises:
    ValueError: If --dataset, --client_batch_size or --client_restriction is
      unset; if --mixing_mode is set but --datacenter_restriction is not; or
      if --mixing_mode=two_way_grad_transfer is set but
      --two_way_grad_transfer_gradients_option is not.
  """
  flags.mark_flag_as_required('dataset')
  flags.mark_flag_as_required('client_batch_size')

  # NOTE: this runs from `main`, i.e. after absl has already parsed the
  # command line, so the required-markers above are not what rejects missing
  # values here; the explicit checks below are.
  if FLAGS.dataset is None:
    raise ValueError(
        'The --%s flag must be specified (cannot be left `None`).' %
        'dataset')

  if _CLIENT_BATCH_SIZE.value is None:
    raise ValueError(
        'The --%s flag must be specified (cannot be left `None`).' %
        _CLIENT_BATCH_SIZE.name)

  if _CLIENT_REST.value is None:
    raise ValueError(
        'The --%s flag must be specified (cannot be left `None`).' %
        _CLIENT_REST.name)

  if _MIXING_MODE.value is not None:
    # Every mixing mode loads datacenter training data, which needs a
    # restriction.
    if _DATACENTER_REST.value is None:
      raise ValueError(
          'The --%s flag must be specified (cannot be left `None`) if '
          '--mixing_mode flag is specified (not `None`).' %
          _DATACENTER_REST.name)

    if _MIXING_MODE.value == 'two_way_grad_transfer':
      if FLAGS.two_way_grad_transfer_gradients_option is None:
        raise ValueError(
            'The --two_way_grad_transfer_gradients_option flag must be '
            'specified (cannot be left `None`) if '
            '--mixing_mode=two_way_grad_transfer.')


def setup_metrics_manager():
  """Creates the CSV release manager that records per-round metrics.

  The CSV file is `<root_output_dir>/results/<experiment_name>/
  experiment.metrics.csv`; its directory is created if it does not exist.

  Returns:
    A `tff.program.CSVFileReleaseManager` writing to that file.
  """
  results_dir = os.path.join(
      FLAGS.root_output_dir, 'results', FLAGS.experiment_name)
  tf.io.gfile.makedirs(results_dir)

  csv_path = os.path.join(results_dir, 'experiment.metrics.csv')
  manager = tff.program.CSVFileReleaseManager(csv_path)
  logging.info('Writing metrics csv to: %s', csv_path)
  return manager


def setup_checkpoint_manager():
  """Returns a program-state manager that saves/restores checkpoints.

  Checkpoints live under `<root_output_dir>/checkpoints/<experiment_name>`.
  """
  return tff.program.FileProgramStateManager(
      os.path.join(FLAGS.root_output_dir, 'checkpoints',
                   FLAGS.experiment_name))


def possibly_write_checkpoint(
    checkpoint_manager,
    state,
    round_num):
  """Save a checkpoint if --rounds_per_checkpoint says this round is due.

  The state passed in is the model-training state as of the END of
  `round_num`, so it is saved under version `round_num + 1`: after a restart
  (e.g. Borg preemption), loading that version resumes training at the start
  of round `round_num + 1`.

  Args:
    checkpoint_manager: Manager whose `save(state, version)` coroutine stores
      the checkpoint.
    state: Training state to save.
    round_num: Number of the round that just finished.
  """
  loop = asyncio.get_event_loop()
  next_round_num = round_num + 1
  if not FLAGS.rounds_per_checkpoint > 0:
    return  # Checkpointing disabled.
  if next_round_num % FLAGS.rounds_per_checkpoint != 0:
    return  # Not a checkpoint round.
  loop.run_until_complete(checkpoint_manager.save(state, next_round_num))


def setup_tensorboard_manager():
  """Returns a TensorBoard release manager for this experiment's summaries.

  Summaries are written under `<root_output_dir>/logdir/<experiment_name>`.
  """
  logdir = os.path.join(
      FLAGS.root_output_dir, 'logdir', FLAGS.experiment_name)
  return tff.program.TensorBoardReleaseManager(logdir)


def possibly_write_metrics(
    metrics_manager,
    tensorboard_manager,
    metrics, round_num):
  """Release `metrics` to CSV and TensorBoard on --rounds_per_metrics rounds.

  Nothing is written when --rounds_per_metrics is not positive or when
  `round_num` is not a multiple of it. Otherwise both releases are run
  concurrently and awaited.

  Args:
    metrics_manager: CSV release manager.
    tensorboard_manager: TensorBoard (tf.summary) release manager.
    metrics: Metrics structure to release.
    round_num: Round number, used as the release key.
  """
  loop = asyncio.get_event_loop()
  period = FLAGS.rounds_per_metrics
  if period <= 0 or round_num % period != 0:
    return
  csv_release = metrics_manager.release(metrics, round_num)
  tensorboard_release = tensorboard_manager.release(metrics, round_num)
  loop.run_until_complete(asyncio.gather(csv_release, tensorboard_release))


def get_process(mixing_task_definition, client_data,
                datacenter_train_dataset_fn):
  """Get the appropriate Process to use, based on --mixing_mode.

  Each mixing mode maps to one `mixing_process_lib` builder. When
  --mixing_mode is unset, plain federated averaging (with TFF's robust
  aggregator) is returned, composed with a step that loads and processes each
  client's dataset from its client id.

  Args:
    mixing_task_definition: Task building blocks (keras/TFF model fns, loss,
      metrics, client and datacenter dataset processing fns, client batch
      size).
    client_data: Federated training data, forwarded to the process builder.
    datacenter_train_dataset_fn: Callable returning the datacenter training
      dataset; may be `None` when --mixing_mode is unset.

  Returns:
    The constructed process.
  """

  def client_optimizer_fn():
    return tf.keras.optimizers.SGD(FLAGS.client_learning_rate)

  def server_optimizer_fn():
    make_optimizer = SERVER_OPTIMIZER_FNS_MAP[FLAGS.server_optimizer]
    return make_optimizer(FLAGS.server_learning_rate)

  def datacenter_metrics_fn():
    return mixing_task_definition.get_metrics()

  def datacenter_processing_fn(datacenter_batch_size):
    return mixing_task_definition.get_datacenter_dataset_processing_fn(
        datacenter_batch_size=datacenter_batch_size,
        limit_num_batches=FLAGS.limit_num_batches)

  def make_datacenter_optimizer_fn(global_clipnorm_flag_name):
    # The datacenter learning rate is the product of the client and server
    # learning rates; the clip norm is read from the named flag.
    datacenter_learning_rate = (
        FLAGS.client_learning_rate * FLAGS.server_learning_rate)

    def datacenter_optimizer_fn():
      return tf.keras.optimizers.SGD(
          datacenter_learning_rate,
          global_clipnorm=getattr(FLAGS, global_clipnorm_flag_name))

    return datacenter_optimizer_fn

  client_dataset_processing_fn = (
      mixing_task_definition.get_client_dataset_processing_fn(
          FLAGS.limit_num_batches, FLAGS.num_epochs))

  # Both federated-averaging variants aggregate model updates with TFF's
  # robust aggregator.
  iterative_process_fn = functools.partial(
      tff.learning.build_federated_averaging_process,
      model_update_aggregation_factory=tff.learning.robust_aggregator())
  augmenting_iterative_process_fn = functools.partial(
      process_with_gradient_transfer_lib
      .build_federated_averaging_process_with_gradient_transfer,
      model_update_aggregation_factory=tff.learning.robust_aggregator())

  def build_parallel_training():
    datacenter_optimizer_fn = make_datacenter_optimizer_fn(
        'parallel_training_central_optimizer_global_clipnorm')
    datacenter_loss_fn = mixing_task_definition.get_loss()
    datacenter_batch_size = (
        FLAGS.parallel_training_num_effective_clients_for_training *
        mixing_task_definition.client_batch_size)
    return mixing_process_lib.build_mixing_process_with_parallel_training(
        keras_model_fn=mixing_task_definition.get_keras_model,
        tff_model_fn=mixing_task_definition.get_tff_model,
        client_data=client_data,
        datacenter_dataset_fn=datacenter_train_dataset_fn,
        num_effective_clients_for_training=(
            FLAGS.parallel_training_num_effective_clients_for_training),
        num_examples_per_effective_client=(
            FLAGS.parallel_training_effective_client_cache_size *
            FLAGS.num_epochs),
        datacenter_loss_fn=datacenter_loss_fn,
        client_optimizer_fn=client_optimizer_fn,
        datacenter_optimizer_fn=datacenter_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        datacenter_metrics_fn=datacenter_metrics_fn,
        client_dataset_processing_fn=client_dataset_processing_fn,
        datacenter_dataset_processing_fn=datacenter_processing_fn(
            datacenter_batch_size),
        datacenter_shuffle_buffer=DATACENTER_SHUFFLE_BUFFER,
        datacenter_weight=FLAGS.mixing_datacenter_weight,
        client_weight=FLAGS.mixing_client_weight,
        iterative_process_fn=iterative_process_fn)

  def build_grad_transfer():
    datacenter_loss_fn = mixing_task_definition.get_loss()
    datacenter_batch_size = (FLAGS.grad_transfer_batch_size_multiplier *
                             mixing_task_definition.client_batch_size)
    return mixing_process_lib.build_mixing_process_with_gradient_transfer(
        keras_model_fn=mixing_task_definition.get_keras_model,
        tff_model_fn=mixing_task_definition.get_tff_model,
        client_data=client_data,
        datacenter_dataset_fn=datacenter_train_dataset_fn,
        datacenter_batch_size=datacenter_batch_size,
        datacenter_loss_fn=datacenter_loss_fn,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_dataset_processing_fn=client_dataset_processing_fn,
        datacenter_dataset_processing_fn=datacenter_processing_fn(
            datacenter_batch_size),
        datacenter_shuffle_buffer=DATACENTER_SHUFFLE_BUFFER,
        datacenter_gradient_weight=FLAGS.mixing_datacenter_weight,
        client_gradient_weight=FLAGS.mixing_client_weight,
        augmenting_iterative_process_fn=augmenting_iterative_process_fn)

  def build_two_way_grad_transfer():
    datacenter_optimizer_fn = make_datacenter_optimizer_fn(
        'two_way_grad_transfer_central_optimizer_global_clipnorm')
    datacenter_loss_fn = mixing_task_definition.get_loss()
    datacenter_batch_size = (
        FLAGS.parallel_training_num_effective_clients_for_training *
        mixing_task_definition.client_batch_size)
    gradients_option = FLAGS.two_way_grad_transfer_gradients_option
    if isinstance(gradients_option, str):
      # A plain string is treated as the name of an enum member.
      gradients_option = mixing_process_lib.GradientsComputationOption[
          gradients_option]
    return mixing_process_lib.build_mixing_process_with_two_way_gradient_transfer(
        keras_model_fn=mixing_task_definition.get_keras_model,
        tff_model_fn=mixing_task_definition.get_tff_model,
        client_data=client_data,
        datacenter_dataset_fn=datacenter_train_dataset_fn,
        num_effective_clients_for_training=(
            FLAGS.parallel_training_num_effective_clients_for_training),
        num_examples_per_effective_client=(
            FLAGS.parallel_training_effective_client_cache_size *
            FLAGS.num_epochs),
        datacenter_loss_fn=datacenter_loss_fn,
        client_optimizer_fn=client_optimizer_fn,
        datacenter_optimizer_fn=datacenter_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        datacenter_metrics_fn=datacenter_metrics_fn,
        client_dataset_processing_fn=client_dataset_processing_fn,
        datacenter_dataset_processing_fn=datacenter_processing_fn(
            datacenter_batch_size),
        datacenter_shuffle_buffer=DATACENTER_SHUFFLE_BUFFER,
        datacenter_weight=FLAGS.mixing_datacenter_weight,
        client_weight=FLAGS.mixing_client_weight,
        augmenting_iterative_process_fn=augmenting_iterative_process_fn,
        augmenting_datacenter_gradients_option=gradients_option)

  def build_example_transfer():
    return mixing_process_lib.build_mixing_process_with_example_transfer(
        tff_model_fn=mixing_task_definition.get_tff_model,
        client_data=client_data,
        datacenter_dataset_fn=datacenter_train_dataset_fn,
        num_examples_to_augment=FLAGS.example_transfer_num_examples,
        num_repetitions=FLAGS.example_transfer_num_repetitions,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_dataset_processing_fn=client_dataset_processing_fn,
        datacenter_shuffle_buffer=DATACENTER_SHUFFLE_BUFFER,
        iterative_process_fn=iterative_process_fn)

  def build_federated_averaging():
    iterative_process = iterative_process_fn(
        mixing_task_definition.get_tff_model,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn)

    # Client-side step: client id -> processed client dataset.
    @tff.tf_computation(tf.string)
    def process_client_data_fn(client_id):
      return client_dataset_processing_fn(
          client_data.dataset_computation(client_id))

    return tff.simulation.compose_dataset_computation_with_iterative_process(
        process_client_data_fn, iterative_process)

  builders = {
      'parallel_training': build_parallel_training,
      'grad_transfer': build_grad_transfer,
      'two_way_grad_transfer': build_two_way_grad_transfer,
      'example_transfer': build_example_transfer,
  }
  build = builders.get(_MIXING_MODE.value, build_federated_averaging)
  return build()


def do_centralized_eval_metrics(model_weights, eval_fn):
  """Evaluates `model_weights` with the centralized `eval_fn`.

  Returns:
    Whatever `eval_fn(model_weights)` returns.
  """
  return eval_fn(model_weights)


def do_federated_eval_metrics(model_weights, federated_eval_data,
                              eval_computation):
  """Calculate metrics against the federated evaluation dataset.

  Samples up to --clients_per_round distinct clients uniformly at random from
  `federated_eval_data.client_ids` and evaluates `model_weights` on them.

  Args:
    model_weights: Model weights to evaluate.
    federated_eval_data: Federated eval data exposing `client_ids`.
    eval_computation: Callable `(model_weights, client_ids) -> metrics`.

  Returns:
    The result of `eval_computation`.
  """
  client_ids = federated_eval_data.client_ids
  # An eval split can have fewer clients than --clients_per_round; evaluate on
  # all of them then, instead of letting `random.sample` raise ValueError.
  num_eval_clients = min(FLAGS.clients_per_round, len(client_ids))
  sample_eval_clients = random.sample(client_ids, num_eval_clients)
  return eval_computation(model_weights, sample_eval_clients)


def do_federated_round(state, federated_train_data, process):
  """Runs one training round of `process` on randomly sampled clients.

  Args:
    state: Current process state.
    federated_train_data: Federated training data exposing `client_ids`.
    process: Process whose `next(state, client_ids)` runs one round.

  Returns:
    A `(new_state, metrics)` tuple from `process.next`.
  """
  participants = random.sample(
      federated_train_data.client_ids, FLAGS.clients_per_round)
  new_state, round_metrics = process.next(state, participants)
  return new_state, round_metrics


def do_federated_rounds(
    init_round_num, num_rounds, state, federated_train_data, process,
    central_eval_fns, federated_eval_data_and_computations, metrics_manager,
    tensorboard_manager, checkpoint_manager):
  """Do multiple rounds of federated training.

  Runs rounds `init_round_num` through `init_round_num + num_rounds - 1`. On
  rounds that are a multiple of --rounds_per_eval, the enabled centralized
  and/or federated evaluations are added to that round's metrics. After every
  round, metrics and checkpoints are passed to `possibly_write_metrics` and
  `possibly_write_checkpoint`.

  Args:
    init_round_num: Number of the first round to run.
    num_rounds: How many rounds to run; zero or negative runs none.
    state: Process state to start from.
    federated_train_data: Federated training data exposing `client_ids`.
    process: Training process.
    central_eval_fns: List of `(restriction, eval_fn)` pairs.
    federated_eval_data_and_computations: List of
      `(restriction, data, computation)` triples.
    metrics_manager: CSV release manager.
    tensorboard_manager: TensorBoard release manager.
    checkpoint_manager: Checkpoint manager.

  Returns:
    `(state, next_round_num)`, where `next_round_num` is the first round not
    run. If `num_rounds <= 0`, returns `(state, init_round_num)` unchanged.
  """
  # Computed up front so the return value is defined even when no round runs
  # (e.g. a restart after the final checkpoint was written).
  end_round_num = init_round_num + max(num_rounds, 0)

  for round_num in range(init_round_num, end_round_num):
    print('Federated round number %d' % round_num)
    logging.info('Federated round number %d', round_num)
    state, metrics = do_federated_round(state, federated_train_data, process)

    if (FLAGS.rounds_per_eval > 0 and (round_num % FLAGS.rounds_per_eval == 0)):

      if FLAGS.central_eval:
        for restriction, eval_fn in central_eval_fns:
          logging.info(
              'Starting central eval for restriction = %s', restriction)
          metrics['eval_cent_' + restriction] = do_centralized_eval_metrics(
              state.model, eval_fn)
          logging.info(
              'Finished central eval for restriction = %s', restriction)

      if FLAGS.federated_eval:
        for restriction, data, computation in federated_eval_data_and_computations:
          logging.info(
              'Starting federated eval for restriction = %s', restriction)
          metrics['eval_' + restriction] = do_federated_eval_metrics(
              state.model, data, computation)
          logging.info(
              'Finished federated eval for restriction = %s', restriction)

    possibly_write_metrics(metrics_manager, tensorboard_manager, metrics,
                           round_num)
    possibly_write_checkpoint(checkpoint_manager, state, round_num)

  return state, end_round_num


def get_mixing_task_definition():
  """Returns the task building blocks for --dataset and --client_batch_size.

  The returned object bundles the model, loss, metrics and data processing
  functions for one mixed-FL task.
  """
  task_name = FLAGS.dataset
  batch_size = FLAGS.client_batch_size
  return tasks.get_mixing_task_definition(
      task=task_name, client_batch_size=batch_size)


def get_federated_train_data():
  """Returns federated training data for --dataset / --client_restriction."""
  dataset_name = FLAGS.dataset
  restriction = _CLIENT_REST.value
  return datasets.get_federated_train_data(dataset_name, restriction)


def get_federated_eval_data(split):
  """Returns federated evaluation data for --dataset restricted to `split`."""
  dataset_name = FLAGS.dataset
  return datasets.get_federated_eval_data(dataset_name, split)


def get_datacenter_train_dataset_fn():
  """Returns a callable producing datacenter training data.

  Uses --dataset and --datacenter_restriction.
  """
  dataset_name = FLAGS.dataset
  restriction = _DATACENTER_REST.value
  return datasets.get_datacenter_train_dataset_fn(dataset_name, restriction)


def get_datacenter_eval_dataset_fn(split):
  """Returns a callable producing datacenter eval data for --dataset/`split`."""
  dataset_name = FLAGS.dataset
  return datasets.get_datacenter_eval_dataset_fn(dataset_name, split)


def mixed_training():
  """Runs the TFF simulation of mixed training, round-by-round.

  Sets up the metrics, checkpoint and TensorBoard managers; builds the task,
  the federated training data, the datacenter training data (only when
  --mixing_mode is set) and the training process; resumes from the latest
  checkpoint if one exists (otherwise saves the initial state as version 0);
  builds the centralized and/or federated evaluations for each
  --eval_restriction; then trains for the rounds remaining until
  --total_rounds.
  """
  loop = asyncio.get_event_loop()

  # Utilities that save metrics and checkpoints, and allow for resumption from
  # latest checkpoint, e.g. in the event that a job is preempted on Borg.
  metrics_manager = setup_metrics_manager()
  checkpoint_manager = setup_checkpoint_manager()
  tensorboard_manager = setup_tensorboard_manager()

  # Data structure containing all the building blocks (model, loss, metrics,
  # data processing, etc.) for a specific mixed FL task.
  mixing_task_definition = get_mixing_task_definition()
  # Load the federated training dataset.
  federated_train_data = get_federated_train_data()
  # Get a fn that can load the 'datacenter' training dataset. OK to be None if
  # no mixing is taking place.
  datacenter_train_dataset_fn = (None if _MIXING_MODE.value is None else
                                 get_datacenter_train_dataset_fn())
  # Create an iterative process or mixing process, and initialize.
  process = get_process(mixing_task_definition, federated_train_data,
                        datacenter_train_dataset_fn)
  init_state = process.initialize()
  # Check for latest checkpoint, in case of restart e.g. due to Borg preemption.
  # A checkpoint's version is the next round to run.
  state, round_num = loop.run_until_complete(
      checkpoint_manager.load_latest(init_state))
  if state is None:
    state = init_state
    round_num = 0
    loop.run_until_complete(checkpoint_manager.save(state, round_num))

  # Setup centralized and federated evaluation tools, which will be used for
  # evaluation of the model that's been trained with both federated training
  # data and datacenter training data.
  central_eval_fns = []
  federated_eval_data_and_computations = []

  client_dataset_processing_fn_for_eval = (
      mixing_task_definition.get_client_dataset_processing_fn())
  # --eval_restriction defaults to `None`; treat that as "no restrictions to
  # evaluate on" rather than failing with a TypeError when iterating.
  eval_restrictions = _EVAL_REST.value or []
  if not eval_restrictions and (FLAGS.central_eval or FLAGS.federated_eval):
    logging.warning(
        'No --eval_restriction given; no evaluation metrics will be computed.')
  for restriction in eval_restrictions:
    if FLAGS.central_eval:
      # Create a centralized evaluation computation (to be run after federated
      # training rounds).
      eval_fn = get_centralized_eval_fn(
          model_fn=mixing_task_definition.get_tff_model,
          dataset_fn=get_datacenter_eval_dataset_fn(restriction),
          mixing_task_definition=mixing_task_definition)
      central_eval_fns.append((restriction, eval_fn))

    if FLAGS.federated_eval:
      data = get_federated_eval_data(restriction)
      # Create federated evaluation computations (to be run after federated
      # training rounds).
      computation = tff.learning.build_federated_evaluation(
          mixing_task_definition.get_tff_model)
      # Compose the eval computation with the data processing step; this will
      # take place at the clients.
      @tff.tf_computation(tf.string)
      def process_client_data_fn_for_eval(client_id):
        return client_dataset_processing_fn_for_eval(
            data.dataset_computation(client_id))
      computation = (
          tff.simulation.compose_dataset_computation_with_computation(
              process_client_data_fn_for_eval, computation))
      federated_eval_data_and_computations.append(
          (restriction, data, computation))

  remaining_rounds = FLAGS.total_rounds - round_num
  if remaining_rounds <= 0:
    # E.g. a restart after the final checkpoint had already been written.
    logging.info(
        'Resumed at round %d, which already reaches --total_rounds=%d; no '
        'training rounds left to run.', round_num, FLAGS.total_rounds)
    return

  # The federated training.
  print('Federated training rounds...')
  logging.info(
      'Federated training rounds...')
  do_federated_rounds(round_num, remaining_rounds, state,
                      federated_train_data, process, central_eval_fns,
                      federated_eval_data_and_computations, metrics_manager,
                      tensorboard_manager, checkpoint_manager)


def get_centralized_eval_fn(
    model_fn,
    dataset_fn,
    mixing_task_definition,
    shuffle_buffer=100000,
    eval_batch_size=1000):
  """Gets fn that calculates eval metrics for model weights.

  The dataset from `dataset_fn` is shuffled (reshuffled on every pass),
  repeated indefinitely and cut into windows of `eval_batch_size` examples.
  Each call of the returned fn evaluates on the next window, so successive
  calls see different random samples.

  Args:
    model_fn: No-arg callable returning a fresh TFF model; its weights are
      overwritten with the weights being evaluated.
    dataset_fn: No-arg callable returning the unprocessed eval
      `tf.data.Dataset`.
    mixing_task_definition: Supplies `get_datacenter_dataset_processing_fn`,
      used to process each window.
    shuffle_buffer: Shuffle buffer size applied to the eval dataset.
    eval_batch_size: Number of examples evaluated per call; also passed as
      `datacenter_batch_size` to the dataset processing fn.

  Returns:
    A callable mapping model weights (structured like
    `tff.learning.ModelWeights.from_model(model_fn())`) to an
    `OrderedDict(eval=<finalized metrics>, num_examples=<int64 count>)`.
  """
  dataset_iterator = iter(dataset_fn().shuffle(
      buffer_size=shuffle_buffer,
      reshuffle_each_iteration=True).repeat().window(size=eval_batch_size))

  dataset_processing_fn = (
      mixing_task_definition.get_datacenter_dataset_processing_fn(
          datacenter_batch_size=eval_batch_size))

  def _centralized_eval_fn(model, dataset, sum_then_finalize):
    """Runs `model` over `dataset`; returns finalized metrics and a count.

    The count comes from each batch's `num_examples`, falling back to the
    leading dimension of the predictions when the model does not report it.
    Metrics are read from the model afterwards via
    `report_local_unfinalized_metrics` and finalized by `sum_then_finalize`.
    """
    def reduce_fn(num_examples, batch):
      model_output = model.forward_pass(batch, training=False)
      if model_output.num_examples is None:
        # Compute shape from the size of the predictions if model didn't use the
        # batch size.
        return num_examples + tf.shape(
            model_output.predictions, out_type=tf.int64)[0]
      else:
        return num_examples + tf.cast(model_output.num_examples, tf.int64)

    initial_state = tf.zeros([], dtype=tf.int64)
    num_examples = dataset.reduce(initial_state, reduce_fn)

    finalized_metrics = sum_then_finalize(
        [model.report_local_unfinalized_metrics()])

    return collections.OrderedDict(
        eval=finalized_metrics, num_examples=num_examples)

  def centralized_eval_fn(incoming_model_weights):
    # Fresh model per call, loaded with the weights being evaluated.
    model = model_fn()
    model_weights = tff.learning.ModelWeights.from_model(model)
    tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                          incoming_model_weights)

    sum_then_finalize = tff.learning.metrics.sum_then_finalize(
        model.metric_finalizers(),
        tff.framework.type_from_tensors(
            model.report_local_unfinalized_metrics()))

    # `window` yields each chunk as a structure of datasets; `zip` turns it
    # back into a dataset of elements before processing.
    processed_dataset = dataset_processing_fn(
        tf.data.Dataset.zip(next(dataset_iterator)))
    return _centralized_eval_fn(model, processed_dataset, sum_then_finalize)

  return centralized_eval_fn


def main(argv):
  """Program entry point: validate the flags, then run mixed training."""
  del argv  # Unused; all configuration comes from absl flags.
  validate_flag_settings()
  mixed_training()


if __name__ == '__main__':
  # absl's `app.run` parses the command-line flags, then calls `main`.
  app.run(main)
