# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for process_with_gradient_transfer_lib."""

import collections
import math
from typing import Callable, List, OrderedDict

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from mixed_fl import process_with_gradient_transfer_lib

# Learning rate of the client-side SGD optimizer used by the tests below.
CLIENT_LEARNING_RATE = 0.01


def _build_client_sgd():
  """Returns a new SGD optimizer configured with `CLIENT_LEARNING_RATE`."""
  return tf.keras.optimizers.SGD(learning_rate=CLIENT_LEARNING_RATE)


def _build_server_sgd():
  """Returns a new SGD optimizer with a learning rate of 1.0."""
  return tf.keras.optimizers.SGD(learning_rate=1.0)


# Zero-argument optimizer factories handed to the process builders as
# `client_optimizer_fn` and `server_optimizer_fn`.
DEFAULT_CLIENT_OPTIMIZER_FN = _build_client_sgd
DEFAULT_SERVER_OPTIMIZER_FN = _build_server_sgd


class LinearRegression(tff.learning.Model):
  """Example of a simple linear regression implemented directly.

  Predictions are computed as `matmul(x, a) + b + c`. `a` and `b` are
  trainable, while `c` is a non-trainable bias that exists only for code
  coverage in testing. Every call to `forward_pass` accumulates the summed
  loss, the number of examples and the number of batches in local variables.
  """

  # A tuple (x, y), where 'x' represents features, and 'y' represents labels.
  Batch = collections.namedtuple('Batch', ['x', 'y'])

  def __init__(self, feature_dim: int = 2):
    """Creates all model and metric variables.

    Args:
      feature_dim: Number of input features, i.e. the number of rows of the
        weight matrix `a`.
    """
    # Define all the variables, similar to what Keras Layers and Models
    # do in build().
    self._feature_dim = feature_dim

    # Metric accumulators, updated on every `forward_pass`.
    self._num_examples = tf.Variable(0, trainable=False)
    self._num_batches = tf.Variable(0, trainable=False)
    self._loss_sum = tf.Variable(0.0, trainable=False)
    # Trainable weights: a (feature_dim x 1) matrix and a scalar bias.
    self._a = tf.Variable([[0.0]] * feature_dim, trainable=True)
    self._b = tf.Variable(0.0, trainable=True)
    # Define a non-trainable model variable (another bias term) for code
    # coverage in testing.
    self._c = tf.Variable(0.0, trainable=False)
    self._input_spec = LinearRegression.make_batch(
        x=tf.TensorSpec([None, self._feature_dim], tf.float32),
        y=tf.TensorSpec([None, 1], tf.float32))

  @property
  def trainable_variables(self):
    """The trainable weights `[a, b]`."""
    return [self._a, self._b]

  @property
  def non_trainable_variables(self):
    """The non-trainable weights `[c]`."""
    return [self._c]

  @property
  def local_variables(self):
    """All metric accumulators mutated by `forward_pass`."""
    # `_num_examples` must be listed as well: it is updated in `forward_pass`
    # and is the denominator of the `loss` metric.
    return [self._num_examples, self._num_batches, self._loss_sum]

  @property
  def input_spec(self):
    """A `Batch` of `tf.TensorSpec`s describing the expected input."""
    # Model expects batched input, but the batch dimension is unspecified.
    return self._input_spec

  @tf.function
  def predict_on_batch(self, x):
    """Returns `matmul(x, a) + b + c` for a batch of features `x`."""
    return tf.matmul(x, self._a) + self._b + self._c

  @tf.function
  def forward_pass(self, batch, training=True):
    """Runs the model on `batch` and updates the metric accumulators.

    Args:
      batch: A `Batch`, or a `dict` with keys `x` and `y` which is converted
        to a `Batch`.
      training: Unused.

    Returns:
      A `tff.learning.BatchOutput` holding the per-example average loss, the
      predictions and the number of examples in the batch.

    Raises:
      ValueError: If `batch.x` or `batch.y` is incompatible with `input_spec`.
    """
    del training  # Unused.
    if isinstance(batch, dict):
      batch = self.make_batch(**batch)
    if not self._input_spec.y.is_compatible_with(batch.y):
      raise ValueError('Expected batch.y to be compatible with '
                       '{} but found {}'.format(self._input_spec.y, batch.y))
    if not self._input_spec.x.is_compatible_with(batch.x):
      raise ValueError('Expected batch.x to be compatible with '
                       '{} but found {}'.format(self._input_spec.x, batch.x))
    predictions = self.predict_on_batch(batch.x)
    residuals = predictions - batch.y
    # Size of the leading dimension of the predictions.
    num_examples = tf.gather(tf.shape(predictions), 0)
    total_loss = 0.5 * tf.reduce_sum(tf.pow(residuals, 2))

    self._loss_sum.assign_add(total_loss)
    self._num_examples.assign_add(num_examples)
    self._num_batches.assign_add(1)

    average_loss = total_loss / tf.cast(num_examples, tf.float32)
    return tff.learning.BatchOutput(
        loss=average_loss, predictions=predictions, num_examples=num_examples)

  @tf.function
  def report_local_unfinalized_metrics(
      self) -> OrderedDict[str, List[tf.Tensor]]:
    """Creates an `OrderedDict` of metric names to unfinalized values.

    Returns:
      A mapping with the single key `loss`, whose value is the pair
      `[loss_sum, num_examples]` with `num_examples` cast to `tf.float32`.
    """
    return collections.OrderedDict(
        loss=[self._loss_sum,
              tf.cast(self._num_examples, tf.float32)])

  def metric_finalizers(
      self) -> OrderedDict[str, Callable[[List[tf.Tensor]], tf.Tensor]]:
    """Creates an `OrderedDict` of metric names to finalizer functions.

    Returns:
      A mapping with the single key `loss`, whose finalizer divides the first
      element of its argument by the second.
    """
    return collections.OrderedDict(loss=tf.function(func=lambda x: x[0] / x[1]))

  @classmethod
  def make_batch(cls, x, y):
    """Returns a `Batch` to pass to the forward pass."""
    return cls.Batch(x, y)

  @tf.function
  def reset_metrics(self):
    """Resets metrics variables to initial value.

    Raises:
      NotImplementedError: Always; this model does not implement the method.
    """
    raise NotImplementedError(
        'The `reset_metrics` method isn\'t implemented for your custom '
        '`tff.learning.Model`. Please implement it before using this method. '
        'You can leave this method unimplemented if you won\'t use this method.'
    )


