# Copyright 2020, Anonymous.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for constructing reconstruction models from Keras models."""

import collections
from typing import Iterable, List, Sequence, Union

import tensorflow as tf
import tensorflow_federated as tff

from reconstruction import reconstruction_model


def from_keras_model(
    keras_model: tf.keras.Model,
    *,  # Caller passes below args by name.
    global_layers: Iterable[tf.keras.layers.Layer],
    local_layers: Iterable[tf.keras.layers.Layer],
    input_spec,
) -> reconstruction_model.ReconstructionModel:
  """Builds a `ReconstructionModel` from a `tf.keras.Model`.

  The `ReconstructionModel` returned by this function uses `keras_model` for
  its forward pass and autodifferentiation steps. During reconstruction,
  variables in `local_layers` are initialized and trained. Post-reconstruction,
  variables in `global_layers` are trained and aggregated on the server.

  Args:
    keras_model: A `tf.keras.Model` object that is not compiled.
    global_layers: Iterable of global layers to be aggregated across users. All
      trainable and non-trainable model variables that can be aggregated on the
      server should be included in these layers.
    local_layers: Iterable of local layers not shared with the server. All
      trainable and non-trainable model variables that should not be aggregated
      on the server should be included in these layers.
    input_spec: A structure of `tf.TensorSpec`s specifying the type of arguments
      the model expects. Notice this must be a compound structure of two
      elements, specifying both the data fed into the model to generate
      predictions, as its first element, as well as the expected type of the
      ground truth as its second.

  Returns:
    A `ReconstructionModel` object.

  Raises:
    TypeError: If `keras_model` is not an instance of `tf.keras.Model`.
    ValueError: If `input_spec` does not contain exactly two top-level
      elements, if `keras_model` was compiled, or if `global_layers` and
      `local_layers` together do not cover exactly the variables of
      `keras_model` (raised by the wrapper's constructor).
  """
  if not isinstance(keras_model, tf.keras.Model):
    # Name the type that is actually required so the message is actionable.
    raise TypeError('Expected `tf.keras.Model`, found {}.'.format(
        type(keras_model)))
  if len(input_spec) != 2:
    raise ValueError('The top-level structure in `input_spec` must contain '
                     'exactly two elements, as it must specify type '
                     'information for both inputs to and predictions from the '
                     'model.')

  # The loss and metrics are supplied separately, so a compiled model (which
  # carries its own) is rejected.
  if keras_model._is_compiled:  # pylint: disable=protected-access
    raise ValueError('`keras_model` must not be compiled')

  return _KerasReconstructionModel(
      inner_model=keras_model,
      global_layers=global_layers,
      local_layers=local_layers,
      input_spec=input_spec)


