# Personalized Federated Learning with Moreau Envelopes
# https://proceedings.neurips.cc/paper/2020/file/f4f1f13c8289ac1b1ee0ff176b56fc60-Paper.pdf

import torch
from torch.autograd import Variable


class pFedMeSSLOptimizer:
    """pFedMe-style bi-level optimizer for a self-supervised client model.

    Holds two models whose trainable parameters are paired positionally (and
    therefore must have matching shapes): ``local_model`` (the personalized
    model) and ``global_model`` (the client's copy of the global weights).
    Each call to :meth:`step` runs ``K`` inner updates of the local model on
    the SSL loss plus the proximal term ``(lambda / 2) * ||local - global||^2``.
    It then runs one outer update that moves the global copy toward the
    personalized model.

    Args:
        device: device the input batches are moved to via ``.to(device)``.
        local_model: personalized model. Must provide ``predict(x1, x2)``
            (returning four tensors), ``loss_fn(z1, z1_, z2, z2_)``,
            ``trainable_parameters()`` and ``parameters()``.
        global_model: global-model copy. Must provide
            ``trainable_parameters()`` and ``parameters()``.
        kwargs: mapping of hyper-parameters.
            Required keys: ``'lr'``, ``'momentum'``, ``'wd'``, ``'lambda'``,
            ``'ssl_method'``, ``'client_optimizer'``, ``'accumulation_steps'``.
            Optional key: ``'K'`` (default 5), the number of inner updates
            per :meth:`step`.
    """

    def __init__(self, device, local_model, global_model, kwargs):
        self.device = device
        self.local_model = local_model
        self.global_model = global_model

        self.lr = kwargs['lr']
        self.momentum = kwargs['momentum']
        self.wd = kwargs['wd']

        self.g_lambda = kwargs['lambda']  # weight of the proximal term
        self.ssl_method = kwargs["ssl_method"]
        self.basic_opt = kwargs['client_optimizer']
        self.accumulation_steps = kwargs['accumulation_steps']

        if self.basic_opt == "sgd":
            self.g_optimizer = torch.optim.SGD(self.global_model.parameters(), lr=self.lr, momentum=self.momentum,
                                               weight_decay=self.wd)
            self.l_optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.lr, momentum=self.momentum,
                                               weight_decay=self.wd)
        else:
            # NOTE(review): this branch ignores ``wd`` and ``momentum`` and uses a
            # fixed weight decay of 0.001. Kept as-is to preserve existing behaviour.
            self.g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.global_model.parameters()),
                                                lr=self.lr,
                                                weight_decay=0.001, amsgrad=True)
            self.l_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.local_model.parameters()),
                                                lr=self.lr,
                                                weight_decay=0.001, amsgrad=True)

        # Number of inner (personalization) updates per call to step().
        self.K = kwargs.get('K', 5)

    def step(self, x_accumulator):
        """Run ``K`` personalized updates, then one update of the global copy.

        Args:
            x_accumulator: iterable of ``(x1, x2)`` batch pairs, which are
                passed to ``local_model.predict``. It is traversed ``K`` times,
                so it is materialized into a list first. One-shot iterators
                such as generators are therefore not exhausted after the first
                pass.

        Returns:
            The sum of the accumulation-scaled SSL losses over all ``K``
            passes, as a tensor detached from the autograd graph. Returns
            ``0`` if ``x_accumulator`` is empty.
        """
        batches = list(x_accumulator)
        total_loss = 0
        for _ in range(self.K):
            for x1, x2 in batches:
                x1, x2 = x1.to(self.device), x2.to(self.device)
                z1, z1_, z2, z2_ = self.local_model.predict(x1, x2)
                # Gradient-accumulation scaling: grads from all batches are
                # summed before the single optimizer step below.
                local_loss = self.local_model.loss_fn(z1, z1_, z2, z2_) / self.accumulation_steps
                local_loss.backward()
                # Detach so the running total does not keep each graph alive.
                total_loss += local_loss.detach()

            reg_loss = self._proximal_loss()
            # With no trainable parameters the proximal term is the float 0.0,
            # which has nothing to backpropagate.
            if torch.is_tensor(reg_loss):
                reg_loss.backward()

            self.l_optimizer.step()  # in original implementation, weight decay, mu= 0.001
            self.l_optimizer.zero_grad()

        self._update_global()
        return total_loss

    def _proximal_loss(self):
        """Return ``(lambda / 2) * sum ||p - g||^2`` over paired trainable parameters.

        The global parameters are detached, so the gradient flows only into
        the local model.
        """
        reg_loss = 0.0
        for p, g_p in zip(self.local_model.trainable_parameters(),
                          self.global_model.trainable_parameters()):
            reg_loss = reg_loss + (self.g_lambda * 0.5) * (p - g_p.detach()).pow(2).sum()
        return reg_loss

    def _update_global(self):
        """Set each global grad to ``lambda * (g - p)`` and take one optimizer step."""
        with torch.no_grad():
            for g, p in zip(self.global_model.trainable_parameters(),
                            self.local_model.trainable_parameters()):
                grad = self.g_lambda * (g - p)
                if g.grad is None:
                    g.grad = grad
                else:
                    g.grad.copy_(grad)
        self.g_optimizer.step()
        self.g_optimizer.zero_grad()
