import torch

class TorchServer:
    """Parameter server for distributed training.

    Holds a model together with its optimizer, lets a caller overwrite the
    model's gradients from a single flattened vector, and applies the update
    (with optional gradient-norm clipping).
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 model: torch.nn.Module,
                 clipping: float,
                 ):
        """Store the optimizer, the model and the clipping threshold.

        Args:
            optimizer: Optimizer whose ``param_groups`` define which
                parameters receive gradients in ``set_gradient``.
            model: Model whose parameters are clipped in ``apply_gradient``.
            clipping: Maximum gradient norm handed to ``clip_grad_norm_``.
                Any falsy value (``0``, ``False``, ``None``) disables
                clipping. ``True`` is numerically 1, so it clips to norm 1.
        """
        self.optimizer = optimizer
        self.model = model
        self.clipping = clipping

    def apply_gradient(self) -> None:
        """Optionally clip the model's gradients, then take one optimizer step."""
        if self.clipping:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clipping)
        self.optimizer.step()

    def set_gradient(self, gradient: torch.Tensor) -> None:
        """Overwrite the existing parameter gradients from a flat vector.

        The vector is consumed in optimizer order (param group by param
        group, parameter by parameter). Parameters whose ``.grad`` is
        ``None`` are skipped and consume no entries from ``gradient``.

        Args:
            gradient: Tensor holding the concatenated gradient values; it is
                flattened before use. Values are cast to each gradient's own
                dtype and device.

        Raises:
            ValueError: If the number of elements in ``gradient`` differs
                from the total size of the gradients being overwritten. The
                check runs before any gradient is modified.
        """
        targets = [
            p.grad
            for group in self.optimizer.param_groups
            for p in group["params"]
            if p.grad is not None
        ]

        with torch.no_grad():
            flat = gradient.reshape(-1)

            # Validate up front so a mismatched vector never leaves the
            # gradients partially overwritten.
            expected = sum(grad.numel() for grad in targets)
            if flat.numel() != expected:
                raise ValueError(
                    f"gradient has {flat.numel()} elements, "
                    f"but the parameters expect {expected}"
                )

            offset = 0
            for grad in targets:
                # numel() works for non-contiguous grads, unlike view(-1).
                count = grad.numel()
                # In-place copy keeps the grad's dtype/device and never
                # aliases the caller's tensor.
                grad.copy_(flat[offset:offset + count].reshape(grad.shape))
                offset += count