# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for mixing_process_lib."""

import abc
import collections
import functools

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from mixed_fl import mixing_process_lib

# Factories that return a freshly constructed SGD optimizer on every call.
# Client and datacenter training use a learning rate of 0.01; the server uses
# a learning rate of 1.0.
DEFAULT_CLIENT_OPTIMIZER_FN = (
    lambda: tf.keras.optimizers.SGD(learning_rate=0.01))
DEFAULT_DATACENTER_OPTIMIZER_FN = (
    lambda: tf.keras.optimizers.SGD(learning_rate=0.01))
DEFAULT_SERVER_OPTIMIZER_FN = (
    lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
# Mean squared error constructed with its default reduction setting.
LOSS = tf.keras.losses.MeanSquaredError()
# Mean squared error constructed with `Reduction.NONE`.
UNREDUCED_LOSS = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.NONE)
# Number of simulated clients generated by `_get_client_data_for_testing`.
NUM_CLIENTS = 3


def _get_keras_model():
  """Builds a one-unit linear Keras model over 2-feature inputs.

  The dense layer has a bias term, no activation, and a zero-initialized
  kernel.

  Returns:
    A `tf.keras.Model` mapping inputs of shape `(2,)` to a single output.
  """
  feature_input = tf.keras.Input(shape=(2,))
  linear_layer = tf.keras.layers.Dense(
      1,
      activation=None,
      use_bias=True,
      kernel_initializer='zeros')
  prediction = linear_layer(feature_input)
  return tf.keras.Model(inputs=feature_input, outputs=prediction)


def _get_tff_model(keras_model, metrics=None):
  """Wraps `keras_model` as a TFF learning model that uses `LOSS`.

  Args:
    keras_model: The Keras model to wrap.
    metrics: Optional metrics forwarded to `tff.learning.from_keras_model`.

  Returns:
    The result of `tff.learning.from_keras_model`.
  """
  # Batches are ordered dicts with 2-feature inputs `x` and 1-dim labels `y`.
  batch_spec = collections.OrderedDict(
      x=tf.TensorSpec(shape=[None, 2], dtype=tf.float32),
      y=tf.TensorSpec(shape=[None, 1], dtype=tf.float32))
  return tff.learning.from_keras_model(
      keras_model=keras_model,
      loss=LOSS,
      metrics=metrics,
      input_spec=batch_spec)


def _get_client_data_for_testing():
  """Builds test client data in which every client holds the same examples.

  Returns:
    A `tff.simulation.datasets.TestClientData` with `NUM_CLIENTS` clients whose
    ids are the strings '0', '1', ...; each client has two examples with
    features [1.0, 2.0] and label 10.0.
  """
  shared_examples = collections.OrderedDict(
      x=[[1.0, 2.0], [1.0, 2.0]],
      y=[[10.0], [10.0]],
  )
  examples_by_client_id = {
      str(client_index): shared_examples
      for client_index in range(NUM_CLIENTS)
  }
  return tff.simulation.datasets.TestClientData(examples_by_client_id)


def _get_datacenter_dataset_for_testing():
  """Builds the nine-example datacenter dataset used by most tests.

  Returns:
    A `tf.data.Dataset` of ordered dicts with keys 'x' and 'y'. Labels range
    from 11.0 to 11.8, i.e. all are above the 10.0 label used for client data.
  """
  features = [
      [1.0, 2.0], [1.0, 2.1], [1.0, 2.2],
      [1.0, 2.3], [1.0, 2.4], [1.0, 2.5],
      [1.0, 2.6], [1.0, 2.7], [1.0, 2.8],
  ]
  labels = [
      [11.0], [11.1], [11.2],
      [11.3], [11.4], [11.5],
      [11.6], [11.7], [11.8],
  ]
  return tf.data.Dataset.from_tensor_slices(
      collections.OrderedDict(x=features, y=labels))


def _client_dataset_processing_fn(unprocessed_dataset, batch_size=2):
  return unprocessed_dataset.batch(batch_size)


def _get_non_mixed_iterative_process(client_data):
  """Builds a plain federated averaging process over `client_data`.

  The process serves as the baseline against which the tests compare how the
  model evolves under mixed training.

  Args:
    client_data: Client data exposing `dataset_computation`.

  Returns:
    A federated averaging process composed with a dataset computation, so that
    `next` takes client ids rather than datasets.
  """
  fed_avg_process = tff.learning.build_federated_averaging_process(
      model_fn=lambda: _get_tff_model(_get_keras_model()),
      client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
      server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN)

  @tff.tf_computation(tf.string)
  def client_data_fn(client_id):
    # Look up the client's raw dataset, then batch it like the mixed processes.
    raw_dataset = client_data.dataset_computation(client_id)
    return _client_dataset_processing_fn(raw_dataset)

  return tff.simulation.compose_dataset_computation_with_iterative_process(
      client_data_fn, fed_avg_process)


