# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Iterative process that accepts augmenting gradients on each next() call."""

import collections
import functools
from typing import Any, Callable, List, Optional, OrderedDict, Tuple, Union

import attr
import tensorflow as tf
import tensorflow_federated as tff

# Type aliases.
# A no-argument callable that builds a fresh `tff.learning.Model`.
_ModelConstructor = Callable[[], tff.learning.Model]
# Client-side reduce state: (number of steps, number of examples, one running
# gradient sum per trainable variable). The last element is variable-length,
# hence `Tuple[tf.Tensor, ...]` rather than the 1-tuple `Tuple[tf.Tensor]`.
_ClientTrainReduceState = Tuple[tf.Tensor, tf.Tensor, Tuple[tf.Tensor, ...]]
# A no-argument callable that builds a fresh Keras optimizer.
_OptimizerConstructor = Callable[[], tf.keras.optimizers.Optimizer]


# Based on method in
# https://github.com/tensorflow/federated/blob/b4ee5a791a137241a715b669b9005836a9b355f2/tensorflow_federated/python/tensorflow_libs/tensor_utils.py#L93
@tf.function
def _zero_all_if_any_non_finite(structure):
  """Replaces every entry of `structure` with zeros if any entry is non-finite.

  Args:
    structure: A structure supported by `tf.nest`.

  Returns:
    `(structure, 0)` when every entry is finite or `structure` has no leaves.
    Otherwise `(zeros_like(structure), 1)`. The second element is a scalar
    tensor built with `tf.constant` that flags whether zeroing happened.
  """
  leaves = tf.nest.flatten(structure)
  if not leaves:
    # Nothing to check; pass the (empty) structure through untouched.
    return (structure, tf.constant(0))
  every_leaf_finite = functools.reduce(
      tf.logical_and,
      [tf.reduce_all(tf.math.is_finite(leaf)) for leaf in leaves])
  return tf.cond(
      every_leaf_finite,
      lambda: (structure, tf.constant(0)),
      lambda: (tf.nest.map_structure(tf.zeros_like, structure),
               tf.constant(1)))


def _check_callable(target,
                    label = None):
  """Checks target is callable."""
  if not callable(target):
    raise TypeError('Expected {} callable, found non-callable {}.'.format(
        '{} to be'.format(label) if label is not None else 'a', type(target)))


@attr.s(eq=False, frozen=True)
class ClientOutputWithAverageGradients(object):
  """Like ClientOutput, but w/ additional client gradient vectors returned.

  Instances are immutable (`frozen=True`) and, because of `eq=False`, compare
  by identity. The fields are positional attrs attributes, so their declaration
  order is part of the constructor signature.
  """
  # Change in the client's trainable model weights produced by local training.
  weights_delta = attr.ib()
  # Weight given to `weights_delta` when aggregating across clients.
  weights_delta_weight = attr.ib()
  # Model metrics reported by the client after local training.
  model_output = attr.ib()
  # Optional extra diagnostics from client-side optimization.
  optimizer_output = attr.ib(default=None)
  # Optional average of the client's gradients over its local training steps,
  # laid out like the model's trainable weights.
  average_gradients = attr.ib(default=None)


@attr.s(eq=False, frozen=True)
class ServerStateWithAverageClientGradients(object):
  """Like ServerState, but w/ additional aggregated client gradient vectors.

  Immutable (`frozen=True`) and compared by identity (`eq=False`). The fields
  are positional attrs attributes, so their declaration order is part of the
  constructor signature.
  """
  # Global model weights.
  model = attr.ib()
  # Server optimizer variables.
  optimizer_state = attr.ib()
  # State of the process that aggregates client model deltas.
  delta_aggregate_state = attr.ib()
  # State of the process that broadcasts the global model to clients.
  model_broadcast_state = attr.ib()
  # Aggregate of the clients' average gradients, laid out like the model's
  # trainable weights (all zeros at initialization).
  average_client_gradients = attr.ib()


class ProcessTypeError(Exception):
  """Raised when a `MeasuredProcess` has an incorrect type signature."""


def _apply_delta(
    *,
    optimizer,
    model_variables,
    delta,
):
  """Takes one optimizer step using the negated `delta` as the gradient.

  Args:
    optimizer: A `tf.keras.optimizers.Optimizer`; its `apply_gradients` method
      performs the update.
    model_variables: A `tff.learning.ModelWeights` of `tf.Variable`s. Only the
      `trainable` portion is updated.
    delta: A structure matching `model_variables.trainable`. It is negated
      before being handed to the optimizer, so a descent step moves the
      variables in the direction of `delta`.
  """
  tf.nest.assert_same_structure(delta, model_variables.trainable)
  negated_deltas = [-1.0 * d for d in tf.nest.flatten(delta)]
  trainable_vars = tf.nest.flatten(model_variables.trainable)
  # Note: this may create variables inside `optimizer`, e.g. on the first use
  # of Adam or momentum optimizers.
  optimizer.apply_gradients(list(zip(negated_deltas, trainable_vars)))


def _eagerly_create_optimizer_variables(*, model_variables, optimizer):
  """Forces `optimizer` to build its variables by tracing `_apply_delta`.

  Both `server_init` and `server_update` need this: the optimizer's variables
  must exist before their initial values can be read into the state.

  Args:
    model_variables: A `tff.learning.ModelWeights` structure of `tf.Variable`s.
    optimizer: A `tf.keras.optimizers.Optimizer`.

  Returns:
    A list of optimizer variables.
  """

  def _spec_of(variable):
    # Describe each trainable variable by a TensorSpec of its current value.
    return tf.TensorSpec.from_tensor(variable.read_value())

  delta_spec = tf.nest.map_structure(_spec_of, model_variables.trainable)
  # Requesting a concrete function traces `_apply_delta` once. The tracing
  # calls `optimizer.apply_gradients`, which creates the optimizer variables.
  traced_apply_delta = tf.function(_apply_delta)
  traced_apply_delta.get_concrete_function(
      optimizer=optimizer, model_variables=model_variables, delta=delta_spec)
  return optimizer.variables()


