import torch
import numpy as np
import pickle
#%%
def nearest_neighbor(gamma, x):
    """Return the row of ``gamma`` closest to ``x`` in Euclidean distance.

    ``gamma`` holds one codeword per row and ``x`` must broadcast against it
    (e.g. a (1, d) row or a scalar when rows have length 1).  The winning
    row is returned with a leading axis of size 1, i.e. shape (1, d).
    """
    distances = torch.norm(x - gamma, dim=1)
    _, closest = torch.topk(-distances, k=1)
    return gamma[closest]

def voronoi(n_atoms, n_iter, alpha, d):
    """Fit an ``n_atoms``-point codebook in R^d by online winner-take-all updates.

    The atoms start i.i.d. N(0, 1 + 2/d).  Each of the ``n_iter`` steps draws
    one standard-normal sample of shape (1, d), finds the closest atom and
    moves it a fraction ``alpha`` of the way towards the sample.

    Returns the (n_atoms, d) codebook tensor.
    """
    init_std = np.sqrt(1 + 2 / d)
    codebook = torch.normal(0, init_std, size=(n_atoms, d))
    for _ in range(n_iter):
        sample = torch.normal(0, 1, size=(1, d))
        winner = torch.topk(-torch.norm(sample - codebook, dim=1), k=1)[1]
        codebook[winner] = codebook[winner] - alpha * (codebook[winner] - sample)
    return codebook


def quantize_sign(x):
    """1-bit sign quantization: map every entry of ``x`` to -1, 0 or +1."""
    signs = x.sign()
    return signs


def quantize_vqsgd(x, d, workers):
    """Quantize ``x`` onto the scaled cross-polytope vertices ``±sqrt(d) e_i`` (vqSGD).

    Vertex ``+sqrt(d) e_i`` gets weight ``max(x_i, 0) / sqrt(d) + slack`` and
    vertex ``-sqrt(d) e_i`` gets ``max(-x_i, 0) / sqrt(d) + slack``, with
    ``slack = (1 - ||x||_1 / sqrt(d)) / (2 d)``.  NaN weights are replaced by 1
    (so an ``x`` containing NaN gives a uniform draw) and the weights are
    renormalized to sum to one.  ``workers`` independent vertices are drawn
    and averaged.

    ``x`` is expected as a (1, d) tensor.  The weights are only non-negative
    when ``||x||_1 <= sqrt(d)``, e.g. for the unit-norm inputs the callers pass.
    Returns a float64 tensor of shape (1, d).
    """
    root_d = np.sqrt(d)
    # One vertex per row: +sqrt(d) e_0 .. e_{d-1}, then -sqrt(d) e_0 .. e_{d-1}.
    vertices = np.transpose(np.concatenate([root_d * np.eye(d), -root_d * np.eye(d)], 1))
    coords = x[0][:d]
    ga = 1 - torch.norm(x, p=1) / root_d
    slack = ga / (2 * d)
    positive = torch.where(coords > 0, coords / root_d + slack, slack)
    negative = torch.where(coords <= 0, -coords / root_d + slack, slack)

    weights = np.nan_to_num(torch.cat([positive, negative]).numpy(), nan=1)
    weights /= np.linalg.norm(weights, 1)

    draws = [vertices[np.random.choice(2 * d, 1, replace=True, p=weights)] for _ in range(workers)]
    return torch.tensor(np.array(draws).mean(0))

def quantize_greedy_hsq(gamma, x):
    """Greedy (deterministic) HSQ: snap ``x`` to its nearest codebook row.

    There is no sampling and no norm rescaling; the closest row of ``gamma``
    is returned with shape (1, d).  Callers in this file pass a unit-norm
    ``x`` and multiply the result by the norm themselves.
    """
    gaps = torch.norm(x - gamma, dim=1)
    best = torch.topk(-gaps, k=1)[1]
    return gamma[best]