class BaseMixingProcessTestMixin(object):
  """Test cases shared by every kind of mixing process in this file.

  Subclasses supply `_get_mixing_process` and `_assert_on_metrics`; concrete
  test classes additionally inherit from `tf.test.TestCase` for the assertion
  methods used here.
  """

  @abc.abstractmethod
  def _get_mixing_process(self, client_data):
    """Returns the mixing process under test, built over `client_data`."""
    raise NotImplementedError

  @abc.abstractmethod
  def _assert_on_metrics(self, metrics, non_mixed_metrics):
    """Checks one round's mixed metrics against the baseline's metrics."""
    raise NotImplementedError

  def test_construction(self):
    """Building the mixing process does not raise."""
    self._get_mixing_process(_get_client_data_for_testing())

  def test_next_works(self):
    """Runs four rounds next to a non-mixed baseline and compares weights."""
    client_data = _get_client_data_for_testing()
    mixing_process = self._get_mixing_process(client_data)

    state = mixing_process.initialize()
    # The kernel is zero-initialized, so all trainable weights start at zero.
    expected_initial_weights = [np.zeros((2, 1)), np.zeros((1,))]
    self.assertAllClose(list(state.model.trainable), expected_initial_weights)
    self.assertAllClose(list(state.model.non_trainable), [])

    baseline_process = _get_non_mixed_iterative_process(client_data)
    baseline_state = baseline_process.initialize()

    num_rounds = 4
    for _ in range(num_rounds):
      state, metrics = mixing_process.next(state, client_data.client_ids)
      baseline_state, baseline_metrics = baseline_process.next(
          baseline_state, client_data.client_ids)

      weight_pairs = zip(state.model.trainable,
                         baseline_state.model.trainable)
      for mixed_var, baseline_var in weight_pairs:
        # Because the datacenter dataset has slightly different labels than
        # the federated data, it 'pulls' the model weights to be slightly
        # different (greater).
        mixed_is_greater = np.greater(mixed_var, baseline_var)
        self.assertTrue(np.all(mixed_is_greater))

      self._assert_on_metrics(metrics, baseline_metrics)


class ParallelTrainingMixingProcessTest(BaseMixingProcessTestMixin):
  """Base tests for the mixing process built with parallel training.

  Subclasses supply `map_fn`, which is applied to every datacenter element
  before batching.
  """

  num_eff_clients = 1
  eff_client_cache_size = 5
  datacenter_weight = 0.25
  # The client weight is the complement of the datacenter weight.
  client_weight = 1.0 - datacenter_weight

  @abc.abstractmethod
  def map_fn(self, element):
    """Maps one raw datacenter element to the structure under test."""
    raise NotImplementedError

  def get_examples_from_preds_fn(self):
    """Returns the function passed as `get_num_examples_from_predictions_fn`."""
    return mixing_process_lib.DEFAULT_NUM_EXAMPLES_FROM_PREDICTIONS_FN

  def get_datacenter_num_examples(self):
    """Returns the value expected for the 'datacenter_num_examples' metric."""
    return self.num_eff_clients * self.eff_client_cache_size

  def _get_mixing_process(self, client_data):
    """Builds the parallel-training mixing process over `client_data`."""

    def map_and_batch_datacenter_dataset(unprocessed_dataset):
      mapped_dataset = unprocessed_dataset.map(self.map_fn)
      return mapped_dataset.batch(2)

    examples_from_preds_fn = self.get_examples_from_preds_fn()
    return mixing_process_lib.build_mixing_process_with_parallel_training(
        keras_model_fn=_get_keras_model,
        tff_model_fn=lambda: _get_tff_model(_get_keras_model()),
        client_data=client_data,
        datacenter_dataset_fn=_get_datacenter_dataset_for_testing,
        num_effective_clients_for_training=self.num_eff_clients,
        num_examples_per_effective_client=self.eff_client_cache_size,
        datacenter_loss_fn=LOSS,
        client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
        datacenter_optimizer_fn=DEFAULT_DATACENTER_OPTIMIZER_FN,
        server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN,
        client_dataset_processing_fn=_client_dataset_processing_fn,
        datacenter_dataset_processing_fn=map_and_batch_datacenter_dataset,
        datacenter_weight=self.datacenter_weight,
        client_weight=self.client_weight,
        get_num_examples_from_predictions_fn=examples_from_preds_fn,
    )

  def _assert_on_metrics(self, metrics, non_mixed_metrics):
    """Checks the federated, datacenter and total example counts."""
    train_metrics = metrics['train']

    federated_count = train_metrics['federated_num_examples']
    self.assertEqual(federated_count,
                     non_mixed_metrics['train']['num_examples'])

    datacenter_count = train_metrics['datacenter_num_examples']
    self.assertEqual(datacenter_count, self.get_datacenter_num_examples())

    # The overall count is the sum of the two sources.
    self.assertEqual(train_metrics['num_examples'],
                     federated_count + datacenter_count)