@tff.tf_computation()
def _dict_update(original_dict, new_dict):
  """Merges `new_dict` into `original_dict`, refusing to overwrite any key.

  Args:
    original_dict: The mapping that receives the new entries. It is modified
      via `update` and then returned.
    new_dict: The mapping whose entries are added.

  Returns:
    `original_dict` after the update.

  Raises:
    ValueError: If the two mappings share any key.
  """
  intersection = original_dict.keys() & new_dict.keys()
  if not intersection:
    original_dict.update(new_dict)
    return original_dict
  raise ValueError(
      f'Cannot merge dicts; there are colliding dict keys: {intersection}')


# ==============================================================================
# Federated Computations
#
# These constructors set up the system-level orchestration logic.
# ==============================================================================


def _build_initialize_computation(
    *,
    model_fn,
    server_optimizer_fn,
    broadcast_process,
    aggregation_process,
):
  """Builds the `initialize` computation for a model delta averaging process.

  Args:
    model_fn: a no-argument callable that constructs and returns a
      `tff.learning.Model`. *Must* construct and return a new model when called.
      Returning captured models from other scopes will raise errors.
    server_optimizer_fn: a no-argument callable that constructs and returns a
      `tf.keras.optimizers.Optimizer`. *Must* construct and return a new
      optimizer when called. Returning captured optimizers from other scopes
      will raise errors.
    broadcast_process: a `tff.templates.MeasuredProcess` to broadcast the global
      model to the clients.
    aggregation_process: a `tff.templates.MeasuredProcess` to aggregate client
      model deltas.

  Returns:
    A `tff.Computation` that initializes the process. The computation takes no
    arguments and returns a `ServerStateWithAverageClientGradients` placed at
    `tff.SERVER`. That state contains:
      * the initial global model weights,
      * the initial server optimizer variables,
      * the initial states of `aggregation_process` and `broadcast_process`,
      * all-zero `average_client_gradients` shaped like the model's trainable
        weights.
  """

  @tff.tf_computation
  def server_init():
    """Creates the initial server-side values for the process state.

    Returns:
      A 3-tuple of:
        * the `tff.learning.ModelWeights` of a freshly constructed model,
        * a `list` of `tf.Variable`s holding the global optimizer state, and
        * a structure matching the `trainable` portion of the
          `tff.learning.ModelWeights`, with every value set to zero.
    """
    model_variables = tff.learning.ModelWeights.from_model(model_fn())
    optimizer = server_optimizer_fn()
    # Momentum and adaptive optimizers may only create their variables on the
    # first `apply_gradients`, so force creation here. Otherwise they would be
    # missing from the initial state.
    optimizer_vars = _eagerly_create_optimizer_variables(
        model_variables=model_variables, optimizer=optimizer)
    # Same shape as the gradients computed from the model, but all zeros; used
    # as the initial `average_client_gradients`.
    zero_gradients = tf.nest.map_structure(tf.zeros_like,
                                           model_variables.trainable)
    return model_variables, optimizer_vars, zero_gradients

  @tff.federated_computation()
  def initialize_computation():
    """Orchestration logic for server model initialization."""
    initial_global_model, initial_global_optimizer_state, initial_gradients = (
        tff.federated_eval(server_init, tff.SERVER))
    return tff.federated_zip(
        ServerStateWithAverageClientGradients(
            model=initial_global_model,
            optimizer_state=initial_global_optimizer_state,
            delta_aggregate_state=aggregation_process.initialize(),
            model_broadcast_state=broadcast_process.initialize(),
            average_client_gradients=initial_gradients))

  return initialize_computation


@tf.function
def _normalize_vector(vector):
  """Scales `vector` to unit global L2 norm; an all-zero input is returned as is.

  Args:
    vector: A `tf.nest` structure of tensors (presumably floating point, as
      required by `tf.linalg.global_norm`).

  Returns:
    A structure like `vector`. If the global norm across all leaves is zero,
    `vector` itself is returned. Otherwise every leaf is divided by that norm.
  """
  global_norm = tf.linalg.global_norm(tf.nest.flatten(vector))

  def _unchanged():
    return vector

  def _scaled_to_unit_norm():
    return tf.nest.map_structure(lambda leaf: leaf / global_norm, vector)

  is_zero_vector = tf.math.equal(global_norm, 0.0)
  return tf.cond(is_zero_vector, _unchanged, _scaled_to_unit_norm)


@tf.function
def _dot_product(vector_a,
                 vector_b):
  """Computes the dot product of two identically structured tensor collections.

  Both arguments are flattened with `tf.nest`, so any matching nested layout
  works, not just flat sequences (e.g. lists of per-layer gradient tensors).
  Corresponding leaves are multiplied elementwise and summed, and the per-leaf
  sums are added together.

  Args:
    vector_a: A `tf.nest` structure of tensors.
    vector_b: A `tf.nest` structure with the same layout as `vector_a`. List
      and tuple containers are treated as interchangeable.

  Returns:
    The scalar sum over all leaves of `reduce_sum(a * b)`. Returns `0.0` when
    the structures have no leaves.

  Raises:
    ValueError: If the two structures are not nested the same way or have a
      different number of leaves.
  """
  # Fail loudly on mismatched inputs instead of silently truncating to the
  # shorter one, as a bare `zip` over the containers would.
  tf.nest.assert_same_structure(vector_a, vector_b, check_types=False)
  cum_sum = 0.0
  for a, b in zip(tf.nest.flatten(vector_a), tf.nest.flatten(vector_b)):
    cum_sum += tf.math.reduce_sum(tf.math.multiply(a, b))
  return cum_sum


