# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides processes to mix datacenter training w/ federated training."""

import abc
import collections
import dataclasses
import enum

from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import tensorflow as tf
import tensorflow_federated as tff
import typing_extensions

from mixed_fl import process_with_gradient_transfer_lib

# Default dataset processing function: hands its argument back unchanged.
IDENTITY_PROCESSING_FN = lambda x: x
# Default builder for the federated iterative process (used as the default of
# `iterative_process_fn` in `build_mixing_process_with_parallel_training`).
TFF_LEARN_FED_AVG_FN = tff.learning.build_federated_averaging_process
# FedAvg builder provided by this project's gradient-transfer library.
GRAD_TRANSFER_FED_AVG_FN = process_with_gradient_transfer_lib.build_federated_averaging_process_with_gradient_transfer
# Default buffer size used when shuffling the datacenter dataset.
DEFAULT_DATACENTER_SHUFFLE_BUFFER = 10000
# Default example counter: the size of the first dimension of the predictions.
DEFAULT_NUM_EXAMPLES_FROM_PREDICTIONS_FN = lambda x: tf.shape(x)[0]

# Type aliases.
# Zero-argument callables that build a Keras model, a TFF model, a list of
# Keras metrics, or a dataset, respectively.
_KerasModelConstructor = Callable[[], tf.keras.Model]
_TffModelConstructor = Callable[[], tff.learning.Model]
_MetricsConstructor = Callable[[], List[tf.keras.metrics.Metric]]
_DatasetConstructor = Callable[[], tf.data.Dataset]
# Maps a dataset to a dataset (e.g. a preprocessing/batching step).
_DatasetProcessingFn = Callable[[tf.data.Dataset], tf.data.Dataset]
# Maps two tensors to a loss tensor; this file invokes it as
# `loss_fn(labels, predictions)`.
_LossFn = Callable[[tf.Tensor, tf.Tensor], tf.Tensor]
# Metrics keyed by name; each value is a tensor or a nested dict of tensors.
_MetricsDict = Dict[str, Union[tf.Tensor, Dict[str, tf.Tensor]]]


class _RoundNumOptimizerConstructor(typing_extensions.Protocol):
  """Structural type for optimizer constructors that accept a round number."""

  def __call__(
      self, round_num):
    """Protocol stub with no behavior; only the call signature matters."""


# An optimizer constructor is either a zero-argument callable or a callable
# that takes the round number (see `_RoundNumOptimizerConstructor`).
_NoArgOptimizerConstructor = Callable[[], tf.keras.optimizers.Optimizer]
_OptimizerConstructor = Union[_NoArgOptimizerConstructor,
                              _RoundNumOptimizerConstructor]


class _BasicIterativeProcessFn(typing_extensions.Protocol):
  """Structural type for builders of a federated iterative process.

  A conforming callable accepts `model_fn`, `client_optimizer_fn` and
  `server_optimizer_fn`.
  """

  def __call__(
      self, model_fn,
      client_optimizer_fn,
      server_optimizer_fn,
  ):
    """Protocol stub with no behavior; only the call signature matters."""


class _RoundNumIterativeProcessFn(typing_extensions.Protocol):
  """Structural type for iterative-process builders (round-number variant).

  A conforming callable accepts `model_fn`, `client_optimizer_fn` and
  `server_optimizer_fn`.

  NOTE(review): as written, the signature is identical to
  `_BasicIterativeProcessFn`; presumably the optimizer constructors here are
  the round-number-aware kind -- confirm.
  """

  def __call__(
      self, model_fn,
      client_optimizer_fn,
      server_optimizer_fn,
  ):
    """Protocol stub with no behavior; only the call signature matters."""


# Either flavor of iterative-process builder declared above.
_IterativeProcessFn = Union[_BasicIterativeProcessFn,
                            _RoundNumIterativeProcessFn]