class ParallelTrainingWithTuplesMixingProcessTest(
    ParallelTrainingMixingProcessTest, tf.test.TestCase):
  """Runs the parallel-training tests with tuple datacenter elements."""

  def map_fn(self, element):
    """Converts an element with keys 'x' and 'y' into an `(x, y)` tuple."""
    features = element['x']
    labels = element['y']
    return features, labels


class ParallelTrainingWithDictsMixingProcessTest(
    ParallelTrainingMixingProcessTest, tf.test.TestCase):
  """Runs the parallel-training tests with unmodified datacenter elements."""

  def map_fn(self, element):
    """Returns `element` unchanged, keeping its original structure."""
    return element


class ParallelTrainingWithCustomExamplesFromPredsFnMixingProcessTest(
    ParallelTrainingMixingProcessTest):
  """Parallel-training tests with a custom examples-from-predictions fn.

  The custom function reports ten times the leading dimension of its input,
  and the expected datacenter example count is scaled by the same factor.
  """

  def get_examples_from_preds_fn(self):
    """Returns a function giving ten times the leading dimension of `x`."""

    def ten_times_leading_dim(x):
      return 10 * tf.shape(x)[0]

    return ten_times_leading_dim

  def get_datacenter_num_examples(self):
    """Returns ten times the default expected datacenter example count."""
    unscaled_count = self.num_eff_clients * self.eff_client_cache_size
    return 10 * unscaled_count


class ParallelTrainingWithCustomExamplesFromPredsFnWithTuplesMixingProcessTest(
    ParallelTrainingWithCustomExamplesFromPredsFnMixingProcessTest,
    tf.test.TestCase):
  """Custom examples-from-predictions tests with tuple datacenter elements."""

  def map_fn(self, element):
    """Converts an element with keys 'x' and 'y' into an `(x, y)` tuple."""
    features = element['x']
    labels = element['y']
    return features, labels


class ParallelTrainingWithCustomExamplesFromPredsFnWithDictsMixingProcessTest(
    ParallelTrainingWithCustomExamplesFromPredsFnMixingProcessTest,
    tf.test.TestCase):
  """Custom examples-from-predictions tests with unmodified elements."""

  def map_fn(self, element):
    """Returns `element` unchanged, keeping its original structure."""
    return element


def _get_grad_augment_mixing_process_for_augmenting_gradient_weight_tests(
    client_data, map_fn, datacenter_gradient_weight):
  """Builds a gradient-transfer mixing process with an explicit gradient weight.

  Args:
    client_data: Client data handed to the mixing process builder.
    map_fn: Function applied to every datacenter element before batching.
    datacenter_gradient_weight: Value forwarded as `datacenter_gradient_weight`.

  Returns:
    The process returned by
    `mixing_process_lib.build_mixing_process_with_gradient_transfer`.
  """
  # The same size is used to batch the dataset and to configure the builder.
  batch_size = 6

  def map_and_batch_datacenter_dataset(unprocessed_dataset):
    mapped_dataset = unprocessed_dataset.map(map_fn)
    return mapped_dataset.batch(batch_size)

  return mixing_process_lib.build_mixing_process_with_gradient_transfer(
      keras_model_fn=_get_keras_model,
      tff_model_fn=lambda: _get_tff_model(_get_keras_model()),
      client_data=client_data,
      datacenter_dataset_fn=_get_datacenter_dataset_for_testing,
      datacenter_batch_size=batch_size,
      datacenter_loss_fn=LOSS,
      client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
      server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN,
      client_dataset_processing_fn=_client_dataset_processing_fn,
      datacenter_dataset_processing_fn=map_and_batch_datacenter_dataset,
      datacenter_gradient_weight=datacenter_gradient_weight)


