# -*- coding: utf-8 -*-

import torch
import numpy as np
import matplotlib.pyplot as plt

class LeastSquaresRegressorTorch1():
    """Distributed mini-batch gradient descent for least squares (no compression).

    ``K`` workers are simulated in a loop. In every iteration each worker
    averages its mini-batch gradients over its own rows, then the server
    averages the ``K`` worker gradients and takes one ``torch.optim.SGD`` step.
    """

    def __init__(self, n_iter=10, eta=0.1, batch_size=10, samples_per_worker=1, K=1):
        """Store the training hyper-parameters.

        Args:
            n_iter: number of server iterations (SGD steps) run by ``fit``.
            eta: SGD learning rate.
            batch_size: mini-batch size inside a worker.
            samples_per_worker: number of leading rows of each worker's index
                list that the worker iterates over.
            K: number of workers.
        """
        self.n_iter = n_iter
        self.eta = eta
        self.batch_size = batch_size
        self.samples_per_worker = samples_per_worker
        self.K = K

    def fit(self, X, Y, to_choose_workers, finalXtrain, init_w):
        """Train ``self.w`` and record the training loss in ``self.history``.

        Args:
            X: design matrix of shape ``(n_instances, n_features)`` (NumPy array).
            Y: targets of length ``n_instances``.
            to_choose_workers: ``K`` index sequences; worker ``k`` trains on the
                first ``samples_per_worker`` rows of ``X[to_choose_workers[k]]``.
            finalXtrain: non-empty list of row indices on which the reported
                mean squared error is computed.
            init_w: initial weights of length ``n_features``; copied, not modified.

        After the call ``self.w`` holds the trained float32 weights and
        ``self.history`` holds ``n_iter + 1`` detached 0-d tensors: the loss
        before training and after every step.

        Raises:
            ValueError: if ``K``, ``samples_per_worker`` or ``batch_size`` is
                not positive, or ``finalXtrain`` is empty.
        """
        if self.K < 1 or self.samples_per_worker < 1 or self.batch_size < 1:
            raise ValueError("K, samples_per_worker and batch_size must be positive")
        if len(finalXtrain) == 0:
            raise ValueError("finalXtrain must not be empty")

        _, n_features = X.shape

        # Wrap the NumPy arrays as float32 tensors.
        Xt = torch.tensor(X, dtype=torch.float)
        Yt = torch.tensor(Y, dtype=torch.float)
        X_eval, Y_eval = Xt[finalXtrain], Yt[finalXtrain]

        # Fresh float32 leaf copy of the initial weights: init_w is left
        # untouched and the dtype always matches Xt.
        self.w = torch.as_tensor(init_w, dtype=torch.float).detach().clone().requires_grad_(True)

        def eval_loss():
            # Evaluated without autograd so the stored losses keep no graph
            # alive and convert directly to NumPy / floats for plotting.
            with torch.no_grad():
                return torch.sum((X_eval.mv(self.w) - Y_eval) ** 2) / len(finalXtrain)

        self.history = [eval_loss()]
        optimizer = torch.optim.SGD([self.w], lr=self.eta)

        # Mini-batch offsets inside a worker. n_batches is a ceiling count, so
        # it is never zero (even when samples_per_worker < batch_size) and it
        # includes a trailing partial batch.
        batch_starts = range(0, self.samples_per_worker, self.batch_size)
        n_batches = len(batch_starts)

        for i in range(self.n_iter):
            print(i)
            worker_grads = []
            for k in range(self.K):
                X_worker = Xt[to_choose_workers[k]]
                Y_worker = Yt[to_choose_workers[k]]

                grad_sum = torch.zeros(n_features)
                for start in batch_starts:
                    Xbatch = X_worker[start:start + self.batch_size]
                    Ybatch = Y_worker[start:start + self.batch_size]
                    loss_batch = torch.sum((Xbatch.mv(self.w) - Ybatch) ** 2) / self.batch_size
                    optimizer.zero_grad()
                    loss_batch.backward()
                    grad_sum += self.w.grad
                worker_grads.append(grad_sum / n_batches)

            # Server: average the worker gradients (in float64, cast back to
            # the parameter dtype); SGD then does w -= eta * grad.
            self.w.grad = torch.stack([g.double() for g in worker_grads]).mean(0).to(self.w.dtype)
            optimizer.step()

            self.history.append(eval_loss())
            
        