def quantize_hsq(gamma, x, norm, nom, quantized=False):
    """Unbiased stochastic quantization of ``norm * x`` onto one row of ``gamma``.

    The target ``g = norm * x`` is expressed as a combination of the codebook
    rows, ``proba = C_dag g`` with ``C_dag = G^T (G G^T)^+`` (``G`` = rows of
    ``gamma`` as columns).  One row ``i`` is drawn with probability
    ``|proba_i| / ||proba||_1`` and returned scaled by
    ``sign(proba_i) * ||proba||_1``.  In expectation this gives ``G proba``,
    which equals ``g`` when the rows of ``gamma`` span R^d.

    Parameters
    ----------
    gamma : torch.Tensor
        Codebook with one atom per row, (n_atoms, d) as built by the callers.
    x : torch.Tensor
        Row vector of shape (1, d); callers in this file pass a unit-norm
        direction.
    norm : torch.Tensor
        Scale applied to ``x`` (callers pass a (1, 1) tensor).
    nom : float
        Half-width of the range [-nom, nom) on which the scale is quantized
        when ``quantized`` is True.
    quantized : bool
        If True, the scale ``sign(proba_i) * ||proba||_1`` is rounded with a
        6-bit ``uniform_quantizer`` on [-nom, nom) before being applied.

    Returns
    -------
    torch.Tensor
        The drawn atom times the (possibly quantized) scale, shape (1, d).
    """
    gamma = gamma.t()  # atoms as columns: (d, n_atoms)
    g = x.t()*norm  # target as a column vector (d, 1)
    # Right pseudo-inverse of the atom matrix: C_dag maps R^d to atom weights.
    C_dag = torch.matmul(gamma.t(), torch.pinverse(torch.matmul(gamma, gamma.t())))
    proba = torch.matmul(C_dag, g)  # signed atom weights, (n_atoms, 1)
    # proba = torch.tensor(np.nan_to_num(proba, nan=1))
    p_tilde = torch.abs(proba) / torch.norm(proba, p=1)
    ### If proba is all zeros the division above is 0/0 = nan; replacing nan by
    ### 1 and renormalizing falls back to a uniform draw over the atoms.
    p_tilde = np.nan_to_num(p_tilde.numpy(), nan=1)
    p_tilde = torch.tensor(p_tilde / np.linalg.norm(p_tilde, 1))
    ###
    gamma = gamma.t()  # back to one atom per row
    chos = np.random.choice([i for i in range(gamma.shape[0])], 1, replace=True, p=p_tilde.t().numpy()[0])
    # print((np.sign(proba[chos])*torch.norm(proba, p=1)).numpy())
    if quantized:
        # Quantize the signed scale sign(proba_i) * ||proba||_1 to 6 bits.
        pseudo_norm = uniform_quantizer((np.sign(proba[chos])*torch.norm(proba, p=1)).numpy(), 6, -nom, nom)
        return gamma[chos] * pseudo_norm
    # print((np.sign(proba[chos])*torch.norm(proba, p=1)).numpy())
    else:
        # Full-precision signed scale: keeps the estimator unbiased.
        return gamma[chos] * (np.sign(proba[chos])*torch.norm(proba, p=1)).item()

def quantize_top1(x):
    """Top-1 sparsification: keep only the largest-magnitude entry of each row.

    ``x`` is 2-D with rows along dim 0.  All other entries are zeroed and no
    rescaling is applied.
    """
    top_idx = torch.abs(x).topk(k=1, dim=1).indices
    keep_mask = torch.zeros_like(x).scatter_(1, top_idx, 1)
    return x * keep_mask

def quantize_rand1(x, d):
    """Unbiased random-1 sparsification of a (1, d) row ``x``.

    One coordinate index is drawn uniformly from ``[0, d)``; that entry is
    kept and multiplied by ``d``, every other entry is zeroed.
    """
    keep_idx = torch.randint(0, d, (1, 1))
    keep_mask = torch.zeros_like(x)
    keep_mask.scatter_(1, keep_idx, 1)
    return d * x * keep_mask