class OneWayGradTransferMixingProcessTest(BaseMixingProcessTestMixin):
  """Base tests for the mixing process built with gradient transfer.

  Subclasses supply `map_fn` (applied to every datacenter element before
  batching) and `datacenter_loss_fn` (the loss used on datacenter data).
  """

  datacenter_batch_size = 6

  @abc.abstractmethod
  def map_fn(self, element):
    """Maps one raw datacenter element to the structure under test."""
    raise NotImplementedError

  @abc.abstractmethod
  def datacenter_loss_fn(self):
    """Returns the Keras loss object used for the datacenter data."""
    raise NotImplementedError

  def _get_mixing_process(self, client_data):
    """Builds the gradient-transfer mixing process over `client_data`."""

    def datacenter_dataset_fn():
      return _get_datacenter_dataset_for_testing()

    def datacenter_dataset_processing_fn(unprocessed_dataset):
      # NOTE(review): the dataset is batched by 2 while
      # `datacenter_batch_size` is 6; the gradient-weight helper batches by
      # the same value it passes to the builder. Confirm this is intended.
      return unprocessed_dataset.map(self.map_fn).batch(2)

    return mixing_process_lib.build_mixing_process_with_gradient_transfer(
        keras_model_fn=_get_keras_model,
        tff_model_fn=lambda: _get_tff_model(_get_keras_model()),
        client_data=client_data,
        datacenter_dataset_fn=datacenter_dataset_fn,
        datacenter_batch_size=self.datacenter_batch_size,
        datacenter_loss_fn=self.datacenter_loss_fn(),
        client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
        server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN,
        client_dataset_processing_fn=_client_dataset_processing_fn,
        datacenter_dataset_processing_fn=datacenter_dataset_processing_fn)

  def _assert_on_metrics(self, metrics, non_mixed_metrics):
    """Checks the gradient metrics that only the mixed process reports."""
    train_metrics = metrics['train']

    self.assertIn('augmenting_gradients_glob_norm', train_metrics)
    self.assertNotIn('augmenting_gradients_glob_norm',
                     non_mixed_metrics['train'])

    self.assertIn('augmenting_datacenter_grads_sample_variance', train_metrics)
    sample_variance = (
        train_metrics['augmenting_datacenter_grads_sample_variance'])
    # Compare the reduction by value: an identity check (`is`) only holds
    # while the loss happens to store the very object that `Reduction.NONE`
    # names.
    if self.datacenter_loss_fn().reduction == tf.keras.losses.Reduction.NONE:
      # An unreduced loss is expected to give a positive sample variance.
      self.assertGreater(sample_variance, 0.0)
    else:
      self.assertEqual(sample_variance, 0.0)

  def _test_datacenter_gradient_weight_works(self,
                                             smaller_datacenter_gradient_weight,
                                             larger_datacenter_gradient_weight):
    """Checks that a larger datacenter gradient weight gives larger weights.

    Two processes that differ only in `datacenter_gradient_weight` each run one
    round from their initial state; every trainable variable of the
    larger-weight process must be elementwise greater.

    Args:
      smaller_datacenter_gradient_weight: The smaller weight to compare.
      larger_datacenter_gradient_weight: The larger weight to compare.
    """
    client_data = _get_client_data_for_testing()
    process_with_smaller_weight = (
        _get_grad_augment_mixing_process_for_augmenting_gradient_weight_tests(
            client_data, self.map_fn, smaller_datacenter_gradient_weight))
    process_with_larger_weight = (
        _get_grad_augment_mixing_process_for_augmenting_gradient_weight_tests(
            client_data, self.map_fn, larger_datacenter_gradient_weight))

    state_with_smaller_weight, _ = process_with_smaller_weight.next(
        process_with_smaller_weight.initialize(), client_data.client_ids)
    state_with_larger_weight, _ = process_with_larger_weight.next(
        process_with_larger_weight.initialize(), client_data.client_ids)

    for smaller_var, larger_var in zip(
        state_with_smaller_weight.model.trainable,
        state_with_larger_weight.model.trainable):
      self.assertTrue(np.all(larger_var > smaller_var))


class OneWayGradTransferWithReducedDatacenterLossMixingProcessTest(
    OneWayGradTransferMixingProcessTest):
  """One-way gradient-transfer tests that use `LOSS` for datacenter data."""

  def datacenter_loss_fn(self):
    """Returns the module-level `LOSS` (default reduction)."""
    return LOSS


class OneWayGradTransferWithUnreducedDatacenterLossMixingProcessTest(
    OneWayGradTransferMixingProcessTest):
  """One-way gradient-transfer tests that use `UNREDUCED_LOSS`."""

  def datacenter_loss_fn(self):
    """Returns the module-level `UNREDUCED_LOSS` (`Reduction.NONE`)."""
    return UNREDUCED_LOSS


class OneWayGradTransferWithTuplesWithReducedDatacenterLossMixingProcessTest(
    OneWayGradTransferWithReducedDatacenterLossMixingProcessTest,
    tf.test.TestCase, parameterized.TestCase):
  """One-way tests: tuple datacenter elements, reduced datacenter loss."""

  def map_fn(self, element):
    """Converts an element with keys 'x' and 'y' into an `(x, y)` tuple."""
    features = element['x']
    labels = element['y']
    return features, labels

  @parameterized.named_parameters(
      ('datacenter_gradient_weight_0p0_vs_0p1', 0.0, 0.1),
      ('datacenter_gradient_weight_0p1_vs_1p0', 0.1, 1.0))
  def test_datacenter_gradient_weight_works(self, smaller_weight,
                                            larger_weight):
    """Compares the model weights produced by the two gradient weights."""
    self._test_datacenter_gradient_weight_works(
        smaller_datacenter_gradient_weight=smaller_weight,
        larger_datacenter_gradient_weight=larger_weight)


