from simulator.base_server import BaseServer

class SynchronizedSGDServer(BaseServer):
    """Parameter server that applies one optimizer step per round of replies.

    Each call to :meth:`step` sends a copy of the current parameters to every
    worker pipe, waits for one reply from each pipe, averages the returned
    gradients over the number of workers, and lets the optimizer apply the
    averaged gradient to ``point``.
    """

    def __init__(self, env, communication_pipes, point, gamma, optimizer_cls, calculate_metrics=None):
        """Store the model, the worker pipes, and build the optimizer.

        Args:
            env: Forwarded to ``BaseServer``; ``step`` later reads
                ``self._env.now`` (presumably a simulation environment
                stored by the base class -- confirm against ``BaseServer``).
            communication_pipes: Sized collection with one entry per worker.
                Each entry must expose ``from_server.put(...)`` and
                ``to_server.get()``.
            point: Object exposing ``parameters()``; its parameters are
                broadcast to workers and updated by the optimizer.
            gamma: Learning rate, passed to the optimizer as ``lr``.
            optimizer_cls: Callable invoked as
                ``optimizer_cls(point.parameters(), lr=gamma)``.
            calculate_metrics: Forwarded unchanged to ``BaseServer``.
        """
        super().__init__(env, calculate_metrics)
        self._point = point
        self._gamma = gamma
        self._communication_pipes = communication_pipes
        self._num_workers = len(self._communication_pipes)
        self._iter = 0

        # Counters reported through the stats mapping.
        self._stats['steps'] = 0
        self._stats['time'] = 0

        self.optimizer = optimizer_cls(self._point.parameters(), lr=self._gamma)

    def prepare(self):
        """Do nothing; exists only as an empty generator.

        Yields nothing and finishes immediately.
        """
        yield from ()

    def step(self):
        """Run one synchronized round: broadcast, gather, average, update.

        This is a generator: it yields the ``get()`` request of every worker
        pipe in turn and resumes with the value sent back in, which must be a
        3-tuple whose first item is a sequence of gradients ordered like
        ``point.parameters()``. The two remaining items are ignored.
        """
        # Broadcast: every worker receives its own cloned snapshot of the
        # parameters together with the current iteration index.
        for channel in self._communication_pipes:
            snapshot = [param.data.clone() for param in self._point.parameters()]
            channel.from_server.put((snapshot, self._iter))

        # Gather: wait for each worker in pipe order and sum the gradients.
        # The running sum stays on the device of the first reply's tensors.
        grad_sum = None
        for channel in self._communication_pipes:
            worker_grad, _, _ = yield channel.to_server.get()
            if grad_sum is None:
                grad_sum = worker_grad
                continue
            grad_sum = [
                acc + incoming.to(acc.device)
                for acc, incoming in zip(grad_sum, worker_grad)
            ]

        worker_count = self._num_workers
        averaged = [total / worker_count for total in grad_sum]

        # Clear stale gradients, then hand the averaged gradient to the
        # optimizer through each parameter's ``grad`` attribute.
        self.optimizer.zero_grad()
        for param, grad in zip(self._point.parameters(), averaged):
            param.grad = grad.to(param.device)
        self.optimizer.step()

        self._stats['steps'] += 1
        self._stats['time'] = self._env.now
        self._iter += 1
