from MQL import MQNode
import NueralNet as OurNNUtils
import numpy as np
import matplotlib
#matplotlib.use('Qt5Agg')  # Set the backend to Qt5Agg
import matplotlib.pyplot as plt
import datetime
import math #for ceiling, floor functions
import copy
import torch
import Encoding
import Plotting   

class Malcom4:
    """Experiment driver for decentralized training with quantized residual exchange.

    Nodes are ``MQNode`` instances that exchange ``computeQ`` messages and
    aggregate with ``aggregateThres`` / ``aggregateNbdThres``. Each run
    appends to ``acc_l``, ``loss_l`` and ``bits_l`` once every
    ``args['acc_idx']`` iterations.
    """

    def __init__(self, dataset="mnist", elias=None):
        """Create an empty runner.

        Args:
            dataset: ``"mnist"`` selects the MNIST sets in ``genDataSets``;
                any other value selects the CIFAR sets.
            elias: table forwarded as ``elias=`` to ``Encoding.calcBitsCum``.
        """
        self.N = 0
        self.iter_max = 0
        self.bits_per_iter = []
        self.acc_l = []   # mean of the nodes' test() results per recorded iteration
        self.loss_l = []  # mean of the nodes' getTrainLoss() per recorded iteration
        self.t_loss = []
        self.bits_l = []  # communicated-bit estimate per recorded iteration
        self.bins = {}
        self.conv = []
        self.temp_save_path = "data/temp/uniQuant"
        self.bits = []
        self.ebits = []
        self.dataset = dataset
        self.test_data = None   # set by genDataSets()
        self.train_data = None  # set by genDataSets()
        self.elias = elias

    def saveGradients(self, num_of_node, iterations, args):
        """Take one local step on node 0 and save its flattened weights.

        ``num_of_node`` MNIST nodes are created, but only node 0 is used. When
        ``iterations`` is positive, node 0 runs ``gradNoThreshold()`` once.
        The result of ``weightDictTo1D()`` is then written to
        ``data/grads/one_iteration.npy``. Only a single step is ever taken;
        with ``iterations <= 0`` nothing is saved.
        """
        test_dl = OurNNUtils.generateTestDataLoader(OurNNUtils.minst_test_data)
        # Partition once so every node takes its own shard of the same split,
        # instead of re-running the partitioner for every node.
        data_loader_dict = OurNNUtils.generateTrainDataloadersNonIid(
            num_of_node, OurNNUtils.minst_training_data)
        nn_dict = {}
        for i in range(num_of_node):
            nn_dict[i] = MQNode(data_loader=data_loader_dict[i], id=i,
                                nn_type=args["nn_type"], test_dl=test_dl, lr=args["lr"],
                                q_scalar=args["q_scalar"], dith=args["dith"], bits=args['bit'], reg=args['reg'],
                                threshold=args['threshold'], quantize_loop=args['quantize_loop'],
                                tr_decay=args['tr_decay'])

        if iterations > 0:
            nn_dict[0].gradNoThreshold()
            vec = nn_dict[0].weightDictTo1D()
            np.save("data/grads/one_iteration.npy", vec)

    def genDataSets(self):
        """Set ``test_data``/``train_data`` to MNIST or CIFAR based on ``self.dataset``."""
        if self.dataset == "mnist":
            self.test_data = OurNNUtils.minst_test_data
            self.train_data = OurNNUtils.minst_training_data
        else:
            self.test_data = OurNNUtils.cifar_testset
            self.train_data = OurNNUtils.cifar_trainset

    def generateNNDict(self, args, num_of_node, test_dl, dist, b):
        """Build one ``MQNode`` per node, each with its own training shard.

        Args:
            args: hyper-parameter dict; the keys read here are forwarded to
                ``MQNode``.
            num_of_node: number of nodes to create (ids ``0..num_of_node-1``).
            test_dl: test data loader shared by every node.
            dist: forwarded to ``OurNNUtils.twoClass`` as ``class_num``.
            b: batch size forwarded to ``OurNNUtils.twoClass``.

        Returns:
            dict mapping node id -> ``MQNode``.
        """
        # Partition once. Calling twoClass per node repeats the work
        # num_of_node times. If twoClass is randomized, per-node calls would
        # also take each node's shard from a different split.
        data_loader_dict = OurNNUtils.twoClass(num_of_node, self.train_data,
                                               class_num=dist, batch_size=b)
        nn_dict = {}
        for i in range(num_of_node):
            nn_dict[i] = MQNode(data_loader=data_loader_dict[i], id=i,
                                nn_type=args["nn_type"], test_dl=test_dl, lr=args["lr"],
                                q_scalar=args["q_scalar"], dith=args["dith"], bits=args['bit'], reg=args['reg'],
                                threshold=args['threshold'], quantize_loop=args['quantize_loop'], l=args['l'],
                                alpha=args['alpha'], eps=args['eps'], lr_const=args['lr_const'],
                                device_str=args['device_str'], E=args['E'], optim=args['optim'],
                                tr_decay=args['tr_decay'])
        return nn_dict

    def genNbdDict(self, nbds):
        """Turn a list of neighbourhoods into a per-node neighbour list.

        Two nodes are neighbours if they appear together in at least one
        neighbourhood. Each neighbour is listed once, in first-seen order, and
        a node is never its own neighbour. Node ids need not be contiguous.

        Args:
            nbds: list of neighbourhoods, each an iterable of node ids.

        Returns:
            dict mapping each node id appearing in ``nbds`` to its neighbours.
        """
        membership = {}
        for group_idx, group in enumerate(nbds):
            for node in group:
                membership.setdefault(node, []).append(group_idx)

        nbd_dict = {}
        for node in sorted(membership):
            neighbours = []
            seen = set()
            for group_idx in membership[node]:
                for other in nbds[group_idx]:
                    # Overlapping neighbourhoods must not list a neighbour twice,
                    # which would double-send messages and inflate bit counts.
                    if other != node and other not in seen:
                        seen.add(other)
                        neighbours.append(other)
            nbd_dict[node] = neighbours
        return nbd_dict

    def computeBits(self, q_i_j_c, args, model):
        """Return the number of bits needed to send the message ``q_i_j_c``.

        With ``args['bit'] == -1`` (no quantization) every entry costs 32
        bits. Otherwise the message is labelled with ``model.bits`` and
        counted by ``Encoding.calcBitsCum`` using ``model.bits + 1`` levels.
        """
        if args['bit'] == -1:
            return 32 * (q_i_j_c.size(dim=0))

        Li = model.bits + 1
        q_i_j_label = Encoding.getLabels(q_i_j_c, Li - 1)
        return Encoding.calcBitsCum(q_i_j_label, Li, device=args["device_str"], elias=self.elias).item()

    @staticmethod
    def _evaluate(nn_dict, num_of_node):
        """Return the mean ``test()`` and mean ``getTrainLoss()`` over all nodes."""
        acc_total = 0
        loss_total = 0
        for node in nn_dict.values():
            acc_total += node.test()
            loss_total += node.getTrainLoss()
        return acc_total / num_of_node, loss_total / num_of_node

    @staticmethod
    def _printNodeGap(nn_dict):
        """Print the L2 distance between the flattened weights of nodes 1 and 0, if both exist."""
        if 0 in nn_dict and 1 in nn_dict:
            print("l2 norm of node 1 and 0: ",
                  torch.norm(nn_dict[1].weightDictTo1D() - nn_dict[0].weightDictTo1D()).item())

    @staticmethod
    def _resultPath(args, prefix, iterations, b, dist, stamp):
        """Build the ``.npy`` path used to save one result series."""
        return (args['data_save_folder'] + prefix + "it" + str(iterations)
                + '_lr' + str(args['lr']) + "_tr" + str(args['threshold'])
                + "_lvl" + str(args['bit']) + "_b" + str(b)
                + '_e' + str(args['E']) + 'dist_' + str(dist) + stamp + ".npy")

    def synchronous(self, num_of_node, iterations, args, dist=5, b=10, nbds=None):
        """Run the synchronous (all nodes every round) variant.

        Each iteration proceeds in three phases:

        1. Every node calls ``gradNoThreshold()``.
        2. Nodes are visited in random order. Each node sends
           ``computeQ(j)`` to its neighbours (also shuffled), and each
           receiver calls ``updateY``.
        3. Every node calls ``aggregateNbdThres``.

        Every ``args['acc_idx']`` iterations the mean test result, mean
        training loss and estimated bits are appended to the result lists.

        Args:
            nbds: list of neighbourhoods (see ``genNbdDict``). ``None``
                (default) connects all ``num_of_node`` nodes.
        """
        self.genDataSets()

        test_dl = OurNNUtils.generateTestDataLoader(self.test_data)
        nn_dict = self.generateNNDict(args, num_of_node, test_dl, dist, b)
        if nbds is None:
            nbds = [list(range(num_of_node))]
        nbd = self.genNbdDict(nbds)
        shared = np.zeros(shape=num_of_node)
        for k in range(iterations):
            shared[:] = 0
            record = k % args["acc_idx"] == 0
            bits = 0
            for i in range(num_of_node):
                nn_dict[i].gradNoThreshold()
            order = np.arange(num_of_node)
            np.random.shuffle(order)
            for i in order:
                nbd_l = nbd[i]
                np.random.shuffle(nbd_l)
                for j in nbd_l:
                    q_i_j, b_i_j = nn_dict[i].computeQ(j)
                    if record and shared[i] == 0:
                        # Bits are estimated once per sender, from its first
                        # message, and scaled by its neighbour count.
                        bits += len(nbd[i]) * self.computeBits(copy.deepcopy(q_i_j), args=args, model=nn_dict[i])
                        shared[i] = 1
                    nn_dict[j].updateY(i, q_i_j, b_i_j)

            for i in order:
                nn_dict[i].aggregateNbdThres(nbd[i])

            if record:
                avg_acc, avg_loss = self._evaluate(nn_dict, num_of_node)
                self.loss_l.append(avg_loss)
                self.acc_l.append(avg_acc)
                self.bits_l.append(bits)
                self._printNodeGap(nn_dict)
                print("Iteration :", k, " acc: ", avg_acc, " bits: ", bits, " loss: ", avg_loss)
            else:
                print("iteration: ", k)

    def runUniQuant(self, num_of_node, iterations, args, dist=5, b=10):
        """Run the pairwise (gossip) variant and save the result series.

        Each iteration picks two distinct nodes ``i`` and ``j``. Both train
        locally, exchange quantized residuals (``computeQ``/``updateY``) and
        aggregate (``aggregateThres``). Every ``args['acc_idx']`` iterations
        the metrics and the bits of that exchange are recorded. At the end
        ``loss_l``, ``acc_l`` and ``bits_l`` are saved as ``.npy`` files under
        ``args['data_save_folder']``. All three files share one timestamp
        suffix.
        """
        self.genDataSets()
        test_dl = OurNNUtils.generateTestDataLoader(self.test_data)
        nn_dict = self.generateNNDict(args, num_of_node, test_dl, dist, b)

        # All communicating pairs are drawn up front.
        pairs = [np.random.choice(np.arange(num_of_node), replace=False, size=2)
                 for _ in range(iterations)]

        for k in range(iterations):
            record = k % args["acc_idx"] == 0
            i, j = pairs[k]
            # (Step 2) Train chosen nodes.
            nn_dict[i].gradNoThreshold()
            nn_dict[j].gradNoThreshold()
            # (Step 3) Quantize weight residuals.
            q_i_j, b_i_j = nn_dict[i].computeQ(j)
            q_j_i, b_j_i = nn_dict[j].computeQ(i)
            if record:
                # Copy the messages before updateY, in case the updates modify
                # them in place. Copies are only needed when bits are counted.
                q_i_j_c = copy.deepcopy(q_i_j)
                q_j_i_c = copy.deepcopy(q_j_i)
            # (Step 7) Update nodes with neighbors' residuals.
            nn_dict[i].updateY(j, q_j_i, b_j_i)
            nn_dict[j].updateY(i, q_i_j, b_i_j)
            # (Step 8-9) Aggregation.
            nn_dict[i].aggregateThres(j)
            nn_dict[j].aggregateThres(i)

            if record:
                avg_acc, avg_loss = self._evaluate(nn_dict, num_of_node)
                bits = (self.computeBits(q_i_j_c, args, nn_dict[i])
                        + self.computeBits(q_j_i_c, args, nn_dict[j]))
                self.bits_l.append(bits)
                self.loss_l.append(avg_loss)
                self.acc_l.append(avg_acc)
                self._printNodeGap(nn_dict)
                print("Iteration :", k, " acc: ", avg_acc, " bits: ", bits, " loss: ", avg_loss)
            else:
                print("iteration: ", k)

        stamp = datetime.datetime.now().isoformat()
        np.save(self._resultPath(args, "loss:", iterations, b, dist, stamp), self.loss_l)
        np.save(self._resultPath(args, "acc:", iterations, b, dist, stamp), self.acc_l)
        np.save(self._resultPath(args, "bit:", iterations, b, dist, stamp), self.bits_l)

        