def nearest_neighbor(gamma, x):
    """Return the codeword of ``gamma`` closest to ``x`` in Euclidean distance.

    Args:
        gamma: codebook tensor holding one codeword per row.
        x: query vector, broadcast against every row of ``gamma``.

    Returns:
        The selected row kept two-dimensional, i.e. shape ``(1, gamma.shape[1])``;
        callers take ``[0]`` to get the vector itself.
    """
    distances = torch.norm(x - gamma, dim=1)
    # topk(k=1) of the negated distances selects the smallest distance and
    # yields a length-1 index tensor, which keeps the indexed row 2-D.
    _, closest = torch.topk(-distances, k=1)
    return gamma[closest]

class LeastSquaresRegressorTorch2():
    """Distributed least-squares SGD with per-bucket gradient quantization (HSQ).

    Each of the ``K`` simulated workers averages its mini-batch gradients and
    compresses the result before the server averages them: the gradient is cut
    into ``buckets`` consecutive chunks of ``gamma.shape[1]`` coordinates, and
    every chunk is replaced by its norm rounded with
    ``uniform_quantizer(norm, 6, 0, 6)`` times the row of ``gamma`` nearest to
    the chunk's unit direction.
    """

    def __init__(self, n_iter=10, eta=0.1, batch_size=10, samples_per_worker=1, K=1, buckets=1):
        """Store the training hyper-parameters.

        Args:
            n_iter: number of server iterations (SGD steps) run by ``fit``.
            eta: SGD learning rate.
            batch_size: mini-batch size inside a worker.
            samples_per_worker: number of leading rows of each worker's index
                list that the worker iterates over.
            K: number of workers.
            buckets: number of chunks each gradient is split into for quantization.
        """
        self.n_iter = n_iter
        self.eta = eta
        self.batch_size = batch_size
        self.samples_per_worker = samples_per_worker
        self.K = K
        self.buckets = buckets

    def fit(self, X, Y, gamma, to_choose_workers, finalXtrain, init_w):
        """Train ``self.w`` with quantized worker gradients.

        Args:
            X: design matrix of shape ``(n_instances, n_features)`` (NumPy array).
            Y: targets of length ``n_instances``.
            gamma: codebook tensor with one codeword per row; the chunk width is
                ``gamma.shape[1]`` (``main`` passes unit-norm rows).
            to_choose_workers: ``K`` index sequences; worker ``k`` trains on the
                first ``samples_per_worker`` rows of ``X[to_choose_workers[k]]``.
            finalXtrain: non-empty list of row indices on which the reported
                mean squared error is computed.
            init_w: initial weights of length ``n_features``; copied, not modified.

        After the call ``self.w`` holds the trained float32 weights and
        ``self.history`` holds ``n_iter + 1`` detached 0-d loss tensors.

        Raises:
            ValueError: on non-positive ``K``/``samples_per_worker``/``batch_size``,
                empty ``finalXtrain``, or ``buckets * gamma.shape[1] != n_features``.
        """
        if self.K < 1 or self.samples_per_worker < 1 or self.batch_size < 1:
            raise ValueError("K, samples_per_worker and batch_size must be positive")
        if len(finalXtrain) == 0:
            raise ValueError("finalXtrain must not be empty")

        _, n_features = X.shape
        # The chunk width is the codeword dimension.
        bucket_dim = gamma.shape[1]
        if self.buckets * bucket_dim != n_features:
            raise ValueError(
                "buckets * gamma.shape[1] must equal n_features "
                f"({self.buckets} * {bucket_dim} != {n_features})"
            )

        # Wrap the NumPy arrays as float32 tensors.
        Xt = torch.tensor(X, dtype=torch.float)
        Yt = torch.tensor(Y, dtype=torch.float)
        X_eval, Y_eval = Xt[finalXtrain], Yt[finalXtrain]

        # Fresh float32 leaf copy of the initial weights (init_w is not modified).
        self.w = torch.as_tensor(init_w, dtype=torch.float).detach().clone().requires_grad_(True)

        def eval_loss():
            # No autograd: stored losses keep no graph and convert cleanly to NumPy.
            with torch.no_grad():
                return torch.sum((X_eval.mv(self.w) - Y_eval) ** 2) / len(finalXtrain)

        self.history = [eval_loss()]
        optimizer = torch.optim.SGD([self.w], lr=self.eta)

        # Ceiling count of mini-batches: never zero, includes a partial batch.
        batch_starts = range(0, self.samples_per_worker, self.batch_size)
        n_batches = len(batch_starts)

        for i in range(self.n_iter):
            print(i)
            worker_grads = []
            for k in range(self.K):
                X_worker = Xt[to_choose_workers[k]]
                Y_worker = Yt[to_choose_workers[k]]

                grad_sum = torch.zeros(n_features)
                for start in batch_starts:
                    Xbatch = X_worker[start:start + self.batch_size]
                    Ybatch = Y_worker[start:start + self.batch_size]
                    loss_batch = torch.sum((Xbatch.mv(self.w) - Ybatch) ** 2) / self.batch_size
                    optimizer.zero_grad()
                    loss_batch.backward()
                    grad_sum += self.w.grad
                grad_worker = grad_sum / n_batches

                # Quantize every chunk: 6-bit norm on [0, 6) times nearest codeword.
                quantized = []
                for b in range(self.buckets):
                    grad_b = grad_worker[b * bucket_dim:(b + 1) * bucket_dim]
                    norm = torch.norm(grad_b)
                    if norm == 0:
                        # All-zero chunk: its direction would be 0/0 = NaN, but
                        # zero itself is represented exactly.
                        quantized.append(torch.zeros_like(grad_b))
                        continue
                    norm_quantized = uniform_quantizer(norm.numpy(), 6, 0, 6)
                    quantized.append(norm_quantized * nearest_neighbor(gamma, grad_b / norm)[0])
                worker_grads.append(torch.cat(quantized))

            # Server: average the worker gradients (in float64, cast back to
            # the parameter dtype); SGD then does w -= eta * grad.
            self.w.grad = torch.stack([g.double() for g in worker_grads]).mean(0).to(self.w.dtype)
            optimizer.step()

            self.history.append(eval_loss())
        
