import torch.optim as optim


class Scaffold_Optimizer(optim.Optimizer):
    """Plain-SGD local optimizer with SCAFFOLD control-variate correction.

    Each call to ``step`` applies, to every parameter that has a gradient::

        y_i <- y_i - lr * (g(y_i) + c - c_i)

    where ``c`` / ``c_i`` are taken from the server / client control mappings
    passed to ``step``.
    """

    def __init__(self, params, method, learning_rate, _step=0):
        """
        Args:
            params: iterable of parameters (or param-group dicts) forwarded to
                ``optim.Optimizer``. Also stored as given on ``self.params``;
                NOTE(review): if a one-shot generator is passed it will already
                have been consumed by the base constructor.
            method: optimizer method name; only consulted by
                ``set_parameters`` (``'sparseadam'`` routes ``embed`` params
                into ``self.sparse_params``).
            learning_rate: step size used by ``step``.
            _step: accepted for backward compatibility; currently unused.
        """
        super(Scaffold_Optimizer, self).__init__(params, {})
        self.params = params
        self.learning_rate = learning_rate
        self.method = method

    def set_parameters(self, params):
        """Split named parameters and build an inner ``optim.SGD`` on the dense ones.

        Args:
            params: iterable of ``(name, parameter)`` pairs — presumably the
                output of ``model.named_parameters()``; confirm with callers.

        Parameters with ``requires_grad`` False are ignored. When ``method``
        is ``'sparseadam'``, parameters whose name contains ``"embed"`` are
        collected in ``self.sparse_params``; all others go to ``self.params``.

        NOTE(review): ``self.sparse_params`` gets no optimizer here, and
        ``self.optimizer`` is not used by ``step``. ``optim.SGD`` rejects an
        empty parameter list.
        """
        self.params = []
        self.sparse_params = []
        for k, p in params:
            if p.requires_grad:
                if self.method != 'sparseadam' or "embed" not in k:
                    self.params.append(p)
                else:
                    self.sparse_params.append(p)

        self.optimizer = optim.SGD(self.params, lr=self.learning_rate)

    def step(self, device, server_controls, client_controls, closure=False):
        """Apply one SCAFFOLD-corrected SGD update to the local model.

        For every parameter ``p`` (the local model ``y_i``) with a gradient::

            y_i <- y_i - lr * (g(y_i) + c - c_i)

        Args:
            device: device on which the arithmetic is done; the control
                tensors are copied there for the update.
            server_controls: mapping whose values are the server control
                variates ``c``, ordered like the optimizer's parameters
                flattened across all param groups.
            client_controls: mapping whose values are this client's control
                variates ``c_i``, with the same ordering.
            closure: accepted for signature compatibility; ignored.

        The control tensors are not modified. Temporary device copies are
        made and discarded, so the controls stay on whatever device the
        caller keeps them on.
        """
        # Build the control iterators once so they stay aligned with the
        # parameters across param groups, instead of restarting at each group.
        server_iter = iter(server_controls.values())
        client_iter = iter(client_controls.values())
        for group in self.param_groups:
            # zip pulls from group['params'] first, so exhausting a group does
            # not consume an extra control from the shared iterators.
            for p, c, ci in zip(group['params'], server_iter, client_iter):
                if p.grad is None:
                    # This parameter's (c, ci) pair was already consumed above,
                    # so alignment for the following parameters is preserved.
                    continue
                # Local update: y_i = y_i - lr * (g(y_i) + c - c_i), where p is
                # y_i, i.e. a parameter of the local model.
                c_dev = c.data.to(device)
                ci_dev = ci.data.to(device)
                dp = p.grad.data + c_dev - ci_dev
                # Rebind p.data (not in-place) to keep the original update semantics.
                p.data = p.data - dp * self.learning_rate


class Ditto_local_Optimizer(optim.Optimizer):
    """SGD optimizer for Ditto's personalized model with a proximal pull.

    Each call to ``step`` applies, to every parameter that has a gradient::

        v_k <- v_k - lr * (grad(v_k) + lambda * (v_k - w^t))

    where ``w^t`` is the matching parameter of the global model.
    """

    def __init__(self, params, learning_rate, ditto_lambda):
        """
        Args:
            params: iterable of parameters (or param-group dicts) forwarded to
                ``optim.Optimizer``; also stored as given on ``self.params``.
            learning_rate: step size used by ``step``.
            ditto_lambda: weight of the proximal term ``(v_k - w^t)``.
        """
        self.params = params
        self.learning_rate = learning_rate
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self.start_decay = False
        self.ditto_lambda = ditto_lambda
        super(Ditto_local_Optimizer, self).__init__(params, {})

    def step(self, updated_global_model_params, closure=None):
        """Apply one proximal SGD update toward the global model.

        Args:
            updated_global_model_params: iterable of global-model tensors
                ``w^t``, ordered like this optimizer's parameters flattened
                across all param groups.
            closure: accepted for signature compatibility; ignored.
        """
        # One shared iterator keeps the global params aligned across param
        # groups. A list would otherwise restart from its first element for
        # every group.
        global_iter = iter(updated_global_model_params)
        for group in self.param_groups:
            # zip pulls from group['params'] first, so finishing a group does
            # not consume an extra global parameter.
            for p, g in zip(group['params'], global_iter):
                if p.grad is None:
                    continue
                # v_k = v_k - η(grad(v_k) + λ(v_k - w^t))
                proximal = self.ditto_lambda * (p.data - g.data)
                p.data = p.data - self.learning_rate * (p.grad.data + proximal)