def _get_federated_dataset_for_testing():
  """Builds the federated dataset shared by the tests.

  Returns:
    A list of three references to one `tf.data.Dataset` that holds two
    examples (`x` of width 2, `y` of width 1) grouped with `batch(2)`.
  """
  examples = collections.OrderedDict()
  examples['x'] = [[1.0, 2.0], [3.0, 4.0]]
  examples['y'] = [[5.0], [6.0]]
  client_ds = tf.data.Dataset.from_tensor_slices(examples).batch(2)
  # Every client uses the very same dataset object.
  return [client_ds for _ in range(3)]


def _get_federated_dataset_with_empty_client_for_testing():
  """Builds the test federated dataset with an empty last client.

  Returns:
    The list from `_get_federated_dataset_for_testing()` with its last dataset
    replaced by a filtered copy that yields no elements.
  """
  *leading_clients, last_client = _get_federated_dataset_for_testing()
  # Filtering out every element empties the dataset for the last client, but
  # maintains the same shape spec.
  emptied_client = last_client.filter(lambda _: False)
  return leading_clients + [emptied_client]


class GradientTransferLibTest(parameterized.TestCase, tf.test.TestCase):

  def assert_structure_prefix_eq(self, a, b):
    """Asserts that the leading fields of `a` and `b` match pairwise.

    Fields are taken from `vars()` of each object, in order. Only as many
    fields as the shorter object has are compared, so either argument may
    carry extra trailing fields. A pair matches when the field names are equal
    and the values have the same nested structure; values are not compared.

    Args:
      a: First object to inspect.
      b: Second object to inspect.
    """
    a_fields = list(vars(a).items())
    b_fields = list(vars(b).items())
    # `zip` stops at the end of the shorter list of fields.
    for (a_name, a_struct), (b_name, b_struct) in zip(a_fields, b_fields):
      self.assertEqual(a_name, b_name)
      tf.nest.assert_same_structure(a_struct, b_struct)

  def test_construction(self):
    """Checks the type signatures of `initialize` and `next`."""
    iterative_process = process_with_gradient_transfer_lib.build_federated_averaging_process_with_gradient_transfer(
        model_fn=LinearRegression,
        client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
        server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN)

    float_scalar = tff.TensorType(tf.float32)

    # Weights of `LinearRegression`: trainable `a` (2x1) and `b`, and the
    # non-trainable `c`.
    model_type = tff.learning.ModelWeights(
        trainable=[tff.TensorType(tf.float32, [2, 1]), float_scalar],
        non_trainable=[float_scalar])

    server_state = process_with_gradient_transfer_lib.ServerStateWithAverageClientGradients(
        model=model_type,
        optimizer_state=[tf.int64],
        delta_aggregate_state=collections.OrderedDict(
            value_sum_process=(), weight_sum_process=()),
        model_broadcast_state=(),
        average_client_gradients=model_type.trainable)
    server_state_type = tff.FederatedType(server_state, tff.SERVER)

    # `initialize` takes no argument and returns the server state.
    tff.test.assert_types_equivalent(
        iterative_process.initialize.type_signature,
        tff.FunctionType(parameter=None, result=server_state_type))

    gradients_type = tff.FederatedType(model_type.trainable, tff.SERVER)
    gradient_weight_type = tff.FederatedType(float_scalar, tff.SERVER)
    dataset_type = tff.FederatedType(
        tff.SequenceType(
            collections.OrderedDict(
                x=tff.TensorType(tf.float32, [None, 2]),
                y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)

    # Every training metric is a float32 scalar; the order of the names is
    # the order of the entries in the expected metrics structure.
    train_metric_names = (
        'loss',
        'augmenting_gradient_weight',
        'client_gradient_weight',
        'augmenting_gradients_glob_norm',
        'augmenting_gradients_norm_squared',
        'start_client_gradients_glob_norm',
        'end_client_gradients_glob_norm',
        'start_total_gradients_glob_norm',
        'end_total_gradients_glob_norm',
        'client_gradients_average_glob_norm',
        'client_gradients_average_norm_squared',
        'total_gradients_glob_norm',
        'total_gradients_norm_squared',
        'client_gradients_biased_sample_variance',
        'client_gradients_unbiased_sample_variance',
        'client_gradients_biased_sample_variance_weighted',
        'client_gradients_unbiased_sample_variance_weighted',
        'bgd_difference_weighted',
        'bgd_ratio_weighted',
        'bgd_difference_unweighted',
        'bgd_ratio_unweighted',
        'lipschitz_smoothness_avg',
        'lipschitz_smoothness_min',
        'lipschitz_smoothness_max',
    )
    train_metrics = collections.OrderedDict(
        (name, float_scalar) for name in train_metric_names)

    metrics_type = tff.FederatedType(
        collections.OrderedDict(
            broadcast=(),
            aggregation=collections.OrderedDict(mean_value=(), mean_weight=()),
            train=train_metrics,
            stat=collections.OrderedDict(
                num_examples=tff.TensorType(tf.int64))), tff.SERVER)

    next_parameter_type = collections.OrderedDict([
        ('server_state', server_state_type),
        ('federated_dataset', dataset_type),
        ('augmenting_gradients', gradients_type),
        ('augmenting_gradient_weight', gradient_weight_type),
        ('client_gradient_weight', gradient_weight_type),
    ])
    tff.test.assert_types_equivalent(
        iterative_process.next.type_signature,
        tff.FunctionType(
            parameter=next_parameter_type,
            result=(server_state_type, metrics_type)))

  def test_iterative_process_works(self):
    """Checks that every trainable variable grows in each of five rounds."""
    client_datasets = _get_federated_dataset_for_testing()

    process = process_with_gradient_transfer_lib.build_federated_averaging_process_with_gradient_transfer(
        model_fn=LinearRegression,
        client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
        server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN)

    state = process.initialize()
    # The initial server model holds all-zero weights.
    self.assertAllClose(list(state.model.trainable), [np.zeros((2, 1)), 0.0])
    self.assertAllClose(list(state.model.non_trainable), [0.0])

    ones_gradients = tf.nest.map_structure(
        tf.ones_like,
        LinearRegression().trainable_variables)

    for _ in range(5):
      next_state, _ = process.next(
          state, client_datasets, ones_gradients, 1.0, 1.0)
      trainable_pairs = zip(state.model.trainable, next_state.model.trainable)
      for old_var, new_var in trainable_pairs:
        # Each entry must be strictly larger than in the previous round.
        self.assertTrue(np.all(new_var > old_var))
      state = next_state

  def test_iterative_process_works_on_empty_client_dataset(self):
    """Checks that training metrics stay non-NaN with one empty client."""
    client_datasets = _get_federated_dataset_with_empty_client_for_testing()

    process = process_with_gradient_transfer_lib.build_federated_averaging_process_with_gradient_transfer(
        model_fn=LinearRegression,
        client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
        server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN)

    state = process.initialize()
    self.assertAllClose(list(state.model.trainable), [np.zeros((2, 1)), 0.0])
    self.assertAllClose(list(state.model.non_trainable), [0.0])

    ones_gradients = tf.nest.map_structure(
        tf.ones_like,
        LinearRegression().trainable_variables)

    for _ in range(5):
      state, metrics = process.next(
          state, client_datasets, ones_gradients, 1.0, 1.0)
      self.assertIn('train', metrics)
      train_metrics = metrics['train']
      # Make sure all metrics are non-NaN (that nothing was divided by zero).
      for metric_value in train_metrics.values():
        self.assertFalse(math.isnan(float(metric_value)))

  def test_iterative_process_output_metrics(self):
    """Checks the training metrics reported after one round.

    The augmenting gradients are all ones with the shapes of the trainable
    variables of `LinearRegression` (a 2x1 matrix and a scalar), i.e. three
    ones in total, whose Euclidean norm is sqrt(3).
    """
    federated_ds = _get_federated_dataset_for_testing()

    iterative_process = process_with_gradient_transfer_lib.build_federated_averaging_process_with_gradient_transfer(
        model_fn=LinearRegression,
        client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
        server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN)

    state = iterative_process.initialize()
    augmenting_gradients = tf.nest.map_structure(
        tf.ones_like,
        LinearRegression().trainable_variables)
    augmenting_gradient_weight = 1.0
    client_gradient_weight = 1.0
    _, metrics = iterative_process.next(state, federated_ds,
                                        augmenting_gradients,
                                        augmenting_gradient_weight,
                                        client_gradient_weight)
    self.assertIn('train', metrics)
    self.assertIn('loss', metrics['train'])
    self.assertIn('augmenting_gradients_glob_norm', metrics['train'])
    # The metric is a float32 value, which carries only about seven
    # significant digits. Comparing at the default seven decimal places would
    # fail on a one-ulp difference, so compare at five places instead.
    self.assertAlmostEqual(
        math.sqrt(3.0),
        metrics['train']['augmenting_gradients_glob_norm'],
        places=5)

  def test_correct_magnitude_given_augmenting_grads(self):
    """Compares one round with all-ones augmentation to plain FedAvg."""
    client_datasets = _get_federated_dataset_for_testing()

    gradient_transfer_process = process_with_gradient_transfer_lib.build_federated_averaging_process_with_gradient_transfer(
        model_fn=LinearRegression,
        client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
        server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN)

    baseline_process = tff.learning.build_federated_averaging_process(
        model_fn=LinearRegression,
        client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
        server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN)

    gradient_transfer_state = gradient_transfer_process.initialize()
    baseline_state = baseline_process.initialize()

    all_ones = tf.nest.map_structure(
        tf.ones_like,
        LinearRegression().trainable_variables)

    gradient_transfer_state, _ = gradient_transfer_process.next(
        gradient_transfer_state, client_datasets, all_ones, 1.0, 1.0)
    baseline_state, _ = baseline_process.next(baseline_state, client_datasets)
    self.assert_structure_prefix_eq(baseline_state, gradient_transfer_state)

    # If augmenting with vectors of ones, the expected difference of each model
    # component b/w vanilla FedAvg and this new augmenting gradient iterative
    # process, after one round, should be equal to the client learning rate.
    trainable_pairs = zip(baseline_state.model.trainable,
                          gradient_transfer_state.model.trainable)
    for baseline_var, gradient_transfer_var in trainable_pairs:
      gap = baseline_var - gradient_transfer_var
      self.assertAllClose(CLIENT_LEARNING_RATE * tf.ones_like(gap), gap)

  def test_identical_to_standard_fed_avg_when_augmenting_grads_are_zeros(self):
    """Checks that all-zero augmenting gradients reproduce plain FedAvg."""
    client_datasets = _get_federated_dataset_for_testing()

    gradient_transfer_process = process_with_gradient_transfer_lib.build_federated_averaging_process_with_gradient_transfer(
        model_fn=LinearRegression,
        client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
        server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN)

    baseline_process = tff.learning.build_federated_averaging_process(
        model_fn=LinearRegression,
        client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
        server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN)

    gradient_transfer_state = gradient_transfer_process.initialize()
    baseline_state = baseline_process.initialize()

    all_zeros = tf.nest.map_structure(
        tf.zeros_like,
        LinearRegression().trainable_variables)

    for _ in range(5):
      gradient_transfer_state, _ = gradient_transfer_process.next(
          gradient_transfer_state, client_datasets, all_zeros, 1.0, 1.0)
      baseline_state, _ = baseline_process.next(baseline_state,
                                                client_datasets)
      self.assert_structure_prefix_eq(baseline_state, gradient_transfer_state)

      # Both processes must hold exactly the same trainable weights.
      trainable_pairs = zip(baseline_state.model.trainable,
                            gradient_transfer_state.model.trainable)
      for baseline_var, gradient_transfer_var in trainable_pairs:
        self.assertAllEqual(baseline_var, gradient_transfer_var)

  def test_identical_to_standard_fed_avg_when_augmenting_gradient_weight_is_zero(
      self):
    """Checks that a zero augmenting weight reproduces plain FedAvg."""
    client_datasets = _get_federated_dataset_for_testing()

    gradient_transfer_process = process_with_gradient_transfer_lib.build_federated_averaging_process_with_gradient_transfer(
        model_fn=LinearRegression,
        client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
        server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN)

    baseline_process = tff.learning.build_federated_averaging_process(
        model_fn=LinearRegression,
        client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
        server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN)

    gradient_transfer_state = gradient_transfer_process.initialize()
    baseline_state = baseline_process.initialize()

    all_ones = tf.nest.map_structure(
        tf.ones_like,
        LinearRegression().trainable_variables)
    # The augmenting gradients are non-zero but are given a weight of zero.
    zero_augmenting_weight, full_client_weight = 0.0, 1.0

    for _ in range(5):
      gradient_transfer_state, _ = gradient_transfer_process.next(
          gradient_transfer_state, client_datasets, all_ones,
          zero_augmenting_weight, full_client_weight)
      baseline_state, _ = baseline_process.next(baseline_state,
                                                client_datasets)
      self.assert_structure_prefix_eq(baseline_state, gradient_transfer_state)

      trainable_pairs = zip(baseline_state.model.trainable,
                            gradient_transfer_state.model.trainable)
      for baseline_var, gradient_transfer_var in trainable_pairs:
        self.assertAllEqual(baseline_var, gradient_transfer_var)

  @parameterized.named_parameters(
      ('with_high_augmenting_weight', 0.9,
       [np.array([[0.0025], [0.008]], dtype=np.float32), -0.0034999996]),
      ('with_low_augmenting_weight', 0.1,
       [np.array([[0.10249999], [0.15199998]], dtype=np.float32), 0.048499998]),
  )
  def test_correct_magnitude_given_gradient_weights(self,
                                                    augmenting_gradient_weight,
                                                    expected_model_weights):
    """Checks the trainable weights after one round of weighted gradients.

    Args:
      augmenting_gradient_weight: Weight given to the all-ones augmenting
        gradients; the client gradients get one minus this weight.
      expected_model_weights: Expected values of the trainable weights after
        one round, in the order of `state.model.trainable`.
    """
    client_datasets = _get_federated_dataset_for_testing()

    process = process_with_gradient_transfer_lib.build_federated_averaging_process_with_gradient_transfer(
        model_fn=LinearRegression,
        client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
        server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN)
    initial_state = process.initialize()

    all_ones = tf.nest.map_structure(
        tf.ones_like,
        LinearRegression().trainable_variables)
    # The two weights always sum to one.
    client_gradient_weight = 1.0 - augmenting_gradient_weight
    final_state, _ = process.next(initial_state, client_datasets, all_ones,
                                  augmenting_gradient_weight,
                                  client_gradient_weight)

    weight_pairs = zip(expected_model_weights, final_state.model.trainable)
    for expected_weight, actual_weight in weight_pairs:
      self.assertAllClose(expected_weight, actual_weight)

  def test_aggregated_average_client_gradients(self):
    """Checks that average client gradients shrink in magnitude each round."""
    client_datasets = _get_federated_dataset_for_testing()

    process = process_with_gradient_transfer_lib.build_federated_averaging_process_with_gradient_transfer(
        model_fn=LinearRegression,
        client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
        server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN)

    state = process.initialize()

    all_ones = tf.nest.map_structure(
        tf.ones_like,
        LinearRegression().trainable_variables)

    # No comparison is possible after the first round, so start with `None`.
    previous_gradients = None
    for _ in range(5):
      state, _ = process.next(state, client_datasets, all_ones, 1.0, 1.0)
      current_gradients = state.average_client_gradients
      if previous_gradients is not None:
        for older, newer in zip(previous_gradients, current_gradients):
          # Every entry must be strictly smaller in absolute value.
          self.assertTrue(np.all(np.abs(newer) < np.abs(older)))
      previous_gradients = current_gradients


# Entry point when the file is executed as a script.
if __name__ == '__main__':
  # Set the TFF execution context first, then hand control to the TensorFlow
  # test runner.
  tff.backends.native.set_local_python_execution_context()
  tf.test.main()