class OneWayGradTransferWithDictsWithReducedDatacenterLossMixingProcessTest(
    OneWayGradTransferWithReducedDatacenterLossMixingProcessTest,
    tf.test.TestCase, parameterized.TestCase):
  """One-way tests: unmodified datacenter elements, reduced datacenter loss."""

  def map_fn(self, element):
    """Returns `element` unchanged, keeping its original structure."""
    return element

  @parameterized.named_parameters(
      ('datacenter_gradient_weight_0p0_vs_0p1', 0.0, 0.1),
      ('datacenter_gradient_weight_0p1_vs_1p0', 0.1, 1.0))
  def test_datacenter_gradient_weight_works(self, smaller_weight,
                                            larger_weight):
    """Compares the model weights produced by the two gradient weights."""
    self._test_datacenter_gradient_weight_works(
        smaller_datacenter_gradient_weight=smaller_weight,
        larger_datacenter_gradient_weight=larger_weight)


class OneWayGradTransferWithTuplesWithUnreducedDatacenterLossMixingProcessTest(
    OneWayGradTransferWithUnreducedDatacenterLossMixingProcessTest,
    tf.test.TestCase, parameterized.TestCase):
  """One-way tests: tuple datacenter elements, unreduced datacenter loss."""

  def map_fn(self, element):
    """Converts an element with keys 'x' and 'y' into an `(x, y)` tuple."""
    features = element['x']
    labels = element['y']
    return features, labels

  @parameterized.named_parameters(
      ('datacenter_gradient_weight_0p0_vs_0p1', 0.0, 0.1),
      ('datacenter_gradient_weight_0p1_vs_1p0', 0.1, 1.0))
  def test_datacenter_gradient_weight_works(self, smaller_weight,
                                            larger_weight):
    """Compares the model weights produced by the two gradient weights."""
    self._test_datacenter_gradient_weight_works(
        smaller_datacenter_gradient_weight=smaller_weight,
        larger_datacenter_gradient_weight=larger_weight)


class OneWayGradTransferWithDictsWithUnreducedDatacenterLossMixingProcessTest(
    OneWayGradTransferWithUnreducedDatacenterLossMixingProcessTest,
    tf.test.TestCase, parameterized.TestCase):
  """One-way tests: unmodified datacenter elements, unreduced loss."""

  def map_fn(self, element):
    """Returns `element` unchanged, keeping its original structure."""
    return element

  @parameterized.named_parameters(
      ('datacenter_gradient_weight_0p0_vs_0p1', 0.0, 0.1),
      ('datacenter_gradient_weight_0p1_vs_1p0', 0.1, 1.0))
  def test_datacenter_gradient_weight_works(self, smaller_weight,
                                            larger_weight):
    """Compares the model weights produced by the two gradient weights."""
    self._test_datacenter_gradient_weight_works(
        smaller_datacenter_gradient_weight=smaller_weight,
        larger_datacenter_gradient_weight=larger_weight)


class TwoWayGradTransferMixingProcessTest(BaseMixingProcessTestMixin):
  """Base tests for the mixing process built with two-way gradient transfer.

  Subclasses supply `map_fn` (applied to every datacenter element before
  batching) and `datacenter_loss_fn` (the loss used on datacenter data), and
  may override `get_augmenting_datacenter_gradients_option`.
  """

  num_eff_clients = 1
  eff_client_cache_size = 5

  @abc.abstractmethod
  def map_fn(self, element):
    """Maps one raw datacenter element to the structure under test."""
    raise NotImplementedError

  @abc.abstractmethod
  def datacenter_loss_fn(self):
    """Returns the Keras loss object used for the datacenter data."""
    raise NotImplementedError

  def get_augmenting_datacenter_gradients_option(self):
    """Returns the option passed as `augmenting_datacenter_gradients_option`."""
    return (mixing_process_lib.GradientsComputationOption
            .FULL_BATCH_WITH_LAST_CHECKPOINT)

  def _get_mixing_process(self, client_data):
    """Builds the two-way gradient-transfer mixing process."""

    def datacenter_dataset_fn():
      return _get_datacenter_dataset_for_testing()

    def datacenter_dataset_processing_fn(unprocessed_dataset):
      return unprocessed_dataset.map(self.map_fn).batch(2)

    build_fn = (
        mixing_process_lib.build_mixing_process_with_two_way_gradient_transfer)
    return build_fn(
        keras_model_fn=_get_keras_model,
        tff_model_fn=lambda: _get_tff_model(_get_keras_model()),
        client_data=client_data,
        datacenter_dataset_fn=datacenter_dataset_fn,
        num_effective_clients_for_training=self.num_eff_clients,
        num_examples_per_effective_client=self.eff_client_cache_size,
        datacenter_loss_fn=self.datacenter_loss_fn(),
        client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
        datacenter_optimizer_fn=DEFAULT_DATACENTER_OPTIMIZER_FN,
        server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN,
        client_dataset_processing_fn=_client_dataset_processing_fn,
        datacenter_dataset_processing_fn=datacenter_dataset_processing_fn,
        augmenting_datacenter_gradients_option=(
            self.get_augmenting_datacenter_gradients_option()))

  def _assert_on_metrics(self, metrics, non_mixed_metrics):
    """Checks the gradient metrics that only the mixed process reports."""
    train_metrics = metrics['train']
    baseline_train_metrics = non_mixed_metrics['train']

    self.assertIn('augmenting_datacenter_grads_glob_norm', train_metrics)
    self.assertNotIn('augmenting_datacenter_grads_glob_norm',
                     baseline_train_metrics)
    self.assertIn('augmenting_client_grads_glob_norm', train_metrics)
    self.assertNotIn('augmenting_client_grads_glob_norm',
                     baseline_train_metrics)

    self.assertIn('augmenting_datacenter_grads_sample_variance', train_metrics)
    sample_variance = (
        train_metrics['augmenting_datacenter_grads_sample_variance'])
    # Compare the reduction by value: an identity check (`is`) only holds
    # while the loss happens to store the very object that `Reduction.NONE`
    # names.
    if (self.datacenter_loss_fn().reduction == tf.keras.losses.Reduction.NONE
        and self.get_augmenting_datacenter_gradients_option()
        is not mixing_process_lib.GradientsComputationOption
        .AVERAGED_FROM_CENTRAL_TRAINING):
      # A positive sample variance is expected only for an unreduced loss
      # whose gradients are not averaged from central training.
      self.assertGreater(sample_variance, 0.0)
    else:
      self.assertEqual(sample_variance, 0.0)