class Choco:
    """Experiment driver for the Choco-SGD baseline.

    Nodes are ``MQNode`` instances that exchange ``computeQ_C`` messages.
    Bits are counted with ``Encoding.calcBitsQSGD``. Each run appends to
    ``acc_l``, ``loss_l`` and ``bits_l`` once every ``args['acc_idx']``
    iterations.
    """

    def __init__(self, dataset="mnist", elias=None):
        """Create an empty runner.

        Args:
            dataset: ``"mnist"`` selects the MNIST sets in ``genDataSets``;
                any other value selects the CIFAR sets.
            elias: table forwarded as ``elias=`` to ``Encoding.calcBitsQSGD``.
        """
        self.N = 0
        self.iter_max = 0
        self.bits_per_iter = []
        self.acc_l = []   # mean of the nodes' test() results per recorded iteration
        self.loss_l = []  # mean of the nodes' getTrainLoss() per recorded iteration
        self.bits_l = []  # communicated-bit estimate per recorded iteration
        self.t_loss = []
        self.temp_save_path = "data/temp/choco"
        self.dataset = dataset
        self.bits = []
        self.test_data = None   # set by genDataSets()
        self.train_data = None  # set by genDataSets()
        self.elias = elias

    def genDataSets(self):
        """Set ``test_data``/``train_data`` to MNIST or CIFAR based on ``self.dataset``."""
        if self.dataset == "mnist":
            self.test_data = OurNNUtils.minst_test_data
            self.train_data = OurNNUtils.minst_training_data
        else:
            self.test_data = OurNNUtils.cifar_testset
            self.train_data = OurNNUtils.cifar_trainset

    def generateNNDict(self, args, num_of_node, test_dl, dist, b):
        """Build one ``MQNode`` per node, each with its own training shard.

        Args:
            args: hyper-parameter dict; the keys read here are forwarded to
                ``MQNode``.
            num_of_node: number of nodes to create (ids ``0..num_of_node-1``).
            test_dl: test data loader shared by every node.
            dist: forwarded to ``OurNNUtils.twoClass`` as ``class_num``.
            b: batch size forwarded to ``OurNNUtils.twoClass``.

        Returns:
            dict mapping node id -> ``MQNode``.
        """
        # Partition once. Calling twoClass per node repeats the work
        # num_of_node times. If twoClass is randomized, per-node calls would
        # also take each node's shard from a different split.
        data_loader_dict = OurNNUtils.twoClass(num_of_node, self.train_data,
                                               class_num=dist, batch_size=b)
        nn_dict = {}
        for i in range(num_of_node):
            nn_dict[i] = MQNode(data_loader=data_loader_dict[i], id=i,
                                nn_type=args["nn_type"], test_dl=test_dl, lr=args["lr"],
                                q_scalar=args["q_scalar"], dith=args["dith"], bits=args['bit'], reg=args['reg'],
                                threshold=args['threshold'], quantize_loop=args['quantize_loop'], l=args['l'],
                                alpha=args['alpha'], eps=args['eps'], lr_const=args['lr_const'],
                                device_str=args['device_str'], E=args['E'], optim=args['optim'],
                                tr_decay=args['tr_decay'])
        return nn_dict

    def genNbdDict(self, nbds):
        """Turn a list of neighbourhoods into a per-node neighbour list.

        Two nodes are neighbours if they appear together in at least one
        neighbourhood. Each neighbour is listed once, in first-seen order, and
        a node is never its own neighbour. Node ids need not be contiguous.

        Args:
            nbds: list of neighbourhoods, each an iterable of node ids.

        Returns:
            dict mapping each node id appearing in ``nbds`` to its neighbours.
        """
        membership = {}
        for group_idx, group in enumerate(nbds):
            for node in group:
                membership.setdefault(node, []).append(group_idx)

        nbd_dict = {}
        for node in sorted(membership):
            neighbours = []
            seen = set()
            for group_idx in membership[node]:
                for other in nbds[group_idx]:
                    # Overlapping neighbourhoods must not list a neighbour twice,
                    # which would double-send messages and inflate bit counts.
                    if other != node and other not in seen:
                        seen.add(other)
                        neighbours.append(other)
            nbd_dict[node] = neighbours
        return nbd_dict

    def computeBits(self, q_i_j_label, args, model):
        """Return the number of bits needed to send one message.

        With ``args['bit'] == -1`` every entry costs 32 bits. Otherwise the
        labels are counted by ``Encoding.calcBitsQSGD`` with
        ``args['bit'] + 1`` levels. ``model`` is accepted for interface parity
        with ``Malcom4.computeBits`` but is not used.
        """
        if args['bit'] == -1:
            return 32 * (q_i_j_label.size(dim=0))

        return Encoding.calcBitsQSGD(q_i_j_label, args['bit'] + 1, device=args["device_str"], elias=self.elias).item()

    @staticmethod
    def _evaluate(nn_dict, num_of_node):
        """Return the mean ``test()`` and mean ``getTrainLoss()`` over all nodes."""
        acc_total = 0
        loss_total = 0
        for node in nn_dict.values():
            acc_total += node.test()
            loss_total += node.getTrainLoss()
        return acc_total / num_of_node, loss_total / num_of_node

    @staticmethod
    def _printNodeGap(nn_dict):
        """Print the L2 distance between the flattened weights of nodes 1 and 0, if both exist."""
        if 0 in nn_dict and 1 in nn_dict:
            print("l2 norm of node 1 and 0: ",
                  torch.norm(nn_dict[1].weightDictTo1D() - nn_dict[0].weightDictTo1D()).item())

    def synchronous(self, num_of_node, iterations, args, dist=5, b=10, nbds=None):
        """Run the synchronous (all nodes every round) Choco-SGD variant.

        Each iteration proceeds in three phases:

        1. Every node calls ``gradNoThreshold()``.
        2. Each node ``i`` sends ``computeQ_C(j)`` to every neighbour ``j``,
           and the receiver calls ``updateY``.
        3. Every node calls ``aggregateNbd``.

        Every ``args['acc_idx']`` iterations the mean test result, mean
        training loss and estimated bits are recorded.

        Args:
            nbds: list of neighbourhoods (see ``genNbdDict``). ``None``
                (default) connects all ``num_of_node`` nodes.
        """
        self.genDataSets()

        test_dl = OurNNUtils.generateTestDataLoader(self.test_data)
        nn_dict = self.generateNNDict(args, num_of_node, test_dl, dist, b)
        if nbds is None:
            nbds = [list(range(num_of_node))]
        nbd = self.genNbdDict(nbds)
        shared = np.zeros(shape=num_of_node)
        for k in range(iterations):
            shared[:] = 0
            record = k % args["acc_idx"] == 0
            bits = 0
            for i in range(num_of_node):
                nn_dict[i].gradNoThreshold()

            for i in range(num_of_node):
                for j in nbd[i]:
                    q_i_j, b_i_j, q_i_j_levels = nn_dict[i].computeQ_C(j)
                    if torch.isnan(q_i_j).any():
                        print("quantization produced a nan!")

                    if record and shared[i] == 0:
                        # Bits are estimated once per sender, from its first
                        # message, and scaled by its neighbour count.
                        bits += len(nbd[i]) * self.computeBits(q_i_j_levels, args=args, model=nn_dict[i])
                        shared[i] = 1
                    nn_dict[j].updateY(i, q_i_j, b_i_j)

            for i in range(num_of_node):
                nn_dict[i].aggregateNbd(nbd[i])

            if record:
                avg_acc, avg_loss = self._evaluate(nn_dict, num_of_node)
                self.loss_l.append(avg_loss)
                self.acc_l.append(avg_acc)
                self.bits_l.append(bits)
                self._printNodeGap(nn_dict)
                print("Iteration :", k, " acc: ", avg_acc, " bits: ", bits, " loss: ", avg_loss)
            else:
                print("iteration: ", k)

    def runChoco(self, num_of_node, iterations, args, dist=5, b=10):
        """Run the pairwise (gossip) Choco-SGD variant and save the bit series.

        Each iteration draws two distinct nodes ``i`` and ``j``. Both train
        locally, exchange ``computeQ_C`` messages and aggregate with
        ``aggregate(..., choco=True)``. Every ``args['acc_idx']`` iterations
        the metrics and the QSGD bit count of that exchange are recorded. At
        the end ``bits_l`` is saved under ``args['data_save_folder']``.
        """
        self.genDataSets()
        test_dl = OurNNUtils.generateTestDataLoader(self.test_data)

        # Partition once so every node takes its own shard of the same split.
        data_loader_dict = OurNNUtils.twoClass(num_of_node, self.train_data,
                                               class_num=dist, batch_size=b)
        nn_dict = {}
        for i in range(num_of_node):
            # NOTE(review): unlike generateNNDict, tr_decay is not forwarded
            # here, so MQNode's own default applies — confirm this is intended.
            nn_dict[i] = MQNode(data_loader=data_loader_dict[i], id=i,
                                nn_type=args["nn_type"], test_dl=test_dl, lr=args["lr"],
                                q_scalar=args["q_scalar"], dith=args["dith"], bits=args['bit'], reg=args['reg'],
                                threshold=args['threshold'], quantize_loop=args['quantize_loop'], l=args['l'],
                                alpha=args['alpha'], eps=args['eps'], lr_const=args['lr_const'],
                                device_str=args['device_str'], E=args['E'], optim=args['optim'])

        for k in range(iterations):
            # Choose two distinct nodes.
            i, j = np.random.choice(np.arange(num_of_node), replace=False, size=2)
            # (Step 2) Train chosen nodes.
            nn_dict[i].gradNoThreshold()
            nn_dict[j].gradNoThreshold()
            # (Step 3) Quantize weight residuals.
            q_i_j, b_i_j, q_i_j_label = nn_dict[i].computeQ_C(j)
            q_j_i, b_j_i, q_j_i_label = nn_dict[j].computeQ_C(i)
            # (Step 7) Update nodes with neighbors' residuals.
            nn_dict[i].updateY(j, q_j_i, b_j_i)
            nn_dict[j].updateY(i, q_i_j, b_i_j)
            # (Step 8-9) Aggregation.
            nn_dict[i].aggregate(j, choco=True)
            nn_dict[j].aggregate(i, choco=True)

            if k % args["acc_idx"] == 0:
                avg_acc, avg_loss = self._evaluate(nn_dict, num_of_node)

                # Counting the number of bits (QSGD) with each sender's level count.
                Li, Lj = nn_dict[i].bits + 1, nn_dict[j].bits + 1
                bits = (Encoding.calcBitsQSGD(q_i_j_label, Li, device=args['device_str'], elias=self.elias).item()
                        + Encoding.calcBitsQSGD(q_j_i_label, Lj, device=args['device_str'], elias=self.elias).item())

                self.loss_l.append(avg_loss)
                self.acc_l.append(avg_acc)
                self._printNodeGap(nn_dict)
                self.bits_l.append(bits)

                print("Iteration: ", k, " acc: ", avg_acc, " bits: ", bits, " loss: ", avg_loss)
            else:
                print("iteration: ", k)

        bit_path = (args['data_save_folder'] + "cho_bit:" + "it" + str(iterations) + '_lr'
                    + str(args['lr']) + "_tr" + str(args['threshold']) + "_lvl"
                    + str(args['bit']) + "_b" + str(b) + '_e' + str(args['E']) + 'dist_' + str(dist))
        bit_path += datetime.datetime.now().isoformat()
        bit_path = bit_path + ".npy"
        np.save(bit_path, self.bits_l)



