import torch


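# Notes:
# 1. A global coin is tossed at every step. With p = global_coin_prob, heads (local SGD step) is expected
#    with probability (1 - p) and tails (averaging step) with probability p; the SGD step size is scaled
#    by 1 / (1 - p) and the averaging weight by 1 / p accordingly.
# 2. All clients must participate in the optimization in every round.
# 3. When the coin lands heads, every client takes exactly one step of SGD.
# 4. When the coin lands tails, the server sends the average of all clients' models to each client,
#    and each client moves its own model toward that average.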

class L2GD:
    """Client-side optimizer implementing an L2GD-style (loopless local gradient descent) update.

    Every call to :meth:`step` consults a shared coin (:meth:`sync_coin`) and either
    takes one local SGD step on ``model`` or pulls ``model``'s parameters toward
    those of ``avg_model``.

    Args:
        device: device that input batches are moved to (via ``.to``) before the forward pass.
        model: local model. Must expose ``parameters()``, ``predict(x1, x2)`` returning
            ``(p1, z2, p2, z1)``, and ``loss_fn(p1, z2, p2, z1)`` returning a scalar loss.
        avg_model: model holding the averaged parameters. Its ``parameters()`` are zipped
            in order with ``model.parameters()``, so the two must line up one-to-one.
        kwargs: configuration mapping with keys ``'lr'``, ``'lambda'``,
            ``'accumulation_steps'``, ``'num_clients'`` and ``'global_coin_prob'``, plus an
            optional ``'weight_decay'`` (defaults to ``0.001``).
    """

    def __init__(self, device, model, avg_model, kwargs):
        self.device = device
        self.model = model
        self.avg_model = avg_model

        self.lr = kwargs['lr']
        self.g_lambda = kwargs['lambda']
        self.accumulation_steps = kwargs['accumulation_steps']
        self.num_clients = kwargs['num_clients']
        self.global_coin_prob = kwargs['global_coin_prob']
        # Formerly hard-coded; the default keeps the previous behavior.
        self.weight_decay = kwargs.get('weight_decay', 0.001)

        # Local step size is rescaled by 1 / (num_clients * (1 - p)); presumably the local
        # (coin == True) branch is expected with probability 1 - p, as in L2GD.
        # NOTE: global_coin_prob == 1 raises ZeroDivisionError here.
        local_lr = self.lr / self.num_clients / (1 - self.global_coin_prob)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=local_lr, weight_decay=self.weight_decay)

    def sync_avg_model(self):
        """Fetch the server's averaged model into ``self.avg_model``.

        Placeholder: currently a no-op, so ``self.avg_model`` must be kept up to date
        elsewhere (TODO confirm intended design).

        Returns:
            None
        """
        pass

    def sync_coin(self):
        """Return the coin-toss result shared by the server.

        ``True`` selects the local SGD branch of :meth:`step`; ``False`` selects the
        averaging branch. Placeholder: currently always returns ``True``.

        Returns:
            bool: the coin outcome.
        """
        return True

    def step(self, x_accumulator):
        """Run one L2GD step.

        Args:
            x_accumulator: iterable of ``(x1, x2)`` input pairs, one per
                gradient-accumulation micro-batch.

        Returns:
            On the SGD branch: the sum of the per-micro-batch losses (each divided by
            ``accumulation_steps``) as a detached tensor, or ``0`` if ``x_accumulator``
            is empty. On the averaging branch: ``0``.
        """
        if self.sync_coin():
            self.optimizer.zero_grad()

            total_loss = 0
            for x1, x2 in x_accumulator:
                x1, x2 = x1.to(self.device), x2.to(self.device)
                p1, z2, p2, z1 = self.model.predict(x1, x2)
                loss = self.model.loss_fn(p1, z2, p2, z1) / self.accumulation_steps
                loss.backward()
                # Detach before accumulating so the returned total does not keep the
                # autograd history of every micro-batch alive.
                total_loss += loss.detach()
            self.optimizer.step()

            return total_loss

        self.sync_avg_model()

        # Mixing weight toward the averaged model, rescaled by 1 / (num_clients * p).
        ratio = self.lr * self.g_lambda / self.num_clients / self.global_coin_prob
        with torch.no_grad():
            for p, a_p in zip(self.model.parameters(), self.avg_model.parameters()):
                # In place: p <- (1 - ratio) * p + ratio * a_p
                p.mul_(1 - ratio).add_(a_p, alpha=ratio)

        return 0
