import torch
import numpy as np
import pickle

#%%
def nearest_neighbor(gamma, x):
    """Return the row of ``gamma`` closest to ``x`` in Euclidean distance.

    The distance is taken along ``dim=1`` of ``x - gamma``. The result is indexed
    with a length-1 index tensor, so for a 2-D ``gamma`` it keeps a leading axis
    of size 1 (shape ``(1, gamma.shape[1])``).
    """
    distances = (x - gamma).norm(dim=1)
    best = (-distances).topk(1).indices
    return gamma[best]

def voronoi(n_atoms, n_iter, alpha, d):
    """Fit a codebook of ``n_atoms`` points in R^d by online competitive learning.

    The codebook starts as i.i.d. N(0, 1 + 2/d) entries. At each of ``n_iter``
    steps, a sample is drawn from N(0, I_d). The codeword closest to it is moved
    a fraction ``alpha`` of the way toward that sample.

    Returns:
        Tensor of shape ``(n_atoms, d)`` holding the trained codewords.
    """
    codebook = torch.normal(0, np.sqrt(1 + 2 / d), size=(n_atoms, d))
    for _ in range(n_iter):
        sample = torch.normal(0, 1, size=(1, d))
        distances = torch.norm(sample - codebook, dim=1)
        winner = torch.topk(-distances, k=1)[1]
        # Pull only the winning codeword toward the sample.
        codebook[winner] -= alpha * (codebook[winner] - sample)
    return codebook


def quantize_sign(x):
    """1-bit sign compression: map each entry of ``x`` to -1, 0 or +1."""
    return x.sign()


def quantize_vqsgd(x, d, workers):
    """Stochastically quantize ``x`` onto the vertices of the scaled cross-polytope (vqSGD).

    The 2*d vertices are ``+sqrt(d) * e_i`` and ``-sqrt(d) * e_i``. For ``||x||_2 <= 1``
    (hence ``||x||_1 <= sqrt(d)``), ``x`` is a convex combination of them:
    - vertex ``+sqrt(d) e_i`` gets weight ``max(x_i, 0) / sqrt(d)``;
    - vertex ``-sqrt(d) e_i`` gets weight ``max(-x_i, 0) / sqrt(d)``;
    - the leftover mass ``1 - ||x||_1 / sqrt(d)`` is spread uniformly over all 2*d
      vertices, where its contributions cancel pairwise.
    ``workers`` vertices are drawn independently with these weights and averaged,
    which gives an unbiased estimate of ``x``.

    Args:
        x: tensor, presumably of shape ``(1, d)``, with Euclidean norm at most 1.
        d: dimension of ``x``.
        workers: number of independent vertex draws to average.

    Returns:
        float64 tensor of shape ``(1, d)``: the mean of the sampled vertices.
    """
    scale = np.sqrt(d)
    # Rows: +sqrt(d) e_0 .. e_{d-1}, then -sqrt(d) e_0 .. e_{d-1}.
    vertices = np.concatenate([scale * np.eye(d), -scale * np.eye(d)], 0)

    row = x[0]
    # Leftover convex-combination mass. It is >= 0 in exact arithmetic when
    # ||x||_2 <= 1. Float rounding (e.g. x = v / ||v|| with near-equal magnitudes)
    # can make it a tiny negative number. That yields negative probabilities, and
    # np.random.choice raises. torch.clamp keeps NaN, so the zero-vector fallback
    # below still applies.
    ga = torch.clamp(1 - torch.norm(x, p=1) / scale, min=0.0)
    uniform_share = ga / (2 * d)
    positive = torch.clamp(row, min=0.0) / scale + uniform_share
    negative = torch.clamp(-row, min=0.0) / scale + uniform_share
    proba = torch.cat([positive, negative]).detach().cpu().numpy()

    # NaN input (e.g. 0/0 from normalizing a zero vector) falls back to uniform weights.
    proba = np.nan_to_num(proba, nan=1)
    proba = proba / np.abs(proba).sum()

    # One call with `workers` draws uses the global NumPy RNG in the same order
    # as `workers` single-draw calls.
    chosen = np.random.choice(2 * d, workers, replace=True, p=proba)
    return torch.tensor(vertices[chosen].mean(0, keepdims=True))

def quantize_greedy_hsq(gamma, x):
    """Deterministic HSQ baseline: return the codeword of ``gamma`` nearest to ``x``.

    Distances are Euclidean, taken along ``dim=1`` of ``x - gamma``. The result keeps
    a leading axis of size 1 (shape ``(1, gamma.shape[1])`` for a 2-D ``gamma``).
    """
    dist = torch.norm(x - gamma, dim=1)
    closest = torch.topk(-dist, k=1)[1]
    return gamma[closest]

def quantize_hsq(gamma, x, norm, nom, quantized=False):
    """Randomized one-codeword quantization of ``norm * x`` over the codebook ``gamma``.

    Let G be the matrix whose columns are the rows of ``gamma``, and g = (x * norm)^T.
    The code computes lam = G^T pinv(G G^T) g (``proba``). This satisfies G lam = g
    whenever G G^T is invertible, i.e. when the codewords span the space; that is
    not checked here. One index j is sampled with probability |lam_j| / ||lam||_1.
    Codeword j is returned scaled by sign(lam_j) * ||lam||_1, so in the exact case
    its expectation is G lam = g.

    Args:
        gamma: codebook tensor, one codeword per row (presumably (n_atoms, d)).
        x: presumably a (1, d) tensor; it is transposed to a column vector.
        norm: factor multiplying ``x`` (callers pass the original norm of x).
        nom: the scale is rounded to the nearest point of a 6-bit uniform grid on
            [-nom, nom) via ``uniform_quantizer``. Only used when ``quantized`` is True.
        quantized: if True, the scale sign(lam_j) * ||lam||_1 is quantized as above.
            The rounding is deterministic (nearest point), so in general the result
            is then no longer exactly unbiased.

    Returns:
        ``gamma[chos]`` (shape (1, d) for a 2-D ``gamma``) times the optionally
        quantized scale. If every coefficient is 0 (e.g. norm == 0), sampling
        falls back to uniform and the returned vector is zero.
    """
    gamma = gamma.t()  # columns are now the codewords
    g = x.t()*norm  # target column vector
    # Minimum-norm right inverse of G: G^T (G G^T)^+
    C_dag = torch.matmul(gamma.t(), torch.pinverse(torch.matmul(gamma, gamma.t())))
    proba = torch.matmul(C_dag, g)#C_dag*g#torch.matmul(C_dag, g)
    # proba = torch.tensor(np.nan_to_num(proba, nan=1))
    # Sampling distribution |lam_j| / ||lam||_1 (NaN when lam == 0).
    p_tilde = torch.abs(proba) / torch.norm(proba, p=1)
    ###
    # Replace NaNs (0/0 case) by 1, giving a uniform distribution, and renormalize.
    p_tilde = np.nan_to_num(p_tilde.numpy(), nan=1)
    p_tilde = torch.tensor(p_tilde / np.linalg.norm(p_tilde, 1))
    ###
    gamma = gamma.t()  # back to one codeword per row
    # Draw one codeword index from the global NumPy RNG.
    chos = np.random.choice([i for i in range(gamma.shape[0])], 1, replace=True, p=p_tilde.t().numpy()[0])
    # print((np.sign(proba[chos])*torch.norm(proba, p=1)).numpy())
    if quantized:
        # 6-bit quantization of the signed scale sign(lam_j) * ||lam||_1 on [-nom, nom).
        pseudo_norm = uniform_quantizer((np.sign(proba[chos])*torch.norm(proba, p=1)).numpy(), 6, -nom, nom)
        return gamma[chos] * pseudo_norm
    # print((np.sign(proba[chos])*torch.norm(proba, p=1)).numpy())
    else:
        return gamma[chos] * (np.sign(proba[chos])*torch.norm(proba, p=1)).item()