# Default location of the optimal scalar quantizer table (8 bits, one level per line).
_OPTIMAL_QUANTIZER_PATH = 'optimal-quantifier-8.txt'
# Parsed codebooks keyed by file path; the quantizers below run in tight loops.
_scalar_codebook_cache = {}


def _load_scalar_codebook(path=_OPTIMAL_QUANTIZER_PATH):
    """Return the scalar quantizer levels stored in ``path`` as an (n, 1) tensor.

    Each line contributes ``float(first_comma_field[:-2])``, exactly as the
    previous inline parser did.  NOTE(review): dropping the last two
    characters looks format-specific; confirm against the table file.
    The file is read once per path and cached.
    """
    codebook = _scalar_codebook_cache.get(path)
    if codebook is None:
        with open(path, 'r') as fh:
            levels = [float(line.split(',')[0][:-2]) for line in fh]
        codebook = torch.tensor(levels).view(-1, 1)
        _scalar_codebook_cache[path] = codebook
    return codebook


def _snap_to_codebook(qu, codebook):
    """Replace each non-zero entry of ``qu[0]`` by its nearest codebook level, in place."""
    for i, e in enumerate(qu[0]):
        if e != 0:
            qu[0][i] = nearest_neighbor(codebook, e)
    return qu


def quantize_top2(x, path=_OPTIMAL_QUANTIZER_PATH):
    """Keep the two largest-magnitude entries of the (1, d) row ``x`` and scalar-quantize them.

    The kept entries are snapped to the nearest level of the table stored in
    ``path``; all other entries are zero.  No rescaling is applied.
    """
    mask = torch.zeros_like(x)
    idx = torch.topk(torch.abs(x), k=2, dim=1)[1]
    mask.scatter_(1, idx, 1)
    return _snap_to_codebook(x * mask, _load_scalar_codebook(path))


def quantize_rand2(x, d, path=_OPTIMAL_QUANTIZER_PATH):
    """Random-2 sparsification of the (1, d) row ``x`` with scalar quantization.

    Two *distinct* coordinates are drawn uniformly, so each one is kept with
    probability 2/d and the ``d / 2`` rescaling is unbiased.  The previous
    ``torch.randint`` draw could pick the same index twice, which biased the
    estimate.  Kept entries are snapped to the levels stored in ``path``.
    """
    mask = torch.zeros_like(x)
    idx = torch.randperm(d)[:2].view(1, -1)
    mask.scatter_(1, idx, 1)
    return d / 2 * _snap_to_codebook(x * mask, _load_scalar_codebook(path))

#%%
def fix_bias(p, x):
    """Look up the radial bias-correction entry for a vector of norm ``||x||``.

    Table entry ``i`` corresponds to ``||x|| ≈ i / 10`` for ``i`` in 0..60, so
    norms above 6 map to entry 60.  Entry 0 is never used: norms that round
    to it use entry 1 instead, which avoids nan/inf downstream.
    Returns ``p[i][0]``.
    """
    grid = np.arange(61)
    nearest = int(np.argmin(np.abs(grid - 10 * np.linalg.norm(x))))
    return p[max(nearest, 1)][0]

# def fix_bias_3_bits(ip, x):
#     proba = 1 / np.abs(np.array(ip)-x.numpy())
#     return np.array(ip)[np.random.choice([i for i in range(8)], 1, replace=True, p=proba / proba.sum())]
def fix_bias_3_bits(ip, x):
    """Stochastically round the single-element tensor ``x`` onto the sorted grid ``ip``.

    Values outside ``[min(ip), max(ip)]`` are clipped to the nearest end
    point.  Inside the range, an ``x`` lying between neighbours ``lo < x < hi``
    is rounded up to ``hi`` with probability ``(x - lo) / (hi - lo)`` and down
    to ``lo`` otherwise, which is unbiased on that interval.  An ``x`` equal
    to a grid point is returned as that point; previously this case fell
    through every branch and returned ``None``.  A NaN ``x`` is propagated.

    Parameters
    ----------
    ip : sequence of float
        Grid levels, assumed sorted in ascending order.
    x : torch.Tensor
        Single-element tensor to quantize.
    """
    grid = np.array(ip)
    value = x.numpy()
    if np.isnan(value):
        return value  # propagate NaN rather than snapping it to a level
    if value > grid.max():
        return grid.max()
    if value < grid.min():
        return grid.min()
    idx = np.abs(grid - value).argsort()[0]
    if value > grid[idx]:
        lo, hi = grid[idx], grid[idx + 1]
    elif value < grid[idx]:
        lo, hi = grid[idx - 1], grid[idx]
    else:
        return grid[idx]  # exactly on a grid level: nothing to round
    proba = (value - lo) / (hi - lo)
    return (torch.bernoulli(torch.tensor(proba)) * (hi - lo) + lo).numpy()
    
    
