from copy import deepcopy
import torch
import numpy as np
import sys


class Bilevel_Federation():
    
    def __init__(self, num_nodes, near_center_idx, far_center_idx, bl_type, loss_function, train_size, val_size, test_size):
        
        self.num_nodes = num_nodes
        self.near_center_idx = near_center_idx
        self.far_center_idx = far_center_idx
        
        if bl_type == "near":
            self.center_idx = near_center_idx
        else:
            self.center_idx = far_center_idx

        self.loss_function = loss_function
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.train_size = train_size
        self.val_size = val_size
        self.test_size = test_size
    
    # initialize dictionaries for storing the results
    def Initialize_Results(self):
        
        self.w_dict = { "num_points": [], "num_syn": [] }

        for node_idx in range(self.num_nodes):
            self.w_dict[node_idx] = []

        # self.metric_dict= { "num_points": [], "num_syn": [], "top1_train_acc": [], "top2_train_acc": [],
        self.metric_dict= { "num_points": [], "num_syn": [],
                            "near_val_loss": [], "near_top1_val_acc": [], "near_top2_val_acc": [],
                            "near_test_loss": [], "near_top1_test_acc": [], "near_top2_test_acc": [],
                            "far_val_loss": [], "far_top1_val_acc": [], "far_top2_val_acc": [],
                            "far_test_loss": [], "far_top1_test_acc": [], "far_top2_test_acc": [] }

        self.current_num_points = 0
        self.current_num_syn = 0

    # initialize equal weights for all partners and store the per-node models
    def Initialize_Variables(self, model_dict, train_loaders, val_loaders, test_loaders, num_batches):

        self.model_dict = model_dict # theta inner-level variable and initialize theta
        self.current_num_syn += 1

        self.weight_dict = {} # w outer-level variable
        for node_idx in range(self.num_nodes):
            self.weight_dict[node_idx] = float(1 / self.num_nodes) # initialize w
            self.w_dict[node_idx].append(self.weight_dict[node_idx])
        self.w_dict["num_points"].append(self.current_num_points)
        self.w_dict["num_syn"].append(self.current_num_syn)

        self.Evaluation(val_loaders=val_loaders, test_loaders=test_loaders)

        # divide training data in to batch tuples
        self.mini_batch_tuple_dict = {}
       
        for batch_idx in range(num_batches):
           
            self.mini_batch_tuple_dict[batch_idx] = {}
            for node_idx in range(self.num_nodes):
                self.mini_batch_tuple_dict[batch_idx][node_idx] = list(train_loaders[node_idx])[batch_idx] 
         
    
    # take a gradient in flat vector form and return it as a list of tensors shaped like the model parameters
    def Grad_Vec_to_List(self, model, grad_vec):
        """Split a flat gradient vector into per-parameter tensors.

        Args:
            model: module whose parameters() define the order and shapes.
            grad_vec: 1-D tensor holding all gradient entries concatenated in
                parameters() order.

        Returns:
            A list with one tensor per model parameter, each a view of the
            matching slice of grad_vec reshaped like that parameter.
        """
        shaped_grads = []
        start = 0
        for tensor in model.parameters():
            stop = start + tensor.numel()
            # slice out this parameter's entries and give them its shape
            shaped_grads.append(grad_vec[start:stop].view_as(tensor).data)
            start = stop

        return shaped_grads
    
    # gradient descent step
    def Gradient_Descent(self, node_idx, grad_list, lr):
       
        model = self.model_dict[node_idx]
        model_params_generator = model.parameters()
       
        model_indication = deepcopy(model)
        model_indication = model_indication.to(self.device)
       
        with torch.no_grad():
            # each parameter in the model do gradient descent
            for (idx, param) in enumerate(model_indication.parameters()):
                    param = next(model_params_generator)
                    param.data -= lr * grad_list[idx]
    
    # compute the gradient of theta w.r.t. the whole training data in a node
    def Compute_Full_Gradient_Train(self, node_idx, train_loaders, indication):
        
        # get theta or theta_reference
        if indication == "ref":
            model = self.model_dict_ref[node_idx]
            model.zero_grad()
        else:
            model = self.model_dict[node_idx]
            model.zero_grad()
        
        # get the model parameter generator
        model_params_generator = model.parameters()
        model_params = list(model_params_generator)

        full_dataset_size = 0 # prepare full_batch size for the weighted sum on mini_batches

        for step, batch_data in enumerate(train_loaders[node_idx]):      

            mini_batch_size = batch_data[0].size(dim=0) # get mini_batch size
            full_dataset_size += int(mini_batch_size)

            inputs, labels = (
                    batch_data[0].to(self.device),
                    batch_data[1].to(self.device),
                )

            outputs = model(inputs)
            loss = self.loss_function(outputs, labels)

            # get mini_batch gradient
            mini_batch_grad = torch.nn.utils.parameters_to_vector(torch.autograd.grad(loss, model_params, create_graph=False))

            # mini_batch gradient is averaged according to this mini_batch, so need to
            # re-assign the weight for mini_batch gradient
            if step == 0:
                full_gradient = mini_batch_grad * int(mini_batch_size)  
            else:
                full_gradient += mini_batch_grad * int(mini_batch_size)

        # after all mini_batches, re-assign the weight for full_batch gradient
        # full_batch size is the size of training dateset
        full_gradient /= int(full_dataset_size)

        return full_gradient
    
    # compute the gradient of theta w.r.t. the whole validation data in a node
    def Compute_Full_Gradient_Val(self, val_loaders):
        
        # get theta
        model = self.model_dict[self.center_idx]
        model.zero_grad()
        
        # get the model parameter generator
        model_params_generator = model.parameters()
        model_params = list(model_params_generator)

        full_dataset_size = 0 # prepare full_batch size for the weighted sum on mini_batches

        for step, batch_data in enumerate(val_loaders[self.center_idx]):        
            
            mini_batch_size = batch_data[0].size(dim=0) # get mini_batch size
            full_dataset_size += int(mini_batch_size)

            inputs, labels = (
                    batch_data[0].to(self.device),
                    batch_data[1].to(self.device),
                )

            outputs = model(inputs)            
            loss = self.loss_function(outputs, labels)

            # get mini_batch gradient
            mini_batch_grad = torch.nn.utils.parameters_to_vector(torch.autograd.grad(loss, model_params, create_graph=False))
            
           # mini_batch gradient is averaged according to this mini_batch, so need to
            # re-assign the weight for mini_batch gradient
            if step == 0:
                full_gradient = mini_batch_grad * int(mini_batch_size)   
            else:
                full_gradient += mini_batch_grad * int(mini_batch_size)
            
        # after all mini_batches, re-assign the weight for full_batch gradient
        # full_batch size is the size of training dateset
        full_gradient /= int(full_dataset_size)

        return full_gradient

    # communicate by weighted fed_avg algorithm
    def Weighted_Fed_Avg(self):
        
        # get all model parameter generators
        model_generator_dict = {}
        for node_idx in range(self.num_nodes):
            model_generator_dict[node_idx] = self.model_dict[node_idx].parameters()
        
        # get the pivot model for showing the order of model parameters
        model_fed_avg = deepcopy(self.model_dict[self.center_idx]) 
        model_fed_avg = model_fed_avg.to(self.device)
       
        with torch.no_grad():
            # parameters in each model multiplied by the respective weight
            for (idx, param) in enumerate(model_fed_avg.parameters()):
               
                param_dict = {}
                for node_idx in range(self.num_nodes):
                    param_dict[node_idx] = next(model_generator_dict[node_idx])
                    param_dict[node_idx] *= float(self.weight_dict[node_idx])
               
                # model parameter in center node collect other parameters and get the sum
                for node_idx in range(self.num_nodes):
                    if node_idx != self.center_idx:
                        param_dict[self.center_idx].data += param_dict[node_idx].data
               
            # other nodes get the result of communication
            model_params = self.model_dict[self.center_idx].state_dict()
           
            for node_idx in range(self.num_nodes):
                self.model_dict[node_idx].load_state_dict(model_params)
                self.model_dict[node_idx].zero_grad()

    # Local-SGD for updating theta
    def Update_Models_SGD(self, val_loaders, test_loaders, num_epoch, num_batches, communicate_period, lr_inner):
        # arguments: D^{train}, T_{1}, tau, eta
       
        step_count = 0 # count in t in { 1, 2, ..., T_{1} }
       
        for i in range(num_epoch): # for t = 0, 1, ..., T_{1}-1
            
            # note that here T_{1} is number of epoches * number of batches
            for batch_idx in range(num_batches): 
                   
                step_count += 1
                
                for node_idx in range(self.num_nodes): # for node i = 1, 2, ..., n in parallel do

                    model = self.model_dict[node_idx] # prepare theta
                    model.zero_grad()
                    model_params_generator = model.parameters()
                    model_params = list(model_params_generator)
                    
                    # choose the mini_batch in the current mini_batch tuple
                    batch_data = self.mini_batch_tuple_dict[batch_idx][node_idx]
                    mini_batch_size = batch_data[0].size(dim=0)

                    inputs, labels = (
                            batch_data[0].to(self.device),
                            batch_data[1].to(self.device),
                        )

                    # apply theta
                    outputs = model(inputs)
                    loss = self.loss_function(outputs, labels)
                    
                    # prepare node mini-batch gradient in vector form (if not, will get a tuple)
                    # this is sgd-gradient
                    mini_batch_grad = torch.nn.utils.parameters_to_vector(torch.autograd.grad(loss, model_params, create_graph=False))
                    if node_idx == 0:
                        self.current_num_points += mini_batch_size
                    
                    # transfer the sgd-gradient from vector form to list form with model parameter like
                    grad_list = self.Grad_Vec_to_List(model=self.model_dict[node_idx], grad_vec=mini_batch_grad)

                    # perform gradient descent on theta
                    self.Gradient_Descent(node_idx=node_idx, grad_list=grad_list, lr=lr_inner)

                # perform weighted fed_avg algorithm if conditions are satisfied 
                # always communicate after the last iteration     
                if step_count % communicate_period == 0 or (i+1)*(batch_idx+1) == num_epoch*num_batches:

                    self.Weighted_Fed_Avg()
                    self.current_num_syn += 1
                    
                    if step_count % (5*communicate_period) == 0 or (i+1)*(batch_idx+1) == num_epoch*num_batches:
                        self.Evaluation(val_loaders=val_loaders, test_loaders=test_loaders)

    # Local-SVRG for updating theta
    def Update_Models_SVRG(self, train_loaders, val_loaders, test_loaders, num_epoch, num_batches, communicate_period, lr_inner):
        
        # arguments: D^{train}, T_{1}, tau, eta, q
        self.model_dict_ref = deepcopy(self.model_dict) # initialize theta_reference
        # False means that full gradient of the current model_reference has not been computed
        ref_indicator_dict = {}
        grad_full_dict = {}
        for node_idx in range(self.num_nodes):
            ref_indicator_dict[node_idx] = "False" 
            grad_full_dict[node_idx] = None
       
        step_count = 0 # count in t in { 1, 2, ..., T_{1} }
       
        for i in range(num_epoch): # for t = 0, 1, ..., T_{1}-1

            # note that here T_{1} is number of epoches * number of batches
            for batch_idx in range(num_batches):
                   
                    step_count += 1

                    for node_idx in range(self.num_nodes): # for node i = 1, 2, ..., n in parallel do

                        model = self.model_dict[node_idx] # prepare theta
                        model.zero_grad()
                        model_params_generator = model.parameters()
                        model_params = list(model_params_generator)
                       
                        model_ref = self.model_dict_ref[node_idx] # prepare theta_reference
                        model_ref.zero_grad()
                        model_ref_params_generator = model_ref.parameters()
                        model_ref_params = list(model_ref_params_generator)
                        
                        # choose the mini_batch in the current mini_batch tuple
                        batch_data = self.mini_batch_tuple_dict[batch_idx][node_idx] 
                        mini_batch_size = batch_data[0].size(dim=0)

                        inputs, labels = (
                                batch_data[0].to(self.device),
                                batch_data[1].to(self.device),
                            )

                        # apply theta
                        outputs = model(inputs)
                        loss = self.loss_function(outputs, labels)

                        # prepare node mini-batch gradient in vector form (if not, will get a tuple)
                        mini_batch_grad = torch.nn.utils.parameters_to_vector(torch.autograd.grad(loss, model_params, create_graph=False))

                        # apply theta_reference
                        outputs_ref = model_ref(inputs)
                        loss_ref = self.loss_function(outputs_ref, labels)
                        
                        # prepare node mini-batch gradient in vector form (if not, will get a tuple)
                        mini_batch_grad_ref = torch.nn.utils.parameters_to_vector(torch.autograd.grad(loss_ref, model_ref_params, create_graph=False))
                        if node_idx == self.center_idx:
                            self.current_num_points += mini_batch_size * 2

                        # compute full gradient on theta_reference
                        if ref_indicator_dict[node_idx] == "False":

                            grad_full_dict[node_idx] = self.Compute_Full_Gradient_Train(node_idx=node_idx, train_loaders=train_loaders, indication="ref")
                            ref_indicator_dict[node_idx] = "True"

                            if node_idx == self.center_idx:
                                self.current_num_points += self.train_size

                        # compute svrg-gradient
                        mini_batch_grad = mini_batch_grad - mini_batch_grad_ref + grad_full_dict[node_idx]

                        # transfer the svrg-gradient from vector form to list form with model parameter like
                        grad_list = self.Grad_Vec_to_List(model=self.model_dict[node_idx], grad_vec=mini_batch_grad)
                        
                        # update theta_reference with probability q for each node
                        # q == 1/num_batches
                        prob_sample = np.random.uniform()
                        
                        # update theta_reference according to prob_sample for each node
                        if prob_sample < 1 / num_batches:
                            self.model_dict_ref[node_idx] = deepcopy(self.model_dict[node_idx])
                            self.model_dict_ref[node_idx].zero_grad()
                            ref_indicator_dict[node_idx] = "False"

                        else:
                            ref_indicator_dict[node_idx] = "True"

                        # perform gradient descent on theta
                        self.Gradient_Descent(node_idx=node_idx, grad_list=grad_list, lr=lr_inner)  

                    # perform weighted fed_avg algorithm if conditions are satisfied 
                    # always communicate after the last iteration  
                    if step_count % communicate_period == 0 or (i+1)*(batch_idx+1) == num_epoch*num_batches:

                        self.Weighted_Fed_Avg()
                        self.current_num_syn += 1

                        if step_count % (5*communicate_period) == 0 or (i+1)*(batch_idx+1) == num_epoch*num_batches:
                            self.Evaluation(val_loaders=val_loaders, test_loaders=test_loaders)
    
    # compute the hv product of g w.r.t. the whole training data in one node
    def Compute_Full_Hv_Product(self, node_idx, train_loaders, x):
        
        # get theta which is fixed in this method
        model = self.model_dict[self.center_idx]
        
        # get the model parameter generator
        model_params_generator = model.parameters()
        model_params = list(model_params_generator)
            
        full_dataset_size = 0 # prepare full_batch size for the weighted sum on mini_batches

        for step, batch_data in enumerate(train_loaders[node_idx]): 
            
            model.zero_grad()

            mini_batch_size = batch_data[0].size(dim=0) # get mini_batch size
            full_dataset_size += mini_batch_size

            inputs, labels = (
                    batch_data[0].to(self.device),
                    batch_data[1].to(self.device),
                )

            outputs = model(inputs)
            loss = self.loss_function(outputs, labels)
            
            # get mini_batch gradient
            mini_batch_grad = torch.nn.utils.parameters_to_vector(torch.autograd.grad(loss, model_params, create_graph=True))
            
            # get mini_batch hv product
            mini_batch_hv = torch.autograd.grad(mini_batch_grad, model_params, grad_outputs=x)
            mini_batch_hv = torch.nn.utils.parameters_to_vector(mini_batch_hv).detach()
            
            # mini_batch hv product is averaged according to this mini_batch, so need to
            # re-assign the weight for mini_batch hv product
            if step == 0:
                full_hv = mini_batch_hv * int(mini_batch_size)
            else:
                full_hv += mini_batch_hv * int(mini_batch_size)
        
        full_hv /= float(full_dataset_size)

        return full_hv

    # Local-SGD for computing hv product of g_val
    def Compute_Hv_Product_SGD(self, b, num_batches, num_epoch, communicate_period, lr_hv):
        
        self.x_dict = {}
        for node_idx in range(self.num_nodes):
            self.x_dict[node_idx] = b.clone().detach() # get the base b of quadratic problem as the initialization
    
        step_count = 0 # count in t in { 1, 2, ..., T_{1} }   

        for i in range(num_epoch): # for t = 0, 1, ..., T_{1}-1

            # note that here T_{1} is number of epoches * number of batches
            for batch_idx in range(num_batches):
                   
                step_count += 1

                for node_idx in range(self.num_nodes):
                    
                    # get theta which is fixed in this method
                    model = self.model_dict[node_idx]
                    model.zero_grad()

                    # get the model parameter generator
                    model_params_generator = model.parameters()
                    model_params = list(model_params_generator)

                    # choose the mini_batch in the current mini_batch tuple
                    batch_data = self.mini_batch_tuple_dict[batch_idx][node_idx]

                    mini_batch_size = batch_data[0].size(dim=0)

                    inputs, labels = (
                            batch_data[0].to(self.device),
                            batch_data[1].to(self.device),
                        )

                    outputs = model(inputs)
                    loss = self.loss_function(outputs, labels)
                   
                    # get mini_batch gradient
                    mini_batch_grad = torch.nn.utils.parameters_to_vector(torch.autograd.grad(loss, model_params, create_graph=True))
                   
                    # get mini_batch hv product on x according to SGD
                    mini_batch_hv = torch.autograd.grad(mini_batch_grad, model_params, grad_outputs=self.x_dict[node_idx])
                    mini_batch_hv = torch.nn.utils.parameters_to_vector(mini_batch_hv).detach()
                    if node_idx == self.center_idx:
                        self.current_num_points += mini_batch_size
                    
                    self.x_dict[node_idx] = self.x_dict[node_idx] - lr_hv * ( mini_batch_hv - b )

                if step_count % communicate_period == 0 or (i+1)*(batch_idx+1) == num_epoch*num_batches:

                    for node_idx in range(self.num_nodes): 
                        if node_idx == 0:
                            x = self.x_dict[node_idx] * float(self.weight_dict[node_idx])
                        else:
                            x += self.x_dict[node_idx] * float(self.weight_dict[node_idx])
                    
                    for node_idx in range(self.num_nodes):
                        self.x_dict[node_idx] = x.clone().detach()

                    self.current_num_syn += 1 
        
        return x
    
    # Local-SVRG for computing hv product of g_val
    def Compute_Hv_Product_SVRG(self, train_loaders, b, num_batches, num_epoch, communicate_period, lr_hv):
        
        self.x_dict = {}
        self.x_ref_dict = {}
        self.ref_indicator_dict = {}
        self.full_hv_dict = {}
        for node_idx in range(self.num_nodes):
            self.x_dict[node_idx] = b.clone().detach() # get the base b of quadratic problem as the initialization of x
            self.x_ref_dict[node_idx] = b.clone().detach() # get the base b of quadratic problem as the initialization of x_ref
            self.ref_indicator_dict[node_idx] = "False" # False means that full hv product of the current x_ref has not been computed
            self.full_hv_dict[node_idx] = None

        step_count = 0 # count in t in { 1, 2, ..., T_{1} }   

        for i in range(num_epoch): # for t = 0, 1, ..., T_{1}-1

            # note that here T_{1} is number of epoches * number of batches
            for batch_idx in range(num_batches):
                   
                step_count += 1

                for node_idx in range(self.num_nodes):
                    
                    # get theta which is fixed in this method
                    model = self.model_dict[node_idx]
                    model.zero_grad()

                    # get the model parameter generator
                    model_params_generator = model.parameters()
                    model_params = list(model_params_generator)

                    # choose the mini_batch in the current mini_batch tuple
                    batch_data = self.mini_batch_tuple_dict[batch_idx][node_idx]

                    mini_batch_size = batch_data[0].size(dim=0)

                    inputs, labels = (
                            batch_data[0].to(self.device),
                            batch_data[1].to(self.device),
                        )

                    outputs = model(inputs)
                    loss = self.loss_function(outputs, labels)
                   
                    # get mini_batch gradient
                    mini_batch_grad = torch.nn.utils.parameters_to_vector(torch.autograd.grad(loss, model_params, create_graph=True))
                   
                    # get mini_batch hv product on x-x_ref according to SVRG
                    mini_batch_hv = torch.autograd.grad(mini_batch_grad, model_params, grad_outputs=self.x_dict[node_idx]-self.x_ref_dict[node_idx])
                    mini_batch_hv = torch.nn.utils.parameters_to_vector(mini_batch_hv).detach()
                    if node_idx == self.center_idx:
                        self.current_num_points += mini_batch_size * 2
                    
                    # compute full hv product on x_ref
                    if self.ref_indicator_dict[node_idx] == "False":

                        full_hv = self.Compute_Full_Hv_Product(node_idx=node_idx, train_loaders=train_loaders, x=self.x_ref_dict[node_idx])
                        self.ref_indicator_dict[node_idx] = "True"

                        if node_idx == self.center_idx:
                            self.current_num_points += self.train_size
                    
                    self.x_dict[node_idx] = self.x_dict[node_idx] - lr_hv * ( mini_batch_hv + full_hv - b )

                    # update theta_reference with probability q for each node
                    # q == 1/num_batches
                    prob_sample = np.random.uniform()
                    
                    # update theta_reference according to prob_sample for each node
                    if prob_sample < 1 / num_batches:
                        self.x_ref_dict[node_idx] = self.x_dict[node_idx].clone().detach()
                        self.ref_indicator_dict[node_idx] = "False"

                    else:
                        self.ref_indicator_dict[node_idx] = "True"
                
                if step_count % communicate_period == 0 or (i+1)*(batch_idx+1) == num_epoch*num_batches:

                    for node_idx in range(self.num_nodes): 
                        if node_idx == 0:
                            x = self.x_dict[node_idx] * float(self.weight_dict[node_idx])
                        else:
                            x += self.x_dict[node_idx] * float(self.weight_dict[node_idx])
                    
                    for node_idx in range(self.num_nodes):
                        self.x_dict[node_idx] = x.clone().detach()

                    self.current_num_syn += 1 
        
        return x

    # compute gradient of w and update w
    def Update_Weights(self, outer_hv_g, train_loaders, lr_outer, stage, w_ub):  
        
        for node_idx in range(self.num_nodes):
            
            inner_grad = self.Compute_Full_Gradient_Train(node_idx=node_idx, 
                                                          train_loaders=train_loaders, 
                                                          indication="self")
            
            outer_grad = ( outer_hv_g @ inner_grad ).item()

            print("For node {}, the outer gradient in stage {} is: ".format(node_idx, stage), outer_grad)
            
            # outer_grad is computed w.r.t. -g, thus here use addition
            self.weight_dict[node_idx] += lr_outer * outer_grad 
        
        self.current_num_points += self.train_size
        self.current_num_syn += 1

        # project the updated weights to the simplex to keep their sum as 1
        # x = self.Proj_Weights_To_Simplex()
        x = self.Proj_Weights_To_Capped_Simplex(w_ub=w_ub)
        
        print("After stage {}, the weights for partners w.r.t. node {} are".format(stage, self.center_idx), x, "after projection.")
        print("-" * 30)

    # project the updated w to the probability simplex
    def Proj_Weights_To_Simplex(self):
        
        # step 1: w is a list
        # create a copy of the list for sorting
        w = []
        w_sorted = [] # this is w_(i)
        for node_idx in range(self.num_nodes):
            w.append(self.weight_dict[node_idx])
            w_sorted.append(self.weight_dict[node_idx]) # copy all elements from w to w_(i)
    
        # step 2 i): sort w_(i) in ascending order
        w_sorted.sort()
    
        # step 2 ii): get n and i
        n = len(w_sorted) # n is the length of w
        i = n - 1 # i = n-1
    
        # step 3: this is a while loop on i
        while i > 0:
        
            # step 3 i): compute t_i
            t_i = 0
        
            # j take values on i, i+1, ..., n-1
            for j in range(i, n):
                # in python list, they are w_(i+1), w_(i+2), ..., w_(n)
                t_i += w_sorted[j]

            t_i -= 1
            t_i /= (n-i) # get t_i
        
            # step 3 ii): if t_i >= w_(i), then t_hat = t_i and go to step 5
            if t_i >= w_sorted[i-1]:
                t_hat = t_i
            
                # step 5: get x = (w-t_hat)_{+} and return x
                x = [0] * n

                for idx in range(n):
                    x[idx] = max(w[idx]-t_hat, 0) # x = (w-t_hat)_{+}
                
                for node_idx in range(self.num_nodes):
                    self.weight_dict[node_idx] = x[node_idx]
                    self.w_dict[node_idx].append(x[node_idx])
                self.w_dict["num_points"].append(self.current_num_points)
                self.w_dict["num_syn"].append(self.current_num_syn)
                
                # return x as the output
                return x
        
            # else, if i >= 1, return to step 3
            else:
                i = i - 1
    
        # if i == 0, go to step 4
        # step 4: compute t_hat
        t_hat = 0

        for idx in range(n):
            t_hat += w[idx]

        t_hat -= 1
        t_hat /= n # get t_hat
    
        # step 5: get x = (w-t_hat)_{+} and return x
        x = [0] * n

        for idx in range(n):
            x[idx] = max(w[idx]-t_hat, 0) # x = (w-t_hat)_{+}

        for node_idx in range(self.num_nodes):
            self.weight_dict[node_idx] = x[node_idx]
            self.w_dict[node_idx].append(x[node_idx])
        self.w_dict["num_points"].append(self.current_num_points)
        self.w_dict["num_syn"].append(self.current_num_syn)

        # return x as the output
        return x
    
    # project the updated w onto the capped probability simplex
    # each weight is restricted to [0, 1/w_ub] where 0 << w_ub << J < K
    def Proj_Weights_To_Capped_Simplex(self, w_ub): 
        
        w = [] # w is an np array
        lb = [] # lb is an np array of 0
        ub = [] # ub is an np array of 1/w_ub

        for node_idx in range(self.num_nodes):
            w.append(self.weight_dict[node_idx])
            lb.append(0)
            ub.append(1/w_ub)

        w = np.array(w)
        lb = np.array(lb)
        ub = np.array(ub)

        n = w.size #  n is the length of w
        total = np.sum(lb)

        lambdas = np.append(lb-w, ub-w)
        idx = np.argsort(lambdas)
        lambdas = lambdas[idx]

        active = 1
        for i in range(1, 2*n):
            total += active*(lambdas[i] - lambdas[i-1])

            if total >= 1:
                lam = (1-total) / active + lambdas[i]

                x = np.clip(w + lam, lb, ub)

                for node_idx in range(self.num_nodes):
                    self.weight_dict[node_idx] = x[node_idx]
                    self.w_dict[node_idx].append(x[node_idx])
                self.w_dict["num_points"].append(self.current_num_points)
                self.w_dict["num_syn"].append(self.current_num_syn)

                return x

            elif idx[i] < n:
                active += 1
            else:
                active -= 1
    
    # evaluate validation and loss, and training, validation and testing accuracy 
    # def Evaluation(self, train_loaders, val_loaders, test_loaders):
    def Evaluation(self, val_loaders, test_loaders):

        # get theta which is fixed in this method
        model = self.model_dict[self.center_idx]
        
        '''
        # training accuracy in bilevel problem
        # this is training accuracy in the inner level of the bilevel problem
        full_train_set_size = 0 # prepare full_batch size for the weighted sum on mini_batches
        full_top1_train_acc = 0
        full_top2_train_acc = 0

        for node_idx in range(self.num_nodes):
                
            for step, batch_data in enumerate(train_loaders[node_idx]):        
                
                model.zero_grad()

                mini_batch_size = batch_data[0].size(dim=0) # get mini_batch size
                full_train_set_size += int(mini_batch_size)

                inputs, labels = (
                        batch_data[0].to(self.device),
                        batch_data[1].to(self.device),
                    )
                
                outputs = model(inputs)

                # get predicted class labels
                values, top1_predicts = torch.max(outputs, dim=1)
                top1_predicts = top1_predicts.long().to(self.device)

                mini_batch_top1_train_acc = int((top1_predicts == labels).to(self.device).nonzero().reshape(-1).size(dim=0))

                values, top2_predicts = torch.topk(outputs, k=2, dim=1)
                top2_predicts = top2_predicts.long().to(self.device)
                mini_batch_top2_train_acc = int(torch.tensor([labels[i] in top2_predicts[i] for i in range(top2_predicts.size(dim=0))])\
                                        .to(self.device).nonzero().reshape(-1).size(dim=0))

                # mini_batch loss is averaged according to this mini_batch, so need to
                # re-assign the weight for mini_batch loss 
                full_top1_train_acc += mini_batch_top1_train_acc 
                full_top2_train_acc += mini_batch_top2_train_acc 

        full_top1_train_acc /= int(full_train_set_size)
        full_top2_train_acc /= int(full_train_set_size)
        self.metric_dict["top1_train_acc"].append(full_top1_train_acc)
        self.metric_dict["top2_train_acc"].append(full_top2_train_acc)
        print("full_top1_train_acc: ", full_top1_train_acc)
        print("full_top2_train_acc: ", full_top2_train_acc)
        print("-" * 30)     
        '''

        # validation loss and validation accuracy in local train
        near_full_val_set_size = 0 # prepare full_batch size for the weighted sum on mini_batches
        near_full_val_loss = 0
        near_full_top1_val_acc = 0
        near_full_top2_val_acc = 0

        for step, batch_data in enumerate(val_loaders[self.near_center_idx]):        
            
            model.zero_grad()

            mini_batch_size = batch_data[0].size(dim=0) # get mini_batch size
            near_full_val_set_size += int(mini_batch_size)

            inputs, labels = (
                    batch_data[0].to(self.device),
                    batch_data[1].to(self.device),
                )
            
            outputs = model(inputs)

            # get mini_batch loss
            mini_batch_val_loss = self.loss_function(outputs, labels)
            mini_batch_val_loss = mini_batch_val_loss.item()
            
            # get predicted class labels
            values, top1_predicts = torch.max(outputs, dim=1)
            top1_predicts = top1_predicts.long().to(self.device)

            mini_batch_top1_val_acc = int((top1_predicts == labels).to(self.device).nonzero().reshape(-1).size(dim=0))

            values, top2_predicts = torch.topk(outputs, k=2, dim=1)
            top2_predicts = top2_predicts.long().to(self.device)
            mini_batch_top2_val_acc = int(torch.tensor([labels[i] in top2_predicts[i] for i in range(top2_predicts.size(dim=0))])\
                                     .to(self.device).nonzero().reshape(-1).size(dim=0))

            # mini_batch loss is averaged according to this mini_batch, so need to
            # re-assign the weight for mini_batch loss 
            near_full_val_loss += mini_batch_val_loss * int(mini_batch_size)
            near_full_top1_val_acc += mini_batch_top1_val_acc 
            near_full_top2_val_acc += mini_batch_top2_val_acc 

        near_full_val_loss /= int(near_full_val_set_size)
        near_full_top1_val_acc /= int(near_full_val_set_size)
        near_full_top2_val_acc /= int(near_full_val_set_size)
        self.metric_dict["near_val_loss"].append(near_full_val_loss)
        self.metric_dict["near_top1_val_acc"].append(near_full_top1_val_acc)
        self.metric_dict["near_top2_val_acc"].append(near_full_top2_val_acc)
        print("near_full_val_loss: ", near_full_val_loss) 
        print("near_full_top1_val_acc: ", near_full_top1_val_acc)
        print("near_full_top2_val_acc: ", near_full_top2_val_acc)
        print("-" * 30)
        
        far_full_val_set_size = 0 # prepare full_batch size for the weighted sum on mini_batches
        far_full_val_loss = 0
        far_full_top1_val_acc = 0
        far_full_top2_val_acc = 0

        for step, batch_data in enumerate(val_loaders[self.far_center_idx]):        
            
            model.zero_grad()

            mini_batch_size = batch_data[0].size(dim=0) # get mini_batch size
            far_full_val_set_size += int(mini_batch_size)

            inputs, labels = (
                    batch_data[0].to(self.device),
                    batch_data[1].to(self.device),
                )
            
            outputs = model(inputs)

            # get mini_batch loss
            mini_batch_val_loss = self.loss_function(outputs, labels)
            mini_batch_val_loss = mini_batch_val_loss.item()
            
            # get predicted class labels
            values, top1_predicts = torch.max(outputs, dim=1)
            top1_predicts = top1_predicts.long().to(self.device)

            mini_batch_top1_val_acc = int((top1_predicts == labels).to(self.device).nonzero().reshape(-1).size(dim=0))

            values, top2_predicts = torch.topk(outputs, k=2, dim=1)
            top2_predicts = top2_predicts.long().to(self.device)
            mini_batch_top2_val_acc = int(torch.tensor([labels[i] in top2_predicts[i] for i in range(top2_predicts.size(dim=0))])\
                                     .to(self.device).nonzero().reshape(-1).size(dim=0))

            # mini_batch loss is averaged according to this mini_batch, so need to
            # re-assign the weight for mini_batch loss 
            far_full_val_loss += mini_batch_val_loss * int(mini_batch_size)
            far_full_top1_val_acc += mini_batch_top1_val_acc 
            far_full_top2_val_acc += mini_batch_top2_val_acc 

        far_full_val_loss /= int(far_full_val_set_size)
        far_full_top1_val_acc /= int(far_full_val_set_size)
        far_full_top2_val_acc /= int(far_full_val_set_size)
        self.metric_dict["far_val_loss"].append(far_full_val_loss)
        self.metric_dict["far_top1_val_acc"].append(far_full_top1_val_acc)
        self.metric_dict["far_top2_val_acc"].append(far_full_top2_val_acc)
        print("far_full_val_loss: ", far_full_val_loss) 
        print("far_full_top1_val_acc: ", far_full_top1_val_acc)
        print("far_full_top2_val_acc: ", far_full_top2_val_acc)
        print("-" * 30)

        # testing loss and testing accuracy in bilevel problem
        near_full_test_set_size = 0 # prepare full_batch size for the weighted sum on mini_batches
        near_full_test_loss = 0
        near_full_top1_test_acc = 0
        near_full_top2_test_acc = 0

        for step, batch_data in enumerate(test_loaders[self.near_center_idx]):        
            
            model.zero_grad()
            
            mini_batch_size = batch_data[0].size(dim=0) # get mini_batch size
            near_full_test_set_size += int(mini_batch_size)

            inputs, labels = (
                    batch_data[0].to(self.device),
                    batch_data[1].to(self.device),
                )
            
            outputs = model(inputs)
            # get mini_batch loss
            mini_batch_test_loss = self.loss_function(outputs, labels)
            mini_batch_test_loss = mini_batch_test_loss.item()
            
            # get predicted class labels
            values, top1_predicts = torch.max(outputs, dim=1)
            top1_predicts = top1_predicts.long().to(self.device)

            mini_batch_top1_test_acc = int((top1_predicts == labels).to(self.device).nonzero().reshape(-1).size(dim=0))

            values, top2_predicts = torch.topk(outputs, k=2, dim=1)
            top2_predicts = top2_predicts.long().to(self.device)
            mini_batch_top2_test_acc = int(torch.tensor([labels[i] in top2_predicts[i] for i in range(top2_predicts.size(dim=0))])\
                                      .to(self.device).nonzero().reshape(-1).size(dim=0))

            # mini_batch loss is averaged according to this mini_batch, so need to
            # re-assign the weight for mini_batch loss 
            near_full_test_loss += mini_batch_test_loss * int(mini_batch_size)
            near_full_top1_test_acc += mini_batch_top1_test_acc 
            near_full_top2_test_acc += mini_batch_top2_test_acc 

        near_full_test_loss /= int(near_full_test_set_size)
        near_full_top1_test_acc /= int(near_full_test_set_size)
        near_full_top2_test_acc /= int(near_full_test_set_size)
        self.metric_dict["near_test_loss"].append(near_full_test_loss)
        self.metric_dict["near_top1_test_acc"].append(near_full_top1_test_acc)
        self.metric_dict["near_top2_test_acc"].append(near_full_top2_test_acc)
        print("near_full_test_loss: ", near_full_test_loss) 
        print("near_full_top1_test_acc: ", near_full_top1_test_acc)
        print("near_full_top2_test_acc: ", near_full_top2_test_acc)
        print("-" * 30)
        
        far_full_test_set_size = 0 # prepare full_batch size for the weighted sum on mini_batches
        far_full_test_loss = 0
        far_full_top1_test_acc = 0
        far_full_top2_test_acc = 0

        for step, batch_data in enumerate(test_loaders[self.far_center_idx]):        
            
            model.zero_grad()
            
            mini_batch_size = batch_data[0].size(dim=0) # get mini_batch size
            far_full_test_set_size += int(mini_batch_size)

            inputs, labels = (
                    batch_data[0].to(self.device),
                    batch_data[1].to(self.device),
                )
            
            outputs = model(inputs)
            # get mini_batch loss
            mini_batch_test_loss = self.loss_function(outputs, labels)
            mini_batch_test_loss = mini_batch_test_loss.item()
            
            # get predicted class labels
            values, top1_predicts = torch.max(outputs, dim=1)
            top1_predicts = top1_predicts.long().to(self.device)

            mini_batch_top1_test_acc = int((top1_predicts == labels).to(self.device).nonzero().reshape(-1).size(dim=0))

            values, top2_predicts = torch.topk(outputs, k=2, dim=1)
            top2_predicts = top2_predicts.long().to(self.device)
            mini_batch_top2_test_acc = int(torch.tensor([labels[i] in top2_predicts[i] for i in range(top2_predicts.size(dim=0))])\
                                      .to(self.device).nonzero().reshape(-1).size(dim=0))

            # mini_batch loss is averaged according to this mini_batch, so need to
            # re-assign the weight for mini_batch loss 
            far_full_test_loss += mini_batch_test_loss * int(mini_batch_size)
            far_full_top1_test_acc += mini_batch_top1_test_acc 
            far_full_top2_test_acc += mini_batch_top2_test_acc 

        far_full_test_loss /= int(far_full_test_set_size)
        far_full_top1_test_acc /= int(far_full_test_set_size)
        far_full_top2_test_acc /= int(far_full_test_set_size)
        self.metric_dict["far_test_loss"].append(far_full_test_loss)
        self.metric_dict["far_top1_test_acc"].append(far_full_top1_test_acc)
        self.metric_dict["far_top2_test_acc"].append(far_full_top2_test_acc)
        print("far_full_test_loss: ", far_full_test_loss) 
        print("far_full_top1_test_acc: ", far_full_top1_test_acc)
        print("far_full_top2_test_acc: ", far_full_top2_test_acc)
        print("-" * 30)

        sys.stdout.flush()
        self.metric_dict["num_points"].append(self.current_num_points)
        self.metric_dict["num_syn"].append(self.current_num_syn)
    
    # return all evaluation results after all stages
    def Output_Results(self):
        
        # Output the averaged weight across all stages for all nodes
        self.w_result = {}
        for node_idx in range(self.num_nodes):
            self.w_result[node_idx] = sum(self.w_dict[node_idx]) / len(self.w_dict[node_idx])
            
            print("Node {} has averaged weight across stages as: {:.4f}.".format(node_idx, self.w_result[node_idx]))
            print("-" * 30)
            sys.stdout.flush()
        
        return self.w_dict, self.metric_dict