import torch


class FedAvgSSLOptimizer:
    """SGD wrapper that trains ``global_model`` with gradient accumulation.

    Args:
        device: Stored as ``self.device``; not used by this class directly.
        global_model: Module whose parameters are optimized. It must provide
            ``predict(x1, x2)`` returning four tensors and
            ``loss_fn(z1, z1_, z2, z2_)`` returning a scalar loss tensor.
        kwargs: Mapping with keys ``'lr'``, ``'lambda'``, ``'momentum'``,
            ``'wd'`` (passed to SGD as ``weight_decay``) and
            ``'accumulation_steps'``.

    Raises:
        KeyError: If a required key is missing from ``kwargs``.
        ValueError: If ``accumulation_steps`` is less than 1.
    """

    def __init__(self, device, global_model, kwargs):
        self.device = device
        self.global_model = global_model

        self.lr = kwargs['lr']
        # Stored for external use; not read anywhere in this class.
        self.g_lambda = kwargs['lambda']
        self.momentum = kwargs['momentum']
        self.wd = kwargs['wd']
        self.accumulation_steps = kwargs['accumulation_steps']
        # Fail fast: 0 would otherwise raise ZeroDivisionError on the first
        # step, and negative values would silently mis-scale gradients.
        if self.accumulation_steps < 1:
            raise ValueError(
                f"accumulation_steps must be >= 1, got {self.accumulation_steps}"
            )

        self.optimizer = torch.optim.SGD(self.global_model.parameters(),
                                         lr=self.lr,
                                         momentum=self.momentum,
                                         weight_decay=self.wd)

    def step(self, batch_idx, x1, x2):
        """Accumulate gradients for one batch; update every N batches.

        Each batch's loss is divided by ``accumulation_steps`` before
        ``backward()``. The accumulated gradient therefore equals the
        gradient of the mean loss over the accumulation window.

        Args:
            batch_idx: Batch counter. ``optimizer.step()`` followed by
                ``optimizer.zero_grad()`` runs when ``batch_idx + 1`` is a
                multiple of ``accumulation_steps``, i.e. after every
                ``accumulation_steps`` batches when counting from 0.
                If training stops mid-window, the trailing batches'
                gradients remain in ``.grad`` and feed the next update.
            x1: First input forwarded to ``global_model.predict``.
            x2: Second input forwarded to ``global_model.predict``
                (presumably two augmented views of one batch — confirm).

        Returns:
            The unscaled loss for this batch, detached from autograd.
        """
        z1, z1_, z2, z2_ = self.global_model.predict(x1, x2)
        global_loss = self.global_model.loss_fn(z1, z1_, z2, z2_)
        # Back-propagate a scaled copy. This keeps `global_loss` exact
        # instead of dividing and re-multiplying, which can round.
        (global_loss / self.accumulation_steps).backward()
        if (batch_idx + 1) % self.accumulation_steps == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()
        # backward() has already consumed the graph. Detaching stops callers
        # that sum losses (e.g. `running += loss`) from building up autograd
        # history across batches.
        return global_loss.detach()
