# Copyright 2022, Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wires together matrix factorization query and TFF DPAggregator."""

import math
from typing import Any, Callable, Optional, Tuple

from absl import logging
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
import tensorflow_privacy as tfp

from distributed_dp_matrix_factorization import accounting_utils
from distributed_dp_matrix_factorization import compression_query
from distributed_dp_matrix_factorization import matrix_constructors
from distributed_dp_matrix_factorization import matrix_factorization_query
from distributed_dp_matrix_factorization import modular_clipping_factory


@tf.function
def _compute_h_sensitivity(h_matrix: tf.Tensor) -> tf.Tensor:
  """Returns the L2 sensitivity of the linear map x -> Hx.

  The sensitivity is measured as the largest L2 norm among the columns of
  `h_matrix` (norms are reduced over axis 0, then the maximum is taken).

  Args:
    h_matrix: The H term of a factorization.

  Returns:
    A scalar tensor holding the maximum column L2 norm of `h_matrix`.
  """
  return tf.reduce_max(tf.linalg.norm(h_matrix, ord=2, axis=0))


def _normalize_w_and_h(w_matrix: tf.Tensor,
                       h_matrix: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
  """Rescales a factorization W H so that x -> Hx has sensitivity 1.

  W is multiplied by, and H divided by, the same scalar (the sensitivity of H),
  so the product W H is unchanged up to floating point rounding.

  Args:
    w_matrix: The W term of the factorization.
    h_matrix: The H term of the factorization.

  Returns:
    The pair `(scaled_w, scaled_h)`.
  """
  sensitivity = _compute_h_sensitivity(h_matrix)
  scaled_w = w_matrix * sensitivity
  scaled_h = h_matrix / sensitivity
  return scaled_w, scaled_h


def _make_residual_matrix(w_matrix: tf.Tensor) -> tf.Tensor:
  """Returns the matrix X whose rows are first differences of W's rows.

  For every vector v the constructed matrix satisfies

  Xv[i] = Wv[i] - Wv[i-1],

  where Wv[-1] is taken to be 0 (so row 0 of X equals row 0 of W).

  Args:
    w_matrix: The matrix W whose one-index residual matrix is built.

  Returns:
    The residual matrix X, with the same dtype as `w_matrix`.
  """
  size = w_matrix.shape[0]
  # Shift operator: a zero first row, then the (size - 1) identity shifted one
  # row down and padded with a zero last column, so that
  # (shift @ W)[i] == W[i - 1] for i >= 1 and (shift @ W)[0] == 0.
  shift = np.pad(np.eye(size - 1), pad_width=((1, 0), (0, 1)))
  previous_rows = tf.matmul(
      tf.constant(shift, dtype=w_matrix.dtype), w_matrix)
  return w_matrix - previous_rows


def _create_residual_linear_query_dp_factory(
    *,
    tensor_specs: matrix_factorization_query.NestedTensorSpec,
    l2_norm_clip,
    std_dev,
    w_matrix: tf.Tensor,
    h_matrix: tf.Tensor,
    residual_clear_query_fn: Callable[
        [matrix_factorization_query.NestedTensorSpec],
        matrix_factorization_query.OnlineQuery],
    clients_per_round: int,
    seed: Optional[int],
    # Compression / distributed-DP parameters.
    beta,
    scale,
    client_template,
    bits,
    dim
) -> tff.aggregators.UnweightedAggregationFactory:
  """Builds a compressed aggregator for a factorized linear query.

  This function represents the integration of the mechanisms presented in
  "Private Online Prefix Sums via Optimal Matrix Factorizations",
  https://arxiv.org/abs/2202.08312, with TFF Aggregators. On top of that, the
  factorized Gaussian sum query is wrapped in a quantization query
  (`compression_query.CompressionSumQuery`), and records are aggregated with a
  modular-clipping sum.

  It produces an aggregator that computes the _residuals_ of an underlying
  linear query S.  That is, we have an original linear query `S` which is
  factorized as `S = W H`. Then, on round `t` the aggregator provides a DP
  estimate of `Sx_t - Sx_{t-1}`.

  The `residual_clear_query_fn` must be a function that computes
  `Sx_t - Sx_{t-1}`, likely in a more efficient manner than matrix
  multiplication. For example, if the query `S` is the prefix-sum query (lower
  triangular matrix of 1s), this `residual_clear_query_fn` should return a
  `matrix_factorization_query.OnlineQuery` computing the identity (which is the
  term-by-term residual of prefix-sum).

  The provided factorization is normalized internally (W is rescaled so that
  x -> Hx has sensitivity 1), and only the residual of the normalized W is
  used by the noise mechanism. Because the DP sum query operates on scaled,
  quantized records, its clip norm is
  `rounded_l2_norm_bound((l2_norm_clip + 1e-5) * scale, beta, dim)` and its
  noise stddev is `std_dev * scale`.

  Notably, tree-based aggregation mechanisms do not normalize sensitivity in
  this way; their sensitivity depends (logarithmically) on the number of
  elements aggregated.

  Args:
    tensor_specs: Nested tensor specs specifying the structure to which the
      constructed mechanism will be applied.
    l2_norm_clip: Global l2 norm to which client contributions are clipped
      (via `tff.aggregators.clipping_factory`). Also passed as the
      `l2_norm_bound` of the quantization parameters, and (plus 1e-5, times
      `scale`) used to derive the clip norm of the DP sum query.
    std_dev: Noise standard deviation before scaling. It is multiplied by
      `scale` and passed as `stddev` to the factorized Gaussian sum query.
    w_matrix: The W term of a matrix factorization S = WH to use.
    h_matrix: The H term of a matrix factorization S = WH to use.
    residual_clear_query_fn: Callable which accepts nested tensor specs and
      returns an instance of `matrix_factorization_query.OnlineQuery`. As noted
      above, this online query should represent the 'residuals' of the matrix S
      (IE, its t^th element should be Sx_t - Sx_{t-1}), for integration with TFF
      aggregators.
    clients_per_round: The number of clients per round to be used with this
      mechanism. Forwarded to
      `matrix_factorization_query.OnTheFlyFactorizedNoiseMechanism`.
    seed: Optional seed which will guarantee deterministic noise generation.
    beta: Conditional-rounding parameter. `None` (or any falsy value) is
      treated as 0; conditional rounding is enabled iff `beta > 0`.
    scale: Quantization scale. Multiplies the clip norm and `std_dev` of the
      DP sum query, and is passed as `quantize_scale` to the quantizer.
    client_template: Passed as `record_template` to
      `compression_query.CompressionSumQuery`.
    bits: Bit width of the modular arithmetic; values are wrapped into
      `[-2**(bits - 1), 2**(bits - 1))`.
    dim: Dimension passed to `accounting_utils.rounded_l2_norm_bound`.

  Returns:
    A `tff.aggregators.UnweightedMeanFactory` whose inner sum factory L2-clips
    client values to `l2_norm_clip`, then applies a
    `tff.aggregators.DifferentiallyPrivateFactory` built from the quantized
    factorized query, with records aggregated by a modular-clipping sum.
  """
  # A falsy `beta` (e.g. None) disables conditional rounding.
  beta = beta or 0
  conditional = beta > 0
  logging.info('Conditional rounding set to %s (beta = %f)', conditional, beta)

  # Build nested aggregators from the innermost outward; each step wraps the
  # previously built factory.
  agg_factory = tff.aggregators.SumFactory()
  # 1. Modular clipping around the plain sum.

  # Modular clipping has exclusive upper bound.
  mod_clip_lo, mod_clip_hi = -(2**(bits - 1)), 2**(bits - 1)
  agg_factory = modular_clipping_factory.ModularClippingSumFactory(
      clip_range_lower=mod_clip_lo,
      clip_range_upper=mod_clip_hi,
      inner_agg_factory=agg_factory)

  normalized_w, _ = _normalize_w_and_h(w_matrix, h_matrix)
  # To integrate with tff.learning, we must compute the residuals of our linear
  # query, which requires computing the residuals of w.
  normalized_residual_w = _make_residual_matrix(normalized_w)

  def make_noise_mech(tensor_specs, stddev):
    # Noise factory handed to the sum query; it closes over the residual of
    # the normalized W.
    return matrix_factorization_query.OnTheFlyFactorizedNoiseMechanism(
        tensor_specs=tensor_specs,
        stddev=stddev,
        w_matrix=normalized_residual_w,
        clients_per_round=clients_per_round,
        seed=seed)

  # 2. DP query: factorized Gaussian sum query in the scaled domain, wrapped in
  # a quantization (compression) query.
  # Add some post-rounding norm leeway to peacefully allow for precision issues.
  scaled_rounded_l2 = accounting_utils.rounded_l2_norm_bound(
      (l2_norm_clip + 1e-5) * scale, beta=beta, dim=dim)
  sum_query = matrix_factorization_query.FactorizedGaussianSumQuery(
      l2_norm_clip=scaled_rounded_l2,
      stddev=std_dev*scale,
      tensor_specs=tensor_specs,
      clear_query_fn=residual_clear_query_fn,
      factorized_noise_fn=make_noise_mech)

  # Wrap the matrix factorization query with quantization ops.
  quantization_params = compression_query.QuantizationParams(
      stochastic=True,
      conditional=conditional,
      l2_norm_bound=l2_norm_clip,
      beta=beta,
      quantize_scale=scale)
  quantized_matrix_fac_query = compression_query.CompressionSumQuery(
      quantization_params=quantization_params,
      inner_query=sum_query,
      record_template=client_template)

  agg_factory = tff.aggregators.DifferentiallyPrivateFactory(
      query=quantized_matrix_fac_query, record_aggregation_factory=agg_factory)

  # 3. L2 norm clipping as the first step applied to client values.
  agg_factory = tff.aggregators.clipping_factory(
      clipping_norm=l2_norm_clip, inner_agg_factory=agg_factory)

  # 4. Apply a MeanFactory at last (mean can't be part of the discrete
  # DPQueries (like the case of Gaussian) as the records may become floats
  # and hence break the decompression process).
  agg_factory = tff.aggregators.UnweightedMeanFactory(
      value_sum_factory=agg_factory)

  # The mean is taken by UnweightedMeanFactory above, instead of wrapping the
  # sum query in a `tfp.NormalizedQuery`.
  return agg_factory

def get_total_dim(client_template):
  """Returns the dimension of the client template as a single vector.

  This is the total number of scalar entries across all elements of
  `client_template`, i.e. the length the template would have if every element
  were flattened and concatenated.

  The product is computed with `math.prod`, which returns an exact Python int:
  a scalar element (shape `()`) contributes `1` rather than numpy's float `1.0`,
  and large shapes cannot overflow a fixed-width numpy integer.

  Args:
    client_template: An iterable of array-like objects exposing `.shape` as a
      sequence of ints.

  Returns:
    The total number of entries as an `int` (0 for an empty template).
  """
  return sum(math.prod(element.shape) for element in client_template)

def pad_dim(dim):
  """Rounds `dim` up to the nearest power of two.

  Args:
    dim: A positive dimension.

  Returns:
    `2 ** ceil(log2(dim))` as a float (e.g. 5 -> 8.0, 8 -> 8.0, 1 -> 1.0).

  Raises:
    ValueError: If `dim` is not positive (log2 would be undefined / -inf).
  """
  if dim <= 0:
    raise ValueError(f'`dim` must be positive, found {dim}.')
  # `np.math` was merely an alias of the stdlib `math` module; it was deprecated
  # in NumPy 1.25 and removed in NumPy 2.0, so use `math` directly.
  return math.pow(2, np.ceil(np.log2(dim)))

def create_residual_prefix_sum_dp_factory(
    *,
    tensor_specs: matrix_factorization_query.NestedTensorSpec,
    w_matrix: tf.Tensor,
    h_matrix: tf.Tensor,
    clients_per_round: int,
    num_rounds: int,
    compression_flags,
    client_template,
    dp_flags,
    sqrt_num_parts,
    seed: Optional[int],
) -> tff.aggregators.UnweightedAggregationFactory:
  """Builds a compressed, distributed-DP aggregator for the prefix-sum query.

  W and H are assumed to represent a so-called 'streaming factorization' of
  the prefix-sum matrix S, as discussed in https://arxiv.org/abs/2202.08312.
  The factorization is checked numerically against a lower-triangular matrix
  of ones. The quantization scale and noise standard deviation are obtained
  from `accounting_utils.ddgauss_params`, using the privacy budget in
  `dp_flags` and the compression settings in `compression_flags`.

  Args:
    tensor_specs: Nested tensor specs specifying the structure to which the
      constructed mechanism will be applied.
    w_matrix: The W term of a matrix factorization S = WH to use.
    h_matrix: The H term of a matrix factorization S = WH to use.
    clients_per_round: The number of clients per round. Used in the
      accounting, in the default delta, and forwarded to the aggregator.
    num_rounds: Number of rounds; passed as `steps` to the accounting and used
      in the default delta.
    compression_flags: Mapping read for the keys 'beta', 'num_bits' and
      'k_stddevs'. A falsy 'k_stddevs' defaults to 4.
    client_template: Iterable of array-like objects exposing `.shape`; its
      total number of entries, rounded up to a power of two, is the `dim`
      used for accounting and norm bounds. It is also forwarded as the
      quantizer's record template.
    dp_flags: Mapping read for the keys 'l2_norm_clip', 'epsilon' and
      'delta'. A falsy 'delta' defaults to
      `1 / (clients_per_round * num_rounds)`.
    sqrt_num_parts: Forwarded to `accounting_utils.ddgauss_params`; see that
      function for its meaning.
    seed: Optional seed which will guarantee deterministic noise generation.

  Returns:
    The aggregation factory built by `_create_residual_linear_query_dp_factory`
    with `matrix_factorization_query.IdentityOnlineQuery` as the residual
    clear query (the residual of prefix-sum is the identity).

  Raises:
    AssertionError: If `w_matrix @ h_matrix` is not close to the prefix-sum
      matrix, or if epsilon or the clip norm is not positive.
  """
  # Verify that W @ H reproduces the prefix-sum (lower-triangular ones) matrix.
  factorized_matrix = (w_matrix @ h_matrix).numpy()
  expected_matrix = np.tril(
      np.ones(shape=[w_matrix.shape[0]] * 2, dtype=factorized_matrix.dtype))
  np.testing.assert_allclose(factorized_matrix, expected_matrix, atol=1e-8)

  # Maximum row norms of the normalized W and of S = W @ H; both are inputs to
  # the accounting below.
  normalized_w, _ = _normalize_w_and_h(w_matrix, h_matrix)
  w_row_norms = tf.linalg.norm(normalized_w, ord=2, axis=1)
  w_max_row_norm = tf.reduce_max(w_row_norms)
  factorized_matrix_row_norms = tf.linalg.norm(factorized_matrix, ord=2, axis=1)
  factorized_matrix_max_norm = tf.reduce_max(factorized_matrix_row_norms)

  clip, epsilon = dp_flags['l2_norm_clip'], dp_flags['epsilon']

  # Parameters for DP
  # NOTE(review): these `assert`s are skipped under `python -O`; consider
  # raising ValueError instead if callers do not rely on AssertionError.
  assert epsilon > 0, f'Epsilon should be positive, found {epsilon}.'
  assert clip is not None and clip > 0, f'Clip must be positive, found {clip}.'
  delta = dp_flags['delta'] or 1.0 / (clients_per_round * num_rounds)  # Default to delta = 1 / N.

  # The accounting uses the dimension rounded up to a power of two (presumably
  # matching padding done by the compression pipeline -- TODO confirm).
  dim = get_total_dim(client_template)
  padded_dim = pad_dim(dim)

  beta=compression_flags['beta']
  bits=compression_flags['num_bits']
  k_stddevs = compression_flags['k_stddevs'] or 4
  gamma, local_stddev = accounting_utils.ddgauss_params(
          epsilon=epsilon,
          l2_clip_norm=clip,
          bits=bits,
          num_clients=clients_per_round,
          dim=padded_dim,
          delta=delta,
          beta=beta,
          k=k_stddevs,
          w_max_norm=w_max_row_norm,
          s_max_norm=factorized_matrix_max_norm,
          steps=num_rounds,
          sqrt_num_parts=sqrt_num_parts)
  # The inverse of `gamma` is used as the quantization scale.
  scale = 1.0 / gamma
  return _create_residual_linear_query_dp_factory(
      tensor_specs=tensor_specs,
      l2_norm_clip=clip,
      std_dev=local_stddev,
      w_matrix=w_matrix,
      h_matrix=h_matrix,
      clients_per_round=clients_per_round,
      seed=seed,
      residual_clear_query_fn=matrix_factorization_query.IdentityOnlineQuery,
      beta=beta,
      bits=bits,
      client_template=client_template,
      scale=scale,
      dim=padded_dim
  )


def create_residual_momentum_dp_factory(
    *,
    tensor_specs: matrix_factorization_query.NestedTensorSpec,
    l2_norm_clip: float,
    noise_multiplier: float,
    w_matrix: tf.Tensor,
    h_matrix: tf.Tensor,
    clients_per_round: int,
    seed: Optional[int],
    momentum_value: float,
    learning_rates: Any = None) -> tff.aggregators.UnweightedAggregationFactory:
  """Implements on-the-fly noise generation for the momentum partial sum query.

  W and H are assumed to represent a so-called 'streaming factorization' of
  the momentum matrix S, which is built by
  `matrix_constructors.momentum_sgd_matrix` and checked numerically against
  `w_matrix @ h_matrix`.

  NOTE(review): as written, this function cannot succeed. See the comment on
  its final call.

  Args:
    tensor_specs: Nested tensor specs specifying the structure to which the
      constructed mechanism will be applied.
    l2_norm_clip: Global l2 norm to which to clip client contributions in the
      constructed aggregator factory.
    noise_multiplier: Multiplier to compute the standard deviation of noise to
      apply to clipped tensors, after transforming with (a normalized version
      of) the matrix `h_matrix`.
    w_matrix: The W term of a matrix factorization S = WH to use.
    h_matrix: The H term of a matrix factorization S = WH to use.
    clients_per_round: The number of clients per round to be used with this
      mechanism.
    seed: Optional seed which will guarantee deterministic noise generation.
    momentum_value: Value of the momentum parameter.
    learning_rates: A vector of learning rates, one per iteration/round.

  Returns:
    An aggregation factory implementing residuals of the momentum query with
    the streaming matrix factorization mechanism.

  Raises:
    AssertionError: If the provided factorization does not factorize the
      specified momentum matrix within an absolute tolerance of 1e-8.
    TypeError: Currently always raised by the final call (see below).
  """

  # Verify that W @ H reproduces the momentum matrix.
  factorized_matrix = (w_matrix @ h_matrix).numpy()
  momentum_matrix = matrix_constructors.momentum_sgd_matrix(
      num_iters=w_matrix.shape[0],
      momentum=momentum_value,
      learning_rates=learning_rates)
  np.testing.assert_allclose(factorized_matrix, momentum_matrix, atol=1e-8)

  def _clear_query_fn(
      tensor_specs: matrix_factorization_query.NestedTensorSpec
  ) -> matrix_factorization_query.OnlineQuery:
    # Online query computing the term-by-term residual of the momentum query.
    return matrix_constructors.MomentumWithLearningRatesResidual(
        tensor_specs, momentum_value, learning_rates=learning_rates)

  # NOTE(review): this call is broken.
  # `_create_residual_linear_query_dp_factory` in this module has no
  # `noise_multiplier` parameter. It requires the keyword-only arguments
  # `std_dev`, `beta`, `scale`, `client_template`, `bits` and `dim`, none of
  # which are passed here, so this raises TypeError.
  # TODO: thread the compression/accounting inputs through, as the prefix-sum
  # factory in this module does, or build an uncompressed Gaussian mechanism
  # here. The second option requires confirming how
  # `OnTheFlyFactorizedNoiseMechanism` uses `clients_per_round`.
  return _create_residual_linear_query_dp_factory(
      tensor_specs=tensor_specs,
      l2_norm_clip=l2_norm_clip,
      noise_multiplier=noise_multiplier,
      w_matrix=w_matrix,
      h_matrix=h_matrix,
      clients_per_round=clients_per_round,
      seed=seed,
      residual_clear_query_fn=_clear_query_fn,
  )