def convergence(node_dict):
    """Return a scalar summary of the nodes' ``prev_dif`` vectors.

    The result is the signed sum of node 0's ``prev_dif`` plus, for every
    node (node 0 included), the sum of the absolute values of its
    ``prev_dif``.

    Args:
        node_dict: mapping of node id -> node; must contain key ``0``, and
            every node must expose an iterable numeric ``prev_dif``.

    Returns:
        The accumulated sum.
    """
    # NOTE(review): node 0 contributes twice (once signed, once as absolute
    # values) and the name "frob" suggests a Frobenius norm, which this is not.
    # Confirm the intended metric before relying on it.
    frob = sum(node_dict[0].prev_dif)
    for node in node_dict.values():
        frob += sum(np.absolute(node.prev_dif))
    return frob

def versusNoPlot(iter, N, args, b=20, dist=5, dataset="mnist", sync=False, nbds=None, choco_bit=7, elias=None):
    """Run Malcom-PSGD, Choco-SGD and an unquantized baseline once, without plotting.

    The baseline is a ``Malcom4`` run with ``bit = -1`` (32 bits per entry in
    its bit count) and ``threshold = 0``. Choco-SGD uses ``choco_bit``.
    ``args`` itself is not modified; deep copies are made for the variants.

    Args:
        iter: number of iterations per run.
        N: number of nodes.
        sync: use the synchronous variants when True, otherwise the pairwise
            ``runChoco`` / ``runUniQuant`` variants.
        nbds: neighbourhoods for the synchronous variants. ``None`` (default)
            puts all ``N`` nodes in one neighbourhood.

    Returns:
        ``(accuracies, losses, bits)``, each a list ordered
        ``[malcom4, choco, base]``.
    """
    if nbds is None:
        nbds = [list(range(N))]

    base_case = copy.deepcopy(args)
    base_case['bit'] = -1
    base_case['threshold'] = 0

    choco_args = copy.deepcopy(args)
    choco_args['bit'] = choco_bit

    choco = Choco(dataset=dataset, elias=elias)
    malcom4 = Malcom4(dataset=dataset, elias=elias)
    base = Malcom4(dataset=dataset, elias=elias)
    if sync:
        choco.synchronous(N, iter, choco_args, dist=dist, b=b, nbds=nbds)
        malcom4.synchronous(N, iter, args, dist=dist, b=b, nbds=nbds)
        base.synchronous(N, iter, base_case, dist=dist, b=b, nbds=nbds)
    else:
        choco.runChoco(N, iter, choco_args, dist=dist, b=b)
        malcom4.runUniQuant(N, iter, args, dist=dist, b=b)
        base.runUniQuant(N, iter, base_case, dist=dist, b=b)

    accuracies = [malcom4.acc_l, choco.acc_l, base.acc_l]
    losses = [malcom4.loss_l, choco.loss_l, base.loss_l]
    bits = [malcom4.bits_l, choco.bits_l, base.bits_l]
    return accuracies, losses, bits
    