class _KerasReconstructionModel(reconstruction_model.ReconstructionModel):
  """Adapter exposing an uncompiled `tf.keras.Model` as a `ReconstructionModel`.

  The wrapper keeps two lists of layers. Variables owned by the global layers
  are the ones sent to the server and aggregated across users; variables owned
  by the local layers are reconstructed at the beginning of each round and are
  not sent to the server. Loss functions and metrics are not held here: they
  are passed to the `tff.templates.IterativeProcess` that wraps this model and
  are computed there for both training and evaluation.
  """

  def __init__(self, inner_model: tf.keras.Model,
               global_layers: Iterable[tf.keras.layers.Layer],
               local_layers: Iterable[tf.keras.layers.Layer],
               input_spec: tff.Type):
    """Stores the model and layers and validates the variable partition.

    Args:
      inner_model: The Keras model used for the forward pass.
      global_layers: Layers whose variables are treated as global.
      local_layers: Layers whose variables are treated as local.
      input_spec: Type specification returned by the `input_spec` property.

    Raises:
      ValueError: If the variables of `global_layers` and `local_layers`
        together are not exactly the variables of `inner_model`.
    """
    self._keras_model = inner_model
    self._global_layers = list(global_layers)
    self._local_layers = list(local_layers)
    self._input_spec = input_spec

    # Variables are keyed by a hashable reference so they can live in a set;
    # the name rides along only to make the error message readable.
    layer_variables = {
        (var.ref(), var.name)
        for layer in self._global_layers + self._local_layers
        for var in layer.trainable_variables + layer.non_trainable_variables
    }
    model_variables = {
        (var.ref(), var.name)
        for var in (inner_model.trainable_variables +
                    inner_model.non_trainable_variables)
    }

    if layer_variables != model_variables:
      # Either side may hold variables the other lacks, hence the symmetric
      # difference rather than a one-directional subtraction.
      raise ValueError('Global and local layers must include all trainable '
                       'and non-trainable variables in the Keras model. '
                       'Difference: {d}, Global and local layers vars: {v}, '
                       'Keras vars: {k}'.format(
                           d=layer_variables ^ model_variables,
                           v=layer_variables,
                           k=model_variables))

  @property
  def global_trainable_variables(self):
    """List of trainable variables of all global layers, in layer order."""
    return [
        var for layer in self._global_layers
        for var in layer.trainable_variables
    ]

  @property
  def global_non_trainable_variables(self):
    """List of non-trainable variables of all global layers, in layer order."""
    return [
        var for layer in self._global_layers
        for var in layer.non_trainable_variables
    ]

  @property
  def local_trainable_variables(self):
    """List of trainable variables of all local layers, in layer order."""
    return [
        var for layer in self._local_layers
        for var in layer.trainable_variables
    ]

  @property
  def local_non_trainable_variables(self):
    """List of non-trainable variables of all local layers, in layer order."""
    return [
        var for layer in self._local_layers
        for var in layer.non_trainable_variables
    ]

  @property
  def input_spec(self):
    """The `input_spec` supplied at construction time."""
    return self._input_spec

  @tf.function
  def forward_pass(self, batch_input, training=True):
    """Runs the Keras model on one batch.

    Args:
      batch_input: Either a mapping (or namedtuple-like object exposing
        `_asdict`) with keys `x` and `y`, or an indexable structure whose
        first element is the model input and second element is the label.
      training: Forwarded to the Keras model call.

    Returns:
      A `reconstruction_model.BatchOutput` carrying the predictions, the
      labels, and the size of the leading dimension of the first flattened
      input tensor as `num_examples`.

    Raises:
      KeyError: If the resolved model input is `None`.
    """
    # Namedtuple-style batches are normalized to a dict first.
    if hasattr(batch_input, '_asdict'):
      batch_input = batch_input._asdict()
    is_mapping = isinstance(batch_input, collections.abc.Mapping)

    inputs = batch_input.get('x') if is_mapping else batch_input[0]
    if inputs is None:
      raise KeyError('Received a batch_input that is missing required key `x`. '
                     'Instead have keys {}'.format(list(batch_input.keys())))
    predictions = self._keras_model(inputs, training=training)

    y_true = batch_input.get('y') if is_mapping else batch_input[1]

    first_input_tensor = tf.nest.flatten(inputs)[0]
    return reconstruction_model.BatchOutput(
        predictions=predictions,
        labels=y_true,
        num_examples=tf.shape(first_input_tensor)[0])