def uniform_quantizer(norm_x_true, b, umin, umax):
    """Round ``norm_x_true`` to the nearest level of a b-bit uniform grid on [umin, umax).

    The levels are ``np.arange(umin, umax, (umax - umin) / 2**b)``: ``umin`` is
    a level, ``umax`` itself is not.  Intended for a scalar or a
    single-element array; the chosen level is returned as a numpy scalar.
    """
    step = (umax - umin) / 2**b
    levels = np.arange(umin, umax, step)
    return levels[np.abs(levels - norm_x_true).argmin()]
#%%
class LeastSquaresRegressorTorch1():
    """Simulated distributed mini-batch SGD on a least-squares objective.

    ``K`` workers each own ``samples_per_worker`` rows of the training set.
    At every iteration each worker computes the gradient of the mean squared
    error over its rows (accumulated over mini-batches of ``batch_size``
    rows).  The worker gradients are averaged and one ``torch.optim.SGD`` step
    with learning rate ``eta`` is taken.  Every worker gradient is appended to
    ``saved_grads`` (K numpy arrays per iteration, in worker order) so it can
    be quantized afterwards.
    """

    def __init__(self, n_iter=10, eta=0.1, batch_size=10, samples_per_worker=1, K=1):
        self.n_iter = n_iter  # number of SGD iterations
        self.eta = eta  # learning rate
        self.batch_size = batch_size  # rows per mini-batch
        self.samples_per_worker = samples_per_worker  # rows processed by each worker
        self.K = K  # number of workers
        self.saved_grads = list()  # per-worker gradients, K entries per iteration

    def fit(self, X, Y, to_choose_workers, finalXtrain, init_w):
        """Run ``n_iter`` distributed SGD iterations.

        Parameters
        ----------
        X : array-like, shape (n_instances, n_features)
        Y : array-like, shape (n_instances,)
        to_choose_workers : sequence of index arrays
            ``to_choose_workers[k]`` are the rows owned by worker ``k``; only the
            first ``samples_per_worker`` of them are used.
        finalXtrain : list of int
            Rows on which the training loss stored in ``history`` is computed.
        init_w : array-like, shape (n_features,)
            Initial weights; copied, never modified.

        Sets ``self.w`` (final weights) and ``self.history`` (mean squared error
        on ``finalXtrain`` before training and after each iteration, as
        detached scalar tensors), and extends ``self.saved_grads``.
        """
        # Wrap the NumPy arrays as float32 PyTorch tensors.
        Xt = torch.tensor(X, dtype=torch.float)
        Yt = torch.tensor(Y, dtype=torch.float)
        n_features = Xt.shape[1]

        # Fresh leaf tensor (avoids the torch.tensor(tensor) copy-construct warning).
        self.w = torch.as_tensor(init_w, dtype=torch.float).clone().detach().requires_grad_(True)
        self.history = [self._eval_loss(Xt, Yt, finalXtrain)]

        optimizer = torch.optim.SGD([self.w], lr=self.eta)

        # Each batch loss is normalized by batch_size.  Rescaling its gradient by
        # batch_size / samples_per_worker makes the sum over batches the mean
        # over all of the worker's samples, also when the last batch is partial.
        # The old ``/ int(spw / bs)`` divided by zero when spw < bs.
        batch_weight = self.batch_size / self.samples_per_worker

        for i in range(self.n_iter):
            print(i)
            concat_grad = []
            for k in range(self.K):
                X_worker = Xt[to_choose_workers[k]]
                Y_worker = Yt[to_choose_workers[k]]
                grad_worker = torch.zeros(n_features)

                for batch_start in range(0, self.samples_per_worker, self.batch_size):
                    batch_end = batch_start + self.batch_size
                    error = X_worker[batch_start:batch_end].mv(self.w) - Y_worker[batch_start:batch_end]
                    loss_batch = torch.sum(error ** 2) / self.batch_size

                    optimizer.zero_grad()
                    loss_batch.backward()
                    grad_worker += self.w.grad * batch_weight

                concat_grad.append(grad_worker.numpy())
            self.saved_grads += concat_grad
            # The server applies the average of the worker gradients.
            self.w.grad = torch.tensor(np.array(concat_grad).mean(0))
            optimizer.step()

            self.history.append(self._eval_loss(Xt, Yt, finalXtrain))

    def _eval_loss(self, Xt, Yt, rows):
        """Mean squared error of the current weights on ``rows``, without building a graph."""
        with torch.no_grad():
            return torch.sum((Xt[rows].mv(self.w) - Yt[rows]) ** 2) / len(rows)

