

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import torch
from scipy.optimize import minimize,fsolve
from scipy.spatial.distance import cdist
import seaborn as sns
sns.set()

from geom_median.torch import compute_geometric_median

def zcdp_to_eps(rho,delta):
  """Convert a rho-zCDP guarantee into the eps of an (eps, delta)-DP guarantee.

  Uses the conversion formula of Lemma 3.6 in [BS16]:
      eps = rho + sqrt(4 * rho * log(sqrt(pi * rho) / delta))

  Parameters
  ----------
  rho = zCDP parameter
  delta = delta of the target (eps, delta)-DP guarantee

  Return
  -------
  eps of the (eps, delta)-DP guarantee
  """
  log_term = np.log(np.sqrt(np.pi * rho)/delta)
  return rho + np.sqrt(4 * rho * log_term)

def eps_to_zcdp(eps,delta):
  """Convert an (eps, delta)-DP guarantee into a rho-zCDP parameter.

  Numerically inverts zcdp_to_eps: finds rho with
  zcdp_to_eps(rho, delta) == eps using scipy's fsolve, started at rho = 0.001.

  Parameters
  ----------
  eps = eps of the DP guarantee
  delta = delta of the DP guarantee

  Return
  -------
  rho, the root reported by fsolve
  """
  def residual(rho):
    # zero exactly when rho converts back to the requested eps
    return zcdp_to_eps(rho,delta) - eps

  solution = fsolve(residual,x0=0.001)
  return solution[-1]

def DPGD(gm,theta_0,iters,priv_budget,stepsize,rad):
    """ Differentially Private Gradient Descent

    Runs projected gradient descent where Gaussian noise is added to every
    gradient, and returns the running average of the projected iterates.

    Parameters
    ----------
    gm = geometric median problem (must provide n, d and grad(theta))
    theta_0 = initialization
    iters = num of iterations
    priv_budget = privacy budget in zCDP
    stepsize = stepsize
    rad = radius of the feasible set defined as theta s.t.
            || theta - theta_0||<= rad.

    Return
    -------
    the average iterate of DPGD; theta_0 itself when iters <= 0
    """
    if iters <= 0:
        # Nothing to run. Returning early also avoids dividing the privacy
        # budget by zero below.
        return theta_0
    sens = 2/(gm.n)  # sensitivity wrt to replacement
    rho = priv_budget / iters  # divide total privacy budget up
    # The noise scale is the same in every iteration, so compute it once.
    noise_scale = sens/np.sqrt(2*rho)
    avg_theta = theta_0
    theta = theta_0
    for t in range(iters):
        noise = np.random.normal(scale=noise_scale,size=gm.d)
        theta = theta - stepsize * (gm.grad(theta) + noise)
        # Project back onto the ball of radius rad around theta_0; the
        # 1e-10 floor guards the division when theta == theta_0.
        norm_diff = np.linalg.norm(theta - theta_0)
        theta = theta_0 + ((theta - theta_0)/max(1e-10,norm_diff)) * min(norm_diff,rad)
        # Running mean of the iterates seen so far (the t = 0 step drops
        # the initial value of avg_theta entirely).
        avg_theta = (t/(t+1)) * avg_theta + (1/(t+1)) * theta
    return avg_theta

class MyGeometricMedian:
  """Geometric median problem over a dataset clipped to a ball of radius R.
  """

  def __init__(self,X,R):
    """Store the (rescaled) data and precompute the non-private solution.

    Parameters
    ----------
    X = n x d numpy array representing datapoints
    R = apriori bound on the norm of the datapoints.
    Rows with norm larger than R are shrunk onto the sphere of radius R,
    so that ||X[i,:]|| <= R for all i; other rows are left unchanged.
    """
    points = np.array(X)
    assert len(points.shape)==2
    self.R = R
    self.n, self.d = points.shape
    row_norms = np.linalg.norm(points, axis=1)
    # Per-row factor min(R / ||x_i||, 1); the 1e-10 floor avoids dividing
    # by zero for rows at the origin.
    ratio_to_bound = self.R/np.maximum(row_norms, 1e-10)
    shrink = np.minimum(ratio_to_bound,np.ones_like(row_norms))
    self.data = (shrink[:,None]) * points
    self.theta_star = self.exact_gm()

  def loss(self,theta):
    """Evaluates the geometric median objective at theta.

    With x_1, ..., x_n the stored data, the loss is
    (1/n)sum_i^n ||theta-x_i||
    """
    return np.mean(np.linalg.norm(self.data - theta, axis=1))

  def grad(self,theta):
    """Evaluates the gradient of the objective at theta.

    With x_1, ..., x_n the stored data, the gradient is
    (1/n)sum_i^n (theta - x_i)/||theta-x_i||
    where a term with theta == x_i contributes zero.
    """
    offsets = theta -  self.data
    lengths = np.linalg.norm(offsets, axis=1, keepdims=True)
    # Floor the lengths so a datapoint equal to theta gives 0/1e-10 = 0.
    unit_offsets = offsets / np.maximum(lengths, 1e-10)
    return unit_offsets.mean(axis=0)

  def exact_gm(self):
    """Computes the non-private geometric median of the stored data.

    Each row is converted to a float torch tensor and handed to
    compute_geometric_median; the result is returned as a numpy array.
    """
    tensors = []
    for row in self.data:
      tensors.append(torch.from_numpy(row).float())
    result = compute_geometric_median(tensors, weights=None)
    return result.median.numpy()

