from misc.utils import *
from torch.utils.data.sampler import SubsetRandomSampler
from split_data import data_load


class SubsetSequentialSampler(torch.utils.data.Sampler):
    """Sampler that yields a given collection of indices in stored order.

    Unlike ``SubsetRandomSampler`` no shuffling is performed: every pass
    over the sampler produces the elements of ``indices`` in the same
    order they are stored.

    Args:
        indices: A sized, iterable collection of values convertible with
            ``int()``.
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        """Return an iterator over the stored indices as plain ``int``s."""
        # Iterate the collection directly instead of indexing by position;
        # int() converts each element to a built-in Python int.
        return (int(index) for index in self.indices)

    def __len__(self):
        """Return the number of stored indices."""
        return len(self.indices)

class DataLoader:
    """Build and hold per-client training loaders plus one test loader.

    On construction, ``get_all_data`` is called to create one torch
    ``DataLoader`` per client (stored in ``client_pa``) and a single test
    loader (stored in ``test_pa``). ``switch`` then selects which client's
    loader and training-set size are exposed through ``pa_loader`` and
    ``train_size``.

    Args:
        args: Configuration object; this class reads ``args.dataset``,
            ``args.n_clients`` and ``args.batch_size`` from it.
    """

    def __init__(self, args):
        self.args = args
        # NOTE(review): n_workers is not passed to any loader in the code
        # visible here; kept because other code may read it.
        self.n_workers = 1
        self.client_id = None
        # Populated by switch(); defined up front so the attributes exist
        # even before a client has been selected.
        self.pa_loader = None
        self.train_size = None

        # Imported locally and kept on the instance: this class has the same
        # name as torch's DataLoader, so a module-level import would clash.
        from torch.utils.data import DataLoader
        self.DataLoader = DataLoader
        # client_id -> training loader, and client_id -> number of samples.
        self.client_pa, self.train_size_all = {}, {}
        self.test_pa = None
        # NOTE(review): all_data is never assigned elsewhere in this class.
        self.all_data = None
        self.get_all_data()

    def switch(self, client_id):
        """Make ``client_id`` the active client.

        Sets ``client_id``, ``pa_loader`` and ``train_size`` to the values
        stored for that client. Does nothing if the client is already
        active.

        Raises:
            KeyError: If no loader or size is stored for ``client_id``. In
                that case the previously active client is left untouched.
        """
        if self.client_id == client_id:
            return
        # Resolve both lookups before mutating any state, so that an unknown
        # id cannot leave client_id out of sync with pa_loader/train_size.
        loader = self.client_pa[client_id]
        size = self.train_size_all[client_id]
        self.client_id = client_id
        self.pa_loader = loader
        self.train_size = size

    def get_all_data(self):
        """Load the dataset and create all client and test loaders.

        Calls ``data_load`` with the configured dataset name and client
        count. The first return value is indexed by client id to obtain that
        client's sample indices; the second is indexed with ``"train"`` and
        ``"test"`` to obtain the datasets.
        """
        n_clients = self.args.n_clients
        batch_size = self.args.batch_size
        dataidxs, nlp_dataset = data_load(dataset=self.args.dataset, client_num=n_clients)

        train_dataset = nlp_dataset["train"]
        for client_id in range(n_clients):
            client_idx = dataidxs[client_id]
            # Each client draws (in random order) only from its own indices
            # of the shared training dataset.
            self.client_pa[client_id] = self.DataLoader(
                train_dataset,
                batch_size=batch_size,
                sampler=SubsetRandomSampler(client_idx),
                pin_memory=True,
            )
            self.train_size_all[client_id] = len(client_idx)

        # A single test loader over the full test split, shared by all clients.
        self.test_pa = self.DataLoader(
            nlp_dataset["test"], batch_size=batch_size, pin_memory=True
        )
