# -*- coding: utf-8 -*-
"""
Created on Wed Apr 10 16:33:41 2024

@author: ZJ
"""
import numpy as np

def epsilon_element(epsilon, delta, d):
    """
    Return epsilon / sqrt(8 * d * ln(1 / delta)).

    NOTE(review): presumably this converts an overall (epsilon, delta)
    budget into a per-element budget when d elements are composed —
    confirm against the reference paper before relying on that reading.
    """
    scale = np.sqrt(8 * d * np.log(1 / delta))
    return epsilon / scale

def nextpow2(d):
    """
    Return the smallest power of two that is >= d, as a Python int.

    Uses exact integer arithmetic.  The float formulation
    ``2 ** ceil(log2(d))`` can round down for large d (for example
    ``d = 2**50 + 1`` yields ``2**50``, which is smaller than d).

    Parameters
    ----------
    d : int or float
        Positive target size.  Non-integer values are rounded up first,
        so any d in (0, 1] gives 1.

    Raises
    ------
    ValueError
        If d is NaN or not positive.
    """
    if isinstance(d, (int, np.integer)):
        n = int(d)
    else:
        # np.ceil returns an integral float, so int() is exact here;
        # int(nan) raises ValueError.
        n = int(np.ceil(d))
    if n < 1:
        raise ValueError(f"d must be positive, got {d!r}")
    # (n - 1).bit_length() is the number of bits needed for n - 1, so
    # 1 << that is the first power of two that is >= n.
    return 1 << (n - 1).bit_length()
    
def GetU(d):
    """
    Return a random d x d orthogonal matrix U = H @ D / sqrt(d).

    H is the Sylvester Hadamard matrix returned by Hadamard(d).  D is a
    diagonal matrix of independent, uniformly random +/-1 signs drawn from
    NumPy's global random state.  Because H @ H.T == d * I and D @ D == I,
    the result is orthogonal.  d must be a power of two, as required by
    Hadamard.
    """
    H = Hadamard(d)
    signs = np.random.choice([-1, 1], d)
    # H @ diag(signs) only multiplies column j of H by signs[j].
    # Broadcasting does that in O(d^2) instead of an O(d^3) dense product
    # with a d x d diagonal matrix.  The values are identical because every
    # entry is +/-1.
    return H * signs / np.sqrt(d)
    
def Hadamard(d):
    """
    Return the d x d Hadamard matrix from Sylvester's construction.

    Starting from [[1]], each step forms [[H, H], [H, -H]], doubling the
    size until it reaches d.  The result is a float64 array of +/-1
    entries with H @ H.T == d * I.

    Parameters
    ----------
    d : int
        Matrix size; must be a positive integer power of 2.

    Raises
    ------
    ValueError
        If d is not a positive power of 2.  Without this check, a size
        like 6 would silently produce a 4 x 4 matrix, and 0 would recurse
        forever.
    """
    n = int(d)
    if n != d or n < 1 or n & (n - 1):
        raise ValueError(f"d must be a positive power of 2, got {d!r}")
    H = np.ones((1, 1))
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H
    
def clip(x, left, right):
    """
    Clamp the scalar x into [left, right].

    Returns right when x exceeds right, left when the value is below left,
    and x itself otherwise.  If left > right the result is always left;
    a NaN x is returned unchanged.
    """
    capped = right if right < x else x
    return left if left > capped else capped

def PrivateRange(X, L, R, tau, epsilon):
    """
    Pick a coarse window around the bulk of X with the exponential mechanism.

    [L, R] is discretised into the grid points L, L + tau, ...,
    L + m * tau with m = ceil((R - L) / tau).  Each sample is clipped to
    [L, R] and assigned to its nearest grid point.  Grid point i gets the
    score

        c[i] = max(#samples in bins strictly left of i,
                   #samples in bins strictly right of i),

    which changes by at most 1 when one sample is replaced.  Point i is
    drawn with probability proportional to exp(-epsilon * c[i] / 2), and
    the window of half-width 2 * tau centred on it is returned.

    Parameters
    ----------
    X : array_like
        One-dimensional array of samples (must not contain NaN).
    L, R : float
        Bounds of the search interval; requires L <= R.
    tau : float
        Grid spacing; must be > 0.
    epsilon : float
        Privacy parameter used in the exponential-mechanism weights.

    Returns
    -------
    tuple
        (x_selected - 2 * tau, x_selected + 2 * tau), where
        x_selected = L + i_selected * tau.

    Raises
    ------
    ValueError
        If tau <= 0, R < L, X is not one-dimensional, or X contains NaN.

    Uses NumPy's global random state (np.random).
    """
    if not tau > 0:
        raise ValueError(f"tau must be positive, got {tau!r}")
    if R < L:
        raise ValueError(f"need L <= R, got L={L!r}, R={R!r}")
    X = np.asarray(X, dtype=float)
    if X.ndim != 1:
        raise ValueError("X must be one-dimensional")
    if np.isnan(X).any():
        raise ValueError("X must not contain NaN")

    n_counts = int(np.ceil((R - L) / tau)) + 1
    # Clip to [L, R], then snap to the nearest grid index.  np.rint rounds
    # halves to even, like np.round.  Clipped values satisfy
    # (x - L) / tau <= (R - L) / tau, so every index is < n_counts.
    idx = np.rint((np.clip(X, L, R) - L) / tau).astype(np.int64)
    counts = np.bincount(idx, minlength=n_counts)

    cum = np.cumsum(counts)
    left = cum - counts    # samples in bins strictly left of i
    right = cum[-1] - cum  # samples in bins strictly right of i
    c = np.maximum(left, right).astype(float)
    # Shift so the minimum score is 0.  This avoids underflow in exp and
    # leaves the normalised probabilities unchanged.
    c -= c.min()
    weights = np.exp(-epsilon * c / 2)
    probs = weights / weights.sum()
    i_selected = np.random.choice(np.arange(n_counts), size=1,
                                  replace=False, p=probs)[0]
    x_selected = L + i_selected * tau
    return (x_selected - 2 * tau, x_selected + 2 * tau)
        
        