def uniform_quantizer(norm_x_true, b, umin, umax):
    """Round a scalar to the nearest level of a uniform ``b``-bit grid.

    The levels are ``np.arange(umin, umax, (umax - umin) / 2**b)``: they start
    at ``umin`` and stop before ``umax``. Values below ``umin`` map to ``umin``
    and values beyond the top level map to the top level. On a tie the lower
    level wins (first index of ``argmin``).

    Returns:
        The chosen level as a NumPy scalar.
    """
    step = (umax - umin) / 2**b
    levels = np.arange(umin, umax, step)
    nearest = np.abs(levels - norm_x_true).argmin()
    return levels[nearest]

class LeastSquaresRegressorTorch4():
    """Distributed least-squares SGD with random-codebook gradient quantization.

    Each of the ``K`` simulated workers averages its mini-batch gradients,
    divides the result by its standard deviation, and splits it into
    ``buckets`` chunks of 16 coordinates. Every chunk is then replaced by the
    nearest row of a freshly drawn Gaussian codebook (``n_atoms`` rows, entries
    with std ``sqrt(1 + 2/16)``), rescaled by that standard deviation. A new
    codebook is drawn for every chunk, worker and iteration.
    """

    def __init__(self, n_iter=10, eta=0.1, batch_size=10, samples_per_worker=1, K=1, buckets=1, n_atoms=256):
        """Store the training hyper-parameters.

        Args:
            n_iter: number of server iterations (SGD steps) run by ``fit``.
            eta: SGD learning rate.
            batch_size: mini-batch size inside a worker.
            samples_per_worker: number of leading rows of each worker's index
                list that the worker iterates over.
            K: number of workers.
            buckets: number of 16-coordinate chunks per gradient.
            n_atoms: number of codewords in each freshly drawn codebook.
        """
        self.n_iter = n_iter
        self.eta = eta
        self.batch_size = batch_size
        self.samples_per_worker = samples_per_worker
        self.K = K
        self.buckets = buckets
        self.n_atoms = n_atoms

    def fit(self, X, Y, to_choose_workers, finalXtrain, init_w):
        """Train ``self.w`` with random-codebook quantized worker gradients.

        Args:
            X: design matrix of shape ``(n_instances, n_features)`` (NumPy array).
            Y: targets of length ``n_instances``.
            to_choose_workers: ``K`` index sequences; worker ``k`` trains on the
                first ``samples_per_worker`` rows of ``X[to_choose_workers[k]]``.
            finalXtrain: non-empty list of row indices on which the reported
                mean squared error is computed.
            init_w: initial weights of length ``n_features``; copied, not modified.

        After the call ``self.w`` holds the trained float32 weights and
        ``self.history`` holds ``n_iter + 1`` detached 0-d loss tensors.

        Raises:
            ValueError: on non-positive ``K``/``samples_per_worker``/``batch_size``,
                empty ``finalXtrain``, or ``buckets * 16 != n_features``.
        """
        if self.K < 1 or self.samples_per_worker < 1 or self.batch_size < 1:
            raise ValueError("K, samples_per_worker and batch_size must be positive")
        if len(finalXtrain) == 0:
            raise ValueError("finalXtrain must not be empty")

        _, n_features = X.shape
        bucket_dim = 16
        if self.buckets * bucket_dim != n_features:
            raise ValueError(
                "buckets * 16 must equal n_features "
                f"({self.buckets} * {bucket_dim} != {n_features})"
            )
        # Std of the Gaussian codeword entries (variance 1 + 2/bucket_dim).
        codeword_std = np.sqrt(1 + 2 / bucket_dim)

        # Wrap the NumPy arrays as float32 tensors.
        Xt = torch.tensor(X, dtype=torch.float)
        Yt = torch.tensor(Y, dtype=torch.float)
        X_eval, Y_eval = Xt[finalXtrain], Yt[finalXtrain]

        # Fresh float32 leaf copy of the initial weights (init_w is not modified).
        self.w = torch.as_tensor(init_w, dtype=torch.float).detach().clone().requires_grad_(True)

        def eval_loss():
            # No autograd: stored losses keep no graph and convert cleanly to NumPy.
            with torch.no_grad():
                return torch.sum((X_eval.mv(self.w) - Y_eval) ** 2) / len(finalXtrain)

        self.history = [eval_loss()]
        optimizer = torch.optim.SGD([self.w], lr=self.eta)

        # Ceiling count of mini-batches: never zero, includes a partial batch.
        batch_starts = range(0, self.samples_per_worker, self.batch_size)
        n_batches = len(batch_starts)

        for i in range(self.n_iter):
            print(i)
            worker_grads = []
            for k in range(self.K):
                X_worker = Xt[to_choose_workers[k]]
                Y_worker = Yt[to_choose_workers[k]]

                grad_sum = torch.zeros(n_features)
                for start in batch_starts:
                    Xbatch = X_worker[start:start + self.batch_size]
                    Ybatch = Y_worker[start:start + self.batch_size]
                    loss_batch = torch.sum((Xbatch.mv(self.w) - Ybatch) ** 2) / self.batch_size
                    optimizer.zero_grad()
                    loss_batch.backward()
                    grad_sum += self.w.grad
                grad_worker = grad_sum / n_batches

                norm = torch.std(grad_worker)
                if norm == 0:
                    # Constant gradient: scaling by its std would divide by zero,
                    # so this worker sends it unquantized.
                    worker_grads.append(grad_worker)
                    continue

                quantized = []
                for b in range(self.buckets):
                    grad_b = grad_worker[b * bucket_dim:(b + 1) * bucket_dim]
                    # Fresh codebook for every chunk, worker and iteration.
                    gamma = torch.normal(0, codeword_std, size=(self.n_atoms, bucket_dim))
                    quantized.append(norm * nearest_neighbor(gamma, grad_b / norm)[0])
                worker_grads.append(torch.cat(quantized))

            # Server: average the worker gradients (in float64, cast back to
            # the parameter dtype); SGD then does w -= eta * grad.
            self.w.grad = torch.stack([g.double() for g in worker_grads]).mean(0).to(self.w.dtype)
            optimizer.step()

            self.history.append(eval_loss())


