from copy import deepcopy
import torch
import numpy as np
import sys

class Benchmark_Fed_Avg():

    def __init__(self, num_nodes, near_center_idx, far_center_idx, loss_function, train_size, val_size, test_size):
        """Store the federation layout, the loss function and the dataset sizes.

        Args:
            num_nodes: number of nodes; the other methods iterate
                ``range(num_nodes)`` to index per-node models and loaders.
            near_center_idx: index of the node whose model is evaluated and
                whose loaders produce the "near" metrics.
            far_center_idx: index of the node whose loaders produce the
                "far" metrics.
            loss_function: callable invoked as ``loss_function(outputs, labels)``.
            train_size: value added to the data-point counter each time SVRG
                recomputes a full gradient on the near-center node.
            val_size: stored as given.
            test_size: stored as given.
        """
        # dataset sizes
        self.train_size, self.val_size, self.test_size = train_size, val_size, test_size

        # use the GPU whenever one is visible to torch
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)

        # federation layout and objective
        self.loss_function = loss_function
        self.num_nodes = num_nodes
        self.near_center_idx, self.far_center_idx = near_center_idx, far_center_idx
    
    # initialize dictionaries for storing the results
    def Initialize_Results(self):
        """Reset the metric history and the running counters.

        ``metric_dict`` maps every metric name to a fresh empty list; the
        counters for processed data points and synchronisations restart at 0.
        """
        metric_names = ["num_points", "num_syn"]
        # key order: near/val, near/test, far/val, far/test
        for center in ("near", "far"):
            for split in ("val", "test"):
                metric_names.append(center + "_" + split + "_loss")
                metric_names.append(center + "_top1_" + split + "_acc")
                metric_names.append(center + "_top2_" + split + "_acc")

        # every metric gets its own list object
        self.metric_dict = {name: [] for name in metric_names}

        self.current_num_syn = 0
        self.current_num_points = 0

    # initialize equal weights for partners and models with the same parameters
    def Initialize_Variables(self, model_dict, train_loaders, val_loaders, test_loaders, num_batches):
        """Install the per-node models, log a first evaluation and cache the training batches.

        Args:
            model_dict: mapping ``node_idx -> model`` (the inner-level variable theta).
            train_loaders: per-node iterables of ``(inputs, labels)`` batches,
                indexable by ``node_idx``.
            val_loaders: per-node validation loaders, forwarded to ``Evaluation``.
            test_loaders: per-node test loaders, forwarded to ``Evaluation``.
            num_batches: number of leading batches to cache for every node.

        Raises:
            IndexError: if a node's loader yields fewer than ``num_batches`` batches.

        Side effects:
            Increments ``current_num_syn`` (``Initialize_Results`` must have
            been called first), appends one evaluation record and fills
            ``mini_batch_tuple_dict[batch_idx][node_idx]``.
        """
        self.model_dict = model_dict  # theta inner-level variable and initialize theta
        self.current_num_syn += 1

        self.Evaluation(val_loaders=val_loaders, test_loaders=test_loaders)

        # divide training data into batch tuples
        self.mini_batch_tuple_dict = {batch_idx: {} for batch_idx in range(num_batches)}
        if num_batches <= 0:
            return

        for node_idx in range(self.num_nodes):
            # Iterate each loader exactly once. Re-listing the loader for every
            # batch index is quadratic, and a shuffling loader reshuffles on
            # every pass, so the cached batches could repeat or skip samples.
            node_batches = list(train_loaders[node_idx])
            for batch_idx in range(num_batches):
                self.mini_batch_tuple_dict[batch_idx][node_idx] = node_batches[batch_idx]

    # convert a flat gradient vector into a list of tensors shaped like the model parameters
    def Grad_Vec_to_List(self, model, grad_vec):
        """Split a flat gradient vector into per-parameter tensors.

        Args:
            model: module whose ``parameters()`` order and shapes define the split.
            grad_vec: 1-D tensor holding the gradients of all parameters,
                concatenated in ``model.parameters()`` order.

        Returns:
            A list with one tensor per model parameter, each reshaped like
            that parameter and taken through ``.data``.
        """
        shaped_grads = []
        start = 0
        for template in model.parameters():
            stop = start + template.numel()
            segment = grad_vec[start:stop]
            shaped_grads.append(segment.view_as(template).data)
            start = stop

        return shaped_grads
    
    # gradient descent step
    def Gradient_Descent(self, node_idx, grad_list, lr):
        """Apply one in-place gradient-descent step to a node's model.

        Args:
            node_idx: key of the model in ``self.model_dict`` to update.
            grad_list: per-parameter gradients in ``model.parameters()`` order
                (for example the output of ``Grad_Vec_to_List``).
            lr: step size.

        Raises:
            IndexError: if ``grad_list`` has fewer entries than the model has
                parameters.
        """
        model = self.model_dict[node_idx]

        # The parameters are walked directly; no copy of the model is needed
        # just to enumerate them.
        with torch.no_grad():
            for idx, param in enumerate(model.parameters()):
                param.data -= lr * grad_list[idx]
    
    # compute the gradient of theta w.r.t. the whole training data in a node
    def Compute_Full_Gradient_Train(self, node_idx, train_loaders, indication):
        """Return the gradient of the loss over a node's whole training loader.

        Every mini-batch gradient is weighted by its batch size and the sum is
        divided by the total number of samples, so batches of different sizes
        contribute proportionally (this assumes ``loss_function`` returns a
        per-batch mean).

        Args:
            node_idx: node whose model and loader are used.
            train_loaders: per-node iterables of ``(inputs, labels)`` batches.
            indication: ``"ref"`` selects ``self.model_dict_ref`` (theta
                reference); any other value selects ``self.model_dict``.

        Returns:
            1-D tensor with the full gradient, flattened in
            ``model.parameters()`` order.

        Raises:
            ValueError: if the node's loader yields no batches.
        """
        # get theta or theta_reference
        model_pool = self.model_dict_ref if indication == "ref" else self.model_dict
        model = model_pool[node_idx]
        model.zero_grad()
        model_params = list(model.parameters())

        full_gradient = None
        full_dataset_size = 0  # total sample count, used for the final re-weighting

        for batch_data in train_loaders[node_idx]:
            mini_batch_size = int(batch_data[0].size(dim=0))
            inputs = batch_data[0].to(self.device)
            labels = batch_data[1].to(self.device)

            loss = self.loss_function(model(inputs), labels)
            mini_batch_grad = torch.nn.utils.parameters_to_vector(
                torch.autograd.grad(loss, model_params, create_graph=False)
            )

            # undo the per-batch averaging so that batches can be summed
            weighted_grad = mini_batch_grad * mini_batch_size
            if full_gradient is None:
                full_gradient = weighted_grad
            else:
                full_gradient += weighted_grad
            full_dataset_size += mini_batch_size

        if full_gradient is None:
            raise ValueError(
                "train loader of node {} yielded no batches; cannot compute a full gradient".format(node_idx)
            )

        # after all mini-batches, average over the size of the training set
        full_gradient /= full_dataset_size

        return full_gradient
    
    # communicate by fed_avg algorithm
    def Fed_Avg(self):
        """Average the parameters of all node models and broadcast the result.

        The near-center model accumulates the parameters of every other node
        in place (in increasing ``node_idx`` order), divides by ``num_nodes``,
        and its ``state_dict`` is then loaded into every node's model. Only
        tensors returned by ``parameters()`` are averaged; any other
        ``state_dict`` entries are copied from the near-center model as they
        are. Each model's gradients are reset with ``zero_grad()`` afterwards.

        Raises:
            IndexError: if a node's model has fewer parameters than the
                near-center model.
        """
        center_idx = self.near_center_idx

        # The parameter lists are used directly; no copy of the center model
        # is needed just to walk the parameters in order.
        node_params = {
            node_idx: list(self.model_dict[node_idx].parameters())
            for node_idx in range(self.num_nodes)
        }
        center_params = node_params[center_idx]

        with torch.no_grad():
            for position, center_param in enumerate(center_params):
                # the center parameter collects all other nodes' parameters
                for node_idx in range(self.num_nodes):
                    if node_idx != center_idx:
                        center_param.data += node_params[node_idx][position].data

                center_param.data /= self.num_nodes

            # other nodes get the result of communication
            averaged_state = self.model_dict[center_idx].state_dict()

            for node_idx in range(self.num_nodes):
                self.model_dict[node_idx].load_state_dict(averaged_state)
                self.model_dict[node_idx].zero_grad()
    
    # Local-SGD for updating theta
    def Update_Models_SGD(self, train_loaders, val_loaders, test_loaders, num_epoch, num_batches, communicate_period, lr_inner):
        """Run Local-SGD on every node with periodic FedAvg synchronisation.

        Arguments correspond to D^{train}, T_{1}, tau, eta; T_{1} equals
        ``num_epoch * num_batches`` local steps.

        Args:
            train_loaders: not read here; the batches cached in
                ``self.mini_batch_tuple_dict`` are used instead.
            val_loaders: forwarded to ``Evaluation``.
            test_loaders: forwarded to ``Evaluation``.
            num_epoch: number of passes over the cached batches.
            num_batches: number of cached batches used per pass.
            communicate_period: ``Fed_Avg`` runs every this many steps, and
                always after the final step.
            lr_inner: step size of the local gradient step.

        Side effects:
            Updates the models in place, increases ``current_num_points`` by
            the near-center batch size per step and ``current_num_syn`` per
            synchronisation, and records an evaluation on every
            synchronisation whose step is a multiple of
            ``5 * communicate_period`` or the final one.
        """
        total_steps = num_epoch * num_batches
        step_count = 0  # t in { 1, 2, ..., T_{1} }

        for _epoch in range(num_epoch):
            for batch_idx in range(num_batches):

                step_count += 1
                # true only on the very last (epoch, batch) pair
                is_final_step = step_count == total_steps

                for node_idx in range(self.num_nodes):  # nodes i = 1, ..., n "in parallel"

                    # prepare theta
                    model = self.model_dict[node_idx]
                    model.zero_grad()
                    theta = list(model.parameters())

                    # this node's mini-batch from the current batch tuple
                    batch_data = self.mini_batch_tuple_dict[batch_idx][node_idx]
                    mini_batch_size = batch_data[0].size(dim=0)
                    inputs = batch_data[0].to(self.device)
                    labels = batch_data[1].to(self.device)

                    # forward pass and stochastic gradient, flattened to one vector
                    loss = self.loss_function(model(inputs), labels)
                    grad_tuple = torch.autograd.grad(loss, theta, create_graph=False)
                    sgd_grad = torch.nn.utils.parameters_to_vector(grad_tuple)

                    # only the near-center node contributes to the data-point counter
                    if node_idx == self.near_center_idx:
                        self.current_num_points += mini_batch_size

                    # back to per-parameter shape, then one descent step on theta
                    grad_list = self.Grad_Vec_to_List(model=self.model_dict[node_idx], grad_vec=sgd_grad)
                    self.Gradient_Descent(node_idx=node_idx, grad_list=grad_list, lr=lr_inner)

                # communicate every `communicate_period` steps and always at the end
                if not (step_count % communicate_period == 0 or is_final_step):
                    continue

                self.Fed_Avg()
                self.current_num_syn += 1

                # evaluate on every fifth synchronisation period and at the end
                if step_count % (5 * communicate_period) == 0 or is_final_step:
                    self.Evaluation(val_loaders=val_loaders, test_loaders=test_loaders)

    # Local-SVRG for updating theta
    def Update_Models_SVRG(self, train_loaders, val_loaders, test_loaders, num_epoch, num_batches, communicate_period, lr_inner):
        """Run Local-SVRG on every node with periodic FedAvg synchronisation.

        Arguments correspond to D^{train}, T_{1}, tau, eta; the reference
        refresh probability q is fixed to ``1 / num_batches``.

        Each local step uses the variance-reduced gradient
        ``g(theta; batch) - g(theta_ref; batch) + full_grad(theta_ref)``.
        After the gradient is formed, and before theta is updated, each node
        replaces its theta_ref by a copy of theta with probability q; the full
        gradient of a new reference is computed lazily on the node's next step.

        Args:
            train_loaders: per-node loaders used for the full-gradient passes.
            val_loaders: forwarded to ``Evaluation``.
            test_loaders: forwarded to ``Evaluation``.
            num_epoch: number of passes over the cached batches.
            num_batches: number of cached batches used per pass.
            communicate_period: ``Fed_Avg`` runs every this many steps, and
                always after the final step.
            lr_inner: step size of the local gradient step.

        Side effects:
            Creates ``model_dict_ref``, ``ref_indicator_dict`` (string flags
            ``"True"``/``"False"``) and ``full_gradient_dict``; updates the
            models in place; for the near-center node adds twice the batch
            size per step plus ``train_size`` per full-gradient pass to
            ``current_num_points``; draws one ``np.random.uniform()`` sample
            per node and step.
        """
        # theta_reference starts as a copy of theta
        self.model_dict_ref = deepcopy(self.model_dict)

        node_ids = range(self.num_nodes)
        # "False": the full gradient of the current reference is not computed yet
        self.ref_indicator_dict = {node_idx: "False" for node_idx in node_ids}
        self.full_gradient_dict = {node_idx: None for node_idx in node_ids}

        total_steps = num_epoch * num_batches
        step_count = 0  # t in { 1, 2, ..., T_{1} }

        for _epoch in range(num_epoch):
            for batch_idx in range(num_batches):

                step_count += 1
                # true only on the very last (epoch, batch) pair
                is_final_step = step_count == total_steps

                for node_idx in range(self.num_nodes):  # nodes i = 1, ..., n "in parallel"
                    is_center = node_idx == self.near_center_idx

                    # prepare theta and theta_reference
                    model = self.model_dict[node_idx]
                    model.zero_grad()
                    theta = list(model.parameters())

                    model_ref = self.model_dict_ref[node_idx]
                    model_ref.zero_grad()
                    theta_ref = list(model_ref.parameters())

                    # this node's mini-batch from the current batch tuple
                    batch_data = self.mini_batch_tuple_dict[batch_idx][node_idx]
                    mini_batch_size = batch_data[0].size(dim=0)
                    inputs = batch_data[0].to(self.device)
                    labels = batch_data[1].to(self.device)

                    # mini-batch gradient at theta, flattened to one vector
                    loss = self.loss_function(model(inputs), labels)
                    grad_theta = torch.nn.utils.parameters_to_vector(
                        torch.autograd.grad(loss, theta, create_graph=False)
                    )

                    # mini-batch gradient at theta_reference on the same batch
                    loss_ref = self.loss_function(model_ref(inputs), labels)
                    grad_theta_ref = torch.nn.utils.parameters_to_vector(
                        torch.autograd.grad(loss_ref, theta_ref, create_graph=False)
                    )
                    if is_center:
                        self.current_num_points += mini_batch_size * 2

                    # lazily compute the full gradient of the current reference
                    if self.ref_indicator_dict[node_idx] == "False":
                        self.full_gradient_dict[node_idx] = self.Compute_Full_Gradient_Train(
                            node_idx=node_idx, train_loaders=train_loaders, indication="ref"
                        )
                        self.ref_indicator_dict[node_idx] = "True"
                        if is_center:
                            self.current_num_points += self.train_size

                    # svrg-gradient, reshaped per parameter
                    svrg_grad = grad_theta - grad_theta_ref + self.full_gradient_dict[node_idx]
                    grad_list = self.Grad_Vec_to_List(model=self.model_dict[node_idx], grad_vec=svrg_grad)

                    # with probability q == 1 / num_batches take the current
                    # (not yet updated) theta as the new reference
                    prob_sample = np.random.uniform()
                    refresh_reference = prob_sample < 1 / num_batches
                    if refresh_reference:
                        self.model_dict_ref[node_idx] = deepcopy(self.model_dict[node_idx])
                    self.ref_indicator_dict[node_idx] = "False" if refresh_reference else "True"

                    # gradient descent on theta
                    self.Gradient_Descent(node_idx=node_idx, grad_list=grad_list, lr=lr_inner)

                # communicate every `communicate_period` steps and always at the end
                if not (step_count % communicate_period == 0 or is_final_step):
                    continue

                self.Fed_Avg()
                self.current_num_syn += 1

                # evaluate on every fifth synchronisation period and at the end
                if step_count % (5 * communicate_period) == 0 or is_final_step:
                    self.Evaluation(val_loaders=val_loaders, test_loaders=test_loaders)
    
    # evaluate the near-center model: validation/test loss and top-1/top-2 accuracy on the near- and far-center loaders
    # def Evaluation(self, train_loaders, val_loaders, test_loaders):
    def Evaluation(self, val_loaders, test_loaders):
        """Evaluate the near-center model and append the results to ``metric_dict``.

        The model of node ``near_center_idx`` is scored, in this order, on the
        near-center validation loader, the far-center validation loader, the
        near-center test loader and the far-center test loader. For each one
        the sample-weighted mean loss and the top-1 / top-2 accuracies are
        appended to ``metric_dict`` and printed. Finally the current
        data-point and synchronisation counters are appended as well.

        Args:
            val_loaders: per-node iterables of ``(inputs, labels)`` validation
                batches, indexable by node index.
            test_loaders: per-node iterables of ``(inputs, labels)`` test
                batches, indexable by node index.

        Raises:
            ZeroDivisionError: if one of the four loaders yields no samples.
        """
        # theta is fixed in this method
        model = self.model_dict[self.near_center_idx]
        model.zero_grad()

        # (metric prefix, split name, loader collection, node whose loader is used)
        evaluation_plan = (
            ("near", "val", val_loaders, self.near_center_idx),
            ("far", "val", val_loaders, self.far_center_idx),
            ("near", "test", test_loaders, self.near_center_idx),
            ("far", "test", test_loaders, self.far_center_idx),
        )

        for center, split, loaders, loader_idx in evaluation_plan:
            loss, top1_acc, top2_acc = self._Evaluate_Loader(model, loaders[loader_idx])

            self.metric_dict[center + "_" + split + "_loss"].append(loss)
            self.metric_dict[center + "_top1_" + split + "_acc"].append(top1_acc)
            self.metric_dict[center + "_top2_" + split + "_acc"].append(top2_acc)

            print(center + "_full_" + split + "_loss: ", loss)
            print(center + "_full_top1_" + split + "_acc: ", top1_acc)
            print(center + "_full_top2_" + split + "_acc: ", top2_acc)
            print("-" * 30)

        sys.stdout.flush()
        self.metric_dict["num_points"].append(self.current_num_points)
        self.metric_dict["num_syn"].append(self.current_num_syn)

    def _Evaluate_Loader(self, model, loader):
        """Return ``(mean loss, top-1 accuracy, top-2 accuracy)`` of ``model`` over ``loader``.

        Batch losses are re-weighted by batch size so the result is the mean
        over all samples (this assumes ``loss_function`` returns a per-batch
        mean). Labels are expected to be a 1-D tensor of class indices and the
        outputs to have at least two columns.

        NOTE(review): the model is run in whatever train/eval mode it is
        currently in; confirm this is intended for models with dropout or
        batch normalisation.
        """
        total_samples = 0
        loss_sum = 0
        top1_hits = 0
        top2_hits = 0

        # no gradients are needed for scoring, so skip building autograd graphs
        with torch.no_grad():
            for batch_data in loader:
                mini_batch_size = int(batch_data[0].size(dim=0))
                inputs = batch_data[0].to(self.device)
                labels = batch_data[1].to(self.device)

                outputs = model(inputs)

                # the batch loss is a batch mean; weight it by the batch size
                loss_sum += self.loss_function(outputs, labels).item() * mini_batch_size

                # top-1: the predicted class equals the label
                _, top1_predicts = torch.max(outputs, dim=1)
                top1_hits += int((top1_predicts == labels).sum().item())

                # top-2: the label is among the two highest-scoring classes
                _, top2_predicts = torch.topk(outputs, k=2, dim=1)
                label_in_top2 = (top2_predicts == labels.unsqueeze(1)).any(dim=1)
                top2_hits += int(label_in_top2.sum().item())

                total_samples += mini_batch_size

        return (
            loss_sum / total_samples,
            top1_hits / total_samples,
            top2_hits / total_samples,
        )
    
    # return all evaluation results after all stages
    def Output_Results(self):
        """Return the metric history collected so far.

        Returns:
            The live ``metric_dict`` object (not a copy), mapping each metric
            name to its list of recorded values.
        """
        collected_metrics = self.metric_dict
        return collected_metrics