# -*- coding: utf-8 -*-
"""
Created on Fri Apr  5 14:29:48 2024

@author: ZJ
"""
import numpy as np

def GetBeta(d, epsilon, delta):
    """
    Compute beta from the dimension d and the parameters epsilon, delta.

    d == 1 : epsilon / (2 * log(1/delta))
    d != 1 : epsilon / (2 * (d + log(2/delta)))
    """
    # Both cases share the form epsilon / (2 * denom); only denom differs.
    denom = np.log(1/delta) if d == 1 else d + np.log(2/delta)
    return epsilon / (2*denom)

def GetAlpha(d, epsilon, delta):
    """
    Compute alpha from the dimension d and the parameters epsilon, delta.

    d == 1 : epsilon / sqrt(log(1/delta))
    d != 1 : epsilon / (5 * sqrt(2 * log(2/delta)))

    Note that d only selects the branch; it does not enter either formula.
    """
    if d != 1:
        return epsilon / (5*np.sqrt(2*np.log(2/delta)))
    return epsilon / np.sqrt(np.log(1/delta))
    
def distance(u,v):
    """
    Return the l2 (Euclidean) distance between u and v: the square root
    of the sum of squared element-wise differences.
    """
    diff = u - v
    return np.sqrt(np.sum(np.square(diff)))

def GetZScore(D):
    """
    Balanced case: return the scalar Z, i.e. the largest distance from
    any row of D to the (unweighted) mean of the rows of D.
    """
    center = np.mean(D, axis=0)
    return max(distance(row, center) for row in D)

def GetZVector(D, weights):
    """
    Unbalanced case: return the vector Z whose i-th entry is the distance
    from row D[i] to the weighted mean of the rows of D.

    weights holds one weight per row of D (presumably the number of
    samples of each user -- confirm against the caller).
    """
    total_weight = np.sum(weights)
    center = weights.dot(D) / total_weight
    return np.array([distance(row, center) for row in D])

def h(D, weights, k, Tarray):
    """
    Unbalanced case: compute h.

    With S = weights * (Tarray + Z) and Z = GetZVector(D, weights), the
    result is

        (sum of the k largest entries of S)
        / (sum of all weights - sum of the k largest weights).

    The callers in this file pass k >= 1.
    """
    wsum = sum(weights)
    scores = weights * (Tarray + GetZVector(D, weights))
    # np.partition(x, -k)[-k:] holds the k largest entries of x (unordered).
    top_scores = np.partition(scores, -k)[-k:]
    top_weights = np.partition(weights, -k)[-k:]
    return np.sum(top_scores) / (wsum - np.sum(top_weights))

def Delta(D, T, kc):
    """
    Balanced case: count how many rows must be dropped before every
    remaining row lies within Zmax = (1 - 2*(kc+1)/n) * T of the mean.

    Rows are dropped greedily: at each step the row farthest from the
    current mean is removed and the mean is recomputed. Zmax is fixed
    from the original number of rows n.

    Returns the number of dropped rows if it is smaller than kc,
    otherwise -1.
    """
    def _dists_to_mean(points):
        center = np.mean(points, axis=0)
        return np.array([distance(p, center) for p in points])

    threshold = (1 - 2*(kc+1)/len(D)) * T
    remaining = D
    dists = _dists_to_mean(remaining)
    removed = 0
    while removed < kc and np.max(dists) > threshold:
        # Keep every row except the current farthest one.
        keep = np.arange(len(remaining)) != np.argmax(dists)
        remaining = remaining[keep]
        dists = _dists_to_mean(remaining)
        removed += 1
    return -1 if removed >= kc else removed

def Delta_unbalanced(D, weights, Tarray, kc):
    """
    Unbalanced case: count how many rows must be replaced before
    h(D, weights, kc+1, Tarray) drops below min(Tarray - Z), where Z is
    the distance vector of the modified data (see GetZVector).

    At each step the row with the smallest Tarray - Z is overwritten, in
    a copy of D, by the weighted mean of the original D, and Z is
    recomputed from that copy. D itself is never modified.

    Returns the number of replaced rows if it is smaller than kc,
    otherwise -1.
    """
    Z = GetZVector(D, weights)
    val = np.min(Tarray - Z)
    Dtemp = D.copy()
    if kc <= 0:
        # No replacement is allowed, so the loop below could never run.
        return -1
    # D, weights and Tarray are not modified here (only Dtemp is), so the
    # threshold and the weighted mean are loop-invariant: compute them once.
    # NOTE(review): both are taken from the original D rather than Dtemp --
    # confirm this is intended.
    threshold = h(D, weights, kc+1, Tarray)
    ybar = weights.dot(D)/np.sum(weights)
    k = 0
    while k < kc and threshold >= val:
        minind = np.argmin(Tarray - Z)
        Dtemp[minind] = ybar
        Z = GetZVector(Dtemp, weights)
        val = np.min(Tarray - Z)
        k += 1
    return k if k < kc else -1
    