def quantize_top1(x):
    """Top-1 sparsification: keep the largest-magnitude entry of each row of ``x``.

    All other entries become zero. The result has the same shape and dtype as ``x``.
    """
    top_idx = x.abs().topk(1, dim=1).indices
    mask = torch.zeros_like(x).scatter_(1, top_idx, 1)
    return x * mask

def quantize_rand1(x, d):
    """Random-1 sparsification: keep one uniformly chosen column of ``x``, scaled by ``d``.

    One index is drawn from the global torch RNG with ``torch.randint(0, d, (1, 1))``.
    The kept entry is multiplied by ``d`` so that the estimate is unbiased.
    """
    picked = torch.randint(0, d, (1, 1))
    mask = torch.zeros_like(x).scatter_(1, picked, 1)
    return d * x * mask

def quantize_top2(x):
    """Top-2 sparsification followed by 8-bit scalar quantization of the kept entries.

    The two largest-magnitude entries of each row of ``x`` (``topk`` on ``dim=1``)
    are kept and the rest are zeroed. Every non-zero kept entry of row 0 is then
    replaced by its nearest level from the scalar codebook in
    ``optimal-quantifier-8.txt``.

    Args:
        x: tensor, presumably of shape ``(1, d)``; only row 0 is re-quantized.

    Returns:
        Tensor shaped like ``x`` with at most two non-zero entries per row.
    """
    # The context manager closes the file even if parsing raises.
    with open('optimal-quantifier-8.txt', 'r') as level_file:
        # Each level comes from the first comma-separated field of a line, with its
        # last two characters dropped. This parsing is unchanged from the original.
        levels = [float(line.split(',')[0][:-2]) for line in level_file]
    # Build the (n_levels, 1) codebook once, not once per kept entry.
    codebook = torch.tensor(levels).view(-1, 1)

    mask = torch.zeros_like(x)
    mask.scatter_(1, torch.topk(torch.abs(x), k=2, dim=1)[1], 1)
    qu = x * mask
    for i, e in enumerate(qu[0]):
        if e != 0:
            qu[0][i] = nearest_neighbor(codebook, e)
    return qu

def quantize_rand2(x, d):
    """Unbiased random-2 sparsification followed by 8-bit scalar quantization.

    Steps:
    1. Draw two distinct coordinates uniformly at random and zero all others.
    2. Replace each non-zero kept entry of row 0 by its nearest level from the
       scalar codebook in ``optimal-quantifier-8.txt``.
    3. Scale the result by ``d / 2``.
    Each coordinate is kept with probability ``2 / d``, so before the scalar
    quantization ``E[output] = x``.

    Args:
        x: tensor, presumably of shape ``(1, d)``.
        d: number of coordinates of ``x``.

    Returns:
        ``d / 2`` times the sparsified, quantized copy of ``x``.
    """
    # The context manager closes the file even if parsing raises.
    with open('optimal-quantifier-8.txt', 'r') as level_file:
        # Same parsing as before: first comma field with its last two characters dropped.
        levels = [float(line.split(',')[0][:-2]) for line in level_file]
    # Build the (n_levels, 1) codebook once, not once per kept entry.
    codebook = torch.tensor(levels).view(-1, 1)

    # Two *distinct* coordinates. torch.randint(0, d, (1, 2)) could draw the same
    # index twice, keeping a single entry while still scaling by d / 2. That made
    # the keep-probability (2d - 1) / d**2 instead of 2 / d, biasing the estimate.
    idx = torch.randperm(d)[:2].view(1, -1)
    mask = torch.zeros_like(x)
    mask.scatter_(1, idx, 1)
    qu = x * mask
    for i, e in enumerate(qu[0]):
        if e != 0:
            qu[0][i] = nearest_neighbor(codebook, e)
    return d/2 * qu

#%%
def fix_bias(p, x):
    """Look up the bias-correction entry of table ``p`` for the norm of ``x``.

    The lookup index is the integer in [0, 60] nearest to ``10 * ||x||``. On an exact
    tie the lower integer wins (first argmin). Index 0 is bumped to 1 to avoid the
    nan/inf entry at the start of the table. Returns ``p[index][0]``.
    """
    scaled_norm = 10 * np.linalg.norm(x)
    grid = np.arange(61)
    index = np.argmin(np.abs(grid - scaled_norm))
    index = max(index, 1)  # avoid nan/inf at index 0
    return p[index][0]

# def fix_bias_3_bits(ip, x):
#     proba = 1 / np.abs(np.array(ip)-x.numpy())
#     return np.array(ip)[np.random.choice([i for i in range(8)], 1, replace=True, p=proba / proba.sum())]
def fix_bias_3_bits(ip, x):
    """Stochastically round ``x`` onto the ascending level set ``ip`` (e.g. 8 levels = 3 bits).

    Behavior by case:
    - Values above ``max(ip)`` or below ``min(ip)`` are clipped to that end level.
    - A value equal to one of the levels is returned as that level.
    - A value strictly between consecutive levels ``lo < x < hi`` becomes ``hi``
      with probability ``(x - lo) / (hi - lo)`` and ``lo`` otherwise. Inside the
      range, the result is therefore unbiased.

    Args:
        ip: ascending sequence of quantization levels.
        x: single-element tensor (or anything ``float()`` accepts).

    Returns:
        The selected level as a NumPy float64 scalar.
    """
    levels = np.asarray(ip, dtype=np.float64)
    # Reduce to a scalar so (1,)- or (1, 1)-shaped inputs behave like 0-d ones.
    value = float(x)
    if value > levels.max():
        return levels.max()
    if value < levels.min():
        return levels.min()
    idx = int(np.abs(levels - value).argsort()[0])
    nearest = levels[idx]
    if value == nearest:
        # Without this case an exact level fell through every branch and returned None.
        return nearest
    # Bracket the value between two consecutive levels (levels assumed ascending).
    if value > nearest:
        lo, hi = nearest, levels[idx + 1]
    else:
        lo, hi = levels[idx - 1], nearest
    proba = (value - lo) / (hi - lo)
    # A single Bernoulli draw from the global torch RNG decides the rounding direction.
    round_up = torch.bernoulli(torch.tensor(proba, dtype=torch.float64)).item()
    return np.float64(round_up * (hi - lo) + lo)
    
    
def uniform_quantizer(norm_x_true, b, umin, umax):
    """Round ``norm_x_true`` to the nearest point of a uniform ``b``-bit grid.

    The grid is ``np.arange(umin, umax, (umax - umin) / 2**b)``. It therefore starts
    at ``umin`` and excludes ``umax``; values past either end snap to the extreme
    grid point. Expects a scalar or a single-element array. The index is the argmin
    over the flattened distances.
    """
    step = (umax - umin) / 2**b
    levels = np.arange(umin, umax, step)
    nearest = np.abs(levels - norm_x_true).argmin()
    return levels[nearest]