class LeastSquaresRegressorTorch5():
    """Distributed least-squares SGD with radially corrected codebook quantization (DoStoVoQ).

    Like the random-codebook variant, each worker divides its averaged gradient
    by its standard deviation, splits it into 16-coordinate chunks and maps each
    chunk to the nearest row of a fresh Gaussian codebook. In addition, every
    quantized chunk is multiplied by a radial factor. The factor is drawn by
    ``fix_bias_3_bits`` from ``radial_levels``, near
    ``1 / fix_bias(pich, ||chunk / std||)``.
    """

    def __init__(self, n_iter=10, eta=0.1, batch_size=10, samples_per_worker=1, K=1, buckets=1, n_atoms=256):
        """Store the training hyper-parameters.

        Args:
            n_iter: number of server iterations (SGD steps) run by ``fit``.
            eta: SGD learning rate.
            batch_size: mini-batch size inside a worker.
            samples_per_worker: number of leading rows of each worker's index
                list that the worker iterates over.
            K: number of workers.
            buckets: number of 16-coordinate chunks per gradient.
            n_atoms: number of codewords in each freshly drawn codebook (should
                match the codebook size ``pich`` was computed for).
        """
        self.n_iter = n_iter
        self.eta = eta
        self.batch_size = batch_size
        self.samples_per_worker = samples_per_worker
        self.K = K
        self.buckets = buckets
        self.n_atoms = n_atoms

    def fit(self, X, Y, pich, to_choose_workers, finalXtrain, init_w, radial_levels=None):
        """Train ``self.w`` with radially corrected quantized worker gradients.

        Args:
            X: design matrix of shape ``(n_instances, n_features)`` (NumPy array).
            Y: targets of length ``n_instances``.
            pich: radial-bias table looked up by ``fix_bias`` at the norm of each
                scaled chunk.
            to_choose_workers: ``K`` index sequences; worker ``k`` trains on the
                first ``samples_per_worker`` rows of ``X[to_choose_workers[k]]``.
            finalXtrain: non-empty list of row indices on which the reported
                mean squared error is computed.
            init_w: initial weights of length ``n_features``; copied, not modified.
            radial_levels: candidate radial factors for ``fix_bias_3_bits``;
                defaults to the module-level ``inv_pich``.

        After the call ``self.w`` holds the trained float32 weights and
        ``self.history`` holds ``n_iter + 1`` detached 0-d loss tensors.

        Raises:
            ValueError: on non-positive ``K``/``samples_per_worker``/``batch_size``,
                empty ``finalXtrain``, or ``buckets * 16 != n_features``.
        """
        if radial_levels is None:
            radial_levels = inv_pich
        if self.K < 1 or self.samples_per_worker < 1 or self.batch_size < 1:
            raise ValueError("K, samples_per_worker and batch_size must be positive")
        if len(finalXtrain) == 0:
            raise ValueError("finalXtrain must not be empty")

        _, n_features = X.shape
        bucket_dim = 16
        if self.buckets * bucket_dim != n_features:
            raise ValueError(
                "buckets * 16 must equal n_features "
                f"({self.buckets} * {bucket_dim} != {n_features})"
            )
        # Std of the Gaussian codeword entries (variance 1 + 2/bucket_dim).
        codeword_std = np.sqrt(1 + 2 / bucket_dim)

        # Wrap the NumPy arrays as float32 tensors.
        Xt = torch.tensor(X, dtype=torch.float)
        Yt = torch.tensor(Y, dtype=torch.float)
        X_eval, Y_eval = Xt[finalXtrain], Yt[finalXtrain]

        # Fresh float32 leaf copy of the initial weights (init_w is not modified).
        self.w = torch.as_tensor(init_w, dtype=torch.float).detach().clone().requires_grad_(True)

        def eval_loss():
            # No autograd: stored losses keep no graph and convert cleanly to NumPy.
            with torch.no_grad():
                return torch.sum((X_eval.mv(self.w) - Y_eval) ** 2) / len(finalXtrain)

        self.history = [eval_loss()]
        optimizer = torch.optim.SGD([self.w], lr=self.eta)

        # Ceiling count of mini-batches: never zero, includes a partial batch.
        batch_starts = range(0, self.samples_per_worker, self.batch_size)
        n_batches = len(batch_starts)

        for i in range(self.n_iter):
            print(i)
            worker_grads = []
            for k in range(self.K):
                X_worker = Xt[to_choose_workers[k]]
                Y_worker = Yt[to_choose_workers[k]]

                grad_sum = torch.zeros(n_features)
                for start in batch_starts:
                    Xbatch = X_worker[start:start + self.batch_size]
                    Ybatch = Y_worker[start:start + self.batch_size]
                    loss_batch = torch.sum((Xbatch.mv(self.w) - Ybatch) ** 2) / self.batch_size
                    optimizer.zero_grad()
                    loss_batch.backward()
                    grad_sum += self.w.grad
                grad_worker = grad_sum / n_batches

                norm = torch.std(grad_worker)
                if norm == 0:
                    # Constant gradient: scaling by its std would divide by zero,
                    # so this worker sends it unquantized.
                    worker_grads.append(grad_worker)
                    continue

                quantized = []
                for b in range(self.buckets):
                    grad_b = grad_worker[b * bucket_dim:(b + 1) * bucket_dim]
                    # Fresh codebook for every chunk, worker and iteration.
                    gamma = torch.normal(0, codeword_std, size=(self.n_atoms, bucket_dim))
                    direction = grad_b / norm
                    # Radial correction drawn near 1 / pich(|direction|).
                    radial = fix_bias_3_bits(radial_levels, 1 / fix_bias(pich, torch.norm(direction)))
                    quantized.append(norm * nearest_neighbor(gamma, direction)[0] * radial)
                worker_grads.append(torch.cat(quantized))

            # Server: average the worker gradients in float64 (the radial factor
            # may already make them float64), then cast to the parameter dtype.
            self.w.grad = torch.stack([g.double() for g in worker_grads]).mean(0).to(self.w.dtype)
            optimizer.step()

            self.history.append(eval_loss())

