#! -*- coding: utf-8
import typing

import torch

from .afedopt import AFedOptClient, AFedOptServer

__all__ = ["FedAdamClient", "FedAdamServer"]


class FedAdamClient(AFedOptClient):
    """Client side of FedAdam.

    Adds nothing of its own: all client behavior is inherited unchanged
    from ``AFedOptClient``.
    """


class FedAdamServer(AFedOptServer):
    """Server side of FedAdam.

    Only ``update_velocity`` is overridden here; everything else is inherited
    from ``AFedOptServer``. The velocity kept per parameter is an Adam-style
    second-moment estimate: an exponential moving average of the squared
    deltas, weighted by the group's ``beta2``.
    """

    def update_velocity(
        self, deltas: typing.Dict[torch.Tensor, torch.Tensor]
    ) -> None:
        """Update each parameter's velocity from its delta.

        For every parameter ``p`` in ``self.param_groups`` that requires grad
        and has an entry in ``deltas``::

            velocity = beta2 * velocity + (1 - beta2) * delta * delta

        The buffer lives in ``self.state[p]["velocity"]`` and is created as
        zeros shaped like ``p`` the first time ``p`` is seen.

        Args:
            deltas: Mapping from parameter tensor to its delta. Lookup is by
                the parameter objects held in ``self.param_groups``; parameters
                absent from the mapping (or not requiring grad) are skipped.
        """
        # Optimizer-state bookkeeping should not be tracked by autograd;
        # otherwise a delta that requires grad would chain a graph through
        # the velocity buffer across rounds.
        with torch.no_grad():
            for group in self.param_groups:
                beta2 = group["beta2"]
                for p in group["params"]:
                    if not p.requires_grad or p not in deltas:
                        continue

                    state, delta = self.state[p], deltas[p]
                    if "velocity" not in state:
                        state["velocity"] = torch.zeros_like(p)
                    # addcmul_ fuses delta * delta into the update, avoiding the
                    # temporary tensor that ``delta ** 2`` would allocate.
                    state["velocity"].mul_(beta2).addcmul_(
                        delta, delta, value=1.0 - beta2
                    )
