class LocalUpdater(object):
    """Local (single-client) trainer with a quadratic penalty towards a prior.

    Each call to :meth:`train` builds a fresh model, initialises it from
    ``post_phi['m0']`` and runs SGD on cross-entropy plus a penalty of the
    form ``sum((m - m0)**2 / (V0 / (n0 + d + 1)))`` scaled by
    ``0.5 * p / |D_i|``, where ``m`` is the flattened current weight vector.
    """

    def __init__(self, args, dataset, idxs, p):
        """Store the configuration and build the client's training loader.

        Args:
            args: Configuration object. This class reads ``local_bs``,
                ``num_workers``, ``momentum``, ``wd``, ``local_ep`` and
                ``device`` from it, and passes it to ``get_model``.
            dataset: Full dataset; only the samples selected by ``idxs`` are
                used (via ``DatasetSplit``).
            idxs: Indices of this client's samples; ``len(idxs)`` is kept as
                the client data size.
            p: Scalar multiplier of the regulariser (``1 - p_drop`` according
                to the original inline comment).
        """
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.ldr_train = DataLoader(
            DatasetSplit(dataset, idxs),
            batch_size=self.args.local_bs,
            shuffle=True,
            num_workers=args.num_workers,
        )
        self.cds = len(idxs)  # client data size (|D_i|)
        self.p = p  # 1 - p_{drop}

    def train(self, post_phi, local_part, lr, m0_vec=None, V0_vec=None, local_eps=None):
        """Train a fresh copy of the model locally and return it.

        Args:
            post_phi: Mapping with keys ``'m0'`` and ``'V0'`` (models whose
                parameters give the prior location / scale) and ``'n0'`` and
                ``'d'`` (combined as ``n0 + d + 1`` in the penalty).
            local_part: ``'body'`` regularises only parameters whose name does
                not contain ``'linear'``; ``'head'`` only those whose name
                does; any other value regularises every parameter.
            lr: Learning rate. A scalar is used for both parameter groups; a
                2-item tuple/list is read as ``(body_lr, head_lr)``.
                NOTE(review): the original code referenced undefined
                ``body_lr``/``head_lr`` and ignored ``lr`` -- confirm this
                interpretation against the callers.
            m0_vec: Optional precomputed ``weights2vec(post_phi['m0'], local_part)``.
            V0_vec: Optional precomputed ``weights2vec(post_phi['V0'], local_part)``.
            local_eps: Number of local epochs; defaults to ``args.local_ep``.

        Returns:
            The locally trained model.
        """
        # Use the stored configuration; a bare ``args`` is not defined here.
        net = get_model(self.args)
        with torch.no_grad():
            for param, src in zip(net.parameters(), post_phi['m0'].parameters()):
                param.copy_(src)

        if m0_vec is None:
            m0_vec = weights2vec(post_phi['m0'], local_part)
        if V0_vec is None:
            V0_vec = weights2vec(post_phi['V0'], local_part)

        if isinstance(lr, (tuple, list)):
            body_lr, head_lr = lr
        else:
            body_lr = head_lr = lr

        net.train()
        body_params = [param for name, param in net.named_parameters() if 'linear' not in name]
        head_params = [param for name, param in net.named_parameters() if 'linear' in name]
        optimizer = torch.optim.SGD(
            [
                {'params': body_params, 'lr': body_lr},
                {'params': head_params, 'lr': head_lr},
            ],
            momentum=self.args.momentum,
            weight_decay=self.args.wd,
        )

        if local_eps is None:
            local_eps = self.args.local_ep

        # Scalar factors of the penalty do not change between batches.
        reg_coeff = 0.5 * self.p / self.cds
        var_divisor = post_phi['n0'] + post_phi['d'] + 1

        for ep in range(local_eps):
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                logits = net(images)
                loss_nll = self.loss_func(logits, labels)
                loss_nll.backward()

                # The penalty gradient is taken w.r.t. the flattened weight
                # vector and then scattered back onto the per-parameter grads.
                m_vec = weights2vec(net, local_part).requires_grad_(True)
                loss_reg = reg_coeff * (((m_vec - m0_vec) ** 2) / (V0_vec / var_divisor)).sum()
                dloss_reg = torch.autograd.grad(loss_reg, m_vec, retain_graph=False, allow_unused=True)[0]

                # Assumes weights2vec concatenates the selected parameters in
                # named_parameters() order -- TODO confirm against its definition.
                idx0 = 0
                for name, param in net.named_parameters():
                    if local_part == 'body' and 'linear' in name:
                        continue
                    if local_part == 'head' and 'linear' not in name:
                        continue
                    # numel() is always an int, including for 0-dim parameters,
                    # so the slice bounds stay valid.
                    idx1 = idx0 + param.numel()
                    if param.grad is not None:
                        param.grad.add_(dloss_reg[idx0:idx1].reshape(param.shape))
                    idx0 = idx1
                optimizer.step()

        return net
