# ditto

import torch
import torch.nn as nn


class DittoSSLOptimizer:
    """Jointly update a shared ("global") model and a personalized ("local") model.

    Each call to :meth:`step` performs two updates on the same batches:

    1. the global model takes one optimizer step on the accumulated
       cross-entropy gradients;
    2. the local model takes one optimizer step on its own accumulated
       cross-entropy gradients plus a proximal penalty
       ``lambda / 2 * ||local - global||^2`` that pulls it towards the
       (already updated) global weights.
    """

    def __init__(self, device, local_model, global_model, kwargs):
        """Create one optimizer per model and the shared loss criterion.

        Args:
            device: Device that batches and the criterion are moved to.
            local_model: Personalized model, updated with the proximal penalty.
            global_model: Shared model, updated with the plain loss.
            kwargs: Mapping with the keys ``'lr'`` (learning rate),
                ``'lambda'`` (proximal penalty weight), ``'client_optimizer'``
                (``"sgd"`` selects SGD, any other value selects Adam with
                AMSGrad) and ``'accumulation_steps'`` (divisor applied to
                every per-batch loss).

        Raises:
            KeyError: If one of the required keys is missing from ``kwargs``.
        """
        self.device = device
        self.local_model = local_model
        self.global_model = global_model
        self.lr = kwargs['lr']
        self.g_lambda = kwargs['lambda']
        self.basic_opt = kwargs['client_optimizer']
        self.accumulation_steps = kwargs['accumulation_steps']

        if self.basic_opt == "sgd":
            self.g_optimizer = torch.optim.SGD(self.global_model.parameters(), lr=self.lr, weight_decay=0.001)
            self.l_optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.lr, weight_decay=0.001)
        else:
            # Every optimizer name other than "sgd" falls back to Adam/AMSGrad.
            self.g_optimizer = torch.optim.Adam(self.global_model.parameters(), lr=self.lr, weight_decay=0.001,
                                                amsgrad=True)
            self.l_optimizer = torch.optim.Adam(self.local_model.parameters(), lr=self.lr, weight_decay=0.001,
                                                amsgrad=True)

        self.criterion = nn.CrossEntropyLoss().to(device)

    def _accumulate_gradients(self, model, batches):
        """Run forward/backward of ``model`` over ``batches`` and sum the scaled losses.

        Gradients are accumulated into the model's parameters; the returned
        total is detached so it does not keep the autograd graph alive.
        """
        total = 0.0
        for (x, labels) in batches:
            x = x.to(self.device)
            labels = labels.to(self.device)
            loss = self.criterion(model(x), labels) / self.accumulation_steps
            loss.backward()
            total += loss.detach()
        return total

    def _proximal_penalty(self):
        """Return ``lambda / 2 * ||local - global||^2`` or ``None`` without trainable pairs."""
        local_params = [p for p in self.local_model.parameters() if p.requires_grad]
        global_params = [p for p in self.global_model.parameters() if p.requires_grad]

        penalty = None
        for (p, g_p) in zip(local_params, global_params):
            # Sum of squares instead of norm() ** 2: same value, but no
            # square root whose derivative is ill-defined when p == g_p.
            # The global weights are detached so only the local model
            # receives a gradient from this term.
            term = (p - g_p.detach()).pow(2).sum()
            penalty = term if penalty is None else penalty + term

        if penalty is None:
            return None
        return (self.g_lambda * 0.5) * penalty

    def step(self, x_accumulator):
        """Apply one optimizer step to the global model, then to the local model.

        Args:
            x_accumulator: Iterable of ``(x, labels)`` batches. It is consumed
                twice (once per model); a one-shot iterator is materialized
                first so that both passes see the same batches.

        Returns:
            The sum of the local model's per-batch losses, each divided by
            ``accumulation_steps``, as a detached tensor. The float ``0.0``
            is returned when there are no batches.
        """
        # An iterator returns itself from iter(); such inputs would be empty
        # on the second pass, so snapshot them. Re-iterable containers are
        # left untouched.
        if iter(x_accumulator) is x_accumulator:
            batches = list(x_accumulator)
        else:
            batches = x_accumulator

        # fedavg: plain update of the shared model
        self._accumulate_gradients(self.global_model, batches)
        self.g_optimizer.step()
        self.g_optimizer.zero_grad()

        # outer optimization (optimize personalized model)
        total_loss = self._accumulate_gradients(self.local_model, batches)

        reg_loss = self._proximal_penalty()
        if reg_loss is not None:
            reg_loss.backward()

        self.l_optimizer.step()
        self.l_optimizer.zero_grad()

        return total_loss