def _build_one_round_computation(
    *,
    model_fn,
    server_optimizer_fn,
    client_optimizer_fn,
    client_weight_fn = None,
    broadcast_process,
    aggregation_process,
    metrics_aggregator
):
  """Builds `next`; along with state and data it also takes augmenting gradient.

  Args:
    model_fn: a no-argument callable that constructs and returns a
      `tff.learning.Model`. *Must* construct and return a new model when called.
      Returning captured models from other scopes will raise errors.
    server_optimizer_fn: a no-argument callable that constructs and returns a
      `tf.keras.optimizers.Optimizer`. *Must* construct and return a new
      optimizer when called. Returning captured optimizers from other scopes
      will raise errors.
    client_optimizer_fn: a no-argument callable that constructs and returns a
      `tf.keras.optimizers.Optimizer` used in calculating each client update.
      *Must* construct and return a new optimizer when called. Returning
      captured optimizers from other scopes will raise errors.
    client_weight_fn: An optional callable that takes the output of
      `model.report_local_unfinalized_metrics` and returns a tensor that
      provides the weight in the federated average of model deltas. If not
      provided, the default is the total number of examples processed on device.
    broadcast_process: a `tff.templates.MeasuredProcess` to broadcast the global
      model to the clients.
    aggregation_process: a `tff.templates.MeasuredProcess` to aggregate client
      model deltas.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a federated TFF computation of the following type signature
      `local_unfinalized_metrics@CLIENTS -> aggregated_metrics@SERVER`.

  Returns:
    A `tff.Computation` that initializes the process. The computation takes
    a tuple of `(ServerStateWithAverageClientGradients@SERVER,
    tf.data.Dataset@CLIENTS, augmenting_gradients@SERVER)` argument, and returns
    a tuple of `(ServerStateWithAverageClientGradients@SERVER, metrics@SERVER)`.
  """
  with tf.Graph().as_default():
    whimsy_model_for_metadata = model_fn()
    model_weights_type = tff.learning.framework.weights_type_from_model(
        whimsy_model_for_metadata)
    unfinalized_metrics_type = tff.types.type_from_tensors(
        whimsy_model_for_metadata.report_local_unfinalized_metrics())
    federated_metrics_aggregation = metrics_aggregator(
        whimsy_model_for_metadata.metric_finalizers(), unfinalized_metrics_type)

    whimsy_optimizer = server_optimizer_fn()
    # We must force variable creation for momentum and adaptive optimizers.
    _eagerly_create_optimizer_variables(
        model_variables=tff.learning.ModelWeights.from_model(
            whimsy_model_for_metadata),
        optimizer=whimsy_optimizer)
    optimizer_variable_type = tff.framework.type_from_tensors(
        whimsy_optimizer.variables())

  @tff.tf_computation(model_weights_type, model_weights_type.trainable,
                      optimizer_variable_type)
  @tf.function
  def server_update(global_model, mean_model_delta, optimizer_state):
    """Updates the global model with the mean model update from clients."""
    with tf.init_scope():
      # Create a structure of variables that the server optimizer can update.
      model_variables = tf.nest.map_structure(
          lambda t: tf.Variable(initial_value=tf.zeros(t.shape, t.dtype)),
          global_model)
      optimizer = server_optimizer_fn()
      # We must force variable creation for momentum and adaptive optimizers.
      _eagerly_create_optimizer_variables(
          model_variables=model_variables, optimizer=optimizer)
    optimizer_variables = optimizer.variables()
    # Set the variables to the current global model, the optimizer will
    # update these variables.
    tf.nest.map_structure(lambda a, b: a.assign(b),
                          (model_variables, optimizer_variables),
                          (global_model, optimizer_state))
    # We might have a NaN value e.g. if all of the clients processed had no
    # data, so the denominator in the federated_mean is zero. If we see any
    # NaNs, zero out the whole update.
    finite_weights_delta, _ = _zero_all_if_any_non_finite(mean_model_delta)
    # Update the global model variables with the delta as a pseudo-gradient.
    _apply_delta(
        optimizer=optimizer,
        model_variables=model_variables,
        delta=finite_weights_delta)
    return model_variables, optimizer_variables

  dataset_type = tff.SequenceType(whimsy_model_for_metadata.input_spec)

  @tff.tf_computation(dataset_type, model_weights_type,
                      model_weights_type.trainable,
                      tff.TensorType(dtype=tf.float32),
                      tff.TensorType(dtype=tf.float32))
  @tf.function
  def _compute_local_training_and_client_delta(
      dataset,
      initial_model_weights,
      augmenting_gradients,
      augmenting_gradient_weight, client_gradient_weight):
    """Performs client optimization, augmenting gradients with additional term.

    The is exactly like the client local computation of FedAvg, except for the
    addition of the `augmenting_gradients` to whatever gradients are locally
    calculated for the loss/local data/local model weights. (I.e., if
    `augmenting_gradients` are vectors of zeros, then vanilla FedAvg is
    recovered).

    Note that the `augmenting_gradients` are summed into the local gradients 'as
    is'; there's no normalization w.r.t. the local gradients of any sort. So
    it's important to understand and reconcile any aspects of the loss function
    that would lead to vast differences in local gradient magnitude (e.g., if
    the loss function was a sum instead of an average, and there were batches of
    different sizes being processed).

    Args:
      dataset: a `tf.data.Dataset` that provides training examples.
      initial_model_weights: a `tff.learning.ModelWeights` containing the
        starting weights.
      augmenting_gradients: A list (in the same format as the trainable weights
        of the model) of augmenting gradients. These augmenting gradients will
        be summed with the stochastic gradients calculated via derivation of the
        tff.learning.Model's loss function applied to batches of the client's
        data (the contents of `dataset`). The summed gradients are then used in
        the client's optimizer for taking a step and updating model weights.
     augmenting_gradient_weight: A scalar weight to apply to the augmenting
       gradients when summing them with the local client gradients. This should
       be chosen in concert with the `client_gradient_weight`, to properly
       weight the two gradient terms relative to each other as desired.
     client_gradient_weight: A scalar weight to apply to the local client
       gradients when summing them with the augmenting gradients. This should be
       chosen in concert with the `augmenting_gradient_weight`, to properly
       weight the two gradient terms relative to each other as desired.

    Returns:
      A `ClientOutputWithAverageGradients` structure. This data structure
      includes a field for the average of the client gradients over the multiple
      steps of client training. This can be aggregated with the client gradients
      from other clients to get an overall average of the gradient w.r.t. client
      data, which can be used for augmentation during centralized training with
      datacenter data (i.e., 2-way gradient transfer, i.e. 'meta-SCAFFOLD' a la
      https://arxiv.org/pdf/1910.06378.pdf, Equation 4, Option II).
    """
    tf.nest.assert_same_structure(initial_model_weights.trainable,
                                  augmenting_gradients)
    with tf.init_scope():
      model = model_fn()
      optimizer = client_optimizer_fn()

    model_weights = tff.learning.ModelWeights.from_model(model)
    tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                          initial_model_weights)

    grad_wt = client_gradient_weight
    aug_wt = augmenting_gradient_weight

    def empty_client_gradients(model_weights):
      empty_client_gradients = tf.nest.map_structure(tf.zeros_like,
                                                     model_weights.trainable)
      zero_examples = tf.constant(0, tf.float32)
      return empty_client_gradients, zero_examples

    def compute_client_gradients(
        batch, model,
        model_weights
    ):
      """Computes gradients of loss w.r.t. model params, using client batch."""
      with tf.GradientTape() as tape:
        output = model.forward_pass(batch, training=True)

      client_gradients = tape.gradient(output.loss, model_weights.trainable)
      for idx, grad in enumerate(client_gradients):
        if isinstance(grad, tf.IndexedSlices):
          client_gradients[idx] = tf.convert_to_tensor(grad)

      if output.num_examples is None:
        num_examples = tf.shape(output.predictions, out_type=tf.float32)[0]
      else:
        num_examples = tf.cast(output.num_examples, tf.float32)

      return client_gradients, num_examples

    def reduce_fn(state,
                  batch):
      """Trains `tff.learning.Model` on client batch; updates gradients sum."""
      num_steps, num_examples_sum, client_gradients_sum = state

      client_gradients, num_new_examples = compute_client_gradients(
          batch, model, model_weights)

      # Add in the augmenting gradients to the original gradients calculated by
      # minimizing the training loss, before applying via the optimizer.
      gradients = tf.nest.map_structure(
          lambda grad, aug: grad_wt * grad + aug_wt * aug, client_gradients,
          augmenting_gradients)
      optimizer.apply_gradients(zip(gradients, model_weights.trainable))

      num_steps = num_steps + tf.constant(1)
      num_examples_sum = num_examples_sum + num_new_examples

      # This assumes that `client_gradients` is a mean of per-example gradients
      # across the batch, as opposed to a sum of per-example gradients. I.e.,
      # that the reduction of the loss was defined as a mean (as opposed to a
      # sum). This is in keeping with the description of the `loss` attribute of
      # the `tff.learning.BatchOutput` object (
      # https://www.tensorflow.org/federated/api_docs/python/tff/learning/BatchOutput#attributes),
      # but it should be noted that there's nothing structurally in TFF that
      # enforces that the loss for a `tff.learning.Model` must be a mean.
      client_gradients_sum = tf.nest.map_structure(
          lambda grad_sum, grad: grad_sum + num_new_examples * grad,
          client_gradients_sum, tuple(client_gradients))

      return num_steps, num_examples_sum, client_gradients_sum

    # Make copy of first batch in the client dataset, use to compute gradients
    # on the same batch of training examples both before and after the
    # training/reduce (to calculate some metrics about changes in client
    # gradients). We need to use an iterator, `get_next_as_optional()`, and
    # tf.cond statements to handle the scenario where `dataset` is empty.
    # Note: we access the iterator via `dataset.__iter__()` as opposed to
    # `iter(dataset)`, b/c the latter causes pytype checking problems.
    # Specifically, the former has a pytype return type of `tf.data.Iterator`,
    # while the latter has a pytype return type of a generic Iterator, but the
    # `get_next_as_optional()` is only associated with the former class.
    first_batch = dataset.__iter__().get_next_as_optional()

    start_client_gradients, _ = tf.cond(
        first_batch.has_value(),
        lambda: compute_client_gradients(
            first_batch.get_value(), model, model_weights),
        lambda: empty_client_gradients(model_weights))

    zero_gradients = tuple(
        tf.nest.map_structure(tf.zeros_like, model_weights.trainable))
    initial_state = (tf.constant(0), tf.zeros(shape=[],
                                              dtype=tf.float32), zero_gradients)
    num_steps, num_examples_sum, client_gradients_sum = dataset.reduce(
        initial_state, reduce_fn)
    # If `num_examples` is zero, then the average client gradients should just
    # be zeroes. Otherwise, divide the sum of the client gradients by
    # `num_examples` to get the average client gradients.
    average_client_gradients = tf.cond(
        tf.math.equal(num_examples_sum, 0.0),
        lambda: zero_gradients,
        lambda: tf.nest.map_structure(
            lambda g: (1.0 / num_examples_sum) * g, client_gradients_sum))
    average_client_gradients = list(average_client_gradients)

    end_client_gradients, _ = tf.cond(
        first_batch.has_value(),
        lambda: compute_client_gradients(
            first_batch.get_value(), model, model_weights),
        lambda: empty_client_gradients(model_weights))

    # Compute some global norms of the client and total gradients, as a measure
    # of the magnitude of the gradients at start and end of client training
    # during a round.
    start_client_gradients_glob_norm = tf.linalg.global_norm(
        tf.nest.flatten(start_client_gradients))
    start_total_gradients = tf.nest.map_structure(
        lambda grad, aug: grad_wt * grad + aug_wt * aug, start_client_gradients,
        augmenting_gradients)
    start_total_gradients_glob_norm = tf.linalg.global_norm(
        tf.nest.flatten(start_total_gradients))

    end_client_gradients_glob_norm = tf.linalg.global_norm(
        tf.nest.flatten(end_client_gradients))
    end_total_gradients = tf.nest.map_structure(
        lambda grad, aug: grad_wt * grad + aug_wt * aug, end_client_gradients,
        augmenting_gradients)
    end_total_gradients_glob_norm = tf.linalg.global_norm(
        tf.nest.flatten(end_total_gradients))

    weights_delta = tf.nest.map_structure(tf.subtract, model_weights.trainable,
                                          initial_model_weights.trainable)
    model_output = model.report_local_unfinalized_metrics()
    # Any counter metrics that are part of the TFF model's unfinalized metrics
    # (e.g., 'num_batches' and 'num_examples' if using a Keras-derived TFF
    # model) will be erroneous (overcounted). This is because the TFF model will
    # update its metrics anytime its forward pass is used (not just during
    # training via reduce_fn), and we use the model twice apart from training
    # (to calculate gradients, once before and once after training in the
    # reduce_fn). Here we update some counter metrics so that they only reflect
    # the amount of data used during actual training/optimization steps.
    if 'num_batches' in model_output:
      model_output['num_batches'][0] = tf.cast(
          num_steps, model_output['num_batches'][0].dtype)
    if 'num_examples' in model_output:
      model_output['num_examples'][0] = tf.cast(
          num_examples_sum, model_output['num_examples'][0].dtype)

    weights_delta, has_non_finite_delta = (
        _zero_all_if_any_non_finite(weights_delta))
    # Zero out the weight if there are any non-finite values.
    if has_non_finite_delta > 0:
      weights_delta_weight = tf.constant(0.0)
    elif client_weight_fn is None:
      weights_delta_weight = num_examples_sum
    else:
      weights_delta_weight = client_weight_fn(model_output)

    gradients_difference = tf.nest.map_structure(tf.subtract,
                                                 end_client_gradients,
                                                 start_client_gradients)

    weights_delta_glob_norm = tf.linalg.global_norm(
        tf.nest.flatten(weights_delta))
    lipschitz_smoothness = tf.cond(
        weights_delta_glob_norm > 0.0,
        lambda: client_gradient_weight * tf.linalg.global_norm(
            tf.nest.flatten(gradients_difference)) / weights_delta_glob_norm,
        lambda: 0.0)

    optimizer_output = collections.OrderedDict(
        num_examples=tf.cast(num_examples_sum, tf.int64),
        start_client_gradients_glob_norm=start_client_gradients_glob_norm,
        start_total_gradients_glob_norm=start_total_gradients_glob_norm,
        end_client_gradients_glob_norm=end_client_gradients_glob_norm,
        end_total_gradients_glob_norm=end_total_gradients_glob_norm,
        lipschitz_smoothness=lipschitz_smoothness)
    return ClientOutputWithAverageGradients(weights_delta, weights_delta_weight,
                                            model_output, optimizer_output,
                                            average_client_gradients)

  broadcast_state = broadcast_process.initialize.type_signature.result.member
  aggregation_state = aggregation_process.initialize.type_signature.result.member

  server_state_type = ServerStateWithAverageClientGradients(
      model=model_weights_type,
      optimizer_state=optimizer_variable_type,
      delta_aggregate_state=aggregation_state,
      model_broadcast_state=broadcast_state,
      average_client_gradients=model_weights_type.trainable)

  average_client_gradients_aggregation_process = tff.aggregators.MeanFactory(
  ).create(model_weights_type.trainable, tff.TensorType(tf.float32))
  average_client_gradients_aggregation_state = (
      average_client_gradients_aggregation_process.initialize())

  @tff.tf_computation()
  def _take_global_norm(tensors):
    return tf.linalg.global_norm(tf.nest.flatten(tensors))

  @tff.federated_computation(
      tff.type_at_server(server_state_type), tff.type_at_clients(dataset_type),
      tff.type_at_server(model_weights_type.trainable),
      tff.type_at_server(tff.TensorType(dtype=tf.float32)),
      tff.type_at_server(tff.TensorType(dtype=tf.float32)))
  def one_round_computation(server_state, federated_dataset,
                            augmenting_gradients, augmenting_gradient_weight,
                            client_gradient_weight):
    """Orchestration logic for one round of optimization.

    Args:
      server_state: a `ServerStateWithAverageClientGradients` named tuple.
      federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.
      augmenting_gradients: A list (in the same format as the trainable weights
        of the model) of augmenting gradients. These augmenting gradients will
        be summed with the stochastic gradients calculated via derivation of the
        tff.learning.Model's loss function applied to batches of the client's
        data (the contents of `federated_dataset`). The summed gradients are
        then used in the client's optimizer for taking a step and updating model
        weights.
     augmenting_gradient_weight: A scalar weight to apply to the augmenting
       gradients when summing them with the local client gradients. This should
       be chosen in concert with the `client_gradient_weight`, to properly
       weight the two gradient terms relative to each other as desired.
     client_gradient_weight: A scalar weight to apply to the local client
       gradients when summing them with the augmenting gradients. This should be
       chosen in concert with the `augmenting_gradient_weight`, to properly
       weight the two gradient terms relative to each other as desired.

    Returns:
      A tuple of updated `ServerStateWithAverageClientGradients` and aggregated
      metrics, both having `tff.SERVER` placement. The server state's
      `average_client_gradients` field contains a weighted aggregation of the
      average gradients computed against the client data (averaged over all
      steps of client optimization).
    """
    broadcast_output = broadcast_process.next(
        server_state.model_broadcast_state, server_state.model)
    client_outputs = tff.federated_map(
        _compute_local_training_and_client_delta,
        (federated_dataset, broadcast_output.result,
         tff.federated_broadcast(augmenting_gradients),
         tff.federated_broadcast(augmenting_gradient_weight),
         tff.federated_broadcast(client_gradient_weight)))
    if aggregation_process.is_weighted:  # pytype: disable=attribute-error  # gen-stub-imports
      aggregation_output = aggregation_process.next(
          server_state.delta_aggregate_state, client_outputs.weights_delta,
          client_outputs.weights_delta_weight)
    else:
      aggregation_output = aggregation_process.next(
          server_state.delta_aggregate_state, client_outputs.weights_delta)
    new_global_model, new_optimizer_state = tff.federated_map(
        server_update, (server_state.model, aggregation_output.result,
                        server_state.optimizer_state))

    # Do a weighted aggregation of the average gradients on the clients.
    average_client_gradients_aggregation_output = (
        average_client_gradients_aggregation_process.next(
            tff.federated_value(average_client_gradients_aggregation_state,
                                tff.SERVER), client_outputs.average_gradients,
            client_outputs.weights_delta_weight))
    aggregated_average_client_gradients = (
        average_client_gradients_aggregation_output.result)

    new_server_state = tff.federated_zip(
        ServerStateWithAverageClientGradients(
            new_global_model, new_optimizer_state, aggregation_output.state,
            broadcast_output.state, aggregated_average_client_gradients))

    @tff.tf_computation
    def _square_value(tensor_value):
      """Computes the square of a tensor."""
      return tensor_value**2

    @tff.tf_computation
    def _multiply_by_gradient_weight_squared(tensor_value, gradient_weight):
      return gradient_weight**2 * tensor_value

    @tff.tf_computation
    def _compute_biased_sampled_variance(client_gradients_norm_squared_average,
                                         client_gradients_average_norm_squared):
      """Get biased sample variance of client gradients; sample is the cohort."""
      return (client_gradients_norm_squared_average -
              client_gradients_average_norm_squared)

    @tff.tf_computation
    def _compute_unbiased_sampled_variance(weights_sum, weights_squared_sum,
                                           biased_sample_variance):
      """Get unbiased sample variance of client gradients."""
      # The Bessel correction term is used to get unbiased sample variance from
      # the biased sample variance.
      bessel_correction = tf.math.divide_no_nan(
          weights_sum**2, weights_sum**2 - weights_squared_sum)
      return bessel_correction * biased_sample_variance

    client_gradients_norm_squared_average = tff.federated_mean(
        tff.federated_map(
            _square_value,
            client_outputs.optimizer_output.start_client_gradients_glob_norm))

    weighted_client_gradients_norm_squared_average = tff.federated_map(
        _multiply_by_gradient_weight_squared,
        (client_gradients_norm_squared_average, client_gradient_weight))

    client_gradients_average_glob_norm = tff.federated_map(
        _take_global_norm, aggregated_average_client_gradients)
    client_gradients_average_norm_squared = tff.federated_map(
        _square_value, client_gradients_average_glob_norm)

    weighted_client_gradients_average_norm_squared = tff.federated_map(
        _multiply_by_gradient_weight_squared,
        (client_gradients_average_norm_squared, client_gradient_weight))

    client_gradients_biased_sample_variance = tff.federated_map(
        _compute_biased_sampled_variance,
        (client_gradients_norm_squared_average,
         client_gradients_average_norm_squared))

    client_gradients_biased_sample_variance_weighted = tff.federated_map(
        _compute_biased_sampled_variance,
        (weighted_client_gradients_norm_squared_average,
         weighted_client_gradients_average_norm_squared))

    weights_sum = tff.federated_sum(client_outputs.weights_delta_weight)
    weights_squared_sum = tff.federated_sum(
        tff.federated_map(_square_value, client_outputs.weights_delta_weight))

    client_gradients_unbiased_sample_variance = tff.federated_map(
        _compute_unbiased_sampled_variance,
        (weights_sum, weights_squared_sum,
         client_gradients_biased_sample_variance))

    client_gradients_unbiased_sample_variance_weighted = tff.federated_map(
        _compute_unbiased_sampled_variance,
        (weights_sum, weights_squared_sum,
         client_gradients_biased_sample_variance_weighted))

    augmenting_gradients_glob_norm = tff.federated_map(_take_global_norm,
                                                       augmenting_gradients)

    augmenting_gradients_norm_squared = tff.federated_map(
        _square_value, augmenting_gradients_glob_norm)

    start_total_gradients_glob_norm = tff.federated_mean(
        client_outputs.optimizer_output.start_total_gradients_glob_norm)

    @tff.tf_computation
    def _compute_weighted_total_gradients_norm(average_client_gradients,
                                               client_gradient_weight,
                                               augmenting_gradients,
                                               augmenting_gradient_weight):
      return client_gradient_weight * tf.linalg.global_norm(
          average_client_gradients
      ) + augmenting_gradient_weight * tf.linalg.global_norm(
          augmenting_gradients)

    total_gradients_glob_norm = tff.federated_map(
        _compute_weighted_total_gradients_norm,
        (aggregated_average_client_gradients, client_gradient_weight,
         augmenting_gradients, augmenting_gradient_weight))

    total_gradients_norm_squared = tff.federated_map(_square_value,
                                                     total_gradients_glob_norm)

    @tff.tf_computation
    def _calculate_bounded_grad_dissimilarity_difference_unweighted(
        client_gradients_norm_squared, augmenting_gradients_norm_squared,
        total_gradients_norm_squared):
      return _calculate_bounded_grad_dissimilarity_difference(
          0.5, client_gradients_norm_squared, 0.5,
          augmenting_gradients_norm_squared, total_gradients_norm_squared)

    @tff.tf_computation
    def _calculate_bounded_grad_dissimilarity_difference(
        client_gradient_weight, client_gradients_norm_squared,
        augmenting_gradient_weight, augmenting_gradients_norm_squared,
        total_gradients_norm_squared):
      return (client_gradient_weight * client_gradients_norm_squared +
              augmenting_gradient_weight * augmenting_gradients_norm_squared -
              total_gradients_norm_squared)

    @tff.tf_computation
    def _calculate_bounded_grad_dissimilarity_ratio_unweighted(
        client_gradients_norm_squared, augmenting_gradients_norm_squared,
        total_gradients_norm_squared):
      return _calculate_bounded_grad_dissimilarity_ratio(
          0.5, client_gradients_norm_squared, 0.5,
          augmenting_gradients_norm_squared, total_gradients_norm_squared)

    @tff.tf_computation
    def _calculate_bounded_grad_dissimilarity_ratio(
        client_gradient_weight, client_gradients_norm_squared,
        augmenting_gradient_weight, augmenting_gradients_norm_squared,
        total_gradients_norm_squared):
      return (client_gradient_weight * client_gradients_norm_squared +
              augmenting_gradient_weight *
              augmenting_gradients_norm_squared) / total_gradients_norm_squared

    bgd_difference_weighted = tff.federated_map(
        _calculate_bounded_grad_dissimilarity_difference,
        (client_gradient_weight, client_gradients_average_norm_squared,
         augmenting_gradient_weight, augmenting_gradients_norm_squared,
         total_gradients_norm_squared))
    bgd_ratio_weighted = tff.federated_map(
        _calculate_bounded_grad_dissimilarity_ratio,
        (client_gradient_weight, client_gradients_average_norm_squared,
         augmenting_gradient_weight, augmenting_gradients_norm_squared,
         total_gradients_norm_squared))
    bgd_difference_unweighted = tff.federated_map(
        _calculate_bounded_grad_dissimilarity_difference_unweighted,
        (client_gradients_average_norm_squared,
         augmenting_gradients_norm_squared, total_gradients_norm_squared))
    bgd_ratio_unweighted = tff.federated_map(
        _calculate_bounded_grad_dissimilarity_ratio_unweighted,
        (client_gradients_average_norm_squared,
         augmenting_gradients_norm_squared, total_gradients_norm_squared))

    aggregated_outputs = federated_metrics_aggregation(
        client_outputs.model_output)
    additional_metrics_dict = collections.OrderedDict(
        # We now add several gradient-related metrics to the measurements output
        # by the iterative process...
        # About the relative weighting we apply to the respective gradients:
        augmenting_gradient_weight=augmenting_gradient_weight,
        client_gradient_weight=client_gradient_weight,
        # About the augmenting gradients:
        augmenting_gradients_glob_norm=augmenting_gradients_glob_norm,
        augmenting_gradients_norm_squared=augmenting_gradients_norm_squared,
        # About the local client and total gradients at start and end of local
        # client descent:
        start_client_gradients_glob_norm=tff.federated_mean(
            client_outputs.optimizer_output.start_client_gradients_glob_norm),
        end_client_gradients_glob_norm=tff.federated_mean(
            client_outputs.optimizer_output.end_client_gradients_glob_norm),
        start_total_gradients_glob_norm=start_total_gradients_glob_norm,
        end_total_gradients_glob_norm=tff.federated_mean(
            client_outputs.optimizer_output.end_total_gradients_glob_norm),
        # About the average client and total gradients:
        client_gradients_average_glob_norm=client_gradients_average_glob_norm,
        client_gradients_average_norm_squared=client_gradients_average_norm_squared,
        total_gradients_glob_norm=total_gradients_glob_norm,
        total_gradients_norm_squared=total_gradients_norm_squared,
        # About the variance of the client gradients (both based on just on the
        # loss function only, and also including the `client_gradient_weight` as
        # a weighting premultiplier):
        client_gradients_biased_sample_variance=(
            client_gradients_biased_sample_variance),
        client_gradients_unbiased_sample_variance=(
            client_gradients_unbiased_sample_variance),
        client_gradients_biased_sample_variance_weighted=(
            client_gradients_biased_sample_variance_weighted),
        client_gradients_unbiased_sample_variance_weighted=(
            client_gradients_unbiased_sample_variance_weighted),
        # About the bounded gradient dissimilarity between the client gradients
        # and the augmenting gradients:
        bgd_difference_weighted=bgd_difference_weighted,
        bgd_ratio_weighted=bgd_ratio_weighted,
        bgd_difference_unweighted=bgd_difference_unweighted,
        bgd_ratio_unweighted=bgd_ratio_unweighted,
        # About the Lipschitz smoothness of the client loss functions:
        lipschitz_smoothness_avg=tff.federated_mean(
            client_outputs.optimizer_output.lipschitz_smoothness),
        lipschitz_smoothness_min=tff.aggregators.federated_min(
            client_outputs.optimizer_output.lipschitz_smoothness),
        lipschitz_smoothness_max=tff.aggregators.federated_max(
            client_outputs.optimizer_output.lipschitz_smoothness),
    )
    aggregated_outputs = tff.federated_map(
        _dict_update, (aggregated_outputs, additional_metrics_dict))

    optimizer_outputs = collections.OrderedDict(
        num_examples=tff.federated_sum(
            client_outputs.optimizer_output.num_examples),)

    measurements = tff.federated_zip(
        collections.OrderedDict(
            broadcast=broadcast_output.measurements,
            aggregation=aggregation_output.measurements,
            train=aggregated_outputs,
            stat=optimizer_outputs))
    return new_server_state, measurements

  return one_round_computation