class TwoWayGradTransferWithReducedDatacenterLossMixingProcessTest(
    TwoWayGradTransferMixingProcessTest):
  """Two-way gradient-transfer tests that use `LOSS` for datacenter data."""

  def datacenter_loss_fn(self):
    """Returns the module-level `LOSS` (default reduction)."""
    return LOSS


class TwoWayGradTransferWithUnreducedDatacenterLossMixingProcessTest(
    TwoWayGradTransferMixingProcessTest):
  """Two-way gradient-transfer tests that use `UNREDUCED_LOSS`."""

  def datacenter_loss_fn(self):
    """Returns the module-level `UNREDUCED_LOSS` (`Reduction.NONE`)."""
    return UNREDUCED_LOSS


class TwoWayGradTransferWithTuplesWithReducedDatacenterLossMixingProcessTest(
    TwoWayGradTransferWithReducedDatacenterLossMixingProcessTest,
    tf.test.TestCase, parameterized.TestCase):
  """Two-way tests: tuple datacenter elements, reduced datacenter loss."""

  def map_fn(self, element):
    """Converts an element with keys 'x' and 'y' into an `(x, y)` tuple."""
    features = element['x']
    labels = element['y']
    return features, labels


class TwoWayGradTransferWithDictsWithReducedDatacenterLossMixingProcessTest(
    TwoWayGradTransferWithReducedDatacenterLossMixingProcessTest,
    tf.test.TestCase, parameterized.TestCase):
  """Two-way tests: unmodified datacenter elements, reduced datacenter loss."""

  def map_fn(self, element):
    """Returns `element` unchanged, keeping its original structure."""
    return element


class TwoWayGradTransferWithTuplesWithUnreducedDatacenterLossMixingProcessTest(
    TwoWayGradTransferWithUnreducedDatacenterLossMixingProcessTest,
    tf.test.TestCase, parameterized.TestCase):
  """Two-way tests: tuple datacenter elements, unreduced datacenter loss."""

  def map_fn(self, element):
    """Converts an element with keys 'x' and 'y' into an `(x, y)` tuple."""
    features = element['x']
    labels = element['y']
    return features, labels


class TwoWayGradTransferWithDictsWithUnreducedDatacenterLossMixingProcessTest(
    TwoWayGradTransferWithUnreducedDatacenterLossMixingProcessTest,
    tf.test.TestCase, parameterized.TestCase):
  """Two-way tests: unmodified datacenter elements, unreduced loss."""

  def map_fn(self, element):
    """Returns `element` unchanged, keeping its original structure."""
    return element


class TwoWayGradTransferWithTuplesWithAugmentGradientsAveragedFromCentralTrainingMixingProcessTest(
    TwoWayGradTransferWithUnreducedDatacenterLossMixingProcessTest,
    tf.test.TestCase, parameterized.TestCase):
  """Two-way tests: tuple elements, gradients averaged from central training."""

  def map_fn(self, element):
    """Converts an element with keys 'x' and 'y' into an `(x, y)` tuple."""
    features = element['x']
    labels = element['y']
    return features, labels

  def get_augmenting_datacenter_gradients_option(self):
    """Selects the `AVERAGED_FROM_CENTRAL_TRAINING` gradients option."""
    options = mixing_process_lib.GradientsComputationOption
    return options.AVERAGED_FROM_CENTRAL_TRAINING