class MeanLossMetric(tf.keras.metrics.Mean):
  """Keras `Mean` metric that tracks the value of a loss function.

  `loss_fn` may be a `tf.keras.losses.Loss` or any callable invoked as
  `loss_fn(y_true, y_pred)`.

  Because the metric holds a reference to the passed-in loss function (and
  returns it from `get_config`), serializing this metric may be problematic.
  """

  def __init__(self, loss_fn, name='loss', dtype=tf.float32):
    """Creates the metric.

    Args:
      loss_fn: Callable computing a loss from `(y_true, y_pred)`.
      name: Name of the metric.
      dtype: Dtype of the metric.
    """
    super().__init__(name=name, dtype=dtype)
    self._loss_fn = loss_fn

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates the loss of one batch, weighted by the batch size.

    Args:
      y_true: Ground-truth values; cast to the metric dtype.
      y_pred: Predictions; cast to the metric dtype. The size of its leading
        dimension is used as the weight of this batch.
      sample_weight: Accepted for signature compatibility; not used.

    Returns:
      Whatever the parent `update_state` returns.
    """
    num_examples = tf.cast(tf.shape(y_pred)[0], self._dtype)
    loss_value = self._loss_fn(
        tf.cast(y_true, self._dtype), tf.cast(y_pred, self._dtype))
    return super().update_state(loss_value, num_examples)

  def get_config(self):
    """Used to recreate an instance of this class during aggregation."""
    merged_config = dict(super().get_config())
    merged_config['loss_fn'] = self._loss_fn
    return merged_config


def federated_aggregate_keras_metric(
    metrics: Union[tf.keras.metrics.Metric, Sequence[tf.keras.metrics.Metric]],
    federated_values: Union[tff.Value, Sequence[tff.Value]]
) -> tff.federated_computation:
  """Aggregates the variables of Keras metrics placed at CLIENTS to SERVER.

  Args:
    metrics: a single `tf.keras.metrics.Metric` or a `Sequence` of metrics. The
      order must match the order of variables in `federated_values`.
    federated_values: a single federated value, or a `Sequence` of federated
      values. The values must all have `tff.CLIENTS` placement. If value is a
      `Sequence` type, it must match the order of the sequence in `metrics`.

  Returns:
    The result of performing a federated sum on federated_values, then assigning
    the aggregated values into the variables of the corresponding
    `tf.keras.metrics.Metric` and calling `tf.keras.metrics.Metric.result`. The
    resulting structure has `tff.SERVER` placement.

  Raises:
    TypeError: If a metric cannot be rebuilt via
      `type(metric).from_config(metric.get_config())`.
  """
  member_types = tf.nest.map_structure(lambda t: t.type_signature.member,
                                       federated_values)

  @tff.tf_computation
  def zeros_fn():
    # `member_types` is a (potentially nested) `tff.StructType`, which is a
    # `structure.Struct`.
    return tff.structure.map_structure(
        lambda v: tf.zeros(v.shape, dtype=v.dtype), member_types)

  zeros = zeros_fn()

  @tff.tf_computation(member_types, member_types)
  def accumulate(accumulators, variables):
    return tf.nest.map_structure(tf.add, accumulators, variables)

  @tff.tf_computation(member_types, member_types)
  def merge(a, b):
    return tf.nest.map_structure(tf.add, a, b)

  @tff.tf_computation(member_types)
  def report(accumulators):
    """Insert `accumulators` back into the keras metric to obtain result."""

    def finalize_metric(
        metric: tf.keras.metrics.Metric,
        values: Union[Sequence[tf.Tensor],
                      collections.OrderedDict]) -> tff.Value:
      # Note: the following call requires that `type(metric).__init__` accept
      # every key of `metric.get_config()` as an argument, which restricts the
      # types of metrics that can be used. The pattern of using default
      # arguments and exporting the values in `get_config()` (see
      # `tf.keras.metrics.TopKCategoricalAccuracy`) works well.
      config = metric.get_config()
      try:
        # This is some trickery to reconstruct a metric object in the current
        # scope, so that the `tf.Variable`s get created when we desire.
        keras_metric = type(metric).from_config(config)
      except TypeError as e:
        # Re-raise with a more helpful message, chained to the original error
        # so its stack trace is preserved.
        raise TypeError(
            'Caught exception trying to call `{t}.from_config()` with '
            'config {c}. Confirm that {t}.__init__() has an argument for '
            'each member of the config.\nException: {e}'.format(
                t=type(metric), c=config, e=e)) from e

      # Load the summed values into the fresh metric's variables, and only
      # compute the result once all of those assignments have run.
      assignments = [
          v.assign(a) for v, a in zip(keras_metric.variables, values)
      ]
      with tf.control_dependencies(assignments):
        return keras_metric.result()

    if isinstance(metrics, tf.keras.metrics.Metric):
      # Only a single metric to aggregate.
      return finalize_metric(metrics, accumulators)
    else:
      # Otherwise map over all the metrics.
      return collections.OrderedDict([
          (name, finalize_metric(metric, values))
          for metric, (name, values) in zip(metrics, accumulators.items())
      ])

  return tff.federated_aggregate(federated_values, zeros, accumulate, merge,
                                 report)


def read_metric_variables(
    metrics: List[tf.keras.metrics.Metric]) -> collections.OrderedDict:
  """Snapshots the current variable values of each metric.

  Args:
    metrics: Metrics whose variables should be read.

  Returns:
    An `OrderedDict` mapping each metric's `name` to a list holding the result
    of `read_value()` for every entry of that metric's `variables`, in the
    order the metrics were given.
  """
  return collections.OrderedDict(
      (metric.name, [variable.read_value() for variable in metric.variables])
      for metric in metrics)


def federated_output_computation_from_metrics(
    metrics: List[tf.keras.metrics.Metric]) -> tff.federated_computation:
  """Builds a federated computation that aggregates Keras metrics.

  Usable for evaluating both Keras and non-Keras models with Keras metrics.
  Client metrics are combined by summing their internal variables, building
  new metrics holding the summed variables, and calling `result()` on each of
  them; `federated_aggregate_keras_metric` performs that work.

  Args:
    metrics: A List of `tf.keras.metrics.Metric` to aggregate.

  Returns:
    A `tff.federated_computation` that takes the clients' metric variables and
    returns the aggregated metric results.
  """
  # The current variable values serve only as a sample from which the
  # computation's input type is derived.
  variable_specs = tf.nest.map_structure(tf.TensorSpec.from_tensor,
                                         read_metric_variables(metrics))
  client_outputs_type = tff.type_at_clients(variable_specs)

  @tff.federated_computation(client_outputs_type)
  def federated_output(local_outputs):
    return federated_aggregate_keras_metric(metrics, local_outputs)

  return federated_output
