# coding=utf-8
# Copyright 2023 The Private Kendall Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tukey code.

Adapted from
https://github.com/google-research/google-research/blob/master/dp_regression/tukey.py.
Modified functions are identified by comments in both the Returns docstring and
the code body.
"""


import numpy as np
import sklearn.linear_model

from private_kendall import random_shuffle


def perturb_and_sort_matrix(input_matrix):
  """Returns a jittered copy of input_matrix with every row sorted.

  Args:
    input_matrix: 2-D array in which each row is one vector to be jittered and
      sorted.

  Returns:
    A new array of the same shape as input_matrix. Every entry has had an
    independent uniform draw from [0, 1e-10) added to it, and each row is then
    sorted in increasing order.
  """
  noise_scale = 1e-10
  num_rows, num_cols = input_matrix.shape
  jitter = noise_scale * np.random.rand(num_rows, num_cols)
  jittered = input_matrix + jitter
  return np.sort(jittered, axis=1)


def log_measure_geq_all_depths(projections):
  """Computes the log volume of the region of at least each possible depth.

  Args:
    projections: Matrix where each row is a projection of the data sorted in
      increasing order.

  Returns:
    Array A of length ceil(m / 2), where m is the number of columns of
    projections. A[i] is the sum over rows of the natural log of the gap
    between the (i+1)-th largest and (i+1)-th smallest entry of that row, i.e.
    the log volume of the region of depth > i. A zero gap yields -inf rather
    than a warning. This code has been modified from the original.
  """
  num_points = len(projections[0])
  num_depths = int(np.ceil(num_points / 2))
  # Column i of `widths` pairs the i-th smallest entry of each row with the
  # i-th largest one.
  widths = np.fliplr(projections) - projections
  # Divide-by-zero warnings are suppressed on purpose: np.log(0) = -inf is the
  # intended value for an empty region.
  with np.errstate(divide="ignore"):
    log_widths = np.log(widths[:, :num_depths])
    return log_widths.sum(axis=0)


def racing_sample(log_terms):
  """Samples an index with probability proportional to exp(log_terms[index]).

  Works entirely in log space, so it stays numerically stable when the
  unnormalized weights exp(log_terms[k]) would overflow or underflow.

  Args:
    log_terms: Array of terms of form log(coefficient) - (exponent term).

  Returns:
    A sample from the distribution over 0, 1, ..., len(log_terms) - 1 where
    integer k has probability proportional to exp(log_terms[k]). For details,
    see the "racing sampling" described in https://arxiv.org/abs/2201.12333.
  """
  uniforms = np.random.uniform(size=log_terms.shape)
  # log(log(1 / U)) is the log of a standard exponential draw; subtracting
  # log_terms[k] rescales it by the k-th weight. The smallest value wins.
  log_race_times = np.log(np.log(1.0 / uniforms)) - log_terms
  return np.argmin(log_race_times)


def restricted_racing_sample_depth(projections, epsilon, restricted_depth):
  """Executes epsilon-DP Tukey depth exponential mechanism on projections.

  Args:
    projections: Matrix where each row is a projection of the data sorted in
      increasing order.
    epsilon: The output will be (epsilon, delta)-DP, where delta is determined
      by the preceding call to distance_to_unsafety.
    restricted_depth: Sampling will be restricted to points in projections of
      depth at least restricted_depth. Must be at least 1.

  Returns:
    A sample from the epsilon-DP Tukey depth exponential mechanism on
    projections (line 3 in Algorithm 1 in https://arxiv.org/pdf/2106.13329.pdf).
    The value is one plus the index sampled among the depths
    restricted_depth, restricted_depth + 1, ..., so 1 corresponds to
    restricted_depth itself. This code has been modified from the original.

  Raises:
    ValueError: If restricted_depth is less than 1.
  """
  if restricted_depth < 1:
    raise ValueError(
        "restricted_depth must be at least 1, got " + str(restricted_depth)
    )
  # Drop the restricted_depth - 1 outermost points on each side of every row.
  # The end index is spelled out explicitly: a stop of -(restricted_depth - 1)
  # becomes -0 == 0 when restricted_depth is 1 and would select nothing.
  trim = restricted_depth - 1
  projections = projections[:, trim:projections.shape[1] - trim]
  atleast_volumes = np.exp(log_measure_geq_all_depths(projections))
  # Volume at exactly a depth is the difference between consecutive "at least"
  # volumes; the deepest region keeps its full volume.
  measure_exact_all = np.copy(atleast_volumes)
  measure_exact_all[:-1] = atleast_volumes[:-1] - atleast_volumes[1:]
  depths = np.arange(restricted_depth,
                     restricted_depth + len(atleast_volumes))
  # Divide-by-zero warnings are ignored on purpose: np.log(0) = -inf gives an
  # empty region zero probability, as intended.
  with np.errstate(divide="ignore"):
    log_terms = np.log(measure_exact_all) + epsilon * depths
  # Add 1 because the index returned by racing_sample is 0-indexed
  return 1 + racing_sample(log_terms)


def distance_to_unsafety(log_volumes, epsilon, delta, t, k_low, k_high):
  """Returns Hamming distance lower bound computed by PTR check.

  Performs a bisection over the neighborhood size k in [k_low, k_high]. Each
  step compares the volume at index t - k - 2 with the volumes at indices
  t + k + 1 onward (excluding the last entry) against a threshold derived from
  epsilon and delta, and narrows the interval accordingly.

  Args:
    log_volumes: Array of logarithmic volumes of regions of different depths;
      log_volumes[i] = log(volume of region of depth > i).
    epsilon: The overall check is (epsilon, delta)-DP.
    delta: The overall check is (epsilon, delta)-DP.
    t: Fixed depth around which neighboring depth volumes are computed.
    k_low: Lower bound for neighborhood size k. First call of this function
      should use k_low=-1
    k_high: Upper bound for neighborhood size k.

  Returns:
    Hamming distance lower bound computed by PTR check.
  """
  log_threshold = np.log(delta / (8 * np.exp(epsilon)))
  last_index = len(log_volumes) - 1
  low, high = k_low, k_high
  while True:
    # int() truncates toward zero, so the midpoint of (-1, 0) is 0.
    k = int((low + high) / 2)
    log_outer = log_volumes[t - k - 2]
    log_inner = log_volumes[t + k + 1:last_index]
    eps_penalty = (epsilon / 2) * np.arange(1, len(log_inner) + 1)
    half_gap = int((high - low) / 2)
    if np.min(log_outer - log_inner - eps_penalty) <= log_threshold:
      if low >= high - 1:
        return high
      low = low + half_gap
    elif high > low + 1:
      high = high - half_gap
    else:
      return low


def pairwiselogdiffexp(a, b):
  """Computes pairwise logdiffexp of a and b using max trick.

  Uses log1p so that the result stays accurate when exp(b_i - a_i) is far
  below machine epsilon; with a plain log(1 - x) the subtraction would round
  to 1 and the correction term would be lost entirely.

  Args:
    a: Array [a_1, ..., a_n]. Requires a_i >= b_i for all i.
    b: Array [b_1, ..., b_n]. Requires a_i >= b_i for all i.

  Returns:
    [log(exp(a_1) - exp(b_1)), ..., log(exp(a_n) - exp(b_n))]. An entry with
    a_i == b_i evaluates to -inf.
  """
  return a + np.log1p(-np.exp(b - a))


def logsumexp(a):
  """Computes logsumexp of a using max trick.

  Args:
    a: Non-empty array [a_1, ..., a_n].

  Returns:
    log(exp(a_1) + ... + exp(a_n)). If the largest entry is not finite (every
    entry is -inf, or some entry is +inf or nan), that largest entry is
    returned as is.
  """
  max_element = np.amax(a)
  # Shifting by a non-finite maximum would compute inf - inf = nan; in that
  # case the maximum itself is already the correct answer.
  if not np.isfinite(max_element):
    return max_element
  return max_element + np.log(np.sum(np.exp(a - max_element)))


def new_distance_to_unsafety(log_volumes, epsilon, delta, t, k_low, k_high):
  """Returns Hamming distance lower bound computed by PTR check.

  Performs a bisection over the neighborhood size k in [k_low, k_high]. Each
  step compares a log numerator and log denominator, both built from
  log_volumes, against a threshold derived from epsilon and delta.

  Args:
    log_volumes: Array of logarithmic volumes of regions of different depths;
      log_volumes[i] = log(volume of region of depth > i).
    epsilon: The overall check is (epsilon, delta)-DP.
    delta: The overall check is (epsilon, delta)-DP.
    t: Fixed depth around which neighboring depth volumes are computed.
    k_low: Lower bound for neighborhood size k. First call of this function
      should use k_low=-1
    k_high: Upper bound for neighborhood size k.

  Returns:
    Hamming distance lower bound computed by PTR check. This code has been
    modified from the original. See ``Modified PTR Lemma'' section of the paper
    for details.
  """
  log_threshold = np.log(delta / (8 * np.exp(epsilon)))
  # log_shell_volumes[i] is the log volume of the set of points with
  # approximate Tukey depth exactly i + 1. It does not depend on k, so it is
  # computed once up front. The last entry keeps its log_volumes value.
  log_shell_volumes = np.copy(log_volumes)
  log_shell_volumes[:-1] = pairwiselogdiffexp(log_volumes[:-1], log_volumes[1:])
  num_shells = len(log_shell_volumes)
  low, high = k_low, k_high
  while True:
    # int() truncates toward zero, so the midpoint of (-1, 0) is 0.
    k = int((low + high) / 2)
    log_numerator = log_volumes[t - k - 2] + epsilon * (t + k + 1)
    # The denominator is the unnormalized probability mass that the
    # exponential mechanism assigns to depths first_depth and above. Depth q
    # gets score epsilon * q and is paired with log_shell_volumes[q - 1].
    # Everything stays in log space to avoid numerical overflow.
    first_depth = t + k - 1
    log_scores = epsilon * np.arange(first_depth, num_shells + 1)
    log_denominator = logsumexp(
        log_scores + log_shell_volumes[first_depth - 1:]
    )
    half_gap = int((high - low) / 2)
    # Apart from the test itself, the bisection logic matches
    # distance_to_unsafety.
    if log_numerator - log_denominator <= log_threshold:
      if low >= high - 1:
        return high
      low = low + half_gap
    elif high > low + 1:
      high = high - half_gap
    else:
      return low


def log_measure_geq_all_dims(depth, projections):
  """Computes log(length) of region of at least depth, for each dimension.

  Args:
    depth: Desired depth in projections. Assumes
      1 <= depth < len(projections[0]) / 2.
    projections: Matrix where each row is a projection of the data sorted in
      increasing order.

  Returns:
    Array A where A[j] is the (natural) logarithm of the length in projections
    of the region of depth >= depth in dimension j+1, i.e. the gap between the
    depth-th largest and depth-th smallest entry of row j.
  """
  lower_edges = projections[:, depth - 1]
  upper_edges = projections[:, -depth]
  return np.log(upper_edges - lower_edges)


def sample_geq_1d(depth, projection):
  """Samples a point of at least given depth from projection.

  Args:
    depth: Lower bound on depth of point to sample. Assumes
      1 <= depth <= len(projection) / 2.
    projection: Increasing array of 1-dimensional points.

  Returns:
    Point sampled uniformly at random between the depth-th smallest and the
    depth-th largest entry of projection, i.e. from the region of at least the
    given depth.
  """
  lower, upper = projection[depth - 1], projection[-depth]
  return np.random.uniform(low=lower, high=upper)


def sample_exact_1d(depth, projection):
  """Samples a point of exactly given depth from projection.

  The region of exactly the given depth consists of two intervals, one on each
  side of the data. An interval is picked with probability proportional to its
  length and a point is then drawn uniformly inside it.

  Args:
    depth: Depth of point to sample. Assumes
      1 <= depth <= len(projection) / 2.
    projection: Increasing array of 1-dimensional points.

  Returns:
    Point sampled uniformly at random from the region of given depth in
    projection.
  """
  left_start, left_end = projection[depth - 1], projection[depth]
  right_start, right_end = projection[-(depth + 1)], projection[-depth]
  left_length = left_end - left_start
  right_length = right_end - right_start
  # First draw chooses the side, second draw chooses the position within it.
  pick_left = np.random.uniform() < left_length / (left_length + right_length)
  if pick_left:
    start, length = left_start, left_length
  else:
    start, length = right_start, right_length
  return start + np.random.uniform() * length


def sample_exact(depth, projections):
  """Samples a point of exactly given depth from projections.

  The region of exactly the given depth is split into d disjoint pieces. In
  piece j, the first j dimensions have depth greater than `depth`, dimension j
  has depth exactly `depth`, and the remaining dimensions have depth at least
  `depth`. A piece is chosen with probability proportional to its volume and a
  point is then sampled uniformly from it.

  Args:
    depth: Depth of point to sample. Assumes
      1 <= depth < len(projections[0]) / 2.
    projections: Matrix where each row is a projection of the data sorted in
      increasing order.

  Returns:
    Point (array of length d, the number of rows of projections) sampled
    uniformly at random from the region of exactly given depth in projections.
  """
  d, _ = projections.shape
  log_measures_greater_than_depth = log_measure_geq_all_dims(
      depth + 1, projections
  )
  log_measures_geq_depth = log_measure_geq_all_dims(depth, projections)
  # exact_lengths[j] = W_{j+1, depth} in the paper's notation
  log_exact_lengths = np.log(
      np.exp(log_measures_geq_depth) - np.exp(log_measures_greater_than_depth)
  )
  # exp(log_volume_greater_than_depth_left[j]) = V_{<j+1, depth+1}, in the
  # paper's notation. log_volume_greater_than_depth_left[j] is the volume
  # measured along the first j dimensions of the region of depth greater than
  # the depth argument to this function
  log_volume_greater_than_depth_left = np.zeros(d)
  log_volume_greater_than_depth_left[1:] = np.cumsum(
      log_measures_greater_than_depth
  )[:-1]
  # exp(log_right_dims_geq_depth[j]) = V_{>j, depth}, in the paper's notation.
  # This must be built from the ">= depth" lengths: the dimensions after the
  # sampled one are drawn below with sample_geq_1d(depth, ...), so weighting
  # them by the "> depth" lengths would not match the regions being sampled.
  log_right_dims_geq_depth = np.zeros(d)
  log_right_dims_geq_depth[:-1] = (
      np.cumsum(log_measures_geq_depth[::-1])[::-1]
  )[1:]
  log_volumes = (
      log_exact_lengths
      + log_volume_greater_than_depth_left
      + log_right_dims_geq_depth
  )
  sampled_volume_idx = racing_sample(log_volumes)
  # Sampled point will be > depth in dimensions before sampled_volume_idx,
  # exactly depth in dimension sampled_volume_idx, and >= depth afterwards
  sample = np.zeros(d)
  for j in range(sampled_volume_idx):
    sample[j] = sample_geq_1d(depth + 1, projections[j, :])
  sample[sampled_volume_idx] = sample_exact_1d(
      depth, projections[sampled_volume_idx, :])
  for j in range(sampled_volume_idx+1, d):
    sample[j] = sample_geq_1d(depth, projections[j, :])
  return sample


def multiple_regressions(features, labels, num_models, use_lasso):
  """Fits num_models regression models on disjoint random batches of the data.

  The data is shuffled, cut into num_models consecutive batches of
  int(n / num_models) points each, and one model without a fitted intercept is
  trained per batch. Leftover points beyond the last full batch are unused.

  Args:
    features: Matrix of feature vectors. Assumed to have intercept feature.
    labels: Vector of labels.
    num_models: Number of models to train.
    use_lasso: Boolean determining whether or not to use Lasso.

  Returns:
    (num_models x d) matrix of models, where features is (n x d). This code
    has been modified from the original.

  Raises:
    RuntimeError: If a batch would hold fewer than d points, i.e. num_models
      models require num_models * d points but features has fewer.
  """
  num_points, dim = features.shape
  batch_size = int(num_points / num_models)
  if batch_size < dim:
    raise RuntimeError(
        "{} models requires {} points, but given features only has {}"
        " points.".format(num_models, num_models * dim, num_points)
    )
  # This call to random_shuffle replaces identical logic in the original code
  features, labels = random_shuffle.random_shuffle(features, labels)
  # The estimator choice has been modified to allow both Lasso and generic OLS
  if use_lasso:
    regressor = sklearn.linear_model.Lasso(fit_intercept=False, max_iter=1000)
  else:
    regressor = sklearn.linear_model.LinearRegression(fit_intercept=False)
  models = np.zeros((num_models, dim))
  for i in range(num_models):
    start = batch_size * i
    stop = start + batch_size
    fitted = regressor.fit(features[start:stop, :], labels[start:stop])
    models[i, :] = fitted.coef_
  return models


def tukey(models, epsilon, delta):
  """Runs (epsilon, delta)-DP Tukey mechanism using models.

  The privacy budget is split evenly: epsilon / 2 is spent on the
  propose-test-release (PTR) check and epsilon / 2 on sampling a depth with
  the restricted exponential mechanism.

  Args:
    models: Feature vectors of non-private models where each row is a model.
      Assumes that each user contributes to a single model.
    epsilon: Computed model satisfies (epsilon, delta)-DP.
    delta: Computed model satisfies (epsilon, delta)-DP.

  Returns:
    Regression model computed using Tukey mechanism as a (d x 1) column array,
    where models is (num_models x d), or the string "ptr fail" if PTR fails.
    This code has been modified from the original.
  """
  projections = perturb_and_sort_matrix(models.T)
  max_depth = int(len(models) / 2)
  # Compute log(volume_i)_{i=1}^max_depth where volume_i is the volume of
  # the region of depth >= i, according to projections
  log_volumes = log_measure_geq_all_depths(projections)
  t = int(max_depth / 2)
  # Do ptr check
  split_epsilon = epsilon / 2
  distance = new_distance_to_unsafety(
      log_volumes, split_epsilon, delta, t, -1, t - 1
  )
  threshold = np.log(1 / (2 * delta)) / split_epsilon
  # The distance has sensitivity 1, so split_epsilon-DP requires Laplace noise
  # of scale 1 / split_epsilon. This is also the scale for which the noise
  # exceeds the threshold above with probability exactly delta.
  noisy_distance = distance + np.random.laplace(scale=1 / split_epsilon)
  if not noisy_distance > threshold:
    # Unlike the original code, ours does not return a trivial model if PTR
    # fails
    return "ptr fail"
  # Sample a depth using the restricted exponential mechanism
  depth = restricted_racing_sample_depth(projections, split_epsilon, t)
  # Sample uniformly from the region of given depth
  return sample_exact(depth, projections[:, t:-t])[:, np.newaxis]