def _is_valid_stateful_process(process):
  """Checks that a `MeasuredProcess` keeps its state and measurements on server.

  Only the placements of the initial state, the incoming state, the outgoing
  state and the measurements are inspected here. Callers are expected to add
  their own validation of the non-state parameters, inputs and result.

  Args:
    process: A measured process to validate.

  Returns:
    `True` iff `process` is a valid stateful process, `False` otherwise.
  """
  initialize_signature = process.initialize.type_signature
  next_signature = process.next.type_signature
  # Each check below bails out as soon as one placement is not `SERVER`, so
  # later (possibly ill-formed) parts of the signature are never inspected.
  if initialize_signature.result.placement is not tff.SERVER:
    return False
  if next_signature.parameter[0].placement is not tff.SERVER:
    return False
  if next_signature.result.state.placement is not tff.SERVER:
    return False
  return next_signature.result.measurements.placement is tff.SERVER


def _is_valid_broadcast_process(process):
  """Validates that a `MeasuredProcess` adheres to the broadcast signature.

  A valid broadcast process is a `tff.templates.MeasuredProcess` whose state
  and measurements are placed at `SERVER` (as checked by
  `_is_valid_stateful_process`), whose broadcast argument is placed at
  `SERVER`, and whose output is placed at `CLIENTS`.

  Args:
    process: The object to validate. Objects that are not a
      `tff.templates.MeasuredProcess` are rejected.

  Returns:
    `True` iff `process` is a valid broadcast process, otherwise `False`.
  """
  # Check the type before touching `process.next`, so that arbitrary objects
  # (which may not have a `next` attribute at all) are rejected with `False`
  # rather than raising `AttributeError`.
  if not isinstance(process, tff.templates.MeasuredProcess):
    return False
  if not _is_valid_stateful_process(process):
    return False
  next_type = process.next.type_signature
  # `parameter[0]` is the server state (validated above); `parameter[1]` is
  # the value to broadcast.
  return (next_type.parameter[1].placement is tff.SERVER and
          next_type.result.result.placement is tff.CLIENTS)


