import torch


class perFedSimSiamOpt:
    """Paired optimizer driving a local and a global SimSiam-style model.

    Each model is expected to expose ``predict(x1, x2)`` returning the tuple
    ``(p1, z2, p2, z1)``, a ``loss_fn(p1, z2, p2, z1)`` and (for the global
    model) a callable ``predictor``; these are the only model members used
    here.
    """

    def __init__(self, device, local_model, global_model, kwargs):
        """Store the hyper-parameters and create one optimizer per model.

        Args:
            device: Device that input batches are moved to in ``step``.
            local_model: Model updated by ``l_optimizer``.
            global_model: Model updated by ``g_optimizer``.
            kwargs: Mapping that must contain the keys ``'lr'``,
                ``'momentum'``, ``'wd'``, ``'lambda'``, ``'ssl_method'``,
                ``'client_optimizer'`` and ``'accumulation_steps'``.

        Raises:
            KeyError: If any of the required keys is missing from ``kwargs``.
        """
        self.device = device
        self.local_model = local_model
        self.global_model = global_model
        self.lr = kwargs['lr']
        self.momentum = kwargs['momentum']
        self.wd = kwargs['wd']
        # Stored for external use; not read by any method of this class.
        self.g_lambda = kwargs['lambda']
        self.ssl_method = kwargs["ssl_method"]
        self.basic_opt = kwargs['client_optimizer']
        self.accumulation_steps = kwargs['accumulation_steps']

        # Each optimizer must own the parameters of its own model; otherwise
        # the local model would never be updated by ``l_optimizer.step()``.
        self.g_optimizer = self._build_optimizer(self.global_model)
        self.l_optimizer = self._build_optimizer(self.local_model)
        self.is_using_global_simsiam = False

    def _build_optimizer(self, model):
        """Create the configured optimizer over ``model``'s parameters."""
        if self.basic_opt == "sgd":
            return torch.optim.SGD(model.parameters(), lr=self.lr, momentum=self.momentum,
                                   weight_decay=self.wd)
        # NOTE(review): the Adam branch uses a fixed weight decay of 0.001
        # rather than kwargs['wd'] -- confirm this is intended.
        return torch.optim.Adam(model.parameters(), lr=self.lr, weight_decay=0.001,
                                amsgrad=True)

    def step(self, x_accumulator):
        """Run one accumulated optimization step over the given micro-batches.

        Args:
            x_accumulator: Iterable of ``(x1, x2)`` tensor pairs. It is
                iterated twice when ``is_using_global_simsiam`` is set, so it
                must be re-iterable (e.g. a list). Each loss is divided by
                ``accumulation_steps``; presumably the caller passes that many
                pairs -- verify against caller.

        Returns:
            The sum of the scaled per-batch local losses, detached from the
            autograd graph (``0.0`` when ``x_accumulator`` is empty).
        """
        # global optimization: fedavg
        if self.is_using_global_simsiam:
            for (x1, x2) in x_accumulator:
                x1, x2 = x1.to(self.device), x2.to(self.device)
                p1, z2, p2, z1 = self.global_model.predict(x1, x2)
                global_loss = self.global_model.loss_fn(p1, z2, p2, z1)
                global_loss = global_loss / self.accumulation_steps
                global_loss.backward()
            self.g_optimizer.step()
            self.g_optimizer.zero_grad()

        # local optimization
        total_loss = 0.0
        for (x1, x2) in x_accumulator:
            x1, x2 = x1.to(self.device), x2.to(self.device)
            p1_l, z2_l, p2_l, z1_l = self.local_model.predict(x1, x2)
            local_loss = self.local_model.loss_fn(p1_l, z2_l, p2_l, z1_l)

            # Cross terms: mix local projections with the global model's
            # outputs and push local projections through the global predictor.
            p1_g, z2_g, p2_g, z1_g = self.global_model.predict(x1, x2)
            local_loss += self.global_model.loss_fn(p1_g, z2_l, self.global_model.predictor(z2_l), z1_g)
            local_loss += self.global_model.loss_fn(self.global_model.predictor(z1_l), z2_g, p2_g, z1_l)

            local_loss = local_loss / self.accumulation_steps

            # Accumulate a detached value so the reported loss does not keep
            # every micro-batch's autograd graph alive.
            total_loss += local_loss.detach()

            local_loss.backward()

        # The cross terms also produce gradients on the global model, so both
        # optimizers are stepped and cleared.
        self.l_optimizer.step()
        self.l_optimizer.zero_grad()
        self.g_optimizer.step()
        self.g_optimizer.zero_grad()
        return total_loss
