from misc.utils import *
import global_var as gvr


class DataLoader:
    """Per-client wrapper around a torch_geometric ``DataLoader``.

    Keeps the currently selected client's partition together with a loader
    over it; :meth:`switch` replaces both when a different client is chosen.
    """

    def __init__(self, args):
        """Store the run configuration and prepare per-client state.

        Args:
            args: Run configuration. Read here: ``model``, ``n_clients`` and,
                when ``model`` is in ``gvr.HYP_METHODS``, ``norm_frc_list``
                (indexed by client id, ``0 .. n_clients - 1``). ``switch``
                also passes ``args`` on to ``get_data``.
        """
        self.args = args
        self.n_workers = 1
        # No client is active yet; ``switch`` compares against this value.
        self.client_id = None

        if self.args.model in gvr.HYP_METHODS:
            # Initial per-client ``k`` taken from ``norm_frc_list``; what ``k``
            # means is defined by the model code, not here — TODO confirm.
            self.init_k = {
                i: args.norm_frc_list[i] for i in range(self.args.n_clients)
            }

        # Imported lazily (keeps torch_geometric off the module import path)
        # and aliased so it does not shadow this class's own name.
        from torch_geometric.loader import DataLoader as PyGDataLoader
        self.DataLoader = PyGDataLoader

    def switch(self, client_id):
        """Make ``client_id`` the active client, loading its data if needed.

        Does nothing when ``client_id`` is already active. Otherwise it:

        * reloads ``self.partition`` via ``get_data``;
        * resets ``self.k`` from ``self.init_k`` (only for models in
          ``gvr.HYP_METHODS``);
        * rebuilds ``self.pa_loader`` over the partition (batch size 1,
          no shuffling).

        Args:
            client_id: Index of the client whose partition should be loaded.
        """
        if self.client_id == client_id:
            return

        self.partition = get_data(self.args, client_id=client_id)
        if self.args.model in gvr.HYP_METHODS:
            self.k = self.init_k[client_id]

        self.pa_loader = self.DataLoader(dataset=self.partition, batch_size=1,
            shuffle=False, num_workers=self.n_workers, pin_memory=False)

        # Record the active client only once loading has fully succeeded, so
        # a failed load is retried on the next call instead of being skipped
        # by the early-return check above.
        self.client_id = client_id


def get_data(args, client_id):
    """Load a single client's partition and return it wrapped in a list.

    The partition file ``partition_{client_id}.pt`` is addressed by the
    relative path ``{dataset}_{mode}/{n_clients}/`` and handed, together
    with ``args.data_path``, to ``torch_load``. Presumably ``torch_load``
    joins the two path parts and deserializes the file; it lives in
    ``misc.utils``, so verify there. Only the ``'client_data'`` entry of
    the loaded object is kept.

    Args:
        args: Run configuration; reads ``data_path``, ``dataset``, ``mode``
            and ``n_clients``.
        client_id: Index of the client whose partition is loaded.

    Returns:
        A one-element list holding the client's ``'client_data'`` object.
        ``DataLoader.switch`` passes this list as the loader's ``dataset``.
    """
    data_root = args.data_path
    partition_file = f'{args.dataset}_{args.mode}/{args.n_clients}/partition_{client_id}.pt'
    loaded = torch_load(data_root, partition_file)
    client_data = loaded['client_data']
    return [client_data]