# ============================================================================


@tff.federated_computation()
def _empty_server_initialization():
  """Returns an empty tuple placed at `tff.SERVER`."""
  empty_server_value = ()
  return tff.federated_value(empty_server_value, tff.SERVER)


def build_federated_averaging_process_with_gradient_transfer(
    model_fn,
    client_optimizer_fn,
    server_optimizer_fn,
    client_weight_fn = None,
    *,
    broadcast_process = None,
    model_update_aggregation_factory = None,
    metrics_aggregator = tff.learning.metrics.sum_then_finalize
):
  """Constructs `tff.templates.IterativeProcess` for FedAvg w/ augmenting grads.

  This is exactly like FedAvg, except for the addition of `augmenting gradients`
  to whatever gradients are locally calculated for the loss/local data/local
  model weights during client updates. (I.e., if `augmenting_gradients` are
  vectors of zeros, then vanilla FedAvg is recovered.) This iterative process
  is useful e.g. if there is a distribution mismatch between the data
  distribution of the federated training data on clients, and the data
  distribution to be targeted at inference time. If you have examples (at the
  server) of the data that is missing in the client devices, you can calculate
  gradients from them using the latest model checkpoint, and then pass this as
  'augmenting gradients' for a given round, to be used to locally modify the
  direction of descent that takes place during client optimization. When
  performed on its own (without the additional piece described in the next
  paragraph), this is referred to as '1-way gradient transfer', b/c we are only
  sharing gradient information from datacenter to clients.

  The server state for this iterative process also contains an
  `average_client_gradients` field that is calculated and populated during a
  round. This can be used for additional augmentation 'the other way', i.e.,
  augmenting a centralized training process on datacenter data with a term that
  represents the gradient w.r.t. client data, to modify the direction of descent
  during datacenter optimization. This is referred to as '2-way gradient
  transfer', b/c we are sharing gradient information from datacenter to inform
  client descent as well as sharing gradient information from clients to inform
  datacenter descent. This could also be considered a 'meta-SCAFFOLD' approach
  b/c one would be doing something akin to Eqns 3-5 in the SCAFFOLD paper
  (https://arxiv.org/pdf/1910.06378.pdf), but instead of applying to non-IID
  clients, applying to a datacenter data distribution and client data
  distribution that differ.

  See mixing_process_lib.py for high-level APIs that provide 1-way and 2-way
  gradient transfer, utilizing this iterative process under the hood.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A no-arg function that constructs and returns a
      `tf.keras.optimizers.Optimizer` used in calculating each client update.
      *Must* construct and return a new optimizer when called. Returning
      captured optimizers from other scopes will raise errors.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`. The `apply_gradients` method of this
      optimizer is used to apply client updates to the server model.
    client_weight_fn: An optional callable that takes the output of
      `model.report_local_unfinalized_metrics` and returns a tensor that
      provides the weight in the federated average of model deltas. If not
      provided, the default is the total number of examples processed on device.
    broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the signature
      `(input_values@SERVER -> output_values@CLIENTS)`. If set to default None,
      the server model is broadcast to the clients using the default
      tff.federated_broadcast.
    model_update_aggregation_factory: An optional aggregation factory that
      constructs the `tff.templates.AggregationProcess` for aggregating the
      client model updates on the server. A
      `tff.aggregators.WeightedAggregationFactory` is created with `tf.float32`
      weights; any other factory is created from the trainable weights type
      alone (i.e. as an unweighted factory). If `None`, uses a default
      constructed `tff.aggregators.MeanFactory`, creating a stateless mean
      aggregation.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a federated TFF computation of the following type signature
      `local_unfinalized_metrics@CLIENTS -> aggregated_metrics@SERVER`. Default
      is `tff.learning.metrics.sum_then_finalize`, which returns a federated TFF
      computation that sums the unfinalized metrics from `CLIENTS`, and then
      applies the corresponding metric finalizers at `SERVER`.

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    ProcessTypeError: if `broadcast_process` does not conform to the signature
      of broadcast (SERVER->CLIENTS).
    TypeError: if the aggregation process created from
      `model_update_aggregation_factory` returns a server value whose member
      type differs from the member type of its client input.
  """
  _check_callable(model_fn)
  _check_callable(client_optimizer_fn)
  _check_callable(server_optimizer_fn)
  if client_weight_fn is not None:
    _check_callable(client_weight_fn)

  model_weights_type = tff.learning.framework.weights_type_from_model(model_fn)

  if broadcast_process is None:
    broadcast_process = (
        tff.learning.framework.build_stateless_broadcaster(
            model_weights_type=model_weights_type))
  if not _is_valid_broadcast_process(broadcast_process):
    raise ProcessTypeError(
        'broadcast_process type signature does not conform to expected '
        'signature (<state@S, input@S> -> <state@S, result@C, measurements@S>).'
        ' Got: {t}'.format(t=broadcast_process.next.type_signature))

  if model_update_aggregation_factory is None:
    model_update_aggregation_factory = tff.aggregators.MeanFactory()
  # Weighted factories additionally need the type of the per-client weight.
  if isinstance(model_update_aggregation_factory,
                tff.aggregators.WeightedAggregationFactory):
    aggregation_process = model_update_aggregation_factory.create(
        model_weights_type.trainable, tff.TensorType(tf.float32))
  else:
    aggregation_process = model_update_aggregation_factory.create(
        model_weights_type.trainable)
  # The aggregated server value must keep the structure of the client inputs
  # (the trainable weights type).
  process_signature = aggregation_process.next.type_signature
  input_client_value_type = process_signature.parameter[1]
  result_server_value_type = process_signature.result[1]
  if input_client_value_type.member != result_server_value_type.member:
    raise TypeError('`model_update_aggregation_factory` does not produce a '
                    'compatible `AggregationProcess`. The processes must '
                    'retain the type structure of the inputs on the '
                    f'server, but got {input_client_value_type.member} != '
                    f'{result_server_value_type.member}.')

  initialize_computation = _build_initialize_computation(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      broadcast_process=broadcast_process,
      aggregation_process=aggregation_process)

  run_one_round_computation = _build_one_round_computation(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_optimizer_fn=client_optimizer_fn,
      client_weight_fn=client_weight_fn,
      broadcast_process=broadcast_process,
      aggregation_process=aggregation_process,
      metrics_aggregator=metrics_aggregator)

  return tff.templates.IterativeProcess(
      initialize_fn=initialize_computation, next_fn=run_one_round_computation)