def GetLambda(D, T, Rc, epsilon, delta):
    """
    Balanced case: return max_k exp(-beta * k) * G[k] for k = 0..n-1.

    D is an (n, d) array; beta = GetBeta(d, epsilon, delta);
    kc = max(ceil(log(n * Rc / T) / beta), int(n / 4)).

    With dist_to_regular = Delta(D, T, kc), G is built as
      G[k] = 2 * T / (n - k - dist_to_regular)  for k <= kc - dist_to_regular,
      G[k] = 2 * Rc                             otherwise,
    and G[0] is replaced by (T + Z) / (n - 1) when
    Z = GetZScore(D) < (1 - 2/n) * T.

    Raises Exception("check") when Delta returns -1.
    """
    n, d = D.shape
    beta = GetBeta(d, epsilon, delta)
    kc = max(int(np.ceil(np.log(n * Rc / T)/beta)), int(n/4))
    dist_to_regular = Delta(D, T, kc)
    if dist_to_regular < 0:
        # NOTE(review): a `2 * Rc` fallback may have been intended here
        # instead of raising -- confirm.
        raise Exception("check")

    G = 2 * Rc * np.ones(n)
    for k in range(kc - dist_to_regular + 1):
        G[k] = 2 * T / (n - k - dist_to_regular)
    Z = GetZScore(D)
    if Z < (1-2/n) * T:
        G[0] = (T+Z)/(n-1)
    return np.max(np.exp(-beta * np.arange(n)) * G)
    
def GetLambda_unbalanced(D, weights, Tarray, Rc, epsilon, delta):
    """
    Unbalanced case: return max_k exp(-beta * k) * G[k] for k = 0..n-1.

    D is an (n, d) array; weights and Tarray hold one entry per row.
    beta = GetBeta(d, epsilon, delta), mtmax = max(weights * Tarray),
    wsum = sum(weights) and kc = ceil(log(wsum * Rc / mtmax) / beta).

    With dist_to_regular = Delta_unbalanced(D, weights, Tarray, kc),
    G is built as
      G[k] = 2 * mtmax / (wsum - w_cum[k + dist_to_regular])
                                    for k <= kc - dist_to_regular,
      G[k] = 2 * Rc                 otherwise,
    where w_cum[j] is the sum of the j+1 largest weights. G[0] is replaced
    by h(D, weights, 1, Tarray) when that value is below
    min(Tarray - GetZVector(D, weights)).

    Raises Exception("check") when Delta_unbalanced returns -1.
    """
    n, d = D.shape
    wsum = np.sum(weights)
    mtmax = np.max(weights * Tarray)
    # Running sum of the weights sorted in descending order, accumulated
    # in float like the rest of the computation.
    w_cum = np.cumsum(np.sort(weights)[::-1], dtype=float)
    beta = GetBeta(d, epsilon, delta)
    kc = int(np.ceil(np.log(wsum * Rc / mtmax)/beta))
    dist_to_regular = Delta_unbalanced(D, weights, Tarray, kc)
    if dist_to_regular < 0:
        # NOTE(review): a `2 * Rc` fallback may have been intended here
        # instead of raising -- confirm.
        raise Exception("check")

    G = 2 * Rc * np.ones(n)
    for k in range(kc - dist_to_regular + 1):
        # NOTE(review): w_cum[j] covers j+1 weights, so this removes one
        # more weight than the `n - k - dist_to_regular` denominator used
        # by GetLambda -- confirm this is intended.
        G[k] = 2 * mtmax / (wsum - w_cum[k + dist_to_regular])
    Z = GetZVector(D, weights)
    val = np.min(Tarray - Z)
    h1 = h(D, weights, 1, Tarray)
    if h1 < val:
        G[0] = h1
    return np.max(np.exp(-beta * np.arange(n)) * G)
    
    