def fix_bias(p, x):
    """Look up the radial-bias table ``p`` at the Euclidean norm of ``x``.

    The table is read with a resolution of 0.1: a norm ``r`` with
    ``0.1 <= r <= 5.9`` selects ``p[int(10 * r)]``. Norms below 0.1 use
    ``p[0]`` and norms above 5.9 use ``p[-1]``.

    Args:
        p: indexable table of bias values.
        x: tensor whose norm is looked up.
    """
    radius = torch.norm(x)
    if radius > 5.9:
        return p[-1]
    if radius < 0.1:
        return p[0]
    return p[int(10 * radius)]


def fix_bias_3_bits(ip, x):
    """Randomly snap ``x`` to one of the levels in ``ip``.

    Each level is drawn with probability proportional to the inverse of its
    distance to ``x``, so closer levels are more likely. If ``x`` coincides
    with a level, that level is returned without drawing, because the weight
    1/0 is infinite. That is the limit of the inverse-distance weights, and it
    avoids NaN probabilities.

    Args:
        ip: sequence of candidate levels (any length).
        x: scalar value as a torch tensor, NumPy value or Python number.

    Returns:
        A length-1 NumPy array holding the chosen level.
    """
    levels = np.array(ip)
    value = x.numpy() if torch.is_tensor(x) else np.asarray(x)
    distance = np.abs(levels - value)

    exact = np.flatnonzero(distance == 0)
    if exact.size:
        return levels[exact[:1]]

    proba = 1 / distance
    return levels[np.random.choice(len(levels), 1, replace=True, p=proba / proba.sum())]


