import torch


class TorchServer:
    """Apply externally supplied flat gradients to a model through an optimizer.

    The server holds a model and the optimizer that updates it.  A caller
    first loads a flattened gradient vector into the parameters' ``.grad``
    fields with :meth:`set_gradient` and then performs the update with
    :meth:`apply_gradient`.
    """

    def __init__(self, optimizer: torch.optim.Optimizer, model: torch.nn.Module):
        """Store the optimizer and the model whose parameters receive gradients.

        Args:
            optimizer: Optimizer stepped by :meth:`apply_gradient`.
            model: Object exposing ``parameters()``; its parameters are
                iterated in that order by :meth:`set_gradient`.
        """
        self.optimizer = optimizer
        self.model = model

    def apply_gradient(self) -> None:
        """Run one optimizer step, then clear the gradients via the optimizer."""
        self.optimizer.step()
        self.optimizer.zero_grad()

    def set_gradient(self, gradient: torch.Tensor) -> None:
        """Scatter a flat gradient vector into the ``.grad`` of each parameter.

        The vector is consumed in ``self.model.parameters()`` order; each
        parameter takes ``numel()`` consecutive entries, reshaped to the
        parameter's shape.  The stored gradients are detached copies, so
        later changes to ``gradient`` do not affect them.

        Args:
            gradient: Tensor holding exactly as many elements as the model
                has parameter entries.  It is flattened before use.

        Raises:
            ValueError: If the number of elements in ``gradient`` differs
                from the total number of parameter entries.  No ``.grad``
                is modified in that case.
        """
        params = list(self.model.parameters())
        expected = sum(p.numel() for p in params)

        flat = gradient.detach().reshape(-1)
        # Validate before touching any parameter so a bad vector cannot
        # leave the model with a half-updated set of gradients.
        if flat.numel() != expected:
            raise ValueError(
                f"gradient has {flat.numel()} elements, "
                f"but the model has {expected} parameter entries"
            )

        beg = 0
        for p in params:
            # numel() works for any memory layout, unlike view(-1), which
            # requires the parameter to be contiguous.
            end = beg + p.numel()
            # copy=True guarantees a fresh tensor even when dtype and device
            # already match, keeping the grad independent of ``gradient``.
            p.grad = flat[beg:end].reshape(p.shape).to(
                device=p.device, dtype=p.dtype, copy=True
            )
            beg = end
