import numpy as np
import random
import torch
import collections as ct
from math import floor


# Class for pre-processing dataset
class Data_Preprocess():
    """Build per-node train/val/test subsets with a controlled label distribution.

    The label space is split into four sub-lists (``label_list_1`` .. ``label_list_4``).
    Intended call order:

    1. ``Decide_Dataset_Sizes`` -- fix the split sizes shared by all nodes.
    2. ``Resample_Data`` -- decide how many instances of each label every split gets.
    3. ``Dataset_Generating`` (setting 1) or ``Dataset_Generating_and_Label_Permute``
       (setting 2) -- draw the instances.
    4. ``Prepare_Dataloaders`` -- wrap each split in a ``torch.utils.data.DataLoader``.
    """

    # (split key, name used in printed reports), in generation / report order.
    _SPLITS = (("train", "training"), ("val", "validation"), ("test", "testing"))

    def __init__(self, X_train, y_train, X_test, y_test, label_list_1, label_list_2, label_list_3, label_list_4,
                 near_center_idx, far_center_idx):
        """Store the raw data and the four label sub-lists.

        Args:
            X_train, X_test: tensors whose rows are picked with ``index_select(dim=0, ...)``.
            y_train, y_test: per-row labels; converted to lists and compared with ``==``
                against the entries of the label sub-lists.
            label_list_1 .. label_list_4: label sub-lists; concatenated (with ``+``) into
                ``self.label_list``, which fixes the order instances are sampled in.
            near_center_idx, far_center_idx: node indices tagged "(near center)" /
                "(far center)" in printed reports.
        """
        self.X_train = X_train
        self.y_train = y_train
        self.y_train_list = list(y_train)

        self.X_test = X_test
        self.y_test = y_test
        self.y_test_list = list(y_test)

        self.label_list_1 = label_list_1
        self.label_list_2 = label_list_2
        self.label_list_3 = label_list_3
        self.label_list_4 = label_list_4

        self.label_list = self.label_list_1 + self.label_list_2 + self.label_list_3 + self.label_list_4

        self.near_center_idx = near_center_idx
        self.far_center_idx = far_center_idx

    def _node_tag(self, node_idx):
        """Return the suffix used after "Node <idx>" in printed reports."""
        if node_idx == self.near_center_idx:
            return " (near center)"
        if node_idx == self.far_center_idx:
            return " (far center)"
        return ""

    def Decide_Dataset_Sizes(self, num_nodes, train_size, val_size, test_size):
        """Give every node the same train/val/test sizes and print them.

        Creates ``self.datasets_info[node_idx][split] = {"size": ...}`` for each split;
        per-label counts are added later by ``Resample_Data``.
        """
        self.num_nodes = num_nodes

        self.train_size = train_size
        self.val_size = val_size
        self.test_size = test_size

        self.datasets_info = {
            node_idx: {
                "train": {"size": self.train_size},
                "val": {"size": self.val_size},
                "test": {"size": self.test_size},
            }
            for node_idx in range(num_nodes)
        }

        for node_idx in range(num_nodes):
            info = self.datasets_info[node_idx]
            print("Node {}{} has {} training instances, {} validation instances and {} testing instances.".format(
                node_idx, self._node_tag(node_idx),
                info["train"]["size"], info["val"]["size"], info["test"]["size"]))
            print("-" * 30)

    def Resample_Data(self, near_node_list, seed, near_p_dict, far_p_dict, margin):
        """Decide per-label counts for every split of every node, then print them.

        Seeds numpy's global RNG with ``seed``. Nodes listed in ``near_node_list`` are
        sized with ``near_p_dict``, all other nodes with ``far_p_dict`` (see
        ``Data_Size_Generating`` for the meaning of the dicts and ``margin``).
        """
        np.random.seed(seed=seed)
        self.near_node_list = near_node_list

        for node_idx in range(self.num_nodes):
            p_dict = near_p_dict if node_idx in self.near_node_list else far_p_dict
            for split, _ in self._SPLITS:
                self.Data_Size_Generating(node_idx=node_idx, indication=split, p_dict=p_dict, margin=margin)

        for node_idx in range(self.num_nodes):
            tag = self._node_tag(node_idx)
            for split, split_name in self._SPLITS:
                print("Node {}{} has {} labels distributed by:".format(node_idx, tag, split_name),
                      self.datasets_info[node_idx][split])
            print("-" * 30)

    @staticmethod
    def _split_label_counts(labels, total, margin):
        """Spread ``total`` instances over ``labels``.

        Every label except the last gets ``total / len(labels)`` plus
        ``np.random.uniform(-margin, margin)`` noise, rounded to an int; the last label gets
        the remainder so the counts sum to ``total``. Returns a dict in ``labels`` order
        (empty when ``labels`` is empty).

        NOTE(review): nothing clamps the counts, so a ``margin`` that is large relative to
        the per-label mean can produce negative counts (most visibly for the last label);
        that only surfaces as an error when the split is sampled.
        """
        counts = {}
        allocated = 0
        last_position = len(labels) - 1
        for position, label in enumerate(labels):
            if position == last_position:
                counts[label] = total - allocated
                continue
            jittered = total / len(labels) + np.random.uniform(low=-margin, high=margin)
            counts[label] = int(round(jittered, ndigits=0))
            allocated += counts[label]
        return counts

    def Data_Size_Generating(self, node_idx, indication, p_dict, margin):
        """Fill in per-label counts for one split of one node.

        The split's ``size`` is divided among the four label sub-lists: sub-lists 1-3 get
        ``round(size * p_dict["label_list_k"])`` instances and sub-list 4 gets whatever is
        left, so the four shares add up to ``size``. Each share is then spread over its
        labels by ``_split_label_counts``. Results are written into
        ``self.datasets_info[node_idx][indication]`` keyed by label.

        Args:
            node_idx: node whose split is being sized.
            indication: split key, one of "train", "val", "test".
            p_dict: mapping with keys "label_list_1" .. "label_list_4". All four are read,
                but the fourth value is not used (sub-list 4 receives the remainder).
            margin: half-width of the uniform jitter added to each label's mean count.
        """
        split_info = self.datasets_info[node_idx][indication]
        size = split_info["size"]

        sublists = (self.label_list_1, self.label_list_2, self.label_list_3, self.label_list_4)
        shares = [p_dict[key] for key in ("label_list_1", "label_list_2", "label_list_3", "label_list_4")]

        sublist_sizes = [int(round(size * share, ndigits=0)) for share in shares[:3]]
        sublist_sizes.append(size - sum(sublist_sizes))

        # Sub-lists are processed in order so numpy's RNG is consumed in a fixed sequence.
        for labels, sublist_size in zip(sublists, sublist_sizes):
            split_info.update(self._split_label_counts(labels, sublist_size, margin))

    def _label_index_pools(self, targets):
        """Map each label in ``self.label_list`` to the positions in ``targets`` holding it.

        Computed once per generation call and reused for every node and split. Labels are
        matched with ``==`` rather than hashing, so tensor / numpy scalar entries work.
        """
        return {label: [position for position, value in enumerate(targets) if value == label]
                for label in self.label_list}

    def _sample_split(self, node_idx, split, X, pools, label_map=None):
        """Draw one split of one node according to ``self.datasets_info``.

        For each label in ``self.label_list`` (in order), sample
        ``datasets_info[node_idx][split][label]`` row indices without replacement from
        ``pools[label]`` with Python's ``random`` module and gather those rows of ``X``.
        Targets are the label itself, or ``label_map[label]`` when ``label_map`` is given
        and contains the label.

        Returns:
            ``{"instances": tensor, "targets": LongTensor}``; both are ``None`` when
            ``self.label_list`` is empty.

        Raises:
            ValueError: if a requested count is negative or exceeds the rows available
                for that label.
        """
        counts = self.datasets_info[node_idx][split]
        instance_parts = []
        target_parts = []
        for label in self.label_list:
            pool = pools[label]
            count = counts[label]
            if not 0 <= count <= len(pool):
                raise ValueError(
                    "Node {} {}: cannot draw {} instances of label {!r}; {} available.".format(
                        node_idx, split, count, label, len(pool)))
            picked = torch.LongTensor(random.sample(pool, count))
            instance_parts.append(X.index_select(dim=0, index=picked))

            target = label_map[label] if label_map is not None and label in label_map else label
            target_parts.append(torch.LongTensor([target]).repeat(len(picked)))

        if not instance_parts:
            return {"instances": None, "targets": None}
        return {"instances": torch.cat(instance_parts, dim=0),
                "targets": torch.cat(target_parts, dim=0)}

    def _sample_node(self, node_idx, train_pools, test_pools, label_map=None):
        """Draw the train, val and test splits of one node (in that order).

        "train" and "val" are both drawn from ``X_train``; "test" from ``X_test``.

        NOTE(review): the val split is sampled from the same training pool as the train
        split with an independent ``random.sample``, so a node's val split can contain
        instances that are also in its train split -- confirm this is intended.
        """
        return {
            "train": self._sample_split(node_idx, "train", self.X_train, train_pools, label_map),
            "val": self._sample_split(node_idx, "val", self.X_train, train_pools, label_map),
            "test": self._sample_split(node_idx, "test", self.X_test, test_pools, label_map),
        }

    # Setting 1 in dataset generation
    def Dataset_Generating(self, seed):
        """Draw every node's splits with the original labels and print the label counts.

        Seeds Python's ``random`` module with ``seed``. Results go to
        ``self.datasets[node_idx][split]`` as ``{"instances", "targets"}``.
        """
        random.seed(a=seed)
        train_pools = self._label_index_pools(self.y_train_list)
        test_pools = self._label_index_pools(self.y_test_list)

        self.datasets = {}
        for node_idx in range(self.num_nodes):
            self.datasets[node_idx] = self._sample_node(node_idx, train_pools, test_pools)

        for node_idx in range(self.num_nodes):
            tag = self._node_tag(node_idx)
            for split, split_name in self._SPLITS:
                print("Node {}{} has {} labels distributed by:".format(node_idx, tag, split_name),
                      ct.Counter(self.datasets[node_idx][split]["targets"].tolist()))
            print("-" * 30)

    # Setting 2 in dataset generation
    def Dataset_Generating_and_Label_Permute(self, seed, label_permute):
        """Like ``Dataset_Generating``, but relabel targets on far nodes.

        Nodes in ``self.near_node_list`` (set by ``Resample_Data``) keep their labels; on
        every other node a label found in ``label_permute`` is replaced by
        ``label_permute[label]`` in all three splits. Prints the planned per-label counts
        next to the resulting target counts.
        """
        random.seed(a=seed)
        train_pools = self._label_index_pools(self.y_train_list)
        test_pools = self._label_index_pools(self.y_test_list)

        self.datasets = {}
        for node_idx in range(self.num_nodes):
            label_map = None if node_idx in self.near_node_list else label_permute
            self.datasets[node_idx] = self._sample_node(node_idx, train_pools, test_pools, label_map)

        for node_idx in range(self.num_nodes):
            tag = self._node_tag(node_idx)
            for split, split_name in self._SPLITS:
                print("Node {}{} has {} labels originally distributed by:".format(node_idx, tag, split_name),
                      self.datasets_info[node_idx][split])
                print("Node {}{} has {} labels permuted by:".format(node_idx, tag, split_name),
                      ct.Counter(self.datasets[node_idx][split]["targets"].tolist()))
            print("-" * 30)

    @staticmethod
    def _batch_size(num_instances, num_batch_lb):
        """Return ``floor(num_instances / num_batch_lb)``, but never less than 1.

        ``DataLoader`` rejects ``batch_size=0``, which the plain floor yields whenever a
        split has fewer instances than ``num_batch_lb``.
        """
        return max(1, floor(num_instances / num_batch_lb))

    def Prepare_Dataloaders(self, num_batch_lb):
        """Wrap every node's splits in shuffling ``DataLoader``s.

        Each split's batch size is ``floor(size / num_batch_lb)`` (at least 1), where
        ``size`` is the split size recorded in ``self.datasets_info``.

        Returns:
            ``(train_loaders, val_loaders, test_loaders)``, each a dict keyed by node index
            (also stored on ``self``).
        """
        self.train_loaders = {}
        self.val_loaders = {}
        self.test_loaders = {}
        loaders_by_split = {"train": self.train_loaders, "val": self.val_loaders, "test": self.test_loaders}

        for node_idx in range(self.num_nodes):
            for split, loaders in loaders_by_split.items():
                data = self.datasets[node_idx][split]
                batch_size = self._batch_size(self.datasets_info[node_idx][split]["size"], num_batch_lb)
                loaders[node_idx] = torch.utils.data.DataLoader(
                    BiasedinstanceDataset(data["instances"], data["targets"]),
                    batch_size=batch_size, shuffle=True)

        return self.train_loaders, self.val_loaders, self.test_loaders
 
class BiasedinstanceDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs ``data[i]`` with ``targets[i]``.

    The dataset length is taken from ``targets``; ``data`` must accept the same indices.
    """

    def __init__(self, data, targets):
        self.data, self.targets = data, targets

    def __len__(self):
        """Number of samples, i.e. ``len(self.targets)``."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Return the ``(instance, label)`` pair stored at position ``idx``."""
        return self.data[idx], self.targets[idx]