def monteCarloVS(iter, N, args, b, dist, dataset, num_trials, sync=False, nbds=None, choco_bit=3, elias=None):
    """Average the three algorithms over ``num_trials`` runs, then save and plot the results.

    Each trial calls ``versusNoPlot``. The per-trial accuracy, loss and bit
    series (rows ordered ``[malcom4, choco, base]``) are averaged. The
    averages are saved as ``.npy`` files under ``args['data_save_folder']``
    and passed to the ``Plotting`` helpers.

    Args:
        iter: iterations per run.
        N: number of nodes.
        num_trials: number of Monte-Carlo trials to average.
        nbds: neighbourhoods for the synchronous variants. ``None`` (default)
            puts all ``N`` nodes in one neighbourhood. This equals the
            previous hard-coded ``[[0, ..., 9]]`` when ``N == 10``.
    """
    if nbds is None:
        nbds = [list(range(N))]

    # Points recorded per run: iterations 0, acc_idx, 2*acc_idx, ... < iter.
    dim = (iter - 1) // args['acc_idx'] + 1
    all_acc = np.zeros(shape=(3, dim))
    all_loss = np.zeros(shape=(3, dim))
    all_bits = np.zeros(shape=(3, dim))

    for _ in range(num_trials):
        acc, loss, bits = versusNoPlot(iter, N, args, b, dist, dataset, sync, nbds, choco_bit, elias)
        all_acc += acc
        all_loss += loss
        all_bits += bits

    all_acc /= num_trials
    all_loss /= num_trials
    all_bits /= num_trials

    # One timestamp so the three files from this run share a suffix.
    stamp = datetime.datetime.now().isoformat()

    def _path(prefix):
        """Build the ``.npy`` path for one averaged series."""
        return (args['data_save_folder'] + prefix + "it" + str(iter) + '_lr' + str(args['lr'])
                + "_tr" + str(args['threshold']) + "_lvl"
                + str(args['bit']) + "_b" + str(b) + '_e' + str(args['E']) + 'dist_' + str(dist)
                + stamp + ".npy")

    loss_path = _path("MC_loss:")
    np.save(loss_path, all_loss)
    acc_path = _path("MC_acc:")
    np.save(acc_path, all_acc)
    bit_path = _path("MC_bit:")
    np.save(bit_path, all_bits)

    # NOTE(review): the data rows are ordered [malcom4, choco, base], but this
    # legend reads [Choco, Decentralized, Malcom]. Confirm that Plotting
    # reorders rows; otherwise the labels are mismatched.
    legend = ["Choco-SGD", "Decentralized-SGD", 'Malcom-PSGD']

    title_acc = "Average Network Accuracy over Iterations"
    Plotting.plotAccFinal(all_acc, args["acc_idx"], acc_path, iter,
                          title_str=title_acc, legend=legend, args=args, b=b, dist=dist, y_lims=[10, 95],
                          axis_font=20, axis_ticks=16, legend_font=19, redux=1)
    title_loss = "Average Loss over Iterations"
    Plotting.plotlossFinal(all_loss, args["acc_idx"], loss_path, iter,
                           title_str=title_loss, legend=legend, args=args, b=b, dist=dist,
                           axis_font=20, axis_ticks=16, legend_font=19, redux=2)
    Plotting.plotBitFinal(all_bits, args["acc_idx"], bit_path, iter,
                          legend=legend, args=args, b=b, dist=dist, axis_font=20, axis_ticks=16,
                          legend_font=19, redux=1)
        






