#! -*- coding: utf-8
import typing

import torch

from .afedopt import AFedOptClient, AFedOptServer

__all__ = ["FedYogiClient", "FedYogiServer"]


class FedYogiClient(AFedOptClient):
    """FedYogi client.

    Client-side behaviour is inherited unchanged from ``AFedOptClient``;
    this subclass adds nothing and exists so FedYogi has a named client
    to pair with ``FedYogiServer``.
    """


class FedYogiServer(AFedOptServer):
    """Server-side FedYogi optimizer.

    Overrides only ``update_velocity`` from ``AFedOptServer``, replacing the
    second-moment ("velocity") update with the Yogi rule from Reddi et al.,
    "Adaptive Federated Optimization" (ICLR 2021):

        v_t = v_{t-1} - (1 - beta2) * delta_t**2 * sign(v_{t-1} - delta_t**2)
    """

    def update_velocity(self, deltas: typing.Dict[torch.Tensor, torch.Tensor]):
        """Update each parameter's ``"velocity"`` state in place from its delta.

        Args:
            deltas: maps a parameter tensor (as it appears in
                ``self.param_groups``) to its delta tensor. Parameters that do
                not require grad, or that have no entry in ``deltas``, are
                skipped.
        """
        for group in self.param_groups:
            beta2 = group["beta2"]
            for p in group["params"]:
                if not p.requires_grad or p not in deltas:
                    continue

                state = self.state[p]
                delta_sq = deltas[p] ** 2
                if "velocity" not in state:
                    # Lazily create a zero velocity the first time p is seen.
                    state["velocity"] = torch.zeros_like(p)
                velocity = state["velocity"]
                # Yogi: the sign compares the *previous velocity* (not the
                # parameter values) with delta**2, so v moves additively
                # toward delta**2 instead of decaying multiplicatively.
                velocity.data.add_(
                    delta_sq * torch.sign(velocity.data - delta_sq),
                    alpha=-(1.0 - beta2),
                )