#%%
# Quantizers compared by ``lsr_grads_distortion``, in evaluation order.
_DISTORTION_METHODS = (
    "sign",
    "random_voronoi",
    "random_voronoi_unbiased",
    "random_voronoi_uq",
    "top2",
    "rand2",
    "hsq",
    "hsq_q",
    "greedy_hsq",
    "greedy_hsq_q",
    "vqsgd",
    "vqsgd_q",
)

# (printed label, method key) pairs, in the order the results are reported.
_DISTORTION_REPORT = (
    ("err_sign", "sign"),
    ("err_top2", "top2"),
    ("err_rand2", "rand2"),
    ("err_polytope", "vqsgd"),
    ("err_polytope_q", "vqsgd_q"),
    ("err_hsq", "hsq"),
    ("err_hsq_q", "hsq_q"),
    ("err_greedy_hsq", "greedy_hsq"),
    ("err_greedy_hsq_q", "greedy_hsq_q"),
    ("err_random_voronoi", "random_voronoi"),
    ("err_random_voronoi_unbiased", "random_voronoi_unbiased"),
    ("err_random_voronoi_uq", "random_voronoi_uq"),
)


def _parallel_radial_errors(avg_q, x):
    """Split ``avg_q`` along the direction of the (1, d) tensor ``x`` and score both parts.

    ``q_para`` is the orthogonal projection of ``avg_q`` onto the line spanned
    by ``x`` and ``q_rad = avg_q - q_para``.  Returns the pair
    ``(||q_para - x||^2, ||q_rad||^2)``.
    """
    projector = (x * x.t() / torch.norm(x) ** 2).numpy()
    q_para = np.matmul(projector, avg_q.reshape(-1))
    q_rad = avg_q - q_para
    return np.linalg.norm(q_para - x.numpy()) ** 2, np.linalg.norm(q_rad) ** 2


def _record_distortion(err, err_rad, method, per_worker, x):
    """Average the workers' quantized blocks and append their errors under ``method``.

    Every method is scored on its own average; previously the ``*_q`` variants
    of greedy HSQ and vqSGD reused the preceding method's projection.
    """
    avg_q = np.array(per_worker).mean(0)
    parallel, radial = _parallel_radial_errors(avg_q, x)
    err[method].append(parallel)
    err_rad[method].append(radial)