if __name__ == "__main__":
    # Hyper-parameters shared by all three algorithms. The baseline and
    # Choco-SGD variants are derived from this dict inside versusNoPlot.
    # Previously noted alternative values: 1640, 20000, .02
    args = {
        "lr": .2,
        "acc_idx": 1,                    # record metrics every acc_idx iterations
        "threshold": 5e-5,
        'bit': 7,
        "q_scalar": 1,
        "dith": 0,
        "save_path": "data/fnn/",
        "nn_type": "full",
        'reg': 0,
        'quantize_loop': None,
        'l': 21,
        'alpha': 640,
        'eps': .7,
        'lr_const': True,
        'device_str': 'cpu',
        'E': 1,
        'optim': "sgd",
        "data_save_folder": "data/fnn/",  # where the .npy result files are written
        'tr_decay': 1,
    }

    # Encoding table used by the bit counters, moved onto the chosen device.
    elias_table = Encoding.return_dict(args['nn_type'])
    elias = torch.tensor(list(elias_table.values()), device=args['device_str'])

    # Monte-Carlo comparison:
    # - 100 iterations on 10 nodes, MNIST;
    # - dist=5 (use 1 for iid data), batch size 200, averaged over 5 trials;
    # - Choco-SGD uses the same level count as Malcom-PSGD;
    # - sync=True runs the synchronous variants (set False for pairwise gossip).
    monteCarloVS(
        100,
        10,
        args,
        dataset='mnist',
        dist=5,
        b=200,
        num_trials=5,
        choco_bit=args['bit'],
        elias=elias,
        sync=True,
    )