class MixingProcess(metaclass=abc.ABCMeta):
  """Establishes interface for what a mixing process should entail in TFF.

  It is purposely meant to closely resemble a tff.templates.IterativeProcess,
  and implementations of this interface probably wrap one (or more)
  tff.templates.IterativeProcess instances.

  The contract for any implementing classes is:
  - a no-arg `initialize` method which returns a
    tff.learning.framework.ServerState (or similar) and performs any setup.
  - a `next` method which takes the previous ServerState (or similar), takes a
    list of client ids for the clients participating in the given round, and
    returns the updated ServerState (or similar) as well as a dictionary of
    metrics.

  All particulars of individual mixing strategies are assumed to be implemented
  within the confines of the `next` method. The objective is to confine all
  behavior of sampling from datasets, training, updating variables, etc., so
  that a user doesn't need to maintain any state or track any variables (other
  than ServerState or similar).
  """

  @abc.abstractmethod
  def initialize(self):
    """Returns the initial state (a ServerState or similar).

    Raises:
      NotImplementedError: Always, in this base class; subclasses must override.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def next(self, prior_state,
           federated_client_ids):
    """Runs one round of the mixing process.

    Args:
      prior_state: The state (a ServerState or similar) from the previous round
        or from `initialize`.
      federated_client_ids: The ids of the clients participating in this round.

    Returns:
      A tuple of the updated state and a dictionary of metrics.

    Raises:
      NotImplementedError: Always, in this base class; subclasses must override.
    """
    raise NotImplementedError


# Maps a model's predictions on a batch to the number of examples in it.
_NumExamplesFromPredictionFn = Callable[[Any], tf.Tensor]
# Mirrors what `_centralized_gradients_with_tuples` returns:
# (average gradients, gradient variance, number of examples, average loss).
_GradientCalculationOutputs = Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor,
                                    tf.Tensor]


@tf.function
def _centralized_gradients_with_dicts(
    model,
    batched_processed_examples,
    loss_fn,
    metrics=None,
    get_num_examples_from_predictions_fn=(
        DEFAULT_NUM_EXAMPLES_FROM_PREDICTIONS_FN),
):
  """Computes batch gradients for examples supplied as an 'x'/'y' mapping.

  The batch is repackaged as an `(inputs, labels)` tuple and handed to
  `_centralized_gradients_with_tuples`, which does the actual work.

  Args:
    model: The model to differentiate.
    batched_processed_examples: A mapping with the model inputs under 'x' and
      the labels under 'y'.
    loss_fn: Callable invoked as `loss_fn(labels, predictions)`.
    metrics: Optional metrics, forwarded unchanged.
    get_num_examples_from_predictions_fn: Callable returning the number of
      examples in a batch of predictions, forwarded unchanged.

  Returns:
    The result of `_centralized_gradients_with_tuples` for this batch.
  """
  features_and_labels = (batched_processed_examples['x'],
                         batched_processed_examples['y'])
  return _centralized_gradients_with_tuples(
      model, features_and_labels, loss_fn, metrics,
      get_num_examples_from_predictions_fn)


@tf.function
def _centralized_gradients_with_tuples(
    model,
    batched_processed_examples,
    loss_fn,  # _LossFn,
    metrics = None,
    get_num_examples_from_predictions_fn
     = DEFAULT_NUM_EXAMPLES_FROM_PREDICTIONS_FN
):
  """Calculate gradients, given a model, batched datacenter examples, and loss.

  Args:
    model: The model whose trainable weights are differentiated. It is called
      as `model(inputs, training=True)`.
    batched_processed_examples: A pair that unpacks into `(inputs, labels)`.
    loss_fn: Callable invoked as `loss_fn(labels, predictions)`. Its output is
      flattened to a vector, so it may be a single value or per-example values.
    metrics: Optional iterable of metrics; when non-empty, each one is updated
      with `(labels, predictions)`.
    get_num_examples_from_predictions_fn: Callable mapping the predictions to
      the number of examples in the batch.

  Returns:
    A tuple of:
      - a list with the mean (over the flattened loss entries) of the
        gradients, one entry per trainable weight.
      - a scalar: the variance of the gradients across the flattened loss
        entries, summed over every coordinate of every trainable weight.
      - the number of examples in the batch, as a float32.
      - the mean of the flattened loss.
  """
  model_vars = tff.learning.ModelWeights.from_model(model)
  # The tape is persistent because `tape.gradient` is requested for each entry
  # of the flattened loss below, rather than just once.
  with tf.GradientTape(persistent=True) as tape:
    inputs, labels = batched_processed_examples
    predictions = model(inputs, training=True)
    if metrics:
      for metric in metrics:
        metric.update_state(labels, predictions)

    # Using a reshape here, to make the downstream code agnostic to whether the
    # loss_fn had a tf.keras.losses.Reduction of SUM_OVER_BATCH_SIZE or NONE (
    # see https://www.tensorflow.org/api_docs/python/tf/keras/losses/Reduction
    # for more different Keras loss reductions).
    possibly_per_example_loss = tf.reshape(loss_fn(labels, predictions), [-1])

    # If reduction was NONE, this will take each per-example loss, and calculate
    # individual gradients for each example's loss.
    def fn(loss):
      return tape.gradient(loss, model_vars.trainable)

    possibly_per_example_gradients = tf.vectorized_map(
        fn=fn, elems=possibly_per_example_loss)

  # Calculate the means of the gradient terms, over the batch.
  average_gradients = [
      tf.math.reduce_mean(g, axis=0) for g in possibly_per_example_gradients
  ]
  # Densify any sparse (IndexedSlices) gradients so callers get plain tensors.
  for idx, grad in enumerate(average_gradients):
    if isinstance(grad, tf.IndexedSlices):
      average_gradients[idx] = tf.convert_to_tensor(grad)

  # Calculate the variance of the gradient terms, over the batch.
  # Each weight's per-coordinate variance (along axis 0) is flattened, and all
  # of them are summed into one scalar. When the loss was already reduced to a
  # single value there is only one row, so this variance is zero.
  variances = [
      tf.reshape(tf.math.reduce_variance(g, axis=0), [-1])
      for g in possibly_per_example_gradients
  ]
  variance = tf.math.reduce_sum(tf.concat(variances, axis=0))

  num_examples_used = tf.cast(
      get_num_examples_from_predictions_fn(predictions), tf.float32)
  average_loss = tf.math.reduce_mean(possibly_per_example_loss, axis=0)
  return average_gradients, variance, num_examples_used, average_loss


# State threaded through the dataset reduce in centralized training:
# (num examples sum, loss sum, num batches sum, sum of centralized gradients).
_ReduceState = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, List[tf.Tensor]]
# A reduce function takes the previous state and a batch, and returns the new
# state.
_ReduceFnForCentralizedTraining = Callable[[_ReduceState, Any], _ReduceState]
# Takes a model and a batch; returns (gradients, variance).
_GradientsAndVarianceFn = Callable[[tf.keras.Model, Any], Tuple[List[tf.Tensor],
                                                                tf.Tensor]]


def _get_reduce_fn_for_dicts(
    model,
    model_vars,
    optimizer,
    loss_fn,
    metrics,
    augmenting_gradients=None,
    augmenting_gradient_weight=0.0,
    local_gradient_weight=1.0,
    get_num_examples_from_predictions_fn=(
        DEFAULT_NUM_EXAMPLES_FROM_PREDICTIONS_FN),
):
  """Provides reduce fn for centralized training, where examples are dicts.

  Args:
    model: The model being trained.
    model_vars: The `tff.learning.ModelWeights` of `model`.
    optimizer: The optimizer that applies the per-batch gradients.
    loss_fn: Callable invoked as `loss_fn(labels, predictions)`.
    metrics: Metrics updated on each batch during the reduce.
    augmenting_gradients: Optional gradients mixed into each batch's gradients;
      forwarded to `_get_reduce_fn_for_tuples`.
    augmenting_gradient_weight: Weight on `augmenting_gradients`.
    local_gradient_weight: Weight on the gradients computed from the batch.
    get_num_examples_from_predictions_fn: Callable mapping predictions to the
      number of examples in a batch.

  Returns:
    A tuple of:
      - a reduce function taking `(previous_state, batch)` where `batch` is a
        mapping with 'x' and 'y' entries, returning the updated state.
      - a function taking `(model, batch)` that returns the batch gradients and
        their variance, scaled by `local_gradient_weight` and its square
        respectively. It does not update `metrics`.
  """
  reduce_fn_for_tuples, _ = _get_reduce_fn_for_tuples(
      model, model_vars, optimizer, loss_fn, metrics, augmenting_gradients,
      augmenting_gradient_weight, local_gradient_weight,
      get_num_examples_from_predictions_fn)

  @tf.function
  def _reduce_fn_for_dicts(previous_state, batch):
    """Runs one training step on `batch` (in dict form)."""
    inputs = batch['x']
    labels = batch['y']
    return reduce_fn_for_tuples(previous_state, (inputs, labels))

  def _weighted_gradients_and_variance_fn_for_dicts(model, batch):
    """Returns weighted gradients and variance for `batch` (in dict form)."""
    # Forward the caller's example-counting function: the gradient computation
    # always evaluates it on the predictions, and the default only works when
    # the predictions are something `tf.shape` accepts.
    gradients, variance, _, _ = _centralized_gradients_with_dicts(
        model, batch, loss_fn,
        get_num_examples_from_predictions_fn=(
            get_num_examples_from_predictions_fn))
    weighted_gradients = tf.nest.map_structure(
        lambda grad: local_gradient_weight * grad, gradients)
    weighted_variance = local_gradient_weight**2 * variance
    return weighted_gradients, weighted_variance

  return _reduce_fn_for_dicts, _weighted_gradients_and_variance_fn_for_dicts


def _get_reduce_fn_for_tuples(
    model,
    model_vars,
    optimizer,
    loss_fn,
    metrics,
    augmenting_gradients=None,
    augmenting_gradient_weight=0.0,
    local_gradient_weight=1.0,
    get_num_examples_from_predictions_fn=(
        DEFAULT_NUM_EXAMPLES_FROM_PREDICTIONS_FN),
):
  """Provides reduce fn for centralized training, where examples are tuples.

  Args:
    model: The model being trained.
    model_vars: The `tff.learning.ModelWeights` of `model`; its trainable
      weights are the ones the optimizer updates.
    optimizer: The optimizer that applies the per-batch gradients.
    loss_fn: Callable invoked as `loss_fn(labels, predictions)`.
    metrics: Metrics updated on each batch during the reduce.
    augmenting_gradients: Optional structure matching `model_vars.trainable`,
      mixed into each batch's gradients. Zeros are used when it is None.
    augmenting_gradient_weight: Weight on `augmenting_gradients`.
    local_gradient_weight: Weight on the gradients computed from the batch.
    get_num_examples_from_predictions_fn: Callable mapping predictions to the
      number of examples in a batch.

  Returns:
    A tuple of:
      - a reduce function taking `(previous_state, batch)` where `batch` is an
        `(inputs, labels)` tuple, returning the updated state.
      - a function taking `(model, batch)` that returns the batch gradients and
        their variance, scaled by `local_gradient_weight` and its square
        respectively. It does not update `metrics`.
  """
  if augmenting_gradients is None:
    augmenting_gradients = tf.nest.map_structure(tf.zeros_like,
                                                 model_vars.trainable)
  else:
    tf.nest.assert_same_structure(model_vars.trainable, augmenting_gradients)

  grad_wt = local_gradient_weight
  aug_wt = augmenting_gradient_weight

  @tf.function
  def _reduce_fn_for_tuples(previous_state, batch):
    """Runs one training step on `batch` (in tuple form)."""
    num_examples_sum, loss_sum, num_batches_sum, central_gradients_sum = previous_state

    central_gradients, _, num_new_examples, loss = (
        _centralized_gradients_with_tuples(
            model, batch, loss_fn, metrics,
            get_num_examples_from_predictions_fn))

    # Add in the augmenting gradients to the centralized gradients calculated by
    # minimizing the training loss, before applying via the optimizer.
    gradients = tf.nest.map_structure(
        lambda grad, aug: grad_wt * grad + aug_wt * aug, central_gradients,
        augmenting_gradients)
    # Note: this may create variables inside `optimizer`, for example if this is
    # the first usage of Adam or momentum optimizers.
    optimizer.apply_gradients(zip(gradients, model_vars.trainable))

    num_examples_sum = num_examples_sum + num_new_examples
    loss_sum = loss_sum + loss
    num_batches_sum = num_batches_sum + tf.constant(1.0, tf.float32)

    # This assumes that `central_gradients` is a mean of per-example gradients
    # across the batch, as opposed to a sum of per-example gradients. I.e., that
    # the reduction of the loss was defined as a mean (as opposed to a sum).
    # Note that the unweighted, un-augmented gradients are what is accumulated.
    central_gradients_sum = tf.nest.map_structure(
        lambda grad_sum, grad: grad_sum + num_new_examples * grad,
        central_gradients_sum, tuple(central_gradients))

    return num_examples_sum, loss_sum, num_batches_sum, central_gradients_sum

  def _weighted_gradients_and_variance_fn_for_tuples(model, batch):
    """Returns weighted gradients and variance for `batch` (in tuple form)."""
    # Forward the caller's example-counting function: the gradient computation
    # always evaluates it on the predictions, and the default only works when
    # the predictions are something `tf.shape` accepts.
    gradients, variance, _, _ = _centralized_gradients_with_tuples(
        model, batch, loss_fn,
        get_num_examples_from_predictions_fn=(
            get_num_examples_from_predictions_fn))
    weighted_gradients = tf.nest.map_structure(
        lambda grad: local_gradient_weight * grad, gradients)
    weighted_variance = local_gradient_weight**2 * variance
    return weighted_gradients, weighted_variance

  return _reduce_fn_for_tuples, _weighted_gradients_and_variance_fn_for_tuples


@tf.function
def _centralized_training_with_dicts(
    model, processed_dataset,
    optimizer, loss_fn,
    metrics,
    get_num_examples_from_predictions_fn
):
  """Runs centralized training over a dataset whose batches are dicts.

  Args:
    model: The model to train.
    processed_dataset: The batched dataset to train on.
    optimizer: The optimizer applying the per-batch gradients.
    loss_fn: Callable invoked as `loss_fn(labels, predictions)`.
    metrics: Metrics to update during training and report afterwards.
    get_num_examples_from_predictions_fn: Callable mapping predictions to the
      number of examples in a batch.

  Returns:
    A tuple of the model weights after training and the training metrics.
  """
  weights = tff.learning.ModelWeights.from_model(model)
  reduce_fn, gradients_and_variance_fn = _get_reduce_fn_for_dicts(
      model, weights, optimizer, loss_fn, metrics,
      get_num_examples_from_predictions_fn=get_num_examples_from_predictions_fn)
  # The third output (the average centralized gradients) is not needed here.
  trained_weights, train_metrics, _unused_average_gradients = (
      _centralized_training(model, processed_dataset, metrics, reduce_fn,
                            gradients_and_variance_fn))
  return trained_weights, train_metrics


@tf.function
def _centralized_training_with_tuples(
    model, processed_dataset,
    optimizer, loss_fn,
    metrics,
    get_num_examples_from_predictions_fn
):
  """Runs centralized training over a dataset whose batches are tuples.

  Args:
    model: The model to train.
    processed_dataset: The batched dataset to train on.
    optimizer: The optimizer applying the per-batch gradients.
    loss_fn: Callable invoked as `loss_fn(labels, predictions)`.
    metrics: Metrics to update during training and report afterwards.
    get_num_examples_from_predictions_fn: Callable mapping predictions to the
      number of examples in a batch.

  Returns:
    A tuple of the model weights after training and the training metrics.
  """
  weights = tff.learning.ModelWeights.from_model(model)
  reduce_fn, gradients_and_variance_fn = _get_reduce_fn_for_tuples(
      model, weights, optimizer, loss_fn, metrics,
      get_num_examples_from_predictions_fn=get_num_examples_from_predictions_fn)
  # The third output (the average centralized gradients) is not needed here.
  trained_weights, train_metrics, _unused_average_gradients = (
      _centralized_training(model, processed_dataset, metrics, reduce_fn,
                            gradients_and_variance_fn))
  return trained_weights, train_metrics


@tf.function
def _centralized_training_with_augmenting_client_gradients_with_dicts(
    model, processed_dataset,
    optimizer, loss_fn,
    metrics,
    augmenting_client_gradients,
    augmenting_client_weight, datacenter_weight,
    get_num_examples_from_predictions_fn=(
        DEFAULT_NUM_EXAMPLES_FROM_PREDICTIONS_FN),
):
  """Train on datacenter data (as dicts), w/ augmenting grads for client info.

  Args:
    model: The model to train.
    processed_dataset: The batched dataset to train on; batches are dicts.
    optimizer: The optimizer applying the per-batch gradients.
    loss_fn: Callable invoked as `loss_fn(labels, predictions)`.
    metrics: Metrics to update during training and report afterwards.
    augmenting_client_gradients: Gradients mixed into every batch's gradients.
    augmenting_client_weight: Weight on `augmenting_client_gradients`.
    datacenter_weight: Weight on the gradients computed from datacenter data.
    get_num_examples_from_predictions_fn: Callable mapping predictions to the
      number of examples in a batch. Defaults to the first dimension of the
      predictions, matching the behavior before this parameter existed.

  Returns:
    The result of `_centralized_training`: the model weights after training,
    the training metrics, and the average centralized gradients.
  """
  model_vars = tff.learning.ModelWeights.from_model(model)
  reduce_fn, gradients_and_variance_fn = _get_reduce_fn_for_dicts(
      model, model_vars, optimizer, loss_fn, metrics,
      augmenting_client_gradients, augmenting_client_weight, datacenter_weight,
      get_num_examples_from_predictions_fn)
  return _centralized_training(model, processed_dataset, metrics, reduce_fn,
                               gradients_and_variance_fn)


@tf.function
def _centralized_training_with_augmenting_client_gradients_with_tuples(
    model, processed_dataset,
    optimizer, loss_fn,
    metrics,
    augmenting_client_gradients,
    augmenting_client_weight, datacenter_weight,
    get_num_examples_from_predictions_fn=(
        DEFAULT_NUM_EXAMPLES_FROM_PREDICTIONS_FN),
):
  """Train on datacenter data (as tuples), w/ augmenting grads for client info.

  Args:
    model: The model to train.
    processed_dataset: The batched dataset to train on; batches are tuples.
    optimizer: The optimizer applying the per-batch gradients.
    loss_fn: Callable invoked as `loss_fn(labels, predictions)`.
    metrics: Metrics to update during training and report afterwards.
    augmenting_client_gradients: Gradients mixed into every batch's gradients.
    augmenting_client_weight: Weight on `augmenting_client_gradients`.
    datacenter_weight: Weight on the gradients computed from datacenter data.
    get_num_examples_from_predictions_fn: Callable mapping predictions to the
      number of examples in a batch. Defaults to the first dimension of the
      predictions, matching the behavior before this parameter existed.

  Returns:
    The result of `_centralized_training`: the model weights after training,
    the training metrics, and the average centralized gradients.
  """
  model_vars = tff.learning.ModelWeights.from_model(model)
  reduce_fn, gradients_and_variance_fn = _get_reduce_fn_for_tuples(
      model, model_vars, optimizer, loss_fn, metrics,
      augmenting_client_gradients, augmenting_client_weight, datacenter_weight,
      get_num_examples_from_predictions_fn)
  return _centralized_training(model, processed_dataset, metrics, reduce_fn,
                               gradients_and_variance_fn)


@tf.function
def _centralized_training(
    model, processed_dataset,
    metrics,
    reduce_fn,
    gradients_and_variance_fn
):
  """Train on datacenter data, w/ augmenting grads for adding client info.

  Args:
    model: The model to train; its weights are updated by `reduce_fn`.
    processed_dataset: The batched dataset to reduce over.
    metrics: Iterable of metrics whose results are added to the output under
      their `name`.
    reduce_fn: Callable taking `(previous_state, batch)` and returning the next
      state, where a state is `(num_examples_sum, loss_sum, num_batches_sum,
      central_gradients_sum)`.
    gradients_and_variance_fn: Callable taking `(model, batch)` and returning
      `(gradients, variance)`; evaluated on the first batch before and after
      training.

  Returns:
    A tuple of:
      - the `tff.learning.ModelWeights` of `model` after training.
      - an ordered dict with a single 'train' entry holding the metrics.
      - a list with the example-weighted average of the per-batch gradients.
  """
  model_vars = tff.learning.ModelWeights.from_model(model)
  # Snapshot the starting weights (via `tf.identity`) so the total change can
  # be measured after training.
  initial_trainable_model_weights = tf.nest.map_structure(
      tf.identity, model_vars.trainable)

  # Make copy of first batch in the centralized dataset, use to compute
  # gradients on the same batch of training examples both before and after the
  # training/reduce (to calculate some metrics about changes in centralized
  # gradients).
  # NOTE(review): presumably this fails if the dataset is empty -- confirm.
  first_batch = processed_dataset.take(1).get_single_element()

  (weighted_centralized_gradients_at_start,
   weighted_centralized_variance_at_start) = gradients_and_variance_fn(
       model, first_batch)

  zero_gradients = tuple(
      tf.nest.map_structure(tf.zeros_like, model_vars.trainable))
  # Initial reduce state: (num_examples_sum, loss_sum, num_batches_sum,
  # central_gradients_sum).
  initial_state = (tf.zeros(shape=[], dtype=tf.float32),
                   tf.zeros(shape=(), dtype=tf.float32),
                   tf.zeros(shape=(), dtype=tf.float32), zero_gradients)
  num_examples_sum, loss_sum, num_batches_sum, central_gradients_sum = (
      processed_dataset.reduce(
          initial_state=initial_state, reduce_func=reduce_fn))

  # If `num_examples` is zero, then the average gradients should just be zeroes.
  # Otherwise, divide the sum of the gradients by `num_examples` to get the
  # average gradients.
  average_central_gradients = tf.cond(
      tf.math.equal(num_examples_sum, 0.0),
      lambda: zero_gradients,
      lambda: tf.nest.map_structure(
          lambda g: (1.0 / num_examples_sum) * g, central_gradients_sum))
  average_central_gradients = list(average_central_gradients)

  # Total change in the trainable weights over the whole reduce.
  weights_delta = tf.nest.map_structure(tf.subtract, model_vars.trainable,
                                        initial_trainable_model_weights)

  (weighted_centralized_gradients_at_end,
   weighted_centralized_variance_at_end) = gradients_and_variance_fn(
       model, first_batch)

  # Ratio of the change in first-batch gradients to the change in weights.
  # NOTE(review): the denominator is zero if training left the weights
  # unchanged; plain division is used here.
  weighted_gradients_difference = tf.nest.map_structure(
      tf.subtract, weighted_centralized_gradients_at_end,
      weighted_centralized_gradients_at_start)
  lipschitz_smoothness = tf.linalg.global_norm(
      tf.nest.flatten(weighted_gradients_difference)) / tf.linalg.global_norm(
          tf.nest.flatten(weights_delta))

  train_dict = collections.OrderedDict(
      avg_loss=loss_sum / num_batches_sum,
      num_batches=num_batches_sum,
      num_examples=num_examples_sum,
      lipschitz_smoothness=lipschitz_smoothness,
      start_centralized_gradients_sample_variance=weighted_centralized_variance_at_start,
      end_centralized_gradients_sample_variance=weighted_centralized_variance_at_end,
  )
  for metric in metrics:
    train_dict[metric.name] = metric.result()
  # From here on `metrics` refers to the output dict, not the input metrics.
  metrics = collections.OrderedDict(train=train_dict)

  return model_vars, metrics, average_central_gradients


def _apply_deltas(
    *,
    optimizer,
    model_variables,
    mean_model_deltas,
):
  """Uses `optimizer` to apply `mean_model_deltas` to the trainable weights.

  Each delta is negated and passed to the optimizer as the gradient for the
  corresponding trainable variable.

  Args:
    optimizer: The optimizer whose `apply_gradients` performs the update.
    model_variables: An object whose `trainable` attribute holds the variables
      to update.
    mean_model_deltas: Deltas with the same structure as
      `model_variables.trainable`.
  """
  trainable_vars = model_variables.trainable
  tf.nest.assert_same_structure(mean_model_deltas, trainable_vars)

  def _as_gradient_and_variable(delta, variable):
    return (-1.0 * delta, variable)

  # Note: this may create variables inside `optimizer`, for example if this is
  # the first usage of Adam or momentum optimizers.
  optimizer.apply_gradients(
      tf.nest.map_structure(_as_gradient_and_variable, mean_model_deltas,
                            trainable_vars))


def _eagerly_create_optimizer_variables(
    *, model_variables,
    optimizer):
  """Forces eager construction of the optimizer variables.

  `_apply_deltas` is traced against tensor specs derived from the trainable
  variables; the tracing is what makes the optimizer create its variables.

  Args:
    model_variables: A `tff.learning.ModelWeights` structure of `tf.Variables`.
    optimizer: A `tf.keras.optimizers.Optimizer`.
  """

  def _spec_of(variable):
    return tf.TensorSpec.from_tensor(variable.read_value())

  delta_specs = tf.nest.map_structure(_spec_of, model_variables.trainable)
  # Only the trace is wanted; the concrete function itself is discarded.
  traceable_apply = tf.function(_apply_deltas)
  traceable_apply.get_concrete_function(
      optimizer=optimizer,
      model_variables=model_variables,
      mean_model_deltas=delta_specs)


@tf.function
def _merge_federated_and_datacenter_outputs(
    model, prior_model_weights,
    federated_model_weights,
    federated_weighting,
    datacenter_model_weights,
    datacenter_weighting, optimizer
):
  """Combines the FL update and the datacenter update into one server step.

  The change each training stream made relative to the prior weights is
  computed, the two changes are averaged using the given weightings, and the
  average is applied to the prior weights through `optimizer`.

  Args:
    model: A tf.keras.Model whose constituent tf.Variables will be used.
    prior_model_weights: A tff.learning.ModelWeights reflecting the model
      parameters before the current round's client and datacenter training.
    federated_model_weights: A tff.learning.ModelWeights reflecting the model
      parameters as updated via client training.
    federated_weighting: The relative weighting of the federated training update
      in the overall update.
    datacenter_model_weights: A tff.learning.ModelWeights reflecting the model
      parameters as updated via datacenter training.
    datacenter_weighting: The relative weighting of the datacenter training
      update in the overall update.
    optimizer: The tf.keras.optimizers.Optimizer to use for updating the model's
      trainable variables, based on the 'gradient' determined by merging the
      client (FL) update and the datacenter update.

  Returns:
    A tuple of:
      - a tff.learning.ModelWeights with the updated model parameters, post
        optimizer update.
      - the weighted-average weight deltas, structured like the trainable
        weights.
      - the weight deltas from the client updates, structured likewise.
      - the weight deltas from the datacenter update, structured likewise.
  """
  prior_model_weights.assign_weights_to(model)
  model_vars = tff.learning.ModelWeights.from_model(model)

  # Momentum and adaptive optimizers need their variables created up front.
  with tf.init_scope():
    _eagerly_create_optimizer_variables(
        model_variables=model_vars, optimizer=optimizer)

  def _change_from_prior(updated_weights):
    """Returns `updated - prior` for every trainable weight."""
    return tf.nest.map_structure(
        lambda updated, prior: updated - prior,
        updated_weights.trainable, prior_model_weights.trainable)

  client_deltas = _change_from_prior(federated_model_weights)
  datacenter_deltas = _change_from_prior(datacenter_model_weights)

  # `divide_no_nan` yields zero (rather than inf/nan) if the weightings sum to
  # zero.
  total_weight = federated_weighting + datacenter_weighting
  inv_total_weight = tf.math.divide_no_nan(1.0, total_weight)

  def _weighted_mean(client_delta, datacenter_delta):
    weighted_sum = (federated_weighting * client_delta +
                    datacenter_weighting * datacenter_delta)
    return inv_total_weight * weighted_sum

  mean_model_deltas = tf.nest.map_structure(_weighted_mean, client_deltas,
                                            datacenter_deltas)

  _apply_deltas(
      optimizer=optimizer,
      model_variables=model_vars,
      mean_model_deltas=mean_model_deltas)

  return model_vars, mean_model_deltas, client_deltas, datacenter_deltas


@dataclasses.dataclass
class _DatacenterTraining:
  """Bundle of the objects used for centralized ('datacenter') training."""
  # Iterator whose items are turned into the round's training dataset.
  dataset_iter: Iterator[tf.data.Dataset]
  # Applied to the round's dataset before training on it.
  dataset_processing_fn: _DatasetProcessingFn
  # Optimizer applying the datacenter gradients.
  optimizer: tf.keras.optimizers.Optimizer
  # Loss used for datacenter training.
  loss_fn: _LossFn
  # Called to obtain the metrics for a round of datacenter training.
  metrics_fn: _MetricsConstructor


class _ParallelTrainingMixingProcess(MixingProcess):
  """A MixingProcess that runs datacenter training alongside each FL round."""

  def __init__(
      self,
      keras_model_fn,
      federated_iterative_process,
      datacenter_training,
      merging_optimizer,
      get_num_examples_from_predictions_fn,
      datacenter_weight=1.0,
      client_weight=1.0,
  ):
    """Constructor for MixingProcess that will blend datacenter training w/ FL.

    Each call to `next` runs one round of the federated iterative process,
    trains a copy of the prior model on a window of datacenter data, and merges
    the two resulting model updates into a new global model.

    Args:
      keras_model_fn: A callable returning an instance of a tf.keras.Model, to
        be used in the centralized training. It should match the (TFF) model
        that has been wrapped into the `federated_iterative_process`.
      federated_iterative_process: A tff.templates.IterativeProcess that has
        already been composed with data processing, so that its `next` method
        expects a list of client ids. This handles the FL training side of the
        in-parallel FL/centralized training streams.
      datacenter_training: A dataclass containing the objects to be used during
        centralized ('datacenter') training: the dataset iterator, dataset
        processing function, optimizer, loss, and metrics constructor.
      merging_optimizer: The optimizer to be used for updating the global model,
        based on the 'gradient' determined by merging the client (FL) update and
        the datacenter update.
      get_num_examples_from_predictions_fn: Callable to calculate the number of
        examples in a batch of predictions made by the model returned by
        keras_model_fn.
      datacenter_weight: This should be chosen in concert with the
        `client_weight`, to properly relatively weight the centralized and
        federated model updates when merging at end of the round.
      client_weight: This should be chosen in concert with the
        `datacenter_weight`, to properly relatively weight the centralized and
        federated model updates when merging at end of the round.
    """
    self._keras_model_fn = keras_model_fn
    self._federated_iterative_process = federated_iterative_process
    self._datacenter_training = datacenter_training
    self._merging_optimizer = merging_optimizer
    self._get_num_examples_from_predictions_fn = (
        get_num_examples_from_predictions_fn)
    self._datacenter_weight = datacenter_weight
    self._client_weight = client_weight

  def initialize(self):
    """Returns the initial state of the wrapped federated process."""
    return self._federated_iterative_process.initialize()

  def next(self, prior_state,
           federated_client_ids):
    """Runs one round of FL plus datacenter training and merges the results.

    Args:
      prior_state: The state from the previous round (or from `initialize`).
      federated_client_ids: The ids of the clients taking part in the FL round.

    Returns:
      A tuple of the new state and the metrics of the federated round, extended
      with the datacenter training metrics and with merge statistics.

    Raises:
      ValueError: If the processed datacenter dataset's element spec is neither
        a sequence nor a mapping.
    """
    datacenter = self._datacenter_training

    # Federated side of the round.
    federated_state, federated_train_metrics = (
        self._federated_iterative_process.next(prior_state,
                                               federated_client_ids))

    # Datacenter side of the round, starting from the same prior weights.
    prior_model = self._keras_model_fn()
    prior_state.model.assign_weights_to(prior_model)

    raw_window = next(datacenter.dataset_iter)
    batched_processed_dataset = datacenter.dataset_processing_fn(
        tf.data.Dataset.zip(raw_window))

    element_spec = batched_processed_dataset.element_spec
    if isinstance(element_spec, collections.abc.Sequence):
      train_fn = _centralized_training_with_tuples
    elif isinstance(element_spec, collections.abc.Mapping):
      train_fn = _centralized_training_with_dicts
    else:
      raise ValueError('The processed examples are in unexpected format of '
                       f'{type(batched_processed_dataset.element_spec)}. '
                       '(Expected either `tuple` or `dict` type.)')

    datacenter_model_weights, datacenter_train_metrics = train_fn(
        prior_model, batched_processed_dataset,
        datacenter.optimizer,
        datacenter.loss_fn,
        datacenter.metrics_fn(),
        self._get_num_examples_from_predictions_fn)

    # Merge both updates into a single server step on a fresh model.
    merged_outputs = _merge_federated_and_datacenter_outputs(
        self._keras_model_fn(), prior_state.model, federated_state.model,
        self._client_weight, datacenter_model_weights,
        self._datacenter_weight, self._merging_optimizer)
    model_params, overall_deltas, client_deltas, datacenter_deltas = (
        merged_outputs)
    state = tff.learning.state_with_new_model_weights(
        federated_state,
        [variable.numpy() for variable in model_params.trainable],
        [variable.numpy() for variable in model_params.non_trainable])

    # Extend the federated metrics in place.
    datacenter_train = datacenter_train_metrics['train']
    federated_train_metrics['datacenter_train'] = datacenter_train
    train = federated_train_metrics['train']
    # Record the per-stream example counts before `num_examples` is turned
    # into the combined total.
    train['federated_num_examples'] = train['num_examples']
    train['datacenter_num_examples'] = datacenter_train['num_examples']
    train['num_examples'] += datacenter_train['num_examples']

    norm_sources = (
        ('federated_delta_glob_norm', client_deltas),
        ('datacenter_delta_glob_norm', datacenter_deltas),
        ('delta_glob_norm', overall_deltas),
    )
    for metric_name, deltas in norm_sources:
      train[metric_name] = tf.linalg.global_norm(tf.nest.flatten(deltas))

    train['federated_weight'] = self._client_weight
    train['datacenter_weight'] = self._datacenter_weight

    return state, federated_train_metrics


def build_mixing_process_with_parallel_training(
    *,
    keras_model_fn,
    tff_model_fn,
    client_data,
    datacenter_dataset_fn,
    num_effective_clients_for_training,
    num_examples_per_effective_client,
    datacenter_loss_fn,
    client_optimizer_fn,
    datacenter_optimizer_fn,
    server_optimizer_fn,
    datacenter_metrics_fn=lambda: [],
    client_dataset_processing_fn=IDENTITY_PROCESSING_FN,
    datacenter_dataset_processing_fn=IDENTITY_PROCESSING_FN,
    datacenter_shuffle_buffer=DEFAULT_DATACENTER_SHUFFLE_BUFFER,
    datacenter_weight=1.0,
    client_weight=1.0,
    iterative_process_fn=TFF_LEARN_FED_AVG_FN,
    get_num_examples_from_predictions_fn=(
        DEFAULT_NUM_EXAMPLES_FROM_PREDICTIONS_FN),
):
  """Factory method for getting a _ParallelTrainingMixingProcess.

  This is a method of federated/datacenter data mixing where, during each round,
  in parallel to the federated training and aggregating taking place with the
  clients, there is an optimization taking place on datacenter data. The client
  FL and the datacenter learning both begin the round with the same model
  weights, and at the conclusion of the round the client updates are combined
  with the datacenter update to do a joint server update, resulting in a single
  merged global model. In that server update the client and datacenter updates
  are weighted by `client_weight` and `datacenter_weight` respectively.

  Args:
    keras_model_fn: A callable returning an instance of a tf.keras.Model. It
      should match the model wrapped inside `tff_model_fn`.
    tff_model_fn: A callable returning an instance of a tff.learning.Model.
    client_data: A federated dataset for the clients (e.g., edge devices). Its
      `dataset_computation` is used to obtain each client's dataset.
    datacenter_dataset_fn: A callable returning the tf.data.Dataset for the
      datacenter data.
    num_effective_clients_for_training: The number of 'effective' clients whose
      data the datacenter training will consume. The number of datacenter
      dataset elements used per round is `num_effective_clients_for_training *
      num_examples_per_effective_client`.
    num_examples_per_effective_client: A number of examples that should be set
      in rough equivalence to the expected number of examples in the federated
      client datasets.
    datacenter_loss_fn: The loss to be used when optimizing on the datacenter
      data; it is invoked as `loss_fn(labels, predictions)`. Note that the
      analogous loss function for the client training is specified via the
      tff.learning.Model returned by the `tff_model_fn` argument.
    client_optimizer_fn: A callable returning the optimizer to be used for
      calculating client updates.
    datacenter_optimizer_fn: A callable returning the optimizer to be used for
      calculating datacenter updates. It is called once, here.
    server_optimizer_fn: A callable returning the optimizer to be used for the
      merged server update. It is called once, here.
    datacenter_metrics_fn: A callable returning list of tf.keras.metric.Metrics,
      to be used as metrics during datacenter training. Note that the
      analogous metrics for the client training are specified via the
      tff.learning.Model returned by the `tff_model_fn` argument.
    client_dataset_processing_fn: A callable for converting the raw (e.g.,
      serialized) unprocessed dataset into a batched, shuffled, processed
      dataset of examples in the format expected by the model during client
      training.
    datacenter_dataset_processing_fn: A callable for converting the raw (e.g.,
      serialized) unprocessed dataset into a batched, shuffled, processed
      dataset of examples in the format expected by the model during datacenter
      training.
    datacenter_shuffle_buffer: Size of buffer to use when shuffling the
      datacenter dataset.
    datacenter_weight: This should be chosen in concert with the
      `client_weight`, to properly relatively weight the centralized and
      federated model updates when merging at end of the round.
    client_weight: This should be chosen in concert with the
      `datacenter_weight`, to properly relatively weight the centralized and
      federated model updates when merging at end of the round.
    iterative_process_fn: A callable that returns an instance of
      tff.templates.IterativeProcess, to be used for performing the federated
      training. It is called with 3 keyword arguments: `model_fn`,
      `client_optimizer_fn` and `server_optimizer_fn`.
    get_num_examples_from_predictions_fn: A callable that takes as input the
      predictions of the model returned by keras_model_fn on a given batch. The
      callable must return the number of examples in the batch. Used to
      correctly calculate the number of examples being trained on in the
      datacenter. By default, we use the first dimension of the predictions
      as the number of examples, and we expect this to satisfy most use cases.

  Returns:
    An instance of _ParallelTrainingMixingProcess.
  """

  @tff.tf_computation(tf.string)
  def process_client_data_fn(client_id):
    raw_client_dataset = client_data.dataset_computation(client_id)
    return client_dataset_processing_fn(raw_client_dataset)

  # The federated process is built with a fixed SGD (learning rate 1.0) server
  # optimizer; the caller's `server_optimizer_fn` is used for the merged update
  # instead (see `merging_optimizer` below).
  unit_rate_sgd_fn = lambda: tf.keras.optimizers.SGD(learning_rate=1.0)
  uncomposed_process = iterative_process_fn(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      server_optimizer_fn=unit_rate_sgd_fn)
  # Compose the iterative process with the data processing step, so that its
  # `next` takes client ids.
  federated_iterative_process = (
      tff.simulation.compose_dataset_computation_with_iterative_process(
          process_client_data_fn, uncomposed_process))

  # Each item drawn from this iterator is one round's window of datacenter
  # data.
  window_size = (
      num_effective_clients_for_training * num_examples_per_effective_client)
  shuffled_dataset = datacenter_dataset_fn().shuffle(
      datacenter_shuffle_buffer, reshuffle_each_iteration=True)
  datacenter_dataset_iterator = iter(
      shuffled_dataset.repeat().window(window_size))

  datacenter_training = _DatacenterTraining(
      dataset_iter=datacenter_dataset_iterator,
      dataset_processing_fn=datacenter_dataset_processing_fn,
      optimizer=datacenter_optimizer_fn(),
      loss_fn=datacenter_loss_fn,
      metrics_fn=datacenter_metrics_fn)
  return _ParallelTrainingMixingProcess(
      keras_model_fn=keras_model_fn,
      federated_iterative_process=federated_iterative_process,
      datacenter_training=datacenter_training,
      merging_optimizer=server_optimizer_fn(),
      get_num_examples_from_predictions_fn=get_num_examples_from_predictions_fn,
      datacenter_weight=datacenter_weight,
      client_weight=client_weight)


def _compute_augmenting_datacenter_gradients(
    prior_model, batched_processed_examples,
    loss_fn):
  """Compute the augmenting gradients from model, data, and loss function."""
  if isinstance(batched_processed_examples, collections.abc.Sequence):
    gradients, variance, num_examples, _ = _centralized_gradients_with_tuples(
        prior_model, batched_processed_examples, loss_fn)
  elif isinstance(batched_processed_examples, collections.abc.Mapping):
    gradients, variance, num_examples, _ = _centralized_gradients_with_dicts(
        prior_model, batched_processed_examples, loss_fn)
  else:
    raise ValueError(
        'The processed examples are in unexpected format of %s. (Expected '
        'either tuple or dict.)' % type(batched_processed_examples))
  return gradients, variance, num_examples


@dataclasses.dataclass
class _DatacenterGradients():
  """Bundle of objects used to compute augmenting datacenter gradients.

  Attributes:
    dataset_iter: Iterator over the datacenter data. Each `next()` call yields
      one window of examples, which the consumer passes through
      `tf.data.Dataset.zip` before processing.
    dataset_processing_fn: Callable applied to the zipped window; the first
      element of the dataset it returns is the batch the gradients are computed
      on.
    loss_fn: Loss function handed to the gradient computation.
  """
  dataset_iter: Iterator[tf.data.Dataset]
  dataset_processing_fn: _DatasetProcessingFn
  loss_fn: _LossFn


class _OneWayGradTransferMixingProcess(MixingProcess):
  """MixingProcess that sends datacenter gradients to the federated clients.

  Every call to `next` computes gradients from one batch of datacenter data at
  the incoming model weights and passes them, as the 'augmenting' gradients,
  into the wrapped iterative process for that round.
  """

  def __init__(self,
               keras_model_fn,
               augmenting_iterative_process,
               datacenter_gradients,
               datacenter_gradient_weight=1.0,
               client_gradient_weight=1.0):
    """Initializes the one-way gradient transfer mixing process.

    Args:
      keras_model_fn: A callable returning an instance of a tf.keras.Model, to
        be used in the centralized gradient calculation. It should match the
        (TFF) model that has been wrapped into the
        `augmenting_iterative_process`.
      augmenting_iterative_process: A tff.templates.IterativeProcess that
        accommodates augmenting gradients in its `next` call. It should already
        have been composed with data processing, so that its `next` method
        expects a list of client ids.
      datacenter_gradients: A dataclass holding the objects used for the
        centralized ('datacenter') calculation of the gradients vector that is
        passed in as the augmenting gradients.
      datacenter_gradient_weight: A scalar weight applied to the augmenting
        datacenter gradients when summing them with the local client gradients.
        This should be chosen in concert with `client_gradient_weight`, to
        weight the two gradient terms relative to each other as desired.
      client_gradient_weight: A scalar weight applied to the local client
        gradients when summing them with the augmenting datacenter gradients.
        This should be chosen in concert with `datacenter_gradient_weight`, to
        weight the two gradient terms relative to each other as desired.
    """
    self._keras_model_fn = keras_model_fn
    self._augmenting_iterative_process = augmenting_iterative_process
    self._datacenter_gradients = datacenter_gradients
    self._datacenter_gradient_weight = datacenter_gradient_weight
    self._client_gradient_weight = client_gradient_weight

  def initialize(self):
    """Returns the initial state of the wrapped iterative process."""
    return self._augmenting_iterative_process.initialize()

  def next(self, prior_state, federated_client_ids):
    """Runs one federated round augmented with datacenter gradients.

    Args:
      prior_state: The server state entering the round. Its `model` weights are
        the ones the datacenter gradients are computed at.
      federated_client_ids: The client ids taking part in this round.

    Returns:
      The `(state, train_metrics)` pair returned by the wrapped iterative
      process, with `train_metrics['train']` extended by the
      `augmenting_datacenter_grads_*` entries and `delta_glob_norm`.
    """
    datacenter = self._datacenter_gradients

    # Materialize the incoming weights in a Keras model for the gradient
    # calculation.
    prior_model = self._keras_model_fn()
    prior_state.model.assign_weights_to(prior_model)

    # Take the next window of datacenter data, process it, and keep only the
    # first element of the processed dataset.
    datacenter_window = next(datacenter.dataset_iter)
    processed_dataset = datacenter.dataset_processing_fn(
        tf.data.Dataset.zip(datacenter_window))
    batched_processed_examples = next(iter(processed_dataset))

    augmenting_gradients, gradients_variance, num_examples_used = (
        _compute_augmenting_datacenter_gradients(
            prior_model, batched_processed_examples, datacenter.loss_fn))

    state, train_metrics = self._augmenting_iterative_process.next(
        prior_state, federated_client_ids, augmenting_gradients,
        self._datacenter_gradient_weight, self._client_gradient_weight)

    # Per-variable change of the trainable weights over the round.
    deltas = tf.nest.map_structure(
        lambda next_var, prev_var: next_var - prev_var,
        state.model.trainable, prior_state.model.trainable)

    round_metrics = train_metrics['train']
    round_metrics['augmenting_datacenter_grads_glob_norm'] = (
        tf.linalg.global_norm(tf.nest.flatten(augmenting_gradients)))
    round_metrics['augmenting_datacenter_grads_sample_variance'] = (
        gradients_variance)
    round_metrics['augmenting_datacenter_grads_num_examples_used'] = (
        num_examples_used)
    round_metrics['delta_glob_norm'] = tf.linalg.global_norm(
        tf.nest.flatten(deltas))

    return state, train_metrics


def build_mixing_process_with_gradient_transfer(
    *,
    keras_model_fn,
    tff_model_fn,
    client_data,
    datacenter_dataset_fn,
    datacenter_batch_size,
    datacenter_loss_fn,
    client_optimizer_fn,
    server_optimizer_fn,
    client_dataset_processing_fn=IDENTITY_PROCESSING_FN,
    datacenter_dataset_processing_fn=IDENTITY_PROCESSING_FN,
    datacenter_shuffle_buffer=DEFAULT_DATACENTER_SHUFFLE_BUFFER,
    datacenter_gradient_weight=1.0,
    client_gradient_weight=1.0,
    augmenting_iterative_process_fn=GRAD_TRANSFER_FED_AVG_FN,
):
  """Builds a _OneWayGradTransferMixingProcess.

  With this kind of mixing, the datacenter data (together with a loss function
  and the latest model weights) is used to compute a list of gradients that
  represents the direction of a loss minimum for the datacenter data
  distribution. These gradients are shipped to the clients during a federated
  round and added to the gradients computed locally from the client data. The
  idea is that the combined gradients point towards a minimum for the overall
  (federated + datacenter) distribution.

  Args:
    keras_model_fn: A callable returning an instance of a tf.keras.Model. It
      should match the model wrapped inside `tff_model_fn`.
    tff_model_fn: A callable returning an instance of a tff.learning.Model.
    client_data: A federated dataset for the clients (e.g., edge devices).
    datacenter_dataset_fn: A callable returning the tf.data.Dataset for the
      datacenter data.
    datacenter_batch_size: The batch size to use when calculating the gradients
      of the datacenter data; it is used as the size of each window drawn from
      the datacenter dataset. This should match the batch size that is applied
      in `datacenter_dataset_processing_fn`.
    datacenter_loss_fn: A callable returning the tf.keras.losses.Loss used to
      calculate the datacenter gradients (i.e., the augmenting gradients). In
      general, this should be the same loss function as the one wrapped into
      the TFF model for application during FL.
    client_optimizer_fn: A callable returning the optimizer used for
      calculating client updates.
    server_optimizer_fn: A callable returning the optimizer used for
      calculating server updates.
    client_dataset_processing_fn: A callable converting the raw (e.g.,
      serialized) unprocessed dataset into a batched, shuffled, processed
      dataset of examples in the format expected by the model during client
      training.
    datacenter_dataset_processing_fn: A callable converting the raw (e.g.,
      serialized) unprocessed dataset into a batched, shuffled, processed
      dataset of examples in the format expected by the model during the
      datacenter gradients calculation.
    datacenter_shuffle_buffer: Size of the buffer to use when shuffling the
      datacenter dataset.
    datacenter_gradient_weight: A scalar weight applied to the augmenting
      datacenter gradients when summing them with the local client gradients.
      This should be chosen in concert with `client_gradient_weight`, to weight
      the two gradient terms relative to each other as desired.
    client_gradient_weight: A scalar weight applied to the local client
      gradients when summing them with the augmenting datacenter gradients.
      This should be chosen in concert with `datacenter_gradient_weight`, to
      weight the two gradient terms relative to each other as desired.
    augmenting_iterative_process_fn: A callable returning an instance of
      tff.templates.IterativeProcess with a specific next() signature, used for
      federated training with gradient augmentation. It should take 3 callables
      as arguments: one returning the tff.learning.Model, one returning the
      client optimizer, and one returning the server optimizer. The returned
      process must have a next() method taking 5 arguments: the server state,
      the federated data for the clients in the round, a list of "augmenting"
      datacenter gradients (which are summed with the local gradients
      calculated during the client update steps), the weight to apply to the
      "augmenting" datacenter gradients, and the weight to apply to the local
      client gradients.

  Returns:
    An instance of _OneWayGradTransferMixingProcess.
  """

  @tff.tf_computation(tf.string)
  def process_client_data_fn(client_id):
    raw_client_dataset = client_data.dataset_computation(client_id)
    return client_dataset_processing_fn(raw_client_dataset)

  uncomposed_process = augmenting_iterative_process_fn(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      server_optimizer_fn=server_optimizer_fn)
  # Fold the per-client data processing into the iterative process, so that it
  # takes place at the workers standing in for the clients; this results in a
  # significant speedup.
  augmenting_iterative_process = (
      tff.simulation.compose_dataset_computation_with_iterative_process(
          process_client_data_fn, uncomposed_process))

  # An endless, reshuffled stream of datacenter examples, cut into windows of
  # one batch each.
  shuffled_datacenter_dataset = datacenter_dataset_fn().shuffle(
      datacenter_shuffle_buffer, reshuffle_each_iteration=True)
  datacenter_windows = shuffled_datacenter_dataset.repeat().window(
      datacenter_batch_size)

  datacenter_gradients = _DatacenterGradients(
      dataset_iter=iter(datacenter_windows),
      dataset_processing_fn=datacenter_dataset_processing_fn,
      loss_fn=datacenter_loss_fn)

  return _OneWayGradTransferMixingProcess(
      keras_model_fn=keras_model_fn,
      augmenting_iterative_process=augmenting_iterative_process,
      datacenter_gradients=datacenter_gradients,
      datacenter_gradient_weight=datacenter_gradient_weight,
      client_gradient_weight=client_gradient_weight)


class GradientsComputationOption(enum.Enum):
  """How the augmenting datacenter gradients sent to clients are computed.

  The option is read by `_TwoWayGradTransferMixingProcess.next` at the end of
  each round, to decide which gradients are passed to the clients in the
  following round.
  """
  # Gradients of a single datacenter batch, evaluated at the model that entered
  # the round (the previous global model).
  FULL_BATCH_WITH_LAST_CHECKPOINT = 'FULL_BATCH_WITH_LAST_CHECKPOINT'
  # The average gradients reported by the round's centralized training.
  AVERAGED_FROM_CENTRAL_TRAINING = 'AVERAGED_FROM_CENTRAL_TRAINING'
  # Gradients of a single datacenter batch, evaluated at the merged model
  # produced by the round.
  FULL_BATCH_WITH_NEXT_CHECKPOINT = 'FULL_BATCH_WITH_NEXT_CHECKPOINT'


class _TwoWayGradTransferMixingProcess(MixingProcess):
  """A MixingProcess where server and clients exchange respective gradients.

  This approach can be considered a meta-SCAFFOLD approach, where instead of
  sharing gradients between individual clients for variance reduction, here we
  exchange gradients between different datasets (centralized dataset and
  decentralized/federated dataset) to achieve mixing.

  The datacenter gradients to send to the clients in the following round are
  kept on this object between calls to `next`, not in the server state.
  """

  # The options `next` knows how to handle; anything else is rejected up front.
  _SUPPORTED_GRADIENTS_OPTIONS = (
      GradientsComputationOption.AVERAGED_FROM_CENTRAL_TRAINING,
      GradientsComputationOption.FULL_BATCH_WITH_LAST_CHECKPOINT,
      GradientsComputationOption.FULL_BATCH_WITH_NEXT_CHECKPOINT,
  )

  def __init__(self,
               keras_model_fn,
               augmenting_iterative_process,
               datacenter_training,
               merging_optimizer,
               datacenter_weight=1.0,
               client_weight=1.0,
               augmenting_datacenter_gradients_option=(
                   GradientsComputationOption.AVERAGED_FROM_CENTRAL_TRAINING)):
    """Constructs class for augmenting gradients b/w clients and server.

    Args:
      keras_model_fn: A callable returning an instance of a tf.keras.Model, to
        be used in the centralized training. It should match the (TFF) model
        that has been wrapped into the `augmenting_iterative_process`.
      augmenting_iterative_process: A tff.templates.IterativeProcess that
        accommodates augmenting gradients in its `next` call. Also, it should
        have already been composed with data processing, so that its `next`
        method expects a list of client ids.
      datacenter_training: A dataclass containing a number of objects to be
        used during centralized ('datacenter') training. These objects are the
        centralized training equivalents of information already packaged up in
        the `augmenting_iterative_process` for the FL training, e.g., objects
        for the data, data processing, optimizer and loss, metrics, etc.
      merging_optimizer: The optimizer to be used for updating the global model,
        based on the 'gradient' determined by merging the client (FL) update and
        the datacenter update.
      datacenter_weight: A scalar weight to apply to the datacenter gradients
        when summing them with the client gradients. This should be chosen in
        concert with the `client_weight`, to properly weight the two gradient
        terms relative to each other as desired.
      client_weight: A scalar weight to apply to the local client gradients when
        summing them with the datacenter gradients. This should be chosen in
        concert with the `datacenter_weight`, to properly weight the two
        gradient terms relative to each other as desired.
      augmenting_datacenter_gradients_option: Which option to use in how to
        compute (at the datacenter) the augmenting gradients that will be sent
        to the clients. Must be a member of `GradientsComputationOption`; this
        is checked on every call to `next`.
    """
    self._keras_model_fn = keras_model_fn
    self._augmenting_iterative_process = augmenting_iterative_process
    self._datacenter_training = datacenter_training
    self._merging_optimizer = merging_optimizer
    self._datacenter_weight = datacenter_weight
    self._client_weight = client_weight
    self.augmenting_datacenter_gradients_option = (
        augmenting_datacenter_gradients_option)
    self._prior_datacenter_gradients = None

  def initialize(self):
    """Returns the initial server state and zeroes the datacenter gradients."""
    initial_state = self._augmenting_iterative_process.initialize()
    # No datacenter gradients exist before the first round, so the clients
    # start with all-zeros augmenting gradients.
    self._prior_datacenter_gradients = tf.nest.map_structure(
        tf.zeros_like, initial_state.model.trainable)
    return initial_state

  def _train_at_datacenter(self, prior_model, prior_state):
    """Runs this round's centralized training on the next datacenter window.

    Args:
      prior_model: Keras model holding the weights that entered the round.
      prior_state: The server state that entered the round; its
        `average_client_gradients` augment the centralized training.

    Returns:
      The `(model_weights, train_metrics, average_gradients)` tuple returned by
      the centralized training helper matching the dataset's element structure.

    Raises:
      ValueError: If the processed dataset's `element_spec` is neither a
        sequence nor a mapping.
    """
    training = self._datacenter_training
    batched_processed_dataset = training.dataset_processing_fn(
        tf.data.Dataset.zip(next(training.dataset_iter)))
    element_spec = batched_processed_dataset.element_spec
    if isinstance(element_spec, collections.abc.Sequence):
      centralized_training_fn = (
          _centralized_training_with_augmenting_client_gradients_with_tuples)
    elif isinstance(element_spec, collections.abc.Mapping):
      centralized_training_fn = (
          _centralized_training_with_augmenting_client_gradients_with_dicts)
    else:
      raise ValueError(
          'The processed examples are in unexpected format of %s. (Expected '
          'either tuple or dict.)' % type(element_spec))
    return centralized_training_fn(
        prior_model, batched_processed_dataset, training.optimizer,
        training.loss_fn, training.metrics_fn(),
        prior_state.average_client_gradients, self._client_weight,
        self._datacenter_weight)

  def next(self, prior_state, federated_client_ids):
    """Runs one round of federated plus datacenter training and merges them.

    Args:
      prior_state: The server state entering the round.
      federated_client_ids: The client ids taking part in this round.

    Returns:
      A `(state, train_metrics)` pair. `state` carries the merged model weights
      together with the remaining fields of the federated round's state.
      `train_metrics` is the federated round's metrics, extended with the
      datacenter training metrics and a number of norm/weight entries.

    Raises:
      ValueError: If `augmenting_datacenter_gradients_option` is not a supported
        `GradientsComputationOption`, or if the processed datacenter dataset
        has elements that are neither sequences nor mappings.
    """
    option = self.augmenting_datacenter_gradients_option
    if option not in self._SUPPORTED_GRADIENTS_OPTIONS:
      # Fail before any training happens, rather than after a full round of
      # work when the option is finally consulted.
      raise ValueError(
          'Unsupported augmenting_datacenter_gradients_option: %r. (Expected a '
          'member of GradientsComputationOption.)' % (option,))

    if self._prior_datacenter_gradients is None:
      # `initialize` has not been called on this object (e.g., `prior_state`
      # was obtained elsewhere), so start from the same all-zeros gradients
      # that `initialize` sets up.
      self._prior_datacenter_gradients = tf.nest.map_structure(
          tf.zeros_like, prior_state.model.trainable)

    prior_model = self._keras_model_fn()
    prior_state.model.assign_weights_to(prior_model)

    federated_state, federated_train_metrics = (
        self._augmenting_iterative_process.next(
            prior_state, federated_client_ids, self._prior_datacenter_gradients,
            self._datacenter_weight, self._client_weight))

    (datacenter_model_weights, datacenter_train_metrics,
     average_datacenter_gradients) = self._train_at_datacenter(
         prior_model, prior_state)

    model_params, overall_deltas, client_deltas, datacenter_deltas = (
        _merge_federated_and_datacenter_outputs(
            self._keras_model_fn(), prior_state.model, federated_state.model,
            self._client_weight, datacenter_model_weights,
            self._datacenter_weight, self._merging_optimizer))
    # The new state is the federated round's state with the model replaced by
    # the merged weights.
    state = (
        process_with_gradient_transfer_lib
        .ServerStateWithAverageClientGradients(
            model=tff.learning.ModelWeights(
                trainable=[x.numpy() for x in model_params.trainable],
                non_trainable=[x.numpy() for x in model_params.non_trainable]),
            optimizer_state=federated_state.optimizer_state,
            delta_aggregate_state=federated_state.delta_aggregate_state,
            model_broadcast_state=federated_state.model_broadcast_state,
            average_client_gradients=federated_state.average_client_gradients))

    if option == GradientsComputationOption.AVERAGED_FROM_CENTRAL_TRAINING:
      # Compute the next augmenting datacenter gradient as an average of the
      # gradients over multiple steps of datacenter descent. This is following
      # an approach that looks like 'Option II' in the SCAFFOLD paper, Eqn 4.
      augmenting_datacenter_gradients = average_datacenter_gradients
      augmenting_datacenter_gradients_variance = 0.0
      num_examples_used = 0
    else:
      if option == GradientsComputationOption.FULL_BATCH_WITH_LAST_CHECKPOINT:
        # Compute the next augmenting datacenter gradients from prior model, to
        # be passed to the clients next round. This is following an approach
        # that looks like 'Option I' in the SCAFFOLD paper, Eqn 4: a single
        # (datacenter) gradient calculated at the previous global model.
        model_to_use_for_gradient_calculation = prior_model
      else:
        # FULL_BATCH_WITH_NEXT_CHECKPOINT, the only other option that passes
        # the check at the top of this method. Compute the next augmenting
        # datacenter gradients from the next model, to be passed to the clients
        # next round. This option is not available in SCAFFOLD, but it is
        # possible here with 2-way Gradient Transfer where we only have two
        # 'meta'-clients.
        model_to_use_for_gradient_calculation = self._keras_model_fn()
        state.model.assign_weights_to(model_to_use_for_gradient_calculation)

      # A fresh window of datacenter data is drawn for this calculation; only
      # the first element of the processed dataset is used.
      batched_processed_examples = next(
          iter(
              self._datacenter_training.dataset_processing_fn(
                  tf.data.Dataset.zip(
                      next(self._datacenter_training.dataset_iter)))))
      (augmenting_datacenter_gradients,
       augmenting_datacenter_gradients_variance,
       num_examples_used) = _compute_augmenting_datacenter_gradients(
           model_to_use_for_gradient_calculation, batched_processed_examples,
           self._datacenter_training.loss_fn)

    datacenter_metrics = datacenter_train_metrics['train']
    federated_train_metrics['datacenter_train'] = datacenter_metrics

    round_metrics = federated_train_metrics['train']
    # Keep the per-source example counts before `num_examples` becomes the
    # combined total.
    round_metrics['federated_num_examples'] = round_metrics['num_examples']
    round_metrics['datacenter_num_examples'] = (
        datacenter_metrics['num_examples'])
    round_metrics['num_examples'] += datacenter_metrics['num_examples']
    round_metrics['federated_delta_glob_norm'] = tf.linalg.global_norm(
        tf.nest.flatten(client_deltas))
    round_metrics['datacenter_delta_glob_norm'] = tf.linalg.global_norm(
        tf.nest.flatten(datacenter_deltas))
    round_metrics['delta_glob_norm'] = tf.linalg.global_norm(
        tf.nest.flatten(overall_deltas))
    round_metrics['federated_weight'] = self._client_weight
    round_metrics['datacenter_weight'] = self._datacenter_weight

    round_metrics['augmenting_datacenter_grads_glob_norm'] = (
        tf.linalg.global_norm(
            tf.nest.flatten(augmenting_datacenter_gradients)))
    round_metrics['augmenting_datacenter_grads_sample_variance'] = (
        augmenting_datacenter_gradients_variance)
    round_metrics['augmenting_datacenter_grads_num_examples_used'] = (
        num_examples_used)
    round_metrics['augmenting_client_grads_glob_norm'] = (
        tf.linalg.global_norm(
            tf.nest.flatten(federated_state.average_client_gradients)))

    # Remember the gradients to send to the clients in the following round.
    self._prior_datacenter_gradients = augmenting_datacenter_gradients
    return state, federated_train_metrics


def build_mixing_process_with_two_way_gradient_transfer(
    *,
    keras_model_fn,
    tff_model_fn,
    client_data,
    datacenter_dataset_fn,
    num_effective_clients_for_training,
    num_examples_per_effective_client,
    datacenter_loss_fn,
    client_optimizer_fn,
    datacenter_optimizer_fn,
    server_optimizer_fn,
    datacenter_metrics_fn=lambda: [],
    client_dataset_processing_fn=IDENTITY_PROCESSING_FN,
    datacenter_dataset_processing_fn=IDENTITY_PROCESSING_FN,
    datacenter_shuffle_buffer=DEFAULT_DATACENTER_SHUFFLE_BUFFER,
    datacenter_weight=1.0,
    client_weight=1.0,
    augmenting_iterative_process_fn=GRAD_TRANSFER_FED_AVG_FN,
    augmenting_datacenter_gradients_option=(
        GradientsComputationOption.AVERAGED_FROM_CENTRAL_TRAINING),
):
  """Builds a _TwoWayGradTransferMixingProcess.

  This resembles one-way gradient transfer, except that augmenting gradients
  travel in both directions: from clients to server (telling centralized
  learning about the loss landscape of the decentralized data) and from server
  to clients (telling decentralized learning about the loss landscape of the
  centralized data). It can be considered analogous to SCAFFOLD
  (https://arxiv.org/abs/1910.06378), where instead of sharing gradients between
  individual clients for variance reduction, gradients are exchanged between
  different datasets (the centralized dataset and the decentralized/federated
  dataset) to achieve mixing.

  Args:
    keras_model_fn: A callable returning an instance of a tf.keras.Model. It
      should match the model wrapped inside `tff_model_fn`.
    tff_model_fn: A callable returning an instance of a tff.learning.Model.
    client_data: A federated dataset for the clients (e.g., edge devices).
    datacenter_dataset_fn: A callable returning the tf.data.Dataset for the
      datacenter data.
    num_effective_clients_for_training: The number of 'effective' clients whose
      data the datacenter training will consume. Each window of datacenter data
      contains `num_effective_clients_for_training *
      num_examples_per_effective_client` examples.
    num_examples_per_effective_client: A number of examples that should be set
      in rough equivalence to the expected number of examples in the federated
      client datasets.
    datacenter_loss_fn: A callable returning the tf.keras.losses.Loss used to
      calculate the datacenter gradients (i.e., the augmenting gradients). In
      general, this should be the same loss function as the one wrapped into
      the TFF model for application during FL.
    client_optimizer_fn: A callable returning the optimizer used for
      calculating client updates.
    datacenter_optimizer_fn: A callable returning the optimizer used for
      calculating datacenter updates.
    server_optimizer_fn: A callable returning the optimizer used for
      calculating server updates. It is called a second time to create the
      optimizer that merges the federated and datacenter updates.
    datacenter_metrics_fn: A callable returning a list of
      tf.keras.metrics.Metric objects, used as metrics during datacenter
      training. The analogous metrics for the client training are specified
      via the tff.learning.Model returned by `tff_model_fn`.
    client_dataset_processing_fn: A callable converting the raw (e.g.,
      serialized) unprocessed dataset into a batched, shuffled, processed
      dataset of examples in the format expected by the model during client
      training.
    datacenter_dataset_processing_fn: A callable converting the raw (e.g.,
      serialized) unprocessed dataset into a batched, shuffled, processed
      dataset of examples in the format expected by the model during the
      datacenter gradients calculation.
    datacenter_shuffle_buffer: Size of the buffer to use when shuffling the
      datacenter dataset.
    datacenter_weight: A scalar weight to apply to the datacenter gradients when
      summing them with the client gradients. This should be chosen in concert
      with `client_weight`, to weight the two gradient terms relative to each
      other as desired. The weight is used when summing gradients both during
      centralized training on the server and during federated training on the
      clients. It is also used to relatively weight the centralized and
      federated model updates when merging at the end of the round.
    client_weight: A scalar weight to apply to the client gradients when summing
      them with the datacenter gradients. This should be chosen in concert with
      `datacenter_weight`, to weight the two gradient terms relative to each
      other as desired. The weight is used when summing gradients both during
      centralized training on the server and during federated training on the
      clients. It is also used to relatively weight the centralized and
      federated model updates when merging at the end of the round.
    augmenting_iterative_process_fn: A callable returning an instance of
      tff.templates.IterativeProcess with a specific next() signature, used for
      federated training with gradient augmentation. It should take 3 callables
      as arguments: one returning the tff.learning.Model, one returning the
      client optimizer, and one returning the server optimizer. The returned
      process must have a next() method taking 5 arguments: the server state,
      the federated data for the clients in the round, a list of "augmenting"
      datacenter gradients (which are summed with the local gradients
      calculated during the client update steps), the weight to apply to the
      "augmenting" datacenter gradients, and the weight to apply to the client
      gradients.
    augmenting_datacenter_gradients_option: Which option to use in how to
      compute (at the datacenter) the augmenting gradients that will be sent to
      the clients.

  Returns:
    An instance of _TwoWayGradTransferMixingProcess.
  """

  @tff.tf_computation(tf.string)
  def process_client_data_fn(client_id):
    raw_client_dataset = client_data.dataset_computation(client_id)
    return client_dataset_processing_fn(raw_client_dataset)

  uncomposed_process = augmenting_iterative_process_fn(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      server_optimizer_fn=server_optimizer_fn)
  # Fold the per-client data processing into the iterative process, so that it
  # takes place at the workers standing in for the clients; this results in a
  # significant speedup.
  augmenting_iterative_process = (
      tff.simulation.compose_dataset_computation_with_iterative_process(
          process_client_data_fn, uncomposed_process))

  # An endless, reshuffled stream of datacenter examples, cut into windows that
  # each hold the data of all 'effective' clients.
  shuffled_datacenter_dataset = datacenter_dataset_fn().shuffle(
      datacenter_shuffle_buffer, reshuffle_each_iteration=True)
  examples_per_window = (
      num_effective_clients_for_training * num_examples_per_effective_client)
  datacenter_windows = shuffled_datacenter_dataset.repeat().window(
      examples_per_window)

  datacenter_optimizer = datacenter_optimizer_fn()
  datacenter_training = _DatacenterTraining(
      dataset_iter=iter(datacenter_windows),
      dataset_processing_fn=datacenter_dataset_processing_fn,
      optimizer=datacenter_optimizer,
      loss_fn=datacenter_loss_fn,
      metrics_fn=datacenter_metrics_fn)

  merging_optimizer = server_optimizer_fn()
  return _TwoWayGradTransferMixingProcess(
      keras_model_fn=keras_model_fn,
      augmenting_iterative_process=augmenting_iterative_process,
      datacenter_training=datacenter_training,
      merging_optimizer=merging_optimizer,
      datacenter_weight=datacenter_weight,
      client_weight=client_weight,
      augmenting_datacenter_gradients_option=(
          augmenting_datacenter_gradients_option))


def _default_mixing_fn(
    client_id,
    client_data,
    datacenter_dataset_fn,
    num_examples_to_augment,
    num_repetitions=1,
    client_dataset_processing_fn=IDENTITY_PROCESSING_FN,
    datacenter_shuffle_buffer=DEFAULT_DATACENTER_SHUFFLE_BUFFER,
):
  """Default mixing strategy for example transfer.

  Appends datacenter examples to the end of a client's dataset, shuffles the
  combined examples, and then applies the client dataset processing.

  Args:
    client_id: Id of the client whose dataset is being augmented.
    client_data: Object whose `dataset_computation(client_id)` returns the
      client's dataset.
    datacenter_dataset_fn: A callable returning the datacenter dataset.
    num_examples_to_augment: Total number of datacenter examples to add,
      counting repetitions.
    num_repetitions: Number of times the taken datacenter examples are
      repeated; `int(num_examples_to_augment / num_repetitions)` examples are
      taken from the shuffled datacenter dataset.
    client_dataset_processing_fn: A callable applied to the mixed dataset.
    datacenter_shuffle_buffer: Size of the buffer to use when shuffling the
      datacenter dataset, before examples are taken from it.

  Returns:
    The result of `client_dataset_processing_fn` applied to the mixed dataset.
  """
  shuffled_datacenter_dataset = datacenter_dataset_fn().shuffle(
      datacenter_shuffle_buffer, reshuffle_each_iteration=False)
  num_examples_to_take = int(num_examples_to_augment / num_repetitions)
  taken_examples = tf.data.Dataset.zip(
      shuffled_datacenter_dataset.take(num_examples_to_take))
  # Cache the taken examples so that every repetition replays the same ones.
  augmentation_dataset = taken_examples.cache().repeat(num_repetitions)

  client_dataset = client_data.dataset_computation(client_id)
  mixed_dataset = client_dataset.concatenate(augmentation_dataset).shuffle(
      1000)
  return client_dataset_processing_fn(mixed_dataset)


# Signature of a mixing function for example transfer. The positional arguments
# mirror those of `_default_mixing_fn`: client id, client data, datacenter
# dataset constructor, number of examples to augment, number of repetitions,
# client dataset processing function, and datacenter shuffle buffer size.
_ExampleTransferMixingFn = Callable[
    [
        str,
        tff.simulation.datasets.ClientData,
        _DatasetConstructor,
        int,
        int,
        _DatasetProcessingFn,
        int,
    ],
    tf.data.Dataset,
]
# The mixing strategy used as the default value for `mixing_fn`.
DEFAULT_MIXING_FN = _default_mixing_fn


class _ExampleTransferMixingProcess(MixingProcess):
  """A MixingProcess where client caches are augmented with datacenter data."""

  def __init__(self, federated_iterative_process):
    """Initializes the example transfer mixing process.

    Calling `next` on this class runs a standard FL round of FedAvg, except that
    each simulated edge device trains not only on its own examples (i.e., the
    device's 'cache') but also on some additional examples from the datacenter
    dataset that were added to the cache. In a real FL training scenario, one
    could think of this data as being shipped to each participating client
    alongside the model checkpoint for the given round.

    Args:
      federated_iterative_process: A tff.templates.IterativeProcess that has
        already been composed with data processing, so that its `next` method
        expects a list of client ids.
    """
    self._federated_iterative_process = federated_iterative_process

  def initialize(self):
    """Returns the initial state of the wrapped iterative process."""
    return self._federated_iterative_process.initialize()

  def next(self, prior_state, federated_client_ids):
    """Runs one round of the wrapped process and records the weight change.

    Args:
      prior_state: The server state entering the round.
      federated_client_ids: The client ids taking part in this round.

    Returns:
      The `(state, train_metrics)` pair returned by the wrapped iterative
      process, with `train_metrics['train']['delta_glob_norm']` added.
    """
    state, train_metrics = self._federated_iterative_process.next(
        prior_state, federated_client_ids)

    # Per-variable change of the trainable weights over the round.
    deltas = tf.nest.map_structure(
        lambda next_var, prev_var: next_var - prev_var,
        state.model.trainable, prior_state.model.trainable)
    delta_glob_norm = tf.linalg.global_norm(tf.nest.flatten(deltas))
    train_metrics['train']['delta_glob_norm'] = delta_glob_norm

    return state, train_metrics


def build_mixing_process_with_example_transfer(
    *,
    tff_model_fn,
    client_data,
    datacenter_dataset_fn,
    num_examples_to_augment,
    num_repetitions = 1,
    mixing_fn = DEFAULT_MIXING_FN,
    client_optimizer_fn,
    server_optimizer_fn,
    client_dataset_processing_fn = IDENTITY_PROCESSING_FN,
    datacenter_shuffle_buffer = DEFAULT_DATACENTER_SHUFFLE_BUFFER,
    iterative_process_fn = TFF_LEARN_FED_AVG_FN,
):
  """Factory method for getting a _ExampleTransferMixingProcess.

  This is a method of federated/datacenter data mixing where examples of
  datacenter data are added to the clients' caches, during each round
  of federated training. Thus, the augmented cache is a mix of both client data
  and datacenter data.

  Args:
    tff_model_fn: A callable returning an instance of a tff.learning.Model.
    client_data: A federated dataset for the clients (e.g., edge devices).
    datacenter_dataset_fn: A callable returning the tf.data.Dataset for the
      datacenter data.
    num_examples_to_augment: The total number of datacenter examples to add to
      each client cache. May be all unique (if `num_repetitions` is 1), or may
      involve some repetition (if `num_repetitions` > 1).
    num_repetitions: The number of times to use a unique datacenter example. The
      total number of unique datacenter examples used per client will be
      `num_examples_to_augment` divided by `num_repetitions`. This argument must
      be positive and such that `num_examples_to_augment` is cleanly divisible
      by it (no remainder).
    mixing_fn: A callable defining a customized federated/datacenter data mixing
      strategy. If it is `None`, `DEFAULT_MIXING_FN` is used. A custom mixing
      function should take `client_id`, `client_data`,
      `datacenter_dataset_fn`, `num_examples_to_augment`, `num_repetitions`,
      `client_dataset_processing_fn`, `datacenter_shuffle_buffer` as positional
      input args, in that order, and return a tf.data.Dataset. See
      `DEFAULT_MIXING_FN` for more details on how to implement a mixing fn.
    client_optimizer_fn: A callable returning the optimizer to be used for
      calculating client updates.
    server_optimizer_fn: A callable returning the optimizer to be used for
      calculating server updates.
    client_dataset_processing_fn: A callable for converting the raw (e.g.,
      serialized) unprocessed dataset into a batched, shuffled, processed
      dataset of examples in the format expected by the model.
    datacenter_shuffle_buffer: Size of buffer to use when shuffling the
      datacenter dataset.
    iterative_process_fn: A callable that returns an instance of
      tff.templates.IterativeProcess, to be used for performing the federated
      training. It is called with 3 keyword arguments: `model_fn` (a callable
      for getting the tff.learning.Model), `client_optimizer_fn` (a callable
      for getting the client optimizer), and `server_optimizer_fn` (a callable
      for getting the server optimizer).

  Returns:
    An instance of _ExampleTransferMixingProcess.

  Raises:
    ValueError: If `num_repetitions` is not positive, or if
      `num_examples_to_augment` divided by `num_repetitions` has non-zero
      remainder.
  """
  # Reject non-positive values up front: zero would otherwise surface as a
  # ZeroDivisionError from the modulo below, and negative values would slip
  # through the divisibility check.
  if num_repetitions < 1:
    raise ValueError('The num_repetitions must be a positive integer, but '
                     'num_repetitions is %r.' % (num_repetitions,))

  if num_examples_to_augment % num_repetitions != 0:
    raise ValueError('The num_examples_to_augment must be evenly divisble by '
                     'num_repetitions, but num_examples_to_augment is %d and '
                     'num_repetitions is %d.' %
                     (num_examples_to_augment, num_repetitions))

  # Honor the documented contract that `None` selects the default strategy.
  resolved_mixing_fn = DEFAULT_MIXING_FN if mixing_fn is None else mixing_fn

  @tff.tf_computation(tf.string)
  def process_and_augment_client_data_fn(client_id):
    # Builds the mixed (client + datacenter) dataset for a single client id.
    return resolved_mixing_fn(client_id, client_data, datacenter_dataset_fn,
                              num_examples_to_augment, num_repetitions,
                              client_dataset_processing_fn,
                              datacenter_shuffle_buffer)

  federated_iterative_process = iterative_process_fn(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      server_optimizer_fn=server_optimizer_fn)
  # Compose the iterative process with the data processing step; this will take
  # place at the Borg workers (i.e. the clients), resulting in significant
  # speedup. After composition, `next` takes client ids rather than datasets.
  federated_iterative_process = (
      tff.simulation.compose_dataset_computation_with_iterative_process(
          process_and_augment_client_data_fn, federated_iterative_process))

  return _ExampleTransferMixingProcess(federated_iterative_process)