#%%
def table2_radial(n_machines):
    """Monte-Carlo comparison of the quantizers on Gaussian vectors in R^16.

    For each of ``n_samples_x`` draws ``x ~ N(0, I_16)`` (shape ``(1, 16)``), every
    scheme quantizes ``x`` independently for ``n_machines`` simulated machines.
    The results are averaged into ``avg_q``. The error of ``avg_q`` is split into
    two parts, both recorded per sample:

    * ``err_<scheme>``: ``||q_para - x||^2``, where ``q_para`` is the orthogonal
      projection of ``avg_q`` onto the line spanned by ``x``;
    * ``err_<scheme>_rad``: ``||avg_q - q_para||^2``, the part orthogonal to ``x``.

    Finally, the mean of both quantities is printed for every scheme, then the
    standard deviation. Nothing is returned.

    Side effects / requirements:
    - reads the pickled table ``./radial_biases/radial_bias_16_8192_.txt``;
    - ``quantize_top2`` / ``quantize_rand2`` read ``optimal-quantifier-8.txt``;
    - prints the sample index at every iteration;
    - consumes the global torch and NumPy RNGs.

    Args:
        n_machines: number of simulated machines whose quantized outputs are averaged.
    """
    d=16  # NOTE: several reshape((16)) calls below hard-code this value
    n_iter=1000  # voronoi training iterations
    n_samples_x=10000  # number of Gaussian test vectors x
    n_samples_grid=1  # repetitions per x (random codebooks are redrawn each time)
    lr=0.1  # voronoi step size (alpha)
    ### quadratic error empirical both on X and grid
    # Per-sample errors: err_<name> = error along x, err_<name>_rad = part orthogonal to x.
    err_sign = []
    err_sign_rad = []
    err_random_voronoi = []
    err_random_voronoi_rad = []
    err_random_voronoi_unbiased = []
    err_random_voronoi_unbiased_rad = []
    err_random_voronoi_uq = []
    err_random_voronoi_uq_rad = []
    err_hsq = []
    err_hsq_rad = []
    err_hsq_q = []
    err_hsq_q_rad = []
    err_greedy_hsq = []
    err_greedy_hsq_rad = []
    err_greedy_hsq_q = []
    err_greedy_hsq_q_rad = []
    err_top2 = []
    err_top2_rad = []
    err_rand2 = []
    err_rand2_rad = []
    err_vqsgd = []
    err_vqsgd_rad = []
    err_vqsgd_q = []
    err_vqsgd_q_rad = []
    
    # 8 levels (3 bits) onto which fix_bias_3_bits stochastically rounds 1 / fix_bias(...).
    inv_pich = [1.302554292, 1.348986681, 1.391496061, 1.433970247, 1.48116051, 1.535987439, 1.604238237, 1.700437542]
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    # Presumably a bias table for random codebooks of 2**13 atoms in d=16, indexed by
    # round(10 * ||x||) through fix_bias -- confirm.
    with open("./radial_biases/radial_bias_"+str(16)+"_"+str(2**13)+"_"+".txt", "rb") as fp:   #Pickling
        pich = pickle.load(fp)
        
    gamma_hsq = voronoi(2**10, n_iter, lr, d)    ### optimal Voronoi
    print("optimal 0 done")
    # Every machine shares the same learned codebook, rescaled to unit-norm codewords.
    gamma_hsq = [gamma_hsq for k in range(n_machines)]
    gamma_hsq = [g / torch.norm(g, dim=1, keepdim=True) for g in gamma_hsq]
    

    for i in range(n_samples_x):
        x = torch.normal(0, 1, size=(1, d))
        print(i)
        for j in range(n_samples_grid):
            # Sign compression. In every block below, q_para = (x^T x / ||x||^2) avg_q
            # is the projection of avg_q onto span(x); q_rad is the remainder.
            avg_q = np.array([quantize_sign(x).numpy() for k in range(n_machines)]).mean(0)
            q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
            q_rad = avg_q-q_para
            err_sign.append(np.linalg.norm(q_para-x.numpy())**2)
            err_sign_rad.append(np.linalg.norm(q_rad)**2)
            # err_sign.append(np.linalg.norm(np.array([quantize_sign(x).numpy() for k in range(n_machines)]).mean(0)-x.numpy())**2)
            # err_qsgd.append(np.linalg.norm(np.array([quantize_qsgd(x, 1, d).numpy() for k in range(n_machines)]).mean(0)-x.numpy())**2)

            # A fresh random Gaussian codebook of 2**13 atoms per machine, for every sample.
            gamma_random = [torch.normal(0, np.sqrt(1+2/d), size=(2**13, d)) for k in range(n_machines)]
            
            # err_random_voronoi.append(np.linalg.norm(np.array([nearest_neighbor(gamma_random[k], x/torch.std(x)).numpy()*fix_bias_3_bits(inv_pich, 1/fix_bias(pich, torch.norm(x, dim=1, keepdim=True))) for k in range(n_machines)]).mean(0)-x.numpy())**2)
            ###
            ### Picherande not encoded TODO
            ###
            # err_random_voronoi.append(np.linalg.norm(np.array([nearest_neighbor(gamma_random[k], x/torch.std(x)).numpy() for k in range(n_machines)]).mean(0)-x.numpy())**2)
            # err_random_voronoi.append(np.linalg.norm(np.array([nearest_neighbor(gamma_random[k], x/torch.std(x)).numpy() / fix_bias(pich, torch.norm(x, dim=1, keepdim=True))[0].item() for k in range(n_machines)]).mean(0)-x.numpy())**2)
            # Random Voronoi: plain nearest codeword, no bias correction.
            avg_q = np.array([nearest_neighbor(gamma_random[k], x).numpy() for k in range(n_machines)]).mean(0)
            q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
            q_rad = avg_q-q_para
            err_random_voronoi.append(np.linalg.norm(q_para-x.numpy())**2)
            err_random_voronoi_rad.append(np.linalg.norm(q_rad)**2)

            # Random Voronoi divided by the (unquantized) bias factor fix_bias(pich, ||x||).
            avg_q = np.array([nearest_neighbor(gamma_random[k], x).numpy() / fix_bias(pich, torch.norm(x, dim=1, keepdim=True)).numpy() for k in range(n_machines)]).mean(0)
            q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
            q_rad = avg_q-q_para
            err_random_voronoi_unbiased.append(np.linalg.norm(q_para-x.numpy())**2)
            err_random_voronoi_unbiased_rad.append(np.linalg.norm(q_rad)**2)         
            
            # Random Voronoi times 1 / fix_bias(...), stochastically rounded to 3 bits.
            avg_q = np.array([nearest_neighbor(gamma_random[k], x).numpy() * fix_bias_3_bits(inv_pich, 1/fix_bias(pich, torch.norm(x, dim=1, keepdim=True))) for k in range(n_machines)]).mean(0)
            q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
            q_rad = avg_q-q_para
            err_random_voronoi_uq.append(np.linalg.norm(q_para-x.numpy())**2)
            err_random_voronoi_uq_rad.append(np.linalg.norm(q_rad)**2)
            
            # err_top2.append(np.linalg.norm(np.array([quantize_top2(x).numpy() for k in range(n_machines)]).mean(0)-x.numpy())**2)
            avg_q = np.array([quantize_top2(x).numpy() for k in range(n_machines)]).mean(0)
            q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
            q_rad = avg_q-q_para
            err_top2.append(np.linalg.norm(q_para-x.numpy())**2)
            err_top2_rad.append(np.linalg.norm(q_rad)**2)
            # err_rand2.append(np.linalg.norm(np.array([quantize_rand2(x, d).numpy() for k in range(n_machines)]).mean(0)-x.numpy())**2)
            avg_q = np.array([quantize_rand2(x, d).numpy() for k in range(n_machines)]).mean(0)
            q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
            q_rad = avg_q-q_para
            err_rand2.append(np.linalg.norm(q_para-x.numpy())**2)
            err_rand2_rad.append(np.linalg.norm(q_rad)**2)
 
            # Exact norm, unit direction, and 6-bit quantized norm on [0, 6).
            norm_x_true = torch.norm(x, dim=1, keepdim=True)
            x_norm = x / norm_x_true
            norm_x = uniform_quantizer(norm_x_true.numpy(), 6, 0, 6)     
            
            # err_hsq.append(np.linalg.norm(np.array([quantize_hsq(gamma_hsq[k], x_norm, norm_x_true, 6).numpy() for k in range(n_machines)]).mean(0)-(x_norm*norm_x_true).numpy())**2)
            # err_hsq.append(np.linalg.norm(np.array([quantize_hsq(gamma_hsq[k], x_norm, norm_x_true, 15).numpy() for k in range(n_machines)]).mean(0)-(x_norm*norm_x_true).numpy())**2)
            # HSQ with an unquantized scale.
            avg_q = np.array([quantize_hsq(gamma_hsq[k], x_norm, norm_x_true, 15).numpy() for k in range(n_machines)]).mean(0)
            q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
            q_rad = avg_q-q_para
            err_hsq.append(np.linalg.norm(q_para-(x_norm*norm_x_true).numpy())**2)
            err_hsq_rad.append(np.linalg.norm(q_rad)**2)
            ### with and without quantization; also make 2 columns
            # HSQ with the scale quantized to 6 bits on [-15, 15).
            avg_q = np.array([quantize_hsq(gamma_hsq[k], x_norm, norm_x_true, 15, quantized=True).numpy() for k in range(n_machines)]).mean(0)
            q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
            q_rad = avg_q-q_para
            err_hsq_q.append(np.linalg.norm(q_para-(x_norm*norm_x_true).numpy())**2)
            err_hsq_q_rad.append(np.linalg.norm(q_rad)**2)
            
            
            # err_greedy_hsq.append(np.linalg.norm(np.array([quantize_greedy_hsq(gamma_hsq[k], x_norm).numpy() for k in range(n_machines)]).mean(0)*norm_x-(x_norm*norm_x_true).numpy())**2)
            # Greedy HSQ: nearest unit codeword to x_norm, rescaled afterwards by the
            # exact norm (err_greedy_hsq*) or the 6-bit quantized norm (err_greedy_hsq_q*).
            avg_q = np.array([quantize_greedy_hsq(gamma_hsq[k], x_norm).numpy() for k in range(n_machines)]).mean(0)
            q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
            q_rad = avg_q-q_para
            err_greedy_hsq.append(np.linalg.norm(q_para*norm_x_true.numpy()-(x_norm*norm_x_true).numpy())**2)
            err_greedy_hsq_rad.append(np.linalg.norm(q_rad*norm_x_true.numpy())**2)
            ### with and without norm quantization
            err_greedy_hsq_q.append(np.linalg.norm(q_para*norm_x-(x_norm*norm_x_true).numpy())**2)
            err_greedy_hsq_q_rad.append(np.linalg.norm(q_rad*norm_x)**2)
            
            # err_vqsgd.append(np.linalg.norm(np.array([quantize_vqsgd(x_norm, d, 2).numpy() for k in range(n_machines)]).mean(0)*norm_x-(x_norm*norm_x_true).numpy())**2)
            # Cross-polytope (vqSGD) on x_norm with 2 draws per machine, rescaled by
            # the exact norm or the quantized norm.
            avg_q = np.array([quantize_vqsgd(x_norm, d, 2).numpy() for k in range(n_machines)]).mean(0)
            q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
            q_rad = avg_q-q_para
            err_vqsgd.append(np.linalg.norm(q_para*norm_x_true.numpy()-(x_norm*norm_x_true).numpy())**2)
            err_vqsgd_rad.append(np.linalg.norm(q_rad*norm_x_true.numpy())**2)
            ### with and without norm quantization
            err_vqsgd_q.append(np.linalg.norm(q_para*norm_x-(x_norm*norm_x_true).numpy())**2)
            err_vqsgd_q_rad.append(np.linalg.norm(q_rad*norm_x)**2)
            
            
    # Mean over samples: (error along x, error orthogonal to x) per scheme.
    print('err_sign', np.array(err_sign).mean(), np.array(err_sign_rad).mean())
    print("err_top2",  np.array(err_top2).mean(), np.array(err_top2_rad).mean())
    print("err_rand2",  np.array(err_rand2).mean(), np.array(err_rand2_rad).mean())
    print('err_polytope', np.array(err_vqsgd).mean(), np.array(err_vqsgd_rad).mean())
    print('err_polytope_q', np.array(err_vqsgd_q).mean(), np.array(err_vqsgd_q_rad).mean())
    print("err_hsq",  np.array(err_hsq).mean(), np.array(err_hsq_rad).mean())
    print("err_hsq_q",  np.array(err_hsq_q).mean(), np.array(err_hsq_q_rad).mean())
    print("err_greedy_hsq",  np.array(err_greedy_hsq).mean(), np.array(err_greedy_hsq_rad).mean())
    print("err_greedy_hsq_q",  np.array(err_greedy_hsq_q).mean(), np.array(err_greedy_hsq_q_rad).mean())
    print("err_random_voronoi",  np.array(err_random_voronoi).mean(), np.array(err_random_voronoi_rad).mean())
    print("err_random_voronoi_unbiased",  np.array(err_random_voronoi_unbiased).mean(), np.array(err_random_voronoi_unbiased_rad).mean())
    print("err_random_voronoi_uq",  np.array(err_random_voronoi_uq).mean(), np.array(err_random_voronoi_uq_rad).mean())
    
    # Standard deviations over samples, in the same order.
    print('err_sign', np.array(err_sign).std(), np.array(err_sign_rad).std())
    print("err_top2",  np.array(err_top2).std(), np.array(err_top2_rad).std())
    print("err_rand2",  np.array(err_rand2).std(), np.array(err_rand2_rad).std())
    print('err_polytope', np.array(err_vqsgd).std(), np.array(err_vqsgd_rad).std())
    print('err_polytope_q', np.array(err_vqsgd_q).std(), np.array(err_vqsgd_q_rad).std())
    print("err_hsq",  np.array(err_hsq).std(), np.array(err_hsq_rad).std())
    print("err_hsq_q",  np.array(err_hsq_q).std(), np.array(err_hsq_q_rad).std())
    print("err_greedy_hsq",  np.array(err_greedy_hsq).std(), np.array(err_greedy_hsq_rad).std())
    print("err_greedy_hsq_q",  np.array(err_greedy_hsq_q).std(), np.array(err_greedy_hsq_q_rad).std())
    print("err_random_voronoi",  np.array(err_random_voronoi).std(), np.array(err_random_voronoi_rad).std())
    print("err_random_voronoi_unbiased",  np.array(err_random_voronoi_unbiased).std(), np.array(err_random_voronoi_unbiased_rad).std())
    print("err_random_voronoi_uq",  np.array(err_random_voronoi_uq).std(), np.array(err_random_voronoi_uq_rad).std())