class WarmUpAlg:
  """Represents the warmup algorithm

  Privately estimates a radius from a geometric grid (radius_finder) and
  then runs several rounds of DPGD over shrinking balls (localization) to
  produce an initialization point.
  """

  def __init__(self,gm,r,rho,beta):
    """Initialize the algorithm.

    Parameters
    ----------
    gm = an instance of geometric median problem
    r = Discretization error
    rho = privacy budget in zCDP
    beta = failure probability
    """
    self.data = gm.data
    self.gm = gm
    self.n, self.d = gm.data.shape
    self.R  = gm.R
    self.rho = rho
    # Candidate radii r, 2r, 4r, ..., up to the largest r * 2**i <= 2R.
    max_exponent = int(np.floor(np.log(2 * self.R / r) / np.log(2)))
    self.grid = [r * 2**i for i in range(max_exponent + 1)]
    gamma = 3/4
    # Number of points (a 3/4 fraction of the data, rounded up) used both
    # in the queries and as the base of the threshold.
    self.m = int(np.ceil((self.n) * gamma))
    self.beta = beta

  def above_threshold(self, queries):
    """Above-Threshold-Algorithm

    Compares each query, perturbed by fresh Laplace noise, against a
    noisy threshold and reports the first query that reaches it.

    Parameters
    ----------
    queries = the input queries

    Return
    -------
    index of the first query whose noisy value is >= the noisy threshold;
    if no query qualifies, an index drawn uniformly from all queries.
    """
    # sqrt(2 * (rho/2)): presumably converts half of the zCDP budget into
    # a pure-DP eps for this step -- confirm against the analysis.
    eps_abvt = np.sqrt(2 * (self.rho/2))
    len_q = len(self.grid)
    thresh = self.m + (18/eps_abvt) * np.log(2/self.beta * len_q)
    thresh_noisy = thresh + np.random.laplace(loc=0, scale = 6/eps_abvt)
    for idx, q in enumerate(queries):
        noise_val = np.random.laplace(loc=0, scale = 12/eps_abvt)
        if q + noise_val >= thresh_noisy:
            return idx
    # if the algorithm "fails", return a random index.
    # randint's upper bound is exclusive, so use len(queries) to make the
    # last index reachable (and to support a single query).
    return np.random.randint(0,len(queries))

  def compute_queries(self):
    """ computes the queries for the above threshold

    For every radius in the grid: count, for each datapoint, how many
    datapoints (itself included) lie within that radius, and average the
    m largest counts.

    Return
    -------
    numpy array with one query value per grid radius
    """
    pairwise = cdist(self.data, self.data, 'euclidean')
    queries = np.zeros(len(self.grid))
    for i, rad in enumerate(self.grid):
        within = (pairwise <= rad).astype(int)
        num_pnts = np.sum(within, axis=1)
        sort_num_pnts = np.sort(num_pnts)
        queries[i] = (1/self.m)*np.sum(sort_num_pnts[-self.m:])
    return queries

  def radius_finder(self):
    """ the radius finder algorithm
        estimates Delta_{3n/4}(theta*)

    Return
    -------
    the grid radius selected by above_threshold on the computed queries
    """
    queries = self.compute_queries()
    idx = self.above_threshold(queries)
    return self.grid[idx]

  def localization(self):
    """Finds a good initialization point

    Runs k_wu rounds of DPGD, each started from the previous round's
    output and constrained to a ball whose radius is updated as
    rad/2 + 12 * hat_delta after every round.

    Return
    -------
    (theta, hat_delta): the output of the last DPGD round and the radius
    estimate produced by radius_finder
    """
    hat_delta = self.radius_finder()
    rad_old = self.R
    # Number of rounds: ceil(log2(R / hat_delta)), at least one.
    k_wu = int(np.ceil(np.log(self.R/hat_delta)/np.log(2)))
    k_wu = max(k_wu,1)
    # Half of rho is split evenly across the DPGD rounds.
    priv_budget_gd = ((self.rho)/2) / k_wu
    theta_old = np.zeros(self.d)
    iters = 500
    for _ in range(k_wu):
        stepsize = rad_old * np.sqrt((2*self.d)/(3 * priv_budget_gd * (self.n)**2))
        theta_new = DPGD(self.gm,theta_old,iters,priv_budget_gd,stepsize,rad_old)
        rad_new = 1/2 * rad_old + 12 * hat_delta
        rad_old = rad_new
        theta_old = theta_new
    return theta_new, hat_delta

