import multiprocessing as mp
import numpy as np
import os
#from numba import njit, prange
from joblib import Parallel, delayed
from scipy.linalg import hadamard


def quantile_binary_search(x, m, u, p, T=10, l=1):
    left = l
    right = u
    i = 0
    while i < T:
        mid = (left+right)/2.
        count = len([xi for xi in x if xi<=mid])
        count_noisy = count + np.random.normal(0, np.sqrt(T/(2.*p))) 
        if count_noisy <= m:
            left = mid# + 1
        else:
            right = mid
        i = i+1
    #print(left)
    #print(right)
    return (left+right)/2.

def clipped_mean(x, n, d, u, p, l=0, threshold=None, T=10):
    """Noisy mean of the rows of ``x`` after clipping each row's L2 norm to ``C``.

    The budget ``p`` is split 25% / 75%. A quarter (``p1``) is spent on
    choosing the clipping radius ``C`` with ``quantile_binary_search`` over
    the row norms, unless ``threshold`` is given. Three quarters (``p2``) go
    to Gaussian noise added to the mean of the clipped rows.

    Args:
        x: array-like of shape (rows, d); one data point per row.
        n: number of data points; used for the search target and noise scale.
        d: dimension of each row and of the returned vector.
        u: upper end of the search interval for ``C``.
        p: total privacy budget (presumably a zCDP rho -- the noise std is
            sensitivity / sqrt(2 * p2); confirm).
        l: lower end of the search interval for ``C``.
        threshold: if not None, used directly as ``C`` and no search is run.
        T: number of bisection steps used to find ``C``.

    Returns:
        ndarray of shape (d,): clipped mean plus N(0, (2C / (n sqrt(2 p2)))^2)
        noise per coordinate.
    """
    p1 = p * 0.25  # budget for choosing the clipping radius
    p2 = p * 0.75  # budget for the noisy mean
    x = np.asarray(x, dtype=float)
    x_norm = np.linalg.norm(x, axis=1)
    if threshold is not None:
        C = threshold
    else:
        # Target rank sits slightly below n, lowered further as the noise
        # scales of both mechanisms grow.
        m = int(n - 2. * np.sqrt(d / (2. * p2)) - np.sqrt(T / (2 * p1)))
        C = quantile_binary_search(x_norm, m, u, p1, T=T, l=l)
    # Rescale only rows longer than C. This is equivalent to min(C/norm, 1),
    # but never divides by a zero norm (which produced NaN when C == 0).
    scale = np.ones(len(x))
    too_long = x_norm > C
    scale[too_long] = C / x_norm[too_long]
    mean = np.mean(x * scale[:, np.newaxis], axis=0)
    # Replacing one of n rows of norm <= C moves the mean by at most 2C/n.
    # NOTE(review): with an explicit threshold the p1 share is never spent, yet
    # the noise still uses only p2; left as-is to keep the expected noise scale.
    noisy_mean = mean + np.random.normal(0, 2. * C / np.sqrt(2. * p2) / n, size=d)
    return noisy_mean


def random_hadamard(d,t=1):
    """Return a random orthogonal ``d x d`` matrix built from Hadamard blocks.

    The result is ``(H S_1)(H S_2)...(H S_t)``. Here ``H`` is the Hadamard
    matrix scaled by ``1/sqrt(d)``, so ``H @ H.T == I``. Each ``S_k`` is a
    diagonal matrix of independent fair random signs. ``d`` must be a power
    of 2 (required by ``scipy.linalg.hadamard``). Any ``t <= 1`` yields a
    single block.
    """
    h_norm = hadamard(d) / np.sqrt(d)

    def signed_block():
        # One uniform draw per column; draws below 0.5 flip that column's sign.
        coins = np.random.uniform(low=0, high=1, size=d)
        signs = np.where(coins < 0.5, -1.0, 1.0)
        # Broadcasting scales column j by signs[j], i.e. h_norm @ diag(signs).
        return h_norm * signs

    product = signed_block()
    for _ in range(1, t):
        product = np.matmul(product, signed_block())
    return product
    

