# -*- coding: utf-8 -*-
import typing

import torch

from .afedopt import AFedOptClient, AFedOptServer

__all__ = ["FedAdagradClient", "FedAdagradServer"]


class FedAdagradClient(AFedOptClient):
    """FedAdagrad client; defines nothing beyond what ``AFedOptClient`` provides."""


class FedAdagradServer(AFedOptServer):
    """FedAdagrad server: accumulates squared deltas into a per-parameter velocity."""

    def update_velocity(self, deltas: typing.Dict[torch.Tensor, torch.Tensor]):
        """Add the element-wise square of each delta to its parameter's velocity.

        Parameters that do not require gradients, or that have no entry in
        ``deltas``, are skipped. The ``"velocity"`` buffer lives in
        ``self.state[param]`` and is created as zeros (via ``torch.zeros_like``)
        the first time a parameter is updated.

        Args:
            deltas: Mapping from a parameter tensor to its delta tensor.
        """
        for param_group in self.param_groups:
            for param in param_group["params"]:
                # Only trainable parameters that actually received a delta are updated.
                if param.requires_grad and param in deltas:
                    param_state = self.state[param]
                    param_delta = deltas[param]

                    # Lazily create the accumulator on first use.
                    if "velocity" not in param_state:
                        param_state["velocity"] = torch.zeros_like(param)

                    accumulator = param_state["velocity"]
                    squared_delta = param_delta**2
                    # In-place update: velocity <- velocity + 1.0 * delta^2
                    accumulator.add_(squared_delta, alpha=1.0)
