from simulator.base_server import BaseServer

class RingmasterSGDServer(BaseServer):
    """Server that applies worker gradients only while they are fresh enough.

    The server keeps an iteration counter that grows by one for every applied
    gradient. Each worker is sent a snapshot of the parameters together with
    the counter value at the time of sending, and returns that value alongside
    its gradient. A gradient tagged with ``iter_grad`` is applied only when
    ``iter_grad >= current_iter - num_grads + 1``, i.e. when fewer than
    ``num_grads`` updates were applied since the snapshot was sent. Otherwise
    the gradient is discarded. In both cases the worker immediately receives a
    fresh snapshot.

    NOTE(review): ``env`` is presumably a SimPy environment and the pipes
    SimPy stores (``any_of`` / ``put`` / ``get`` are used); ``point`` is
    presumably a ``torch.nn.Module`` (``parameters()`` is used) -- confirm.
    """

    def __init__(self, env, communication_pipes, point, gamma, optimizer_cls, num_grads, calculate_metrics=None):
        """Store the configuration and build the optimizer.

        Args:
            env: Simulation environment, forwarded to ``BaseServer``.
            communication_pipes: Sized iterable with one entry per worker;
                each entry exposes ``from_server`` and ``to_server`` pipes.
            point: Model whose ``parameters()`` are optimized.
            gamma: Learning rate, passed to the optimizer as ``lr``.
            optimizer_cls: Callable invoked as
                ``optimizer_cls(point.parameters(), lr=gamma)``.
            num_grads: Staleness threshold used by ``step``.
            calculate_metrics: Forwarded unchanged to ``BaseServer``.
        """
        super().__init__(env, calculate_metrics)
        self._point = point
        self._gamma = gamma
        self._num_grads = num_grads
        self._communication_pipes = communication_pipes
        self._num_workers = len(self._communication_pipes)
        self._iter = 0  # number of gradients applied so far
        # ``self._stats`` is assumed to be created by ``BaseServer.__init__``.
        self._stats["worker_useful"] = [0] * self._num_workers
        self._stats["worker_calculated"] = [0] * self._num_workers

        self.optimizer = optimizer_cls(
            self._point.parameters(), lr=self._gamma
        )

    def _send_point(self, pipe):
        """Put a cloned parameter snapshot and the current iteration into ``pipe``."""
        # Clone so that later in-place optimizer updates do not alter what
        # the worker has already been handed.
        snapshot = [p.data.clone() for p in self._point.parameters()]
        pipe.put((snapshot, self._iter))

    def prepare(self):
        """Send the initial point to every worker and start listening for results.

        This is a generator that finishes immediately without yielding.
        """
        self._pipes_from_server = [pipes.from_server for pipes in self._communication_pipes]
        self._pipes_to_server = [pipes.to_server for pipes in self._communication_pipes]
        for pipe in self._pipes_from_server:
            self._send_point(pipe)
        # One pending "get" per worker; the list index is the worker index.
        self._get_events = [pipe.get() for pipe in self._pipes_to_server]
        return
        # The unreachable ``yield`` below turns this function into a generator.
        yield

    def step(self):
        """Wait for at least one worker result and process every one that arrived.

        This is a generator: it yields a single ``any_of`` event over the
        pending per-worker ``get`` events. For each completed event it updates
        the statistics, applies the gradient if it passes the staleness check,
        sends the worker a fresh snapshot and re-arms that worker's ``get``.
        """
        filled_pipes_dict = yield self._env.any_of(self._get_events)
        # The third tuple element is not used here; presumably it is the
        # worker's own count of computed gradients -- confirm against the worker.
        for event, (gradient, iter_grad, _num_calc) in filled_pipes_dict.items():
            worker_index = self._get_events.index(event)
            self._stats["worker_calculated"][worker_index] += 1

            # Apply only if fewer than ``num_grads`` updates happened since
            # this worker's snapshot was taken.
            if iter_grad >= self._iter - self._num_grads + 1:
                self._stats["worker_useful"][worker_index] += 1

                self.optimizer.zero_grad()
                for p, g in zip(self._point.parameters(), gradient):
                    p.grad = g.to(p.device)
                self.optimizer.step()

                self._iter += 1

            # Accepted or not, the worker restarts from the latest point.
            self._send_point(self._pipes_from_server[worker_index])
            self._get_events[worker_index] = self._pipes_to_server[worker_index].get()

    def get_point(self):
        """Return the model being optimized (the live object, not a copy)."""
        return self._point

    def get_stats(self):
        """Return the statistics dictionary, including the per-worker counters."""
        return self._stats