class TwoWayGradTransferWithDictsWithAugmentGradientsAveragedFromCentralTrainingMixingProcessTest(
    TwoWayGradTransferWithUnreducedDatacenterLossMixingProcessTest,
    tf.test.TestCase, parameterized.TestCase):
  """Two-way tests: unmodified elements, gradients averaged centrally."""

  def map_fn(self, element):
    """Returns `element` unchanged, keeping its original structure."""
    return element

  def get_augmenting_datacenter_gradients_option(self):
    """Selects the `AVERAGED_FROM_CENTRAL_TRAINING` gradients option."""
    options = mixing_process_lib.GradientsComputationOption
    return options.AVERAGED_FROM_CENTRAL_TRAINING


def _get_example_transfer_mixing_process_for_repetition_tests(
    client_data,
    num_examples_to_augment,
    num_repetitions):
  """Builds an example-transfer process that reports the average label.

  The datacenter dataset holds exactly two examples, with labels 0.0 and 20.0,
  so the 'avg_label' metric reveals which datacenter examples were mixed in.

  Args:
    client_data: Client data handed to the mixing process builder.
    num_examples_to_augment: Value forwarded as `num_examples_to_augment`.
    num_repetitions: Value forwarded as `num_repetitions`.

  Returns:
    The process returned by
    `mixing_process_lib.build_mixing_process_with_example_transfer`.
  """

  class LabelAvg(tf.keras.metrics.Mean):
    """A `tf.keras.metrics.Metric` that averages the labels seen."""

    def __init__(self, name='avg_label', dtype=tf.float32):
      super().__init__(name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
      # Only the labels are averaged; the predictions are ignored.
      del y_pred
      return super().update_state(y_true, sample_weight=sample_weight)

  def datacenter_dataset_fn():
    two_examples = collections.OrderedDict(
        x=[[0.0, 0.0], [0.0, 0.0]],
        y=[[0.0], [20.0]])
    return tf.data.Dataset.from_tensor_slices(two_examples)

  def tff_model_fn():
    return _get_tff_model(_get_keras_model(), [LabelAvg()])

  return mixing_process_lib.build_mixing_process_with_example_transfer(
      tff_model_fn=tff_model_fn,
      client_data=client_data,
      datacenter_dataset_fn=datacenter_dataset_fn,
      num_examples_to_augment=num_examples_to_augment,
      num_repetitions=num_repetitions,
      client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
      server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN,
      client_dataset_processing_fn=_client_dataset_processing_fn)


def _get_example_transfer_mixing_process_for_custom_mixing_fn_tests(
    client_data,
    num_examples_to_augment
):
  """Builds an example-transfer process that uses a custom `mixing_fn`.

  The custom mixing function appends datacenter examples to every client batch.
  Client data is batched one example at a time, and every datacenter example
  has label 11.5.

  Args:
    client_data: Client data handed to the mixing process builder.
    num_examples_to_augment: Value forwarded as `num_examples_to_augment`.

  Returns:
    The process returned by
    `mixing_process_lib.build_mixing_process_with_example_transfer`.
  """

  class LabelAvg(tf.keras.metrics.Mean):
    """A `tf.keras.metrics.Metric` that averages the labels seen."""

    def __init__(self, name='avg_label', dtype=tf.float32):
      super().__init__(name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
      # Only the labels are averaged; the predictions are ignored.
      del y_pred
      return super().update_state(y_true, sample_weight=sample_weight)

  def datacenter_dataset_fn():
    single_example = collections.OrderedDict(x=[[1.0, 2.0]], y=[[11.5]])
    return tf.data.Dataset.from_tensor_slices(single_example).repeat(9)

  def merge_fn(client_data, datacenter_data):
    """Append datacenter data to the end of client data."""
    merged_x = tf.concat([client_data['x'], datacenter_data['x']], axis=0)
    merged_y = tf.concat([client_data['y'], datacenter_data['y']], axis=0)
    return collections.OrderedDict(x=merged_x, y=merged_y)

  def custom_mixing_fn(client_id, client_data, datacenter_dataset_fn,
                       num_examples_to_augment, num_repetitions,
                       client_dataset_processing_fn, datacenter_shuffle_buffer):
    """Custom mixing fn to merge datacenter and client data at batch level."""
    del num_repetitions
    client_batches = client_dataset_processing_fn(
        client_data.dataset_computation(client_id))
    # Shuffle once, repeat indefinitely, and batch the datacenter examples so
    # that each client batch is paired with `num_examples_to_augment` of them.
    shuffled_datacenter = datacenter_dataset_fn().shuffle(
        datacenter_shuffle_buffer, reshuffle_each_iteration=False)
    datacenter_batches = shuffled_datacenter.repeat().batch(
        num_examples_to_augment)
    paired_batches = tf.data.Dataset.zip((client_batches, datacenter_batches))
    return paired_batches.map(merge_fn)

  def tff_model_fn():
    return _get_tff_model(_get_keras_model(), [LabelAvg()])

  single_example_batching_fn = functools.partial(
      _client_dataset_processing_fn, batch_size=1)
  return mixing_process_lib.build_mixing_process_with_example_transfer(
      tff_model_fn=tff_model_fn,
      client_data=client_data,
      datacenter_dataset_fn=datacenter_dataset_fn,
      num_examples_to_augment=num_examples_to_augment,
      mixing_fn=custom_mixing_fn,
      client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
      server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN,
      client_dataset_processing_fn=single_example_batching_fn)


class ExampleTransferMixingProcessTest(BaseMixingProcessTestMixin,
                                       tf.test.TestCase,
                                       parameterized.TestCase):
  """Tests for the mixing process built with example transfer."""

  num_examples_to_augment = 5

  def _get_mixing_process(self, client_data):
    """Builds the example-transfer mixing process over `client_data`."""

    def datacenter_dataset_fn():
      return _get_datacenter_dataset_for_testing()

    return mixing_process_lib.build_mixing_process_with_example_transfer(
        tff_model_fn=lambda: _get_tff_model(_get_keras_model()),
        client_data=client_data,
        datacenter_dataset_fn=datacenter_dataset_fn,
        num_examples_to_augment=self.num_examples_to_augment,
        client_optimizer_fn=DEFAULT_CLIENT_OPTIMIZER_FN,
        server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN,
        client_dataset_processing_fn=_client_dataset_processing_fn)

  def _assert_on_metrics(self, metrics, non_mixed_metrics):
    """Checks that each client adds `num_examples_to_augment` examples."""
    self.assertEqual(
        metrics['train']['num_examples'],
        non_mixed_metrics['train']['num_examples'] +
        self.num_examples_to_augment * NUM_CLIENTS)

  def test_no_repetition_two_unique_examples_each_used_once(self):
    num_examples_to_augment = 2
    num_repetitions = 1

    client_data = _get_client_data_for_testing()
    mixing_process = _get_example_transfer_mixing_process_for_repetition_tests(
        client_data, num_examples_to_augment, num_repetitions)

    state = mixing_process.initialize()
    _, metrics = mixing_process.next(state, client_data.client_ids)
    # All the labels in the federated data are equal to 10.0. The datacenter
    # data used for augmenting has one example with label 0.0 and one example
    # with label 20.0. If two unique examples are used, then we expect the
    # average label to be 10.0; but if only one datacenter example is used,
    # twice (i.e., two 0.0s or two 20.0s), then the average label will not be
    # 10.0.
    self.assertEqual(10.0, metrics['train']['avg_label'])

  def test_repetition_one_unique_example_used_twice(self):
    num_examples_to_augment = 2
    num_repetitions = 2

    client_data = _get_client_data_for_testing()
    mixing_process = _get_example_transfer_mixing_process_for_repetition_tests(
        client_data, num_examples_to_augment, num_repetitions)

    state = mixing_process.initialize()
    _, metrics = mixing_process.next(state, client_data.client_ids)
    # All the labels in the federated data are equal to 10.0. The datacenter
    # data used for augmenting has one example with label 0.0 and one example
    # with label 20.0. With two repetitions, a single datacenter example is
    # expected to be used twice (i.e., two 0.0s or two 20.0s), so the average
    # label must differ from the 10.0 that two unique examples would give.
    self.assertNotEqual(10.0, metrics['train']['avg_label'])

  @parameterized.named_parameters(
      ('num_repetitions_exceeds_num_examples_to_augment', 3, 2),
      ('num_repetitions_not_even_divisor_of_num_examples_to_augment', 2, 3))
  def test_raise_value_error_if(self, num_repetitions, num_examples_to_augment):
    """Building the process with an invalid repetition count must raise."""
    client_data = _get_client_data_for_testing()
    with self.assertRaises(ValueError):
      _get_example_transfer_mixing_process_for_repetition_tests(
          client_data, num_examples_to_augment, num_repetitions)

  def test_mixing_with_custom_mixing_fn(self):
    num_examples_to_augment = 2

    client_data = _get_client_data_for_testing()
    mixing_process = (
        _get_example_transfer_mixing_process_for_custom_mixing_fn_tests(
            client_data, num_examples_to_augment))

    state = mixing_process.initialize()
    # Take the ids from the client data, as the other tests do, rather than
    # hard-coding them, so the test keeps tracking `NUM_CLIENTS`.
    _, metrics = mixing_process.next(state, client_data.client_ids)
    # Each batch of training data contains one example from the federated data
    # and two augmented examples from the datacenter data. All the labels in the
    # federated data are equal to 10.0. All the labels in the datacenter data
    # are equal to 11.5. Then the average label will be 11.0.
    self.assertEqual(11.0, metrics['train']['avg_label'])


if __name__ == '__main__':
  # Install TFF's local Python execution context before the tests run.
  tff.backends.native.set_local_python_execution_context()
  tf.test.main()