def compare_algs(gm,r,eps):
    """ compares our proposed algorithm and the baseline
        in terms of ratio of loss

    Parameters
    ----------
    gm = geometric median problem (provides n, d, R, loss and theta_star)
    r = discretization error passed to the warmup algorithm
    eps = eps of the (eps, delta)-DP guarantee, with delta = 1/gm.n

    Return
    -------
    (ratio_our, ratio_dpgd): loss of each algorithm's output divided by
    the loss at the non-private geometric median gm.theta_star
    """
    rho = eps_to_zcdp(eps,1/gm.n)
    theta_star = gm.theta_star
    #proposed algorithm: warmup with rho/2, then fine-tuning DPGD with rho/2
    beta = 0.05
    wualg = WarmUpAlg(gm,r,rho/2,beta)
    theta_0, hat_delta = wualg.localization()
    rad_fine_tune = 25 * hat_delta
    stepsize_fine_tune = 50 * hat_delta * np.sqrt(gm.d/(6*rho* (gm.n)**2))
    iters_fine_tune = int((gm.n)**2 * rho/(256 * gm.d))
    theta_out = DPGD(gm,
                 theta_0,
                 iters_fine_tune,
                 rho/2,
                 stepsize_fine_tune,
                 rad_fine_tune)
    ratio_our = gm.loss(theta_out)/gm.loss(theta_star)
    #baseline: dpgd from the origin over the whole ball of radius gm.R.
    # The radius is read from the problem instance rather than from a
    # module-level variable, so the function works with any gm.
    stepsize_0 = 2*gm.R*np.sqrt(gm.d/(12*rho* (gm.n)**2))
    iters_0 = int((gm.n)**2 * rho/(128* gm.d))
    rad_0 = gm.R
    theta_dpgd = DPGD(gm,
                 np.zeros(gm.d),
                 iters_0,
                 rho,
                 stepsize_0,
                 rad_0)
    ratio_dpgd = gm.loss(theta_dpgd)/gm.loss(theta_star)
    return ratio_our, ratio_dpgd

def my_data(dim,n):
    """Generate a synthetic dataset of n points in dimension dim.

    About 9/10 of the points form a tight Gaussian cluster (covariance
    0.01 * I) around a random center of norm 50; the remaining points are
    drawn with a uniformly random direction and radius 100 * U**(1/dim),
    U uniform on [0, 1).

    Return
    -------
    numpy array with the cluster points first, followed by the others
    """
    radius = 1e2
    n_cluster = int(9/10 * n)
    n_spread = int(n - n_cluster)

    # group 1: cluster around a random point at distance radius/2 from 0
    direction = np.random.multivariate_normal(mean = np.zeros(dim), cov = np.identity(dim))
    center = radius/2 * direction/np.linalg.norm(direction)
    cluster_cov = 0.01 * np.identity(dim)
    cluster = []
    for _ in range(n_cluster):
        cluster.append(np.random.multivariate_normal(mean = center, cov = cluster_cov))

    # group 2: random unit directions scaled by random radii
    unit_dirs = np.random.normal(size=(n_spread, dim))
    unit_dirs /= np.linalg.norm(unit_dirs, axis=1, keepdims=True)
    radii = radius * (np.random.rand(n_spread) ** (1/dim))
    spread = unit_dirs * radii[:, np.newaxis]

    return np.concatenate((cluster,spread),axis=0)

# ---- Experiment configuration (shared by the experiments in this script) ----
R_vals = np.logspace(3,10,8)  # 8 values of the norm bound R: 1e3, 1e4, ..., 1e10
num_rep = 20                  # repetitions of compare_algs per value of R
dim = 200                     # data dimension d
n = 3000                      # number of datapoints
r = 0.05                      # passed to compare_algs as r (discretization error)

