import numpy as np
import random
import torch
import collections as ct
from math import floor


# Class for pre-processing dataset
class Data_Preprocess():

    def __init__(self, train_set_data, train_set_targets, test_set_data, test_set_targets, label_list_1, 
                 label_list_2, label_list_3, label_list_4, near_center_idx, far_center_idx):
        """Store the raw splits, the four label groups and the two center node ids.

        Args:
            train_set_data: training examples; rows are later picked with
                ``index_select`` along dim 0.
            train_set_targets: training labels; kept as given and also copied
                into ``self.train_set_targets_list``.
            test_set_data: test examples, indexed the same way as the training data.
            test_set_targets: test labels; also copied into ``self.test_set_targets_list``.
            label_list_1, label_list_2, label_list_3, label_list_4: the four label
                groups; ``self.label_list`` is their concatenation in this order.
            near_center_idx: node index whose log lines are tagged "(near center)".
            far_center_idx: node index whose log lines are tagged "(far center)".
        """
        # Raw splits, plus list copies of the targets used for per-label lookups.
        self.train_set_data, self.train_set_targets = train_set_data, train_set_targets
        self.test_set_data, self.test_set_targets = test_set_data, test_set_targets
        self.train_set_targets_list = list(train_set_targets)
        self.test_set_targets_list = list(test_set_targets)

        # The four label groups and their concatenation (group order preserved).
        self.label_list_1, self.label_list_2 = label_list_1, label_list_2
        self.label_list_3, self.label_list_4 = label_list_3, label_list_4
        self.label_list = label_list_1 + label_list_2 + label_list_3 + label_list_4

        # Node ids that get special tags in the printed summaries.
        self.near_center_idx, self.far_center_idx = near_center_idx, far_center_idx
   
    def Decide_Dataset_Sizes(self, num_nodes, train_size, val_size, test_size):
        """Assign the same train/val/test sizes to every node and print them.

        Sets ``self.num_nodes``, ``self.train_size``, ``self.val_size`` and
        ``self.test_size``, and creates ``self.datasets_info`` mapping each node
        index in ``range(num_nodes)`` to
        ``{"train": {"size": ...}, "val": {"size": ...}, "test": {"size": ...}}``.
        Each node gets its own inner dicts, which later methods fill with
        per-label counts.
        """
        self.num_nodes = num_nodes
        self.train_size, self.val_size, self.test_size = train_size, val_size, test_size

        self.datasets_info = {
            node: {
                "train": {"size": self.train_size},
                "val": {"size": self.val_size},
                "test": {"size": self.test_size},
            }
            for node in range(num_nodes)
        }

        for node in range(num_nodes):
            if node == self.near_center_idx:
                template = "Node {} (near center) has {} training images, {} validation images and {} testing images."
            elif node == self.far_center_idx:
                template = "Node {} (far center) has {} training images, {} validation images and {} testing images."
            else:
                template = "Node {} has {} training images, {} validation images and {} testing images."
            splits = self.datasets_info[node]
            print(template.format(node, splits["train"]["size"], splits["val"]["size"], splits["test"]["size"]))
            print("-" * 30)
           
    def Resample_Data(self, near_node_list, seed, near_p_dict, far_p_dict, margin):
        """Draw per-label counts for every node and split, then print them.

        Seeds NumPy's global RNG with ``seed``, stores ``near_node_list`` on the
        instance, and calls ``Data_Size_Generating`` for "train", "val" and
        "test" of each node in that order. Nodes in ``near_node_list`` use
        ``near_p_dict``; all other nodes use ``far_p_dict``. Reads
        ``self.num_nodes`` and ``self.datasets_info`` (set by
        ``Decide_Dataset_Sizes``).
        """
        np.random.seed(seed=seed)
        self.near_node_list = near_node_list

        for node in range(self.num_nodes):
            p_dict = near_p_dict if node in self.near_node_list else far_p_dict
            # Fixed split order: every call may consume NumPy random numbers.
            for split in ("train", "val", "test"):
                self.Data_Size_Generating(node_idx=node, indication=split, p_dict=p_dict, margin=margin)

        for node in range(self.num_nodes):
            if node == self.near_center_idx:
                templates = ("Node {} (near center) has training labels distributed by:",
                             "Node {} (near center) has validation labels distributed by:",
                             "Node {} (near center) has testing labels distributed by:")
            elif node == self.far_center_idx:
                templates = ("Node {} (far center) has training labels distributed by:",
                             "Node {} (far center) has validation labels distributed by:",
                             "Node {} (far center) has testing labels distributed by:")
            else:
                templates = ("Node {} has training labels distributed by:",
                             "Node {} has validation labels distributed by:",
                             "Node {} has testing labels distributed by:")
            for template, split in zip(templates, ("train", "val", "test")):
                print(template.format(node), self.datasets_info[node][split])
            print("-" * 30)

    def Data_Size_Generating(self, node_idx, indication, p_dict, margin):
        """Fill ``self.datasets_info[node_idx][indication]`` with per-label counts.

        The split size stored under that dict's ``"size"`` key is first divided
        among the four label groups. Groups 1-3 receive
        ``int(round(size * p_dict["label_list_k"]))`` (Python ``round``: ties to
        even). Group 4 receives the remainder, so the four groups always sum to
        ``size`` and ``p_dict["label_list_4"]`` has no effect on the result.
        Each group is then spread over its labels by ``_split_group_size``.

        Args:
            node_idx: key into ``self.datasets_info``.
            indication: split key, e.g. "train", "val" or "test".
            p_dict: dict whose "label_list_1".."label_list_3" entries give each
                group's share of the split.
            margin: half-width of the uniform noise added to each label's count.

        Raises:
            ValueError: if a non-empty group ends up with a negative size, e.g.
                when the rounded shares of groups 1-3 exceed ``size``.
        """
        info = self.datasets_info[node_idx][indication]
        size = info["size"]

        group_1_size = int(round(size * p_dict["label_list_1"], ndigits=0))
        group_2_size = int(round(size * p_dict["label_list_2"], ndigits=0))
        group_3_size = int(round(size * p_dict["label_list_3"], ndigits=0))
        # Group 4 absorbs the rounding of groups 1-3.
        group_4_size = size - group_1_size - group_2_size - group_3_size

        # Groups are processed in order 1..4 so NumPy's global RNG is consumed in a
        # fixed order (one uniform draw per label, except each group's last label).
        for labels, group_size in ((self.label_list_1, group_1_size),
                                   (self.label_list_2, group_2_size),
                                   (self.label_list_3, group_3_size),
                                   (self.label_list_4, group_4_size)):
            self._split_group_size(labels, group_size, margin, info)

    @staticmethod
    def _split_group_size(labels, group_size, margin, counts):
        """Spread ``group_size`` over ``labels``, writing ``counts[label]`` for each.

        Every label except the last gets
        ``int(round(group_size / len(labels) + u))`` with
        ``u = np.random.uniform(-margin, margin)``. That value is clamped to
        ``[0, group_size - already_assigned]``, and the last label takes the
        rest, so the counts are non-negative and sum to ``group_size``.

        The clamp only changes a draw that would otherwise have produced a
        negative count somewhere in the group; such a count makes
        ``random.sample`` fail when the datasets are generated. Every draw
        that previously gave valid counts is reproduced exactly.

        Raises:
            ValueError: if ``labels`` is non-empty and ``group_size`` is negative.
        """
        if not labels:
            return
        if group_size < 0:
            raise ValueError(
                "cannot split a negative size ({}) across labels {!r}; check that the "
                "label_list_1..label_list_3 shares in p_dict are non-negative and do not "
                "exceed 1 in total".format(group_size, labels))

        mean = group_size / len(labels)
        last = len(labels) - 1
        assigned = 0
        for position, label in enumerate(labels):
            if position == last:
                # The last label takes whatever is left so the group sums exactly.
                counts[label] = group_size - assigned
                break
            label_size = int(round(mean + np.random.uniform(low=-margin, high=margin), ndigits=0))
            # Never negative, and never more than the budget still available.
            label_size = min(max(label_size, 0), group_size - assigned)
            counts[label] = label_size
            assigned += label_size
    # Setting 1 in dataset generation
    def Dataset_Generating(self, seed):
        """Build every node's train/val/test tensors with true labels (setting 1).

        For each node and each label in ``self.label_list`` (in list order),
        ``self.datasets_info[node][split][label]`` example indices are drawn
        without replacement with ``random.sample``. "train" and "val" are drawn
        from the training pool and "test" from the test pool.

        The result is stored as ``self.datasets[node][split] = {"images": ..., "targets": ...}``:
        - images are the selected rows (``index_select`` on dim 0), grouped by label;
        - targets is a LongTensor holding the matching labels;
        - both are None when ``self.label_list`` is empty.

        Per-node label counts are printed afterwards.

        Args:
            seed: seed for Python's ``random`` module, which drives every draw.

        NOTE(review): "train" and "val" are sampled independently from the same
        training pool, so a node's validation images can repeat its training
        images. Left unchanged because fixing it would alter every seeded draw.
        """
        random.seed(a=seed)

        # Per-label index pools, built once per call instead of once per
        # (node, split, label). They follow the target-list order, so the seeded
        # draws below are reproducible.
        train_pool = {
            label: [i for i, target in enumerate(self.train_set_targets_list) if target == label]
            for label in self.label_list
        }
        test_pool = {
            label: [i for i, target in enumerate(self.test_set_targets_list) if target == label]
            for label in self.label_list
        }

        def sample_split(source_data, pool, counts):
            # Draw counts[label] rows per label and stack them label by label.
            image_parts, target_parts = [], []
            for label in self.label_list:
                chosen = torch.LongTensor(random.sample(pool[label], counts[label]))
                image_parts.append(source_data.index_select(dim=0, index=chosen))
                target_parts.append(torch.LongTensor([label]).repeat(len(chosen)))
            if not image_parts:
                return {"images": None, "targets": None}
            # A single concatenation instead of growing the tensors per label.
            return {"images": torch.cat(image_parts, dim=0),
                    "targets": torch.cat(target_parts, dim=0)}

        self.datasets = {}
        for node_idx in range(self.num_nodes):
            sizes = self.datasets_info[node_idx]
            # The dict literal evaluates left to right, so draws happen train -> val -> test.
            self.datasets[node_idx] = {
                "train": sample_split(self.train_set_data, train_pool, sizes["train"]),
                "val": sample_split(self.train_set_data, train_pool, sizes["val"]),
                "test": sample_split(self.test_set_data, test_pool, sizes["test"]),
            }

        for node_idx in range(self.num_nodes):
            if node_idx == self.near_center_idx:
                templates = ("Node {} (near center) has training labels distributed by:",
                             "Node {} (near center) has validation labels distributed by:",
                             "Node {} (near center) has testing labels distributed by:")
            elif node_idx == self.far_center_idx:
                templates = ("Node {} (far center) has training labels distributed by:",
                             "Node {} (far center) has validation labels distributed by:",
                             "Node {} (far center) has testing labels distributed by:")
            else:
                templates = ("Node {} has training labels distributed by:",
                             "Node {} has validation labels distributed by:",
                             "Node {} has testing labels distributed by:")
            for template, split in zip(templates, ("train", "val", "test")):
                print(template.format(node_idx), ct.Counter(self.datasets[node_idx][split]["targets"].tolist()))
            print("-" * 30)

    # Setting 2 in dataset generation
    def Dataset_Generating_and_Label_Permute(self, seed, label_permute):
        """Build every node's datasets, relabelling targets on far nodes (setting 2).

        For each node and each label in ``self.label_list`` (in list order),
        ``self.datasets_info[node][split][label]`` example indices are drawn
        without replacement with ``random.sample``. "train" and "val" are drawn
        from the training pool and "test" from the test pool.

        Relabelling depends on the node:
        - nodes in ``self.near_node_list`` (set by ``Resample_Data``) keep
          their true labels as targets;
        - on every other node, a label that is a key of ``label_permute`` gets
          target ``label_permute[label]``, while its images are still drawn
          from the true label's pool.

        Results go to ``self.datasets[node][split] = {"images": ..., "targets": ...}``.
        Both are None when ``self.label_list`` is empty. The requested counts
        (``self.datasets_info``) and the resulting target counts are printed.

        Args:
            seed: seed for Python's ``random`` module, which drives every draw.
            label_permute: mapping from true label to the target label used on
                far nodes; labels that are not keys are left unchanged.

        NOTE(review): "train" and "val" are sampled independently from the same
        training pool, so a node's validation images can repeat its training
        images. Left unchanged because fixing it would alter every seeded draw.
        """
        random.seed(a=seed)

        # Per-label index pools, built once per call instead of once per
        # (node, split, label). They follow the target-list order, so the seeded
        # draws below are reproducible.
        train_pool = {
            label: [i for i, target in enumerate(self.train_set_targets_list) if target == label]
            for label in self.label_list
        }
        test_pool = {
            label: [i for i, target in enumerate(self.test_set_targets_list) if target == label]
            for label in self.label_list
        }

        def sample_split(source_data, pool, counts, relabel):
            # Draw counts[label] rows per label; optionally map targets through label_permute.
            image_parts, target_parts = [], []
            for label in self.label_list:
                chosen = torch.LongTensor(random.sample(pool[label], counts[label]))
                image_parts.append(source_data.index_select(dim=0, index=chosen))
                target = label_permute[label] if relabel and label in label_permute else label
                target_parts.append(torch.LongTensor([target]).repeat(len(chosen)))
            if not image_parts:
                return {"images": None, "targets": None}
            # A single concatenation instead of growing the tensors per label.
            return {"images": torch.cat(image_parts, dim=0),
                    "targets": torch.cat(target_parts, dim=0)}

        self.datasets = {}
        for node_idx in range(self.num_nodes):
            sizes = self.datasets_info[node_idx]
            relabel = node_idx not in self.near_node_list
            # The dict literal evaluates left to right, so draws happen train -> val -> test.
            self.datasets[node_idx] = {
                "train": sample_split(self.train_set_data, train_pool, sizes["train"], relabel),
                "val": sample_split(self.train_set_data, train_pool, sizes["val"], relabel),
                "test": sample_split(self.test_set_data, test_pool, sizes["test"], relabel),
            }

        for node_idx in range(self.num_nodes):
            if node_idx == self.near_center_idx:
                templates = (
                    ("Node {} (near center) has training labels originally distributed by:",
                     "Node {} (near center) has training labels permuted by:"),
                    ("Node {} (near center) has validation labels originally distributed by:",
                     "Node {} (near center) has validation labels permuted by:"),
                    ("Node {} (near center) has testing labels originally distributed by:",
                     "Node {} (near center) has testing labels permuted by:"),
                )
            elif node_idx == self.far_center_idx:
                templates = (
                    ("Node {} (far center) has training labels originally distributed by:",
                     "Node {} (far center) has training labels permuted by:"),
                    ("Node {} (far center) has validation labels originally distributed by:",
                     "Node {} (far center) has validation labels permuted by:"),
                    ("Node {} (far center) has testing labels originally distributed by:",
                     "Node {} (far center) has testing labels permuted by:"),
                )
            else:
                templates = (
                    ("Node {} has training labels originally distributed by:",
                     "Node {} has training labels permuted by:"),
                    ("Node {} has validation labels originally distributed by:",
                     "Node {} has validation labels permuted by:"),
                    ("Node {} has testing labels originally distributed by:",
                     "Node {} has testing labels permuted by:"),
                )
            for (original_template, permuted_template), split in zip(templates, ("train", "val", "test")):
                print(original_template.format(node_idx), self.datasets_info[node_idx][split])
                print(permuted_template.format(node_idx), ct.Counter(self.datasets[node_idx][split]["targets"].tolist()))
            print("-" * 30)
    
    # Setting 3 in dataset generation
    def Dataset_Generating_and_Rotation(self, seed):

        random.seed(a=seed)
        # By torch.rot90(), rotation direction is from the first towards the second axis if rotate_choice == 1, 
        # and from the second towards the first for rotate_choice == -1.
        rotate_choice = random.choice(seq=[-1,1])
        if rotate_choice == 1:
            print("rotation direction is from the first towards the second axis.")
        else:
            print("rotation direction is from the second towards the first axis.")
        print("-" * 30)

        self.datasets = {}
        for node_idx in range(self.num_nodes):

            self.datasets[node_idx] = { "train": { "images": None, "targets": None } }

            if node_idx in self.near_node_list:
                for idx, label in enumerate(self.label_list):
                    
                    train_this_label_list = [idx for idx in range(len(self.train_set_targets_list)) if self.train_set_targets_list[idx] == label]
                    this_label_index_list = random.sample(train_this_label_list, self.datasets_info[node_idx]["train"][label])
                    this_label_index_list = torch.LongTensor(this_label_index_list)
                    
                    if idx == 0:
                        self.datasets[node_idx]["train"]["images"] = self.train_set_data.index_select(dim=0, index=this_label_index_list)
                        self.datasets[node_idx]["train"]["targets"] = torch.LongTensor([label]).repeat(len(this_label_index_list))
                    else:
                        self.datasets[node_idx]["train"]["images"] = torch.cat((self.datasets[node_idx]["train"]["images"],
                                                                                self.train_set_data.index_select(dim=0, index=this_label_index_list)), dim=0)
                        self.datasets[node_idx]["train"]["targets"] = torch.cat((self.datasets[node_idx]["train"]["targets"],
                                                                                 torch.LongTensor([label]).repeat(len(this_label_index_list))), dim=0)
            
            else:
                # rotate data by image for training set in far nodes
                for idx, label in enumerate(self.label_list):

                    train_this_label_list = [idx for idx in range(len(self.train_set_targets_list)) if self.train_set_targets_list[idx] == label]
                    this_label_index_list = random.sample(train_this_label_list, self.datasets_info[node_idx]["train"][label])
                    this_label_index_list = torch.LongTensor(this_label_index_list)

                    if idx == 0:

                        data_points = self.train_set_data.index_select(dim=0, index=this_label_index_list)  
                        num_data_points = data_points.size(dim=0)

                        for data_point_idx in range(num_data_points):
                            if data_point_idx == 0:
                                self.datasets[node_idx]["train"]["images"] = torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                         k=rotate_choice, dims=[2,3])
                                '''
                                self.datasets[node_idx]["train"]["images"] = torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                         k=random.choice(seq=[-1,0,1]), dims=[2,3])
                                '''
                            else:
                                self.datasets[node_idx]["train"]["images"] = torch.cat((self.datasets[node_idx]["train"]["images"],
                                                                                        torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                                    k=rotate_choice, dims=[2,3])), dim=0)

                        self.datasets[node_idx]["train"]["targets"] = torch.LongTensor([label]).repeat(len(this_label_index_list))
                    
                    else:

                        data_points = self.train_set_data.index_select(dim=0, index=this_label_index_list)
                        num_data_points = data_points.size(dim=0)
                        
                        for data_point_idx in range(num_data_points):
                        
                            self.datasets[node_idx]["train"]["images"] = torch.cat((self.datasets[node_idx]["train"]["images"],
                                                                                    torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                                k=rotate_choice, dims=[2,3])), dim=0)
                    
                        self.datasets[node_idx]["train"]["targets"] = torch.cat((self.datasets[node_idx]["train"]["targets"],
                                                                                 torch.LongTensor([label]).repeat(len(this_label_index_list))), dim=0)

            self.datasets[node_idx]["val"] = { "images": None, "targets": None }

            if node_idx in self.near_node_list:
                for idx, label in enumerate(self.label_list):
                    
                    train_this_label_list = [idx for idx in range(len(self.train_set_targets_list)) if self.train_set_targets_list[idx] == label]
                    this_label_index_list = random.sample(train_this_label_list, self.datasets_info[node_idx]["val"][label])
                    this_label_index_list = torch.LongTensor(this_label_index_list)
                    
                    if idx == 0:
                        self.datasets[node_idx]["val"]["images"] = self.train_set_data.index_select(dim=0, index=this_label_index_list)
                        self.datasets[node_idx]["val"]["targets"] = torch.LongTensor([label]).repeat(len(this_label_index_list))
                    else:
                        self.datasets[node_idx]["val"]["images"] = torch.cat((self.datasets[node_idx]["val"]["images"],
                                                                                self.train_set_data.index_select(dim=0, index=this_label_index_list)), dim=0)
                        self.datasets[node_idx]["val"]["targets"] = torch.cat((self.datasets[node_idx]["val"]["targets"],
                                                                                 torch.LongTensor([label]).repeat(len(this_label_index_list))), dim=0)
            
            else:
                # rotate data by image for validation set in far nodes
                for idx, label in enumerate(self.label_list):

                    train_this_label_list = [idx for idx in range(len(self.train_set_targets_list)) if self.train_set_targets_list[idx] == label]
                    this_label_index_list = random.sample(train_this_label_list, self.datasets_info[node_idx]["val"][label])
                    this_label_index_list = torch.LongTensor(this_label_index_list)

                    if idx == 0:

                        data_points = self.train_set_data.index_select(dim=0, index=this_label_index_list)  
                        num_data_points = data_points.size(dim=0)

                        for data_point_idx in range(num_data_points):
                            if data_point_idx == 0:
                                self.datasets[node_idx]["val"]["images"] = torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                         k=rotate_choice, dims=[2,3])
                            else:
                                self.datasets[node_idx]["val"]["images"] = torch.cat((self.datasets[node_idx]["val"]["images"],
                                                                                        torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                                    k=rotate_choice, dims=[2,3])), dim=0)

                        self.datasets[node_idx]["val"]["targets"] = torch.LongTensor([label]).repeat(len(this_label_index_list))
                    
                    else:

                        data_points = self.train_set_data.index_select(dim=0, index=this_label_index_list)
                        num_data_points = data_points.size(dim=0)
                        
                        for data_point_idx in range(num_data_points):
                        
                            self.datasets[node_idx]["val"]["images"] = torch.cat((self.datasets[node_idx]["val"]["images"],
                                                                                    torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                                k=rotate_choice, dims=[2,3])), dim=0)
                    
                        self.datasets[node_idx]["val"]["targets"] = torch.cat((self.datasets[node_idx]["val"]["targets"],
                                                                                 torch.LongTensor([label]).repeat(len(this_label_index_list))), dim=0)
                           
            self.datasets[node_idx]["test"] = { "images": None, "targets": None }

            if node_idx in self.near_node_list:
                for idx, label in enumerate(self.label_list):
                    
                    test_this_label_list = [idx for idx in range(len(self.test_set_targets_list)) if self.test_set_targets_list[idx] == label]
                    this_label_index_list = random.sample(test_this_label_list, self.datasets_info[node_idx]["test"][label])
                    this_label_index_list = torch.LongTensor(this_label_index_list)
                    
                    if idx == 0:
                        self.datasets[node_idx]["test"]["images"] = self.test_set_data.index_select(dim=0, index=this_label_index_list)
                        self.datasets[node_idx]["test"]["targets"] = torch.LongTensor([label]).repeat(len(this_label_index_list))
                    else:
                        self.datasets[node_idx]["test"]["images"] = torch.cat((self.datasets[node_idx]["test"]["images"],
                                                                                self.test_set_data.index_select(dim=0, index=this_label_index_list)), dim=0)
                        self.datasets[node_idx]["test"]["targets"] = torch.cat((self.datasets[node_idx]["test"]["targets"],
                                                                                 torch.LongTensor([label]).repeat(len(this_label_index_list))), dim=0)
            
            else:
                # rotate data by image for testing set in far nodes
                for idx, label in enumerate(self.label_list):

                    test_this_label_list = [idx for idx in range(len(self.test_set_targets_list)) if self.test_set_targets_list[idx] == label]
                    this_label_index_list = random.sample(test_this_label_list, self.datasets_info[node_idx]["test"][label])
                    this_label_index_list = torch.LongTensor(this_label_index_list)

                    if idx == 0:

                        data_points = self.test_set_data.index_select(dim=0, index=this_label_index_list)  
                        num_data_points = data_points.size(dim=0)

                        for data_point_idx in range(num_data_points):
                            if data_point_idx == 0:
                                self.datasets[node_idx]["test"]["images"] = torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                         k=rotate_choice, dims=[2,3])
                            else:
                                self.datasets[node_idx]["test"]["images"] = torch.cat((self.datasets[node_idx]["test"]["images"],
                                                                                        torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                                    k=rotate_choice, dims=[2,3])), dim=0)

                        self.datasets[node_idx]["test"]["targets"] = torch.LongTensor([label]).repeat(len(this_label_index_list))
                    
                    else:

                        data_points = self.test_set_data.index_select(dim=0, index=this_label_index_list)
                        num_data_points = data_points.size(dim=0)
                        
                        for data_point_idx in range(num_data_points):
                        
                            self.datasets[node_idx]["test"]["images"] = torch.cat((self.datasets[node_idx]["test"]["images"],
                                                                                    torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                                k=rotate_choice, dims=[2,3])), dim=0)
                    
                        self.datasets[node_idx]["test"]["targets"] = torch.cat((self.datasets[node_idx]["test"]["targets"],
                                                                                 torch.LongTensor([label]).repeat(len(this_label_index_list))), dim=0)
           
        for node_idx in range(self.num_nodes):
            
            if node_idx == self.near_center_idx:

                print("Node {} (near center) has training labels distributed by:".format(node_idx), ct.Counter(self.datasets[node_idx]["train"]["targets"].tolist()))
                print("Node {} (near center) has validation labels distributed by:".format(node_idx), ct.Counter(self.datasets[node_idx]["val"]["targets"].tolist()))
                print("Node {} (near center) has testing labels distributed by:".format(node_idx), ct.Counter(self.datasets[node_idx]["test"]["targets"].tolist()))
            
            elif node_idx == self.far_center_idx:

                print("Node {} (far center) has training labels distributed by:".format(node_idx), ct.Counter(self.datasets[node_idx]["train"]["targets"].tolist()))
                print("Node {} (far center) has validation labels distributed by:".format(node_idx), ct.Counter(self.datasets[node_idx]["val"]["targets"].tolist()))
                print("Node {} (far center) has testing labels distributed by:".format(node_idx), ct.Counter(self.datasets[node_idx]["test"]["targets"].tolist()))
            else:

                print("Node {} has training labels distributed by:".format(node_idx), ct.Counter(self.datasets[node_idx]["train"]["targets"].tolist()))               
                print("Node {} has validation labels distributed by:".format(node_idx), ct.Counter(self.datasets[node_idx]["val"]["targets"].tolist()))
                print("Node {} has testing labels distributed by:".format(node_idx), ct.Counter(self.datasets[node_idx]["test"]["targets"].tolist()))
            
            print("-" * 30)
 
    # Setting 4 in dataset generation
    def Dataset_Generating_Label_Permute_and_Rotation(self, seed, label_permute):

        random.seed(a=seed)
        # By torch.rot90(), rotation direction is from the first towards the second axis if rotate_choice == 1, 
        # and from the second towards the first for rotate_choice == -1.
        rotate_choice = random.choice(seq=[-1,1])
        if rotate_choice == 1:
            print("rotation direction is from the first towards the second axis.")
        else:
            print("rotation direction is from the second towards the first axis.")
        print("-" * 30)

        self.datasets = {}
        for node_idx in range(self.num_nodes):

            self.datasets[node_idx] = { "train": { "images": None, "targets": None } }

            if node_idx in self.near_node_list:
                for idx, label in enumerate(self.label_list):
                    
                    train_this_label_list = [idx for idx in range(len(self.train_set_targets_list)) if self.train_set_targets_list[idx] == label]
                    this_label_index_list = random.sample(train_this_label_list, self.datasets_info[node_idx]["train"][label])
                    this_label_index_list = torch.LongTensor(this_label_index_list)
                    
                    if idx == 0:
                        self.datasets[node_idx]["train"]["images"] = self.train_set_data.index_select(dim=0, index=this_label_index_list)
                        self.datasets[node_idx]["train"]["targets"] = torch.LongTensor([label]).repeat(len(this_label_index_list))
                    else:
                        self.datasets[node_idx]["train"]["images"] = torch.cat((self.datasets[node_idx]["train"]["images"],
                                                                                self.train_set_data.index_select(dim=0, index=this_label_index_list)), dim=0)
                        self.datasets[node_idx]["train"]["targets"] = torch.cat((self.datasets[node_idx]["train"]["targets"],
                                                                                 torch.LongTensor([label]).repeat(len(this_label_index_list))), dim=0)
            
            else:
                # rotate data by image and permute labels according to the preset dictionary for training set in far nodes
                for idx, label in enumerate(self.label_list):

                    train_this_label_list = [idx for idx in range(len(self.train_set_targets_list)) if self.train_set_targets_list[idx] == label]
                    this_label_index_list = random.sample(train_this_label_list, self.datasets_info[node_idx]["train"][label])
                    this_label_index_list = torch.LongTensor(this_label_index_list)

                    if idx == 0:

                        data_points = self.train_set_data.index_select(dim=0, index=this_label_index_list)  
                        num_data_points = data_points.size(dim=0)

                        for data_point_idx in range(num_data_points):
                            if data_point_idx == 0:
                                self.datasets[node_idx]["train"]["images"] = torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                         k=rotate_choice, dims=[2,3])
                            else:
                                self.datasets[node_idx]["train"]["images"] = torch.cat((self.datasets[node_idx]["train"]["images"],
                                                                                        torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                                    k=rotate_choice, dims=[2,3])), dim=0)
                        if label in label_permute.keys():
                            self.datasets[node_idx]["train"]["targets"] = torch.LongTensor([label_permute[label]]).repeat(len(this_label_index_list))
                        else:
                            self.datasets[node_idx]["train"]["targets"] = torch.LongTensor([label]).repeat(len(this_label_index_list))
                    else:

                        data_points = self.train_set_data.index_select(dim=0, index=this_label_index_list)
                        num_data_points = data_points.size(dim=0)
                        
                        for data_point_idx in range(num_data_points):
                        
                            self.datasets[node_idx]["train"]["images"] = torch.cat((self.datasets[node_idx]["train"]["images"],
                                                                                    torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                                k=rotate_choice, dims=[2,3])), dim=0)
                        if label in label_permute.keys():
                            self.datasets[node_idx]["train"]["targets"] = torch.cat((self.datasets[node_idx]["train"]["targets"],
                                                                                     torch.LongTensor([label_permute[label]]).repeat(len(this_label_index_list))), dim=0)
                        else:
                            self.datasets[node_idx]["train"]["targets"] = torch.cat((self.datasets[node_idx]["train"]["targets"],
                                                                                     torch.LongTensor([label]).repeat(len(this_label_index_list))), dim=0)
            
            self.datasets[node_idx]["val"] = { "images": None, "targets": None }

            if node_idx in self.near_node_list:
                for idx, label in enumerate(self.label_list):
                    
                    train_this_label_list = [idx for idx in range(len(self.train_set_targets_list)) if self.train_set_targets_list[idx] == label]
                    this_label_index_list = random.sample(train_this_label_list, self.datasets_info[node_idx]["val"][label])
                    this_label_index_list = torch.LongTensor(this_label_index_list)
                    
                    if idx == 0:
                        self.datasets[node_idx]["val"]["images"] = self.train_set_data.index_select(dim=0, index=this_label_index_list)
                        self.datasets[node_idx]["val"]["targets"] = torch.LongTensor([label]).repeat(len(this_label_index_list))
                    else:
                        self.datasets[node_idx]["val"]["images"] = torch.cat((self.datasets[node_idx]["val"]["images"],
                                                                                self.train_set_data.index_select(dim=0, index=this_label_index_list)), dim=0)
                        self.datasets[node_idx]["val"]["targets"] = torch.cat((self.datasets[node_idx]["val"]["targets"],
                                                                                 torch.LongTensor([label]).repeat(len(this_label_index_list))), dim=0)
            
            else:
                # rotate data by image and permute labels according to the preset dictionary for validation set in far nodes
                for idx, label in enumerate(self.label_list):

                    train_this_label_list = [idx for idx in range(len(self.train_set_targets_list)) if self.train_set_targets_list[idx] == label]
                    this_label_index_list = random.sample(train_this_label_list, self.datasets_info[node_idx]["val"][label])
                    this_label_index_list = torch.LongTensor(this_label_index_list)

                    if idx == 0:

                        data_points = self.train_set_data.index_select(dim=0, index=this_label_index_list)  
                        num_data_points = data_points.size(dim=0)

                        for data_point_idx in range(num_data_points):
                            if data_point_idx == 0:
                                self.datasets[node_idx]["val"]["images"] = torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                         k=rotate_choice, dims=[2,3])
                            else:
                                self.datasets[node_idx]["val"]["images"] = torch.cat((self.datasets[node_idx]["val"]["images"],
                                                                                        torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                                    k=rotate_choice, dims=[2,3])), dim=0)
                        if label in label_permute.keys():
                            self.datasets[node_idx]["val"]["targets"] = torch.LongTensor([label_permute[label]]).repeat(len(this_label_index_list))
                        else:
                            self.datasets[node_idx]["val"]["targets"] = torch.LongTensor([label]).repeat(len(this_label_index_list))
                    else:

                        data_points = self.train_set_data.index_select(dim=0, index=this_label_index_list)
                        num_data_points = data_points.size(dim=0)
                        
                        for data_point_idx in range(num_data_points):
                        
                            self.datasets[node_idx]["val"]["images"] = torch.cat((self.datasets[node_idx]["val"]["images"],
                                                                                    torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                                k=rotate_choice, dims=[2,3])), dim=0)
                        if label in label_permute.keys():
                            self.datasets[node_idx]["val"]["targets"] = torch.cat((self.datasets[node_idx]["val"]["targets"],
                                                                                     torch.LongTensor([label_permute[label]]).repeat(len(this_label_index_list))), dim=0)
                        else:
                            self.datasets[node_idx]["val"]["targets"] = torch.cat((self.datasets[node_idx]["val"]["targets"],
                                                                                     torch.LongTensor([label]).repeat(len(this_label_index_list))), dim=0)

            self.datasets[node_idx]["test"] = { "images": None, "targets": None }

            if node_idx in self.near_node_list:
                for idx, label in enumerate(self.label_list):
                    
                    test_this_label_list = [idx for idx in range(len(self.test_set_targets_list)) if self.test_set_targets_list[idx] == label]
                    this_label_index_list = random.sample(test_this_label_list, self.datasets_info[node_idx]["test"][label])
                    this_label_index_list = torch.LongTensor(this_label_index_list)
                    
                    if idx == 0:
                        self.datasets[node_idx]["test"]["images"] = self.test_set_data.index_select(dim=0, index=this_label_index_list)
                        self.datasets[node_idx]["test"]["targets"] = torch.LongTensor([label]).repeat(len(this_label_index_list))
                    else:
                        self.datasets[node_idx]["test"]["images"] = torch.cat((self.datasets[node_idx]["test"]["images"],
                                                                                self.test_set_data.index_select(dim=0, index=this_label_index_list)), dim=0)
                        self.datasets[node_idx]["test"]["targets"] = torch.cat((self.datasets[node_idx]["test"]["targets"],
                                                                                 torch.LongTensor([label]).repeat(len(this_label_index_list))), dim=0)
            
            else:
                # rotate data by image and permute labels according to the preset dictionary for testing set in far nodes
                for idx, label in enumerate(self.label_list):

                    test_this_label_list = [idx for idx in range(len(self.test_set_targets_list)) if self.test_set_targets_list[idx] == label]
                    this_label_index_list = random.sample(test_this_label_list, self.datasets_info[node_idx]["test"][label])
                    this_label_index_list = torch.LongTensor(this_label_index_list)

                    if idx == 0:

                        data_points = self.test_set_data.index_select(dim=0, index=this_label_index_list)  
                        num_data_points = data_points.size(dim=0)

                        for data_point_idx in range(num_data_points):
                            if data_point_idx == 0:
                                self.datasets[node_idx]["test"]["images"] = torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                         k=rotate_choice, dims=[2,3])
                            else:
                                self.datasets[node_idx]["test"]["images"] = torch.cat((self.datasets[node_idx]["test"]["images"],
                                                                                        torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                                    k=rotate_choice, dims=[2,3])), dim=0)
                        if label in label_permute.keys():
                            self.datasets[node_idx]["test"]["targets"] = torch.LongTensor([label_permute[label]]).repeat(len(this_label_index_list))
                        else:
                            self.datasets[node_idx]["test"]["targets"] = torch.LongTensor([label]).repeat(len(this_label_index_list))
                    else:

                        data_points = self.test_set_data.index_select(dim=0, index=this_label_index_list)
                        num_data_points = data_points.size(dim=0)
                        
                        for data_point_idx in range(num_data_points):
                        
                            self.datasets[node_idx]["test"]["images"] = torch.cat((self.datasets[node_idx]["test"]["images"],
                                                                                    torch.rot90(input=torch.index_select(input=data_points,dim=0,index=torch.tensor([data_point_idx])), 
                                                                                                k=rotate_choice, dims=[2,3])), dim=0)
                        if label in label_permute.keys():
                            self.datasets[node_idx]["test"]["targets"] = torch.cat((self.datasets[node_idx]["test"]["targets"],
                                                                                     torch.LongTensor([label_permute[label]]).repeat(len(this_label_index_list))), dim=0)
                        else:
                            self.datasets[node_idx]["test"]["targets"] = torch.cat((self.datasets[node_idx]["test"]["targets"],
                                                                                     torch.LongTensor([label]).repeat(len(this_label_index_list))), dim=0)
           
        for node_idx in range(self.num_nodes):
            
            if node_idx == self.near_center_idx:

                print("Node {} (near center) has training labels distributed by:".format(node_idx), self.datasets_info[node_idx]["train"])
                print("Node {} (near center) has training labels permuted by:".format(node_idx), ct.Counter(self.datasets[node_idx]["train"]["targets"].tolist()))

                print("Node {} (near center) has validation labels distributed by:".format(node_idx), self.datasets_info[node_idx]["val"])
                print("Node {} (near center) has validation labels permuted by:".format(node_idx), ct.Counter(self.datasets[node_idx]["val"]["targets"].tolist()))

                print("Node {} (near center) has testing labels distributed by:".format(node_idx), self.datasets_info[node_idx]["test"])
                print("Node {} (near center) has testing labels permuted by:".format(node_idx), ct.Counter(self.datasets[node_idx]["test"]["targets"].tolist()))
            
            elif node_idx == self.far_center_idx:

                print("Node {} (far center) has training labels distributed by:".format(node_idx), self.datasets_info[node_idx]["train"])
                print("Node {} (far center) has training labels permuted by:".format(node_idx), ct.Counter(self.datasets[node_idx]["train"]["targets"].tolist()))

                print("Node {} (far center) has validation labels distributed by:".format(node_idx), self.datasets_info[node_idx]["val"])
                print("Node {} (far center) has validation labels permuted by:".format(node_idx), ct.Counter(self.datasets[node_idx]["val"]["targets"].tolist()))

                print("Node {} (far center) has testing labels distributed by:".format(node_idx), self.datasets_info[node_idx]["test"])
                print("Node {} (far center) has testing labels permuted by:".format(node_idx), ct.Counter(self.datasets[node_idx]["test"]["targets"].tolist()))
            else:

                print("Node {} has training labels distributed by:".format(node_idx), self.datasets_info[node_idx]["train"])
                print("Node {} has training labels permuted by:".format(node_idx), ct.Counter(self.datasets[node_idx]["train"]["targets"].tolist()))
                
                print("Node {} has validation labels distributed by:".format(node_idx), self.datasets_info[node_idx]["val"])
                print("Node {} has validation labels permuted by:".format(node_idx), ct.Counter(self.datasets[node_idx]["val"]["targets"].tolist()))

                print("Node {} has testing labels distributed by:".format(node_idx), self.datasets_info[node_idx]["test"])
                print("Node {} has testing labels permuted by:".format(node_idx), ct.Counter(self.datasets[node_idx]["test"]["targets"].tolist()))
            
            print("-" * 30)

    def Prepare_Dataloaders(self, num_batch_lb):
        """Wrap every node's train/val/test split in a shuffled ``DataLoader``.

        The batch size of a split is ``floor(size / num_batch_lb)``, where
        ``size`` is ``self.datasets_info[node_idx][split]["size"]``. Presumably
        ``num_batch_lb`` is a lower bound on the number of batches per epoch;
        TODO confirm. The batch size is clamped to at least 1: when
        ``size < num_batch_lb`` the formula gives 0, which ``DataLoader``
        rejects with ``ValueError``.

        Args:
            num_batch_lb: divisor applied to each split size to obtain its
                batch size.

        Returns:
            ``(train_loaders, val_loaders, test_loaders)``: three dicts keyed by
            node index. They are also stored on ``self``.
        """
        self.train_loaders = {}
        self.val_loaders = {}
        self.test_loaders = {}

        loaders_by_split = (("train", self.train_loaders),
                            ("val", self.val_loaders),
                            ("test", self.test_loaders))

        for node_idx in range(self.num_nodes):
            for split, loaders in loaders_by_split:
                split_data = self.datasets[node_idx][split]
                batch_size = max(1, floor(self.datasets_info[node_idx][split]["size"] / num_batch_lb))
                loaders[node_idx] = torch.utils.data.DataLoader(
                    BiasedImageDataset(split_data["images"], split_data["targets"]),
                    batch_size=batch_size, shuffle=True)

        return self.train_loaders, self.val_loaders, self.test_loaders
 
class BiasedImageDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-built, index-aligned images and targets.

    ``dataset[i]`` yields the tuple ``(data[i], targets[i])``. The dataset length
    is taken from ``targets``.
    """

    def __init__(self, data, targets):
        # Both containers are stored as given; no copy or conversion is made.
        self.data, self.targets = data, targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]