#%%
def table4_radial_all(n_machines):
    """Relative-error comparison of the quantizers on stored training gradients.

    Data: 10 pickled gradient files
    ``./saved_grads/saved_grad_epoch_10_<b>_cifar10_vgg16_sgd_8_256_1.txt``,
    each reshaped to 8 rows. Only row 0 is used.

    Processing:
    - From row 0, the first 70 segments of 32 * 16 = 512 coordinates are used.
    - Each segment is cut into 32 blocks of d = 16 coordinates.
    - Each block is divided by the standard deviation of its segment before it is
      quantized, and the quantized output is multiplied back by that std.

    Per block:
    - ``avg_q`` is the average over ``n_machines`` independent quantizations.
    - ``q_para`` is the projection of ``avg_q`` onto the block ``x``, and
      ``q_rad = avg_q - q_para`` is the part orthogonal to ``x``.

    Per segment, the summed squared errors are divided by the segment's squared
    norm and stored in ``Err_*``. The mean and then the standard deviation of
    ``Err_*`` over all segments are printed. Nothing is returned.

    Side effects / requirements:
    - reads the pickle files above and ``./radial_biases/radial_bias_16_8192_.txt``;
    - ``quantize_top2`` / ``quantize_rand2`` read ``optimal-quantifier-8.txt``;
    - prints progress;
    - consumes the global torch and NumPy RNGs.

    Args:
        n_machines: number of simulated machines whose quantized outputs are averaged.
    """
    d=16  # block size; the reshapes and slices below hard-code 16
    n_iter=1000  # voronoi training iterations
    lr=0.1  # voronoi step size (alpha)
    ### quadratic error empirical both on X and grid
    # Per-segment relative errors (summed block errors / squared segment norm).
    Err_sign = []
    Err_sign_rad = []
    Err_random_voronoi = []
    Err_random_voronoi_rad = []
    Err_random_voronoi_unbiased = []
    Err_random_voronoi_unbiased_rad = []
    Err_random_voronoi_uq = []
    Err_random_voronoi_uq_rad = []
    Err_hsq = []
    Err_hsq_rad = []
    Err_hsq_q = []
    Err_hsq_q_rad = []
    Err_greedy_hsq = []
    Err_greedy_hsq_rad = []
    Err_greedy_hsq_q = []
    Err_greedy_hsq_q_rad = []
    Err_top2 = []
    Err_top2_rad = []
    Err_rand2 = []
    Err_rand2_rad = []
    Err_vqsgd = []
    Err_vqsgd_rad = []
    Err_vqsgd_q = []
    Err_vqsgd_q_rad = []
    
    # 8 levels (3 bits) onto which fix_bias_3_bits stochastically rounds 1 / fix_bias(...).
    inv_pich = [1.302554292, 1.348986681, 1.391496061, 1.433970247, 1.48116051, 1.535987439, 1.604238237, 1.700437542]
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open("./radial_biases/radial_bias_"+str(16)+"_"+str(2**13)+"_"+".txt", "rb") as fp:   #Pickling
        pich = pickle.load(fp)
        
    gamma_hsq = voronoi(2**10, n_iter, lr, d)    ### optimal Voronoi
    print("optimal 0 done")
    # Every machine shares the same learned codebook, rescaled to unit-norm codewords.
    gamma_hsq = [gamma_hsq for k in range(n_machines)]
    gamma_hsq = [g / torch.norm(g, dim=1, keepdim=True) for g in gamma_hsq]
    

    for b in range(10):
        with open("./saved_grads/saved_grad_epoch_10_"+str(b)+"_cifar10_vgg16_sgd_8_256_1.txt", "rb") as fp:   #Pickling
            g = np.array(pickle.load(fp))
        # 8 rows (presumably one per worker, matching the "_8_" in the file name -- confirm).
        g=g.reshape((8, -1))
        
        for s in range(70):
            print(s)
        
            # Per-block errors for segment s (coordinates 512*s .. 512*(s+1) of row 0).
            err_sign = []
            err_sign_rad = []
            err_random_voronoi = []
            err_random_voronoi_rad = []
            err_random_voronoi_unbiased = []
            err_random_voronoi_unbiased_rad = []
            err_random_voronoi_uq = []
            err_random_voronoi_uq_rad = []
            err_hsq = []
            err_hsq_rad = []
            err_hsq_q = []
            err_hsq_q_rad = []
            err_greedy_hsq = []
            err_greedy_hsq_rad = []
            err_greedy_hsq_q = []
            err_greedy_hsq_q_rad = []
            err_top2 = []
            err_top2_rad = []
            err_rand2 = []
            err_rand2_rad = []
            err_vqsgd = []
            err_vqsgd_rad = []
            err_vqsgd_q = []
            err_vqsgd_q_rad = []
            
            for j in range(32*s, 32*(s+1)):
                
                # The comprehension averages n_machines identical copies, so x is just
                # block j of row 0 as a (1, 16) tensor.
                x = torch.tensor(np.array([torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16)).numpy() for k in range(n_machines)]).mean(0)) #x_avg
                
                # Sign compression of the std-normalized block, rescaled by the std.
                # In every block below, q_para projects avg_q onto span(x); q_rad is the remainder.
                avg_q = np.array([quantize_sign(torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)])).numpy() * np.std(g[0, 32*16*s:32*16*(s+1)]) for k in range(n_machines)]).mean(0)
                q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
                q_rad = avg_q-q_para
                err_sign.append(np.linalg.norm(q_para-x.numpy())**2)
                err_sign_rad.append(np.linalg.norm(q_rad)**2)
                # err_sign.append(np.linalg.norm(np.array([quantize_sign(x).numpy() for k in range(n_machines)]).mean(0)-x.numpy())**2)
                # err_qsgd.append(np.linalg.norm(np.array([quantize_qsgd(x, 1, d).numpy() for k in range(n_machines)]).mean(0)-x.numpy())**2)
    
                # Random-Voronoi schemes only run when the normalized block norm is < 6
                # (fix_bias clips its index at 60, i.e. norm 6).
                # NOTE(review): skipped blocks add nothing to err_random_voronoi*, yet
                # Err_random_voronoi* is still divided by the full segment norm.
                if torch.norm(torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)]), dim=1, keepdim=True) < 6:

                    # A fresh random Gaussian codebook of 2**13 atoms per machine, for every block.
                    gamma_random = [torch.normal(0, np.sqrt(1+2/d), size=(2**13, d)) for k in range(n_machines)]
                    
                    # err_random_voronoi.append(np.linalg.norm(np.array([nearest_neighbor(gamma_random[k], x/torch.std(x)).numpy() for k in range(n_machines)]).mean(0)-x.numpy())**2)
                    # err_random_voronoi.append(np.linalg.norm(np.array([nearest_neighbor(gamma_random[k], x/torch.std(x)).numpy() / fix_bias(pich, torch.norm(x, dim=1, keepdim=True))[0].item() for k in range(n_machines)]).mean(0)-x.numpy())**2)
                    # Plain nearest codeword, no bias correction.
                    avg_q = np.array([nearest_neighbor(gamma_random[k], torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)])).numpy() * np.std(g[0, 32*16*s:32*16*(s+1)]) for k in range(n_machines)]).mean(0)
                    q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
                    q_rad = avg_q-q_para
                    err_random_voronoi.append(np.linalg.norm(q_para-x.numpy())**2)
                    err_random_voronoi_rad.append(np.linalg.norm(q_rad)**2)

                    # Divided by the (unquantized) bias factor fix_bias(pich, ||normalized block||).
                    avg_q = np.array([nearest_neighbor(gamma_random[k], torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)])).numpy() / fix_bias(pich, torch.norm(torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)]), dim=1, keepdim=True)).numpy() * np.std(g[0, 32*16*s:32*16*(s+1)]) for k in range(n_machines)]).mean(0)
                    q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
                    q_rad = avg_q-q_para
                    err_random_voronoi_unbiased.append(np.linalg.norm(q_para-x.numpy())**2)
                    err_random_voronoi_unbiased_rad.append(np.linalg.norm(q_rad)**2)         
                    
                    # Multiplied by 1 / fix_bias(...), stochastically rounded to 3 bits.
                    avg_q = np.array([nearest_neighbor(gamma_random[k], torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)])).numpy() * fix_bias_3_bits(inv_pich, 1/fix_bias(pich, torch.norm(torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)]), dim=1, keepdim=True))) * np.std(g[0, 32*16*s:32*16*(s+1)]) for k in range(n_machines)]).mean(0)
                    # avg_q = np.array([nearest_neighbor(gamma_random[k], x/torch.std(x)).numpy() * fix_bias_3_bits(inv_pich, 1/fix_bias(pich, torch.norm(x, dim=1, keepdim=True))) for k in range(n_machines)]).mean(0)
                    q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
                    q_rad = avg_q-q_para
                    err_random_voronoi_uq.append(np.linalg.norm(q_para-x.numpy())**2)
                    err_random_voronoi_uq_rad.append(np.linalg.norm(q_rad)**2)
                
                # err_top2.append(np.linalg.norm(np.array([quantize_top2(x).numpy() for k in range(n_machines)]).mean(0)-x.numpy())**2)
                avg_q = np.array([quantize_top2(torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)])).numpy() * np.std(g[0, 32*16*s:32*16*(s+1)]) for k in range(n_machines)]).mean(0)
                q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
                q_rad = avg_q-q_para
                err_top2.append(np.linalg.norm(q_para-x.numpy())**2)
                err_top2_rad.append(np.linalg.norm(q_rad)**2)
                # err_rand2.append(np.linalg.norm(np.array([quantize_rand2(x, d).numpy() for k in range(n_machines)]).mean(0)-x.numpy())**2)
    
                avg_q = np.array([quantize_rand2(torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)]), d).numpy() * np.std(g[0, 32*16*s:32*16*(s+1)]) for k in range(n_machines)]).mean(0)
                q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
                q_rad = avg_q-q_para
                err_rand2.append(np.linalg.norm(q_para-x.numpy())**2)
                err_rand2_rad.append(np.linalg.norm(q_rad)**2)
     
                # Norm of the std-normalized block, and its 6-bit quantization on [0, 12).
                norm_x_true = torch.norm(torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)]), dim=1, keepdim=True)
                # x_norm = x / norm_x_true
                norm_x = uniform_quantizer(norm_x_true.numpy(), 6, 0, 12)     
                
                # err_hsq.append(np.linalg.norm(np.array([quantize_hsq(gamma_hsq[k], x_norm, norm_x_true, 6).numpy() for k in range(n_machines)]).mean(0)-(x_norm*norm_x_true).numpy())**2)
                # err_hsq.append(np.linalg.norm(np.array([quantize_hsq(gamma_hsq[k], x_norm, norm_x_true, 15).numpy() for k in range(n_machines)]).mean(0)-(x_norm*norm_x_true).numpy())**2)
                # avg_q = np.array([quantize_hsq(gamma_hsq[k], x_norm, norm_x_true, 15).numpy() for k in range(n_machines)]).mean(0)
                # HSQ on the unit direction of the normalized block, scaled by its exact norm
                # (scale unquantized), then rescaled by the segment std.
                avg_q = np.array([quantize_hsq(gamma_hsq[k], torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)]) / torch.norm(torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)]), dim=1, keepdim=True), torch.norm(torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)]), dim=1, keepdim=True), 25).numpy() * np.std(g[0, 32*16*s:32*16*(s+1)]) for k in range(n_machines)]).mean(0)
                q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
                q_rad = avg_q-q_para
                err_hsq.append(np.linalg.norm(q_para-x.numpy())**2)
                err_hsq_rad.append(np.linalg.norm(q_rad)**2)
                ### with and without quantization; also make 2 columns
                # Same, with the HSQ scale quantized to 6 bits on [-25, 25).
                avg_q = np.array([quantize_hsq(gamma_hsq[k], torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)]) / torch.norm(torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)]), dim=1, keepdim=True), torch.norm(torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)]), dim=1, keepdim=True), 25, quantized=True).numpy() * np.std(g[0, 32*16*s:32*16*(s+1)]) for k in range(n_machines)]).mean(0)
                # avg_q = np.array([quantize_hsq(gamma_hsq[k], x_norm, norm_x_true, 15, quantized=True).numpy() for k in range(n_machines)]).mean(0)
                q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
                q_rad = avg_q-q_para
                err_hsq_q.append(np.linalg.norm(q_para-x.numpy())**2)
                err_hsq_q_rad.append(np.linalg.norm(q_rad)**2)
                
                
                # err_greedy_hsq.append(np.linalg.norm(np.array([quantize_greedy_hsq(gamma_hsq[k], x_norm).numpy() for k in range(n_machines)]).mean(0)*norm_x-(x_norm*norm_x_true).numpy())**2)
                # Greedy HSQ: nearest unit codeword to the unit direction, rescaled by the
                # exact norm (err_greedy_hsq*) or the 6-bit norm (err_greedy_hsq_q*).
                avg_q = np.array([quantize_greedy_hsq(gamma_hsq[k], torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)]) / norm_x_true).numpy() * np.std(g[0, 32*16*s:32*16*(s+1)]) for k in range(n_machines)]).mean(0)
                q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
                q_rad = avg_q-q_para
                err_greedy_hsq.append(np.linalg.norm(q_para*norm_x_true.numpy()-x.numpy())**2)
                err_greedy_hsq_rad.append(np.linalg.norm(q_rad*norm_x_true.numpy())**2)
                ### with and without norm quantization
                err_greedy_hsq_q.append(np.linalg.norm(q_para*norm_x-x.numpy())**2)
                err_greedy_hsq_q_rad.append(np.linalg.norm(q_rad*norm_x)**2)
                
                # err_vqsgd.append(np.linalg.norm(np.array([quantize_vqsgd(x_norm, d, 2).numpy() for k in range(n_machines)]).mean(0)*norm_x-(x_norm*norm_x_true).numpy())**2)
                # Cross-polytope (vqSGD) on the unit direction with 2 draws per machine,
                # rescaled by the exact norm or the 6-bit norm.
                avg_q = np.array([quantize_vqsgd(torch.tensor(g[0, 16*j:16*(j+1)]).view((1, 16))/np.std(g[0, 32*16*s:32*16*(s+1)]) / norm_x_true, d, 2).numpy() * np.std(g[0, 32*16*s:32*16*(s+1)]) for k in range(n_machines)]).mean(0)
                q_para = np.matmul( (x*x.t() / torch.norm(x)**2).numpy(), avg_q.reshape((16)))
                q_rad = avg_q-q_para
                err_vqsgd.append(np.linalg.norm(q_para*norm_x_true.numpy()-x.numpy())**2)
                err_vqsgd_rad.append(np.linalg.norm(q_rad*norm_x_true.numpy())**2)
                ### with and without norm quantization
                err_vqsgd_q.append(np.linalg.norm(q_para*norm_x-x.numpy())**2)
                err_vqsgd_q_rad.append(np.linalg.norm(q_rad*norm_x)**2)
                
            # Relative error of segment s: summed block errors / squared norm of the segment.
            Err_sign.append(np.array(err_sign).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_sign_rad.append(np.array(err_sign_rad).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_random_voronoi.append(np.array(err_random_voronoi).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_random_voronoi_rad.append(np.array(err_random_voronoi_rad).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_random_voronoi_unbiased.append(np.array(err_random_voronoi_unbiased).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_random_voronoi_unbiased_rad.append(np.array(err_random_voronoi_unbiased_rad).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_random_voronoi_uq.append(np.array(err_random_voronoi_uq).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_random_voronoi_uq_rad.append(np.array(err_random_voronoi_uq_rad).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_hsq.append(np.array(err_hsq).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_hsq_rad.append(np.array(err_hsq_rad).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_hsq_q.append(np.array(err_hsq_q).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_hsq_q_rad.append(np.array(err_hsq_q_rad).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_greedy_hsq.append(np.array(err_greedy_hsq).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_greedy_hsq_rad.append(np.array(err_greedy_hsq_rad).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_greedy_hsq_q.append(np.array(err_greedy_hsq_q).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_greedy_hsq_q_rad.append(np.array(err_greedy_hsq_q_rad).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_top2.append(np.array(err_top2).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_top2_rad.append(np.array(err_top2_rad).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_rand2.append(np.array(err_rand2).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_rand2_rad.append(np.array(err_rand2_rad).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_vqsgd.append(np.array(err_vqsgd).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_vqsgd_rad.append(np.array(err_vqsgd_rad).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_vqsgd_q.append(np.array(err_vqsgd_q).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
            Err_vqsgd_q_rad.append(np.array(err_vqsgd_q_rad).sum() / (np.linalg.norm(g[0, 32*16*s:32*16*(s+1)])**2))
        
        
    # Mean over all segments: (error along x, error orthogonal to x) per scheme.
    print('err_sign', np.array(Err_sign).mean(), np.array(Err_sign_rad).mean())
    print("err_top2",  np.array(Err_top2).mean(), np.array(Err_top2_rad).mean())
    print("err_rand2",  np.array(Err_rand2).mean(), np.array(Err_rand2_rad).mean())
    print('err_polytope', np.array(Err_vqsgd).mean(), np.array(Err_vqsgd_rad).mean())
    print('err_polytope_q', np.array(Err_vqsgd_q).mean(), np.array(Err_vqsgd_q_rad).mean())
    print("err_hsq",  np.array(Err_hsq).mean(), np.array(Err_hsq_rad).mean())
    print("err_hsq_q",  np.array(Err_hsq_q).mean(), np.array(Err_hsq_q_rad).mean())
    print("err_greedy_hsq",  np.array(Err_greedy_hsq).mean(), np.array(Err_greedy_hsq_rad).mean())
    print("err_greedy_hsq_q",  np.array(Err_greedy_hsq_q).mean(), np.array(Err_greedy_hsq_q_rad).mean())
    print("err_random_voronoi",  np.array(Err_random_voronoi).mean(), np.array(Err_random_voronoi_rad).mean())
    print("err_random_voronoi_unbiased",  np.array(Err_random_voronoi_unbiased).mean(), np.array(Err_random_voronoi_unbiased_rad).mean())
    print("err_random_voronoi_uq",  np.array(Err_random_voronoi_uq).mean(), np.array(Err_random_voronoi_uq_rad).mean())
    
    # Standard deviations over segments, in the same order.
    print('err_sign', np.array(Err_sign).std(), np.array(Err_sign_rad).std())
    print("err_top2",  np.array(Err_top2).std(), np.array(Err_top2_rad).std())
    print("err_rand2",  np.array(Err_rand2).std(), np.array(Err_rand2_rad).std())
    print('err_polytope', np.array(Err_vqsgd).std(), np.array(Err_vqsgd_rad).std())
    print('err_polytope_q', np.array(Err_vqsgd_q).std(), np.array(Err_vqsgd_q_rad).std())
    print("err_hsq",  np.array(Err_hsq).std(), np.array(Err_hsq_rad).std())
    print("err_hsq_q",  np.array(Err_hsq_q).std(), np.array(Err_hsq_q_rad).std())
    print("err_greedy_hsq",  np.array(Err_greedy_hsq).std(), np.array(Err_greedy_hsq_rad).std())
    print("err_greedy_hsq_q",  np.array(Err_greedy_hsq_q).std(), np.array(Err_greedy_hsq_q_rad).std())
    print("err_random_voronoi",  np.array(Err_random_voronoi).std(), np.array(Err_random_voronoi_rad).std())
    print("err_random_voronoi_unbiased",  np.array(Err_random_voronoi_unbiased).std(), np.array(Err_random_voronoi_unbiased_rad).std())
    print("err_random_voronoi_uq",  np.array(Err_random_voronoi_uq).std(), np.array(Err_random_voronoi_uq_rad).std())


def _scaled_block(g, k, cols, std, d):
    """Return machine ``k``'s slice ``g[k, cols]`` as a fresh (1, d) tensor divided by ``std``."""
    return torch.tensor(g[k, cols]).view((1, d)) / std


def _polar_blocks(g, cols, stds, d):
    """Return one fresh ``(unit, norm)`` pair per machine for the block ``cols``.

    Machine ``k``'s slice is first divided by ``stds[k]``; ``norm`` is the
    (1, 1) Euclidean norm of that scaled slice and ``unit`` is the scaled
    slice divided by ``norm``.  New tensors are built on every call so that a
    quantizer -- or ``uniform_quantizer``, which receives ``norm.numpy()``, an
    array sharing the tensor's memory -- can never leak an in-place change
    into the inputs of another quantizer.
    """
    pairs = []
    for k, std in enumerate(stds):
        scaled = _scaled_block(g, k, cols, std, d)
        norm = torch.norm(scaled, dim=1, keepdim=True)
        pairs.append((scaled / norm, norm))
    return pairs


def _radial_errors(avg_q, x_np, proj):
    """Split an averaged estimate into its parts along and across the true average.

    Args:
        avg_q: averaged quantized vector with ``d`` entries.
        x_np: true average block as a (1, d) array.
        proj: (d, d) orthogonal projector onto the line spanned by ``x_np``.

    Returns:
        ``(parallel_error, radial_error)``: ``||proj @ avg_q - x||^2`` and the
        squared norm of ``avg_q - proj @ avg_q``.
    """
    q_para = np.matmul(proj, avg_q.reshape(-1))
    q_rad = avg_q - q_para
    return np.linalg.norm(q_para - x_np) ** 2, np.linalg.norm(q_rad) ** 2


def table4_radial_distributed_grad(n_machines):
    """Compare quantizers on a saved gradient averaged over several machines.

    The gradient stored in
    ``./saved_grads/saved_grad_epoch_10_0_cifar10_vgg16_sgd_8_4096_1.txt`` is
    reshaped into 8 rows (one per machine) and the first ``n_machines`` rows
    are used.  Coordinates are processed in 70 chunks of 32 blocks of 16
    coordinates.  For every block, each machine divides its slice by the
    standard deviation of its own chunk, quantizes it, multiplies back by that
    standard deviation, and the results are averaged over machines.  The
    average ``q`` is compared with the true average ``x`` of the machines'
    blocks and split into

    * a parallel error ``||P_x q - x||^2`` (``P_x`` the projector onto ``x``),
    * a radial error ``||q - P_x q||^2``.

    Per chunk, each error is summed over the blocks and divided by the squared
    norm of the chunk's true average gradient.  Finally the mean, then the
    standard deviation, over chunks of both errors is printed for each method.

    Args:
        n_machines: number of gradient rows whose quantized blocks are
            averaged; at most 8 (the saved gradient is reshaped to 8 rows).

    Side effects:
        Loads two local files with ``pickle`` (only use trusted files) and
        prints progress and results.
    """
    d = 16
    n_iter = 1000
    lr = 0.1
    blocks_per_chunk = 32
    n_chunks = 70
    # Report labels, in printing order.
    labels = (
        "err_sign",
        "err_top2",
        "err_rand2",
        "err_polytope",    # vqSGD, exact block norm
        "err_polytope_q",  # vqSGD, uniformly quantized block norm
        "err_hsq",
        "err_hsq_q",
        "err_greedy_hsq",
        "err_greedy_hsq_q",
        "err_random_voronoi",
        "err_random_voronoi_unbiased",
        "err_random_voronoi_uq",
    )
    # label -> (parallel errors, radial errors), one normalized value per chunk.
    totals = {label: ([], []) for label in labels}

    # Table consumed by fix_bias_3_bits (its meaning is defined there).
    inv_pich = [1.302554292, 1.348986681, 1.391496061, 1.433970247,
                1.48116051, 1.535987439, 1.604238237, 1.700437542]
    # NOTE: pickle can execute arbitrary code on load; these must be trusted files.
    with open(f"./radial_biases/radial_bias_{d}_{2**13}_.txt", "rb") as fp:
        pich = pickle.load(fp)

    # A single Voronoi codebook (2**10 atoms) is trained once, its rows are
    # normalized to unit norm, and each machine gets its own copy.
    codebook = voronoi(2**10, n_iter, lr, d)
    print("optimal 0 done")
    gamma_hsq = [codebook / torch.norm(codebook, dim=1, keepdim=True)
                 for _ in range(n_machines)]

    machines = range(n_machines)
    for b in range(1):
        grad_path = f"./saved_grads/saved_grad_epoch_10_{b}_cifar10_vgg16_sgd_8_4096_1.txt"
        with open(grad_path, "rb") as fp:
            g = np.array(pickle.load(fp))
        g = g.reshape((8, -1))  # one row per machine

        for s in range(n_chunks):
            print(s)
            chunk = slice(blocks_per_chunk * d * s, blocks_per_chunk * d * (s + 1))
            # Per-machine scale shared by every block of this chunk.
            stds = [np.std(g[k, chunk]) for k in machines]
            errors = {label: ([], []) for label in labels}

            for j in range(blocks_per_chunk * s, blocks_per_chunk * (s + 1)):
                cols = slice(d * j, d * (j + 1))
                # True average of the machines' raw blocks, shape (1, d).
                x = torch.tensor(np.array([g[k, cols].reshape((1, d)) for k in machines]).mean(0))
                x_np = x.numpy()
                proj = (x * x.t() / torch.norm(x) ** 2).numpy()
                estimates = {}

                estimates["err_sign"] = np.array([
                    quantize_sign(_scaled_block(g, k, cols, stds[k], d)).numpy() * stds[k]
                    for k in machines]).mean(0)

                # Random-Voronoi quantizers are only evaluated when the average
                # scaled block has norm < 6 (presumably the range the bias
                # tables cover -- TODO confirm).
                scaled_mean = np.array([g[k, cols].reshape((1, d)) / stds[k] for k in machines]).mean(0)
                if torch.norm(torch.tensor(scaled_mean)) < 6:
                    gamma_random = [torch.normal(0, np.sqrt(1 + 2 / d), size=(2**13, d))
                                    for _ in machines]
                    scaled = [_scaled_block(g, k, cols, stds[k], d) for k in machines]
                    # The nearest atom is the same for the three variants below.
                    nearest = [nearest_neighbor(gamma_random[k], scaled[k]).numpy() for k in machines]
                    estimates["err_random_voronoi"] = np.array([
                        nearest[k] * stds[k] for k in machines]).mean(0)
                    estimates["err_random_voronoi_unbiased"] = np.array([
                        nearest[k] / fix_bias(pich, torch.norm(scaled[k], dim=1, keepdim=True)).numpy() * stds[k]
                        for k in machines]).mean(0)
                    estimates["err_random_voronoi_uq"] = np.array([
                        nearest[k]
                        * fix_bias_3_bits(inv_pich, 1 / fix_bias(pich, torch.norm(scaled[k], dim=1, keepdim=True)))
                        * stds[k]
                        for k in machines]).mean(0)

                estimates["err_top2"] = np.array([
                    quantize_top2(_scaled_block(g, k, cols, stds[k], d)).numpy() * stds[k]
                    for k in machines]).mean(0)
                estimates["err_rand2"] = np.array([
                    quantize_rand2(_scaled_block(g, k, cols, stds[k], d), d).numpy() * stds[k]
                    for k in machines]).mean(0)

                # Each machine rescales by its OWN chunk std (previously g[0]'s).
                estimates["err_hsq"] = np.array([
                    quantize_hsq(gamma_hsq[k], unit, norm, 25).numpy() * stds[k]
                    for k, (unit, norm) in enumerate(_polar_blocks(g, cols, stds, d))]).mean(0)
                estimates["err_hsq_q"] = np.array([
                    quantize_hsq(gamma_hsq[k], unit, norm, 25, quantized=True).numpy() * stds[k]
                    for k, (unit, norm) in enumerate(_polar_blocks(g, cols, stds, d))]).mean(0)

                estimates["err_greedy_hsq"] = np.array([
                    quantize_greedy_hsq(gamma_hsq[k], unit).numpy() * stds[k] * norm.numpy()
                    for k, (unit, norm) in enumerate(_polar_blocks(g, cols, stds, d))]).mean(0)
                estimates["err_greedy_hsq_q"] = np.array([
                    quantize_greedy_hsq(gamma_hsq[k], unit).numpy() * stds[k]
                    * uniform_quantizer(norm.numpy(), 6, 0, 12)
                    for k, (unit, norm) in enumerate(_polar_blocks(g, cols, stds, d))]).mean(0)

                estimates["err_polytope"] = np.array([
                    quantize_vqsgd(unit, d, 2).numpy() * stds[k] * norm.numpy()
                    for k, (unit, norm) in enumerate(_polar_blocks(g, cols, stds, d))]).mean(0)
                estimates["err_polytope_q"] = np.array([
                    quantize_vqsgd(unit, d, 2).numpy() * stds[k]
                    * uniform_quantizer(norm.numpy(), 6, 0, 12)
                    for k, (unit, norm) in enumerate(_polar_blocks(g, cols, stds, d))]).mean(0)

                # Every estimate is decomposed against its OWN projection
                # (the *_q variants previously reused a stale one).
                for label, avg_q in estimates.items():
                    para, rad = _radial_errors(avg_q, x_np, proj)
                    errors[label][0].append(para)
                    errors[label][1].append(rad)

            # Squared norm of the chunk's true average over the participating
            # machines (equals g[0]'s chunk when n_machines == 1).
            ref = np.linalg.norm(g[:n_machines, chunk].mean(0)) ** 2
            for label in labels:
                for total, per_block in zip(totals[label], errors[label]):
                    total.append(np.array(per_block).sum() / ref)

    for label in labels:
        para, rad = totals[label]
        print(label, np.array(para).mean(), np.array(rad).mean())

    for label in labels:
        para, rad = totals[label]
        print(label, np.array(para).std(), np.array(rad).std())

#%%
    
if __name__ == "__main__":
    # Other experiments defined in this script (uncomment one to run it):
    #   table2_radial(1), table2_radial(20)
    #   table4_radial_all(1), table4_radial_all(8)
    #   table4_radial_distributed_grad(1)
    table4_radial_distributed_grad(n_machines=8)