# The 8 candidate radial factors (3 bits) from which fix_bias_3_bits draws in
# LeastSquaresRegressorTorch5; presumably inverse radial-bias values for the
# 16-dim / 2**13-codeword setting — TODO confirm against radial_biases/.
inv_pich = [
    1.302554292,
    1.348986681,
    1.391496061,
    1.433970247,
    1.48116051,
    1.535987439,
    1.604238237,
    1.700437542,
]

#%%
### Compare 
n_samples = 2**14           # total number of synthetic training samples
nb_buckets = 2**5           # 16-coordinate quantization chunks per gradient
n_features = 16 * nb_buckets  # = 2**9, one 16-dim chunk per bucket
# n_atoms = 2**16           # superseded by a total bit budget
niter = 10                  # training iterations plotted per method
bs = 128                    # mini-batch size inside a worker (alt: 2**6)
workers = 2**5              # number of simulated workers K (also tried 10, 100)
spw = 2**8                  # samples drawn per worker (alt: 2**9)

def main(f):
    """Compare uncompressed and quantized distributed least-squares training.

    Builds a synthetic regression problem and assigns ``spw`` random samples to
    each of ``workers`` simulated workers. It then runs three methods from the
    same initial weights:

    * plain distributed gradient descent for 50 iterations; its final loss
      serves as the reference ``loss_star``;
    * the HSQ variant;
    * the DoStoVoQ variant.

    Each method's excess training loss is plotted on a log scale, and the
    figure is saved as a PDF and shown.

    Args:
        f: step-size factor; every method uses ``lr = 1 / (f * L)`` where ``L``
            is the largest eigenvalue of ``X^T X`` divided by ``n_samples``.
    """
    import pickle

    Xtrain = np.random.normal(size=(n_samples, n_features))
    # Only feature 0 carries signal; Gaussian noise with std 0.4 is added.
    Ytrain = -0.9 * Xtrain[:, 0] + 0.4 * np.random.normal(size=Xtrain.shape[0])
    # Worker k trains on spw indices drawn uniformly with replacement.
    to_choose_workers = [np.random.choice(n_samples, spw, replace=True) for _ in range(workers)]

    # Union of all workers' indices (the rows on which the training loss is
    # reported), ordered worker by worker and ascending within a worker. A set
    # keeps membership tests O(1) instead of rescanning a growing list for
    # every candidate index.
    finalXtrain = []
    seen = set()
    for worker_indices in to_choose_workers:
        for idx in np.unique(worker_indices).tolist():
            if idx not in seen:
                seen.add(idx)
                finalXtrain.append(idx)

    # Step size: X^T X is symmetric PSD, so its largest singular value is its
    # largest eigenvalue. All methods share lr = 1 / (f * L).
    s = np.linalg.svd(np.matmul(np.transpose(Xtrain), Xtrain), compute_uv=False)
    L = s.max() / n_samples
    lr = 1 / (f * L)

    init_w = 2 * torch.rand(n_features, dtype=torch.float) - 1

    # Reference run: 50 uncompressed iterations; the final loss stands in for
    # the optimal loss.
    regr1 = LeastSquaresRegressorTorch1(n_iter=50, eta=lr, batch_size=bs, samples_per_worker=spw, K=workers)
    regr1.fit(Xtrain, Ytrain, to_choose_workers, finalXtrain, init_w)
    loss_star = regr1.history[-1].item()

    def excess(history):
        # float() works on loss entries whether or not they still track gradients.
        return np.array([float(h) for h in history]) - loss_star

    plt.figure()
    plt.plot(excess(regr1.history[:niter + 1]), '.-', label='GD')

    # HSQ: norm coded on 6 bits, direction from a fixed unit-norm codebook of
    # 2**10 codewords in R^16, shared by all workers and all iterations.
    gamma = torch.normal(0, np.sqrt(1 + 2 / 16), size=(2**10, 16))
    gamma /= torch.norm(gamma, dim=1, keepdim=True)
    regr2 = LeastSquaresRegressorTorch2(n_iter=niter, eta=lr, batch_size=bs, samples_per_worker=spw, K=workers, buckets=nb_buckets)
    regr2.fit(Xtrain, Ytrain, gamma, to_choose_workers, finalXtrain, init_w)
    plt.plot(excess(regr2.history), '.--', label='HSQ-greed')

    # Different codebook for each worker and each iteration (disabled):
    # regr4 = LeastSquaresRegressorTorch4(n_iter=niter, eta=lr, batch_size=bs, samples_per_worker=spw, K=workers, buckets=nb_buckets, n_atoms=2**13)
    # regr4.fit(Xtrain, Ytrain, to_choose_workers, finalXtrain, init_w)
    # plt.plot(excess(regr4.history), '.-', label='Dir.Unbiased−DoStoVoQ')

    # pich: radial unbiasing table for 16-dim chunks and 2**13 codewords,
    # read from ./radial_biases/radial_bias_16_8192_.txt.
    # NOTE: pickle can execute code on load; only use trusted, locally
    # generated files here.
    bucket_dim, n_atoms = 16, 2**13
    with open(f"./radial_biases/radial_bias_{bucket_dim}_{n_atoms}_.txt", "rb") as fp:
        pich = pickle.load(fp)
    p = [e[0] for e in pich]

    regr5 = LeastSquaresRegressorTorch5(n_iter=niter, eta=lr, batch_size=bs, samples_per_worker=spw, K=workers, buckets=nb_buckets, n_atoms=n_atoms)
    regr5.fit(Xtrain, Ytrain, p, to_choose_workers, finalXtrain, init_w)
    plt.plot(excess(regr5.history), '.-', label='DoStoVoQ')

    # plt.legend()
    plt.xlabel('Number of iterations')
    plt.ylabel('log(excess train loss)')
    plt.yscale('log')
    plt.grid(True, which="both", linestyle='--')

    plt.savefig('corr_log10_linear_excess_train_loss_'+str(workers)+'_workers_'+str(f)+'L.pdf', format='pdf', dpi=1200)
    plt.show()

#%%
if __name__ == "__main__":
    # Step-size factor f in lr = 1 / (f * L). Other factors explored earlier:
    # 1.2, 1.3, 1.5, 1.7, 2, 4 and 8.
    step_factor = 3
    main(step_factor)