def random_rotation_mean(x,d,u,p,T=10,prop=0.25):
    """Noisy mean estimate via random rotation, private centering, then clipping.

    Steps:
      1. Rotate the rows of ``x`` by a random orthogonal matrix ``HD``
         (``random_hadamard`` with three sign/Hadamard blocks).
      2. Privately estimate each rotated coordinate's median with
         ``quantile_binary_search_mp``, spending ``prop * p`` split evenly
         over the ``d`` coordinates and searching each in ``[-u/sqrt(d), u/sqrt(d)]``.
      3. Subtract those medians and run ``clipped_mean`` with the remaining
         ``(1 - prop) * p``.
      4. Add the medians back and rotate with ``HD.T`` (the inverse, since
         ``HD`` is orthogonal).

    Args:
        x: array-like of shape (n, d); one data point per row.
        d: dimension; presumably must be a power of 2 because
            ``random_hadamard`` uses ``scipy.linalg.hadamard``.
        u: upper end of the clipping-radius search; presumably a bound on row
            norms -- confirm with callers.
        p: total privacy budget.
        T: bisection steps per coordinate median.
        prop: fraction of ``p`` spent on the coordinate medians.

    Returns:
        ndarray of shape (d,): the mean estimate in the original coordinates.
    """
    # Accept lists / array-likes too; the original crashed on ``x.T`` for them.
    x = np.asarray(x, dtype=float)
    n = len(x)
    HD = random_hadamard(d, t=3)
    HD_inv = HD.T  # orthogonal matrix: transpose is the inverse
    x_hat = (HD @ x.T).T
    ps = prop * p / d  # per-coordinate budget for the median searches
    # Presumably: after a random rotation a coordinate of a norm-u vector is
    # typically on the order of u/sqrt(d), hence the narrower search range.
    up = u / np.sqrt(d)
    c_tilde = np.asarray(
        quantile_binary_search_mp(x_hat, d, 0.5 * n, up, ps, T=T, l=-up),
        dtype=float,
    )
    x_shifted = x_hat - c_tilde
    y_tilde = clipped_mean(x_shifted, n, d, u, (1 - prop) * p)
    y_hat = HD_inv @ (y_tilde + c_tilde)
    return y_hat


def quantile_binary_search_mp_wrapper(args):
    """Call ``quantile_binary_search`` with its positional arguments packed in one sequence.

    The single-argument form presumably exists so the search can be handed to a
    one-argument mapper such as ``multiprocessing.Pool.map``. It is not called
    by any active code in this file.

    Args:
        args: positional arguments for ``quantile_binary_search``, e.g.
            ``(x, m, u, p, T, l)``.

    Returns:
        Whatever ``quantile_binary_search`` returns.
    """
    return quantile_binary_search(*args)


def quantile_binary_search_mp(x,d,m,u,p,T=10,l=1):
    """Run ``quantile_binary_search`` on columns ``0..d-1`` of ``x`` in parallel.

    One joblib task is created per column, using ``mp.cpu_count()`` workers.

    Args:
        x: 2-D array indexed as ``x[:, col]``; each column is searched independently.
        d: number of leading columns to process.
        m, u, p, T, l: forwarded unchanged to every ``quantile_binary_search`` call.

    Returns:
        list with one search result per column, in column order.

    NOTE(review): the noise inside each search comes from the executing
    worker's global ``np.random`` state, so ``np.random.seed`` in the parent
    may not make results reproducible -- confirm for the joblib backend in use.
    """
    worker_count = mp.cpu_count()
    jobs = (
        delayed(quantile_binary_search)(x[:, col], m, u, p, T=T, l=l)
        for col in range(d)
    )
    results = Parallel(n_jobs=worker_count)(jobs)
    return results