def lsr_grads_distortion(n_machines):
    """Measure how much each gradient quantizer distorts the workers' averaged gradient.

    Runs ``LeastSquaresRegressorTorch1`` with ``n_machines`` workers on a
    synthetic least-squares problem and keeps every worker gradient.  For each
    SGD iteration, every worker gradient is cut into blocks of ``d = 16``
    coordinates and normalized by the std of its 32-block group.  Each block
    is quantized with every method in ``_DISTORTION_METHODS``, rescaled, and
    averaged over the workers.  The averaged block ``avg_q`` is split along the
    exact averaged block ``x`` into a parallel and a radial part.  The squared
    errors ``||q_para - x||^2`` and ``||q_rad||^2`` are summed over the group
    and divided by the squared norm of the exact averaged group.

    Reads ``./radial_biases/radial_bias_16_8192_.txt`` (pickled bias table) and,
    through ``quantize_top2``/``quantize_rand2``, ``optimal-quantifier-8.txt``.

    Prints the mean, then the standard deviation over SGD iterations, of both
    errors for every method.

    Returns
    -------
    (Err, Err_rad) : tuple of dict[str, list]
        For each method key, the normalized parallel / radial error of every
        (iteration, group) pair.
    """
    d = 16
    n_iter = 1000  # steps for fitting the HSQ codebook
    lr = 0.1  # learning rate of the HSQ codebook fit
    blocks_per_group = 32  # blocks sharing one std normalization
    n_groups = 1  # number of block groups evaluated per gradient
    workers = range(n_machines)

    inv_pich = [1.302554292, 1.348986681, 1.391496061, 1.433970247, 1.48116051, 1.535987439, 1.604238237, 1.700437542]
    # NOTE(review): pickle.load can execute code from the file; only use trusted, locally generated tables.
    with open("./radial_biases/radial_bias_" + str(d) + "_" + str(2**13) + "_" + ".txt", "rb") as fp:
        pich = pickle.load(fp)

    hsq_codebook = voronoi(2**10, n_iter, lr, d)  # optimal Voronoi codebook
    print("optimal 0 done")
    hsq_codebook = hsq_codebook / torch.norm(hsq_codebook, dim=1, keepdim=True)
    gamma_hsq = [hsq_codebook for k in workers]

    n_samples = 2**14
    n_features = d * blocks_per_group  # = 2**9
    niter = 10
    bs = 128
    spw = 2**11
    Xtrain = np.random.normal(size=(n_samples, n_features))
    Ytrain = -0.9 * Xtrain[:, 0] + 0.4 * np.random.normal(size=Xtrain.shape[0])
    to_choose_workers = [np.random.choice(n_samples, spw, replace=True) for k in workers]

    # Union of the workers' rows, ordered worker by worker (ascending within a
    # worker) as before, but built with a set instead of quadratic list scans.
    finalXtrain = []
    seen = set()
    for rows in to_choose_workers:
        for i in sorted(set(rows.tolist())):
            if i not in seen:
                seen.add(i)
                finalXtrain.append(i)

    # Step size 1 / (f L), L = largest eigenvalue of X^T X / n.
    _, sing_vals, _ = np.linalg.svd(np.matmul(np.transpose(Xtrain), Xtrain))
    L = sing_vals.max() / n_samples
    f = 3
    lr_gd = 1 / (f * L)

    init_w = 2 * torch.rand(n_features, dtype=torch.float) - 1
    regr1 = LeastSquaresRegressorTorch1(n_iter=niter, eta=lr_gd, batch_size=bs, samples_per_worker=spw, K=n_machines)
    regr1.fit(Xtrain, Ytrain, to_choose_workers, finalXtrain, init_w)
    grad = regr1.saved_grads

    Err = {m: [] for m in _DISTORTION_METHODS}
    Err_rad = {m: [] for m in _DISTORTION_METHODS}
    group_size = d * blocks_per_group

    for b in range(niter):
        # Gradients of SGD iteration b, one row per worker.
        g = np.array(grad[b * n_machines:(b + 1) * n_machines]).reshape((n_machines, -1))

        for s in range(n_groups):
            print(s)
            group = slice(group_size * s, group_size * (s + 1))
            stds = [np.std(g[k, group]) for k in workers]
            err = {m: [] for m in _DISTORTION_METHODS}
            err_rad = {m: [] for m in _DISTORTION_METHODS}

            for j in range(blocks_per_group * s, blocks_per_group * (s + 1)):
                cols = slice(d * j, d * (j + 1))
                chunks = [torch.tensor(g[k, cols]).view((1, d)) for k in workers]
                z = [chunks[k] / stds[k] for k in workers]  # std-normalized blocks
                norms = [torch.norm(z[k], dim=1, keepdim=True) for k in workers]
                units = [z[k] / norms[k] for k in workers]  # unit directions
                x = torch.tensor(np.array([c.numpy() for c in chunks]).mean(0))  # exact average

                def record(method, per_worker):
                    _record_distortion(err, err_rad, method, per_worker, x)

                record("sign", [quantize_sign(z[k]).numpy() * stds[k] for k in workers])

                # NOTE(review): the bias table covers norms 0..6 only, so blocks
                # outside it are skipped and add nothing to the Voronoi sums.
                if torch.norm(torch.tensor(np.array([zk.numpy() for zk in z]).mean(0))) < 6:
                    gamma_random = [torch.normal(0, np.sqrt(1 + 2 / d), size=(2**13, d)) for k in workers]
                    nearest = [nearest_neighbor(gamma_random[k], z[k]).numpy() for k in workers]
                    bias = [fix_bias(pich, norms[k]) for k in workers]
                    record("random_voronoi", [nearest[k] * stds[k] for k in workers])
                    record("random_voronoi_unbiased", [nearest[k] / bias[k].numpy() * stds[k] for k in workers])
                    record("random_voronoi_uq",
                           [nearest[k] * fix_bias_3_bits(inv_pich, 1 / bias[k]) * stds[k] for k in workers])

                record("top2", [quantize_top2(z[k]).numpy() * stds[k] for k in workers])
                record("rand2", [quantize_rand2(z[k], d).numpy() * stds[k] for k in workers])

                # Rescale by worker k's own std (the hsq variant used worker 0's).
                record("hsq", [quantize_hsq(gamma_hsq[k], units[k], norms[k], 25).numpy() * stds[k]
                               for k in workers])
                record("hsq_q", [quantize_hsq(gamma_hsq[k], units[k], norms[k], 25, quantized=True).numpy() * stds[k]
                                 for k in workers])

                greedy = [quantize_greedy_hsq(gamma_hsq[k], units[k]).numpy() * stds[k] for k in workers]
                record("greedy_hsq", [greedy[k] * norms[k].numpy() for k in workers])
                record("greedy_hsq_q", [greedy[k] * uniform_quantizer(norms[k].numpy(), 6, 0, 12) for k in workers])

                record("vqsgd", [quantize_vqsgd(units[k], d, 2).numpy() * stds[k] * norms[k].numpy()
                                 for k in workers])
                record("vqsgd_q", [quantize_vqsgd(units[k], d, 2).numpy() * stds[k]
                                   * uniform_quantizer(norms[k].numpy(), 6, 0, 12) for k in workers])

            # Normalize by the squared norm of the exact averaged group (for a
            # single worker this is that worker's own gradient group).
            scale = np.linalg.norm(g[:, group].mean(0)) ** 2
            for m in _DISTORTION_METHODS:
                Err[m].append(np.array(err[m]).sum() / scale)
                Err_rad[m].append(np.array(err_rad[m]).sum() / scale)

    for label, m in _DISTORTION_REPORT:
        print(label, np.array(Err[m]).mean(), np.array(Err_rad[m]).mean())

    for label, m in _DISTORTION_REPORT:
        print(label, np.array(Err[m]).std(), np.array(Err_rad[m]).std())

    return Err, Err_rad

#%%
if __name__ == "__main__":
    # Single-worker baseline first, then 8 simulated workers.
    for machines in (1, 8):
        lsr_grads_distortion(machines)