# ---- Experiment with eps = 3 ----
eps = 3
# Despite the *_mean names, these lists collect the MEDIAN ratio over the
# repetitions; the *_std lists collect std / sqrt(num_rep).
ratio_our_mean, ratio_our_std  = [], []
ratio_dpgd_mean, ratio_dpgd_std = [], []
datapoints = my_data(dim,n)
for R in R_vals:
    print(R)
    temp_ratio_our = []
    temp_ratio_dpgd = []
    # The same datapoints are reused for every R; only the bound changes.
    gm = MyGeometricMedian(datapoints,R)
    for _ in range(num_rep):
        ratio_our, ratio_dpgd = compare_algs(gm,r,eps)
        temp_ratio_our.append(ratio_our)
        temp_ratio_dpgd.append(ratio_dpgd)
    ratio_our_mean.append(np.median(temp_ratio_our))
    ratio_our_std.append(np.std(temp_ratio_our)/np.sqrt(num_rep))
    ratio_dpgd_mean.append(np.median(temp_ratio_dpgd))
    ratio_dpgd_std.append(np.std(temp_ratio_dpgd)/np.sqrt(num_rep))

# Log-log plot of the loss ratio against R for both algorithms, written to
# high-eps.pdf. NOTE: the title hard-codes d, n and eps; keep it in sync
# with dim, n and eps above.
plt.title(r"$ d=200, n=3000, \varepsilon = 3, \delta = 1/n$",fontsize=20)
plt.errorbar(R_vals, ratio_our_mean, yerr=ratio_our_std, label = 'Ours', linewidth=1.5)
plt.errorbar(R_vals, ratio_dpgd_mean, yerr=ratio_dpgd_std,label = 'DPGD', linewidth=1.5)
plt.xscale('log')
plt.yscale('log')
plt.xlabel(r"$ R$",fontsize=20)
plt.ylabel(r"$F(\theta;X^{(n)})/F(\theta^\star;X^{(n)})$",fontsize=20)
plt.legend(fontsize="20")
plt.savefig('high-eps.pdf',bbox_inches='tight')
plt.show()

# ---- Experiment with eps = 2 (fresh dataset, same configuration) ----
eps = 2
# Despite the *_mean names, these lists collect the MEDIAN ratio over the
# repetitions; the *_std lists collect std / sqrt(num_rep).
ratio_our_mean, ratio_our_std  =  [], []
ratio_dpgd_mean, ratio_dpgd_std =  [], []
datapoints = my_data(dim,n)
# R is kept as a module-level loop variable on purpose: it is part of the
# script's global state, so this loop is not wrapped in a function.
for R in R_vals:
    print(R)
    temp_ratio_our = []
    temp_ratio_dpgd = []
    gm = MyGeometricMedian(datapoints,R)
    for _ in range(num_rep):
        ratio_our, ratio_dpgd = compare_algs(gm,r,eps)
        temp_ratio_our.append(ratio_our)
        temp_ratio_dpgd.append(ratio_dpgd)
    ratio_our_mean.append(np.median(temp_ratio_our))
    ratio_our_std.append(np.std(temp_ratio_our)/np.sqrt(num_rep))
    ratio_dpgd_mean.append(np.median(temp_ratio_dpgd))
    ratio_dpgd_std.append(np.std(temp_ratio_dpgd)/np.sqrt(num_rep))

# Start from a fresh figure: with a non-interactive backend plt.show() does
# not clear the current figure, so without this the curves of an earlier
# plot in this script would also end up in low-eps.pdf.
plt.figure()
plt.title(r"$ d=200, n=3000, \varepsilon = 2, \delta = 1/n$",fontsize=20)
plt.errorbar(R_vals, ratio_our_mean, yerr=ratio_our_std, label = 'Ours', linewidth=1.5)
plt.errorbar(R_vals, ratio_dpgd_mean, yerr=ratio_dpgd_std,label = 'DPGD', linewidth=1.5)
plt.xscale('log')
plt.yscale('log')
plt.xlabel(r"$ R$",fontsize=20)
plt.ylabel(r"$F(\theta;X^{(n)})/F(\theta^\star;X^{(n)})$",fontsize=20)
plt.legend(fontsize="20")
# Save once, after the figure is complete (labels and legend included).
plt.savefig('low-eps.pdf',bbox_inches='tight')
plt.show()