"""
def Delta(D, T):
    n = len(D)
    ybar = np.mean(D, axis = 0)
    dists = np.array([distance(D[i], ybar) for i in range(len(D))])
    inds = np.argsort(dists)
    D_sorted = D[inds]
    Zarray = []
    kmarray = []
    for l in range(n):
        Dtemp = D_sorted[:n-l,:]
        Z = GetZScore(Dtemp)
        km = np.floor((1-Z/T)*n/2)-1-l
        Zarray.append(Z)
        kmarray.append(km)
    l = 0
    Deltaarray = -1 * np.ones(n)
    k = 0
    while True:
        if k <= kmarray[l]:
            Deltaarray[k] = l
            k += 1
        else:
            l += 1
        if k >= n or l >= n:
            break
    return Deltaarray

def GetGvalue(D, T, Rc):
    Deltaarray = Delta(D, T)
    n = len(D)
    G = np.zeros(n)
    G[0] = (T + GetZScore(D))/(n-1)
    for k in range(1, n):
        if Deltaarray[k] >= 0:
            G[k] = 2 * T / (n - k - Deltaarray[k])
        else:
            G[k] = 2 * Rc
    return G

def GetLambda(D, T, Rc, epsilon, delta):
    n, d = D.shape
    beta = GetBeta(d, epsilon, delta)
    G = GetGvalue(D, T, Rc)
    return np.max(np.exp(-beta * np.arange(n)) * G)
"""

def GetLambdaSimple(D, T, Z, Rc, epsilon, delta):
    """
    Simplified variant that takes Z as an argument.

    With km = floor((1 - Z/T) * n / 2) - 1 and
    beta = GetBeta(d, epsilon, delta), the head term is the maximum of 0
    and 2 * T * exp(-beta * k) / (n - k) over k = 0..km, and the tail
    term is 2 * Rc * exp(-beta * (km + 1)).

    Prints km and raises Exception("check") when the head term is
    smaller than the tail term; otherwise returns the larger of the two.
    """
    n, d = D.shape
    beta = GetBeta(d, epsilon, delta)
    km = int(np.floor((1-Z/T) * n/2))-1
    # Leading 0 reproduces the initial value of the running maximum.
    head = max([0] + [2 * T * np.exp(-beta * k) / (n-k)
                      for k in range(km+1)])
    tail = 2*Rc*np.exp(-beta*(km+1))
    if head < tail:
        print(km)
        raise Exception("check")
    return max(head, tail)

def HLM_balanced(D, T):
    """
    Balanced case: iteratively reweighted mean of the rows of D.

    Starting from the origin, each row gets weight min(1, T / r), where r
    is its distance to the current centre (weight 1 when r == 0). The
    weights are normalised to sum to one and the centre is moved to the
    weighted mean of the rows. Iteration stops once a step is shorter
    than 1e-5, and the centre from before that last step is returned.
    """
    n, d = D.shape
    center = np.zeros(d)
    while True:
        w = np.zeros(n)
        for i in range(n):
            r = distance(D[i], center)
            w[i] = 1 if r == 0 else min(1, T/r)
        w = w / np.sum(w)
        updated = (w.reshape(1, -1) @ D).ravel()
        if distance(updated, center) < 1e-5:
            return center
        center = updated

def HLM_unbalanced(D, weights, Tarray):
    """
    Unbalanced case: iteratively reweighted mean of the rows of D.

    Starting from the origin, row i gets weight
    weights[i] * min(1, Tarray[i] / r), where r is its distance to the
    current centre. The weights are normalised to sum to one and the
    centre is moved to the weighted mean of the rows. Iteration stops
    once a step is shorter than 1e-5, and the centre from before that
    last step is returned.
    """
    n, d = D.shape
    c = np.zeros(d)
    w = np.zeros(n)
    while True:
        for i in range(n):
            r = distance(D[i], c)
            # A row sitting exactly on the centre gets clipping factor 1,
            # as in HLM_balanced, instead of dividing by zero.
            factor = 1 if r == 0 else min(1, Tarray[i]/r)
            w[i] = weights[i] * factor
        w = w / np.sum(w)
        cnext = np.matmul(w.reshape(1,-1), D).ravel()
        if distance(cnext, c) < 1e-5:
            return c
        c = cnext