
class Worker:
    """Worker that answers each server request with one stochastic gradient.

    The worker is driven as a generator-based process (presumably by a
    SimPy-style environment -- confirm against the caller): every value it
    yields is an event object produced by the pipe or by ``function``, and
    the result of that event is sent back into the generator.
    """

    def __init__(self, worker_index, env, communication_pipe, function):
        """Store the collaborators; nothing is started here.

        Args:
            worker_index: Identifier of this worker (stored, not used here).
            env: Simulation environment (stored, not used by this class).
            communication_pipe: Object exposing ``from_server.get()`` and
                ``to_server.put(message)``.
            function: Object exposing ``stochastic_gradient(point)``.
        """
        self._worker_index = worker_index
        self._env = env
        self._communication_pipe = communication_pipe
        self._function = function
        # Total number of gradients this worker has computed so far.
        self._num_calculated = 0

    def run(self):
        """Serve requests forever.

        Each round: receive a ``(point, iteration)`` pair from the server,
        compute one stochastic gradient at ``point``, and reply with
        ``(gradient, iteration, num_calculated)``.
        """
        pipe = self._communication_pipe
        while True:
            request = yield pipe.from_server.get()
            point, server_iter = request

            gradient = yield self._function.stochastic_gradient(point)
            self._num_calculated += 1

            reply = (gradient, server_iter, self._num_calculated)
            pipe.to_server.put(reply)


class WorkerWithLocalSteps:
    """Worker that keeps taking local gradient steps between server updates.

    Two generator-based processes cooperate (presumably scheduled by a
    SimPy-style environment -- confirm against the caller):

    * ``run`` listens to the server. It installs new ``(point, iteration)``
      pairs and answers ``'allreduce'`` requests with the gradients
      accumulated so far.
    * ``run_gradient`` repeatedly computes stochastic gradients at the current
      point, accumulates them, and applies a local step to the point.
    """

    def __init__(self, worker_index, env, communication_pipe, function, gamma, local_step_size_multiplier=1.0):
        """Store the collaborators and step-size settings; nothing is started.

        Args:
            worker_index: Identifier of this worker (used in the init log line).
            env: Environment whose ``process(generator)`` starts ``run_gradient``.
            communication_pipe: Object exposing ``from_server.get()`` and
                ``to_server.put(message, ...)``.
            function: Object exposing ``stochastic_gradient(point)``.
            gamma: Base step size.
            local_step_size_multiplier: Factor applied to ``gamma`` for the
                local steps.
        """
        self._worker_index = worker_index
        self._env = env
        self._communication_pipe = communication_pipe
        self._function = function
        # NOTE(review): this counter is never updated anywhere in this class.
        self._num_calculated = 0

        self._gamma = gamma
        self._local_step_size_multiplier = local_step_size_multiplier

        # All of these stay None until the first (point, iteration) pair arrives.
        self._point = None
        self._iter = None
        self._local_sum = None
        self._num_local_sum = None

    def _reset_local_sum(self):
        """Forget the gradients accumulated for the previous point."""
        self._local_sum = None
        self._num_local_sum = 0

    def run(self):
        """Wait for the initial point, start ``run_gradient``, then serve messages.

        Messages handled after initialisation:

        * the string ``'allreduce'`` -- reply with
          ``(local_sum, num_local_sum, iteration)``;
        * anything else -- unpacked as a new ``(point, iteration)`` pair, which
          also resets the local accumulator.
        """
        pipe = self._communication_pipe

        while self._point is None:
            next_data = yield pipe.from_server.get()
            if not isinstance(next_data, tuple):
                # allreduce message can be sent before init as it uses _fast=True
                print(f"Worker {self._worker_index} received unexpected message on init: {next_data}")
                # Placeholder reply: zero sum, zero gradients, iteration -1.
                pipe.to_server.put((0, 0, -1))
                continue
            self._point, self._iter = next_data

        self._reset_local_sum()
        self._env.process(self.run_gradient())

        while True:
            data = yield pipe.from_server.get()
            if isinstance(data, str) and data == 'allreduce':
                pipe.to_server.put((self._local_sum, self._num_local_sum, self._iter))
                continue
            self._point, self._iter = data
            self._reset_local_sum()

    def run_gradient(self):
        """Compute gradients forever, keeping only those that are still current.

        A gradient is discarded when the iteration changed while it was being
        computed. An accepted gradient is added to the local sum, applied to
        the point as a local step, and announced to the server with a
        ``('calculated', iteration)`` message sent with ``_fast=True``.

        NOTE(review): the point entries are updated in place through
        ``p.data.add_``; whether the server hands out a private copy of the
        point is not visible from here -- confirm.
        """
        while True:
            started_at_iter = self._iter
            gradient = yield self._function.stochastic_gradient(self._point)
            if started_at_iter != self._iter:
                # The server replaced the point meanwhile; this gradient is stale.
                continue

            # Element-wise accumulation; the first gradient is stored as is.
            if self._local_sum is None:
                self._local_sum = gradient
            else:
                self._local_sum = [
                    acc + grad.to(acc.device)
                    for acc, grad in zip(self._local_sum, gradient)
                ]
            self._num_local_sum += 1

            # Local step: point <- point - multiplier * gamma * gradient.
            step_scale = -self._local_step_size_multiplier * self._gamma
            for param, grad in zip(self._point, gradient):
                param.data.add_(step_scale * grad.to(param.data.device))

            self._communication_pipe.to_server.put(('calculated', self._iter), _fast=True)

class WorkerWithTargetComputeCommunicateRatio:
    """Worker that performs a fixed number of local steps per server request.

    The number of local steps is either ``min_local_steps`` or, when a target
    ratio is given, derived from the pipe's communication time and the
    function's computation time. The worker is driven as a generator-based
    process (presumably by a SimPy-style environment -- confirm against the
    caller).
    """

    def __init__(self, worker_index, env, communication_pipe, function, gamma, local_step_size_multiplier=1, min_local_steps=1, target_communicate_compute_ratio = None):
        """Store the collaborators and derive the number of local steps.

        Args:
            worker_index: Identifier of this worker (stored, not used here).
            env: Simulation environment (stored, not used by this class).
            communication_pipe: Object exposing ``from_server.get()`` and
                ``to_server.put(message)``; ``to_server._time_to_communicate``
                is read when a target ratio is given.
            function: Object exposing ``stochastic_gradient(point)``;
                ``_time_to_calculate`` is read when a target ratio is given.
            gamma: Base step size.
            local_step_size_multiplier: Factor applied to ``gamma`` for the
                local steps.
            min_local_steps: Lower bound on (and default for) the local steps.
            target_communicate_compute_ratio: When not ``None``, the local
                step count becomes ``max(min_local_steps,
                int(1 + time_to_communicate // (time_to_calculate * ratio)))``.
        """
        self._worker_index = worker_index
        self._env = env
        self._communication_pipe = communication_pipe
        self._function = function
        # Total number of gradients this worker has computed so far.
        self._num_calculated = 0

        self._gamma = gamma
        self._target_communicate_compute_ratio = target_communicate_compute_ratio
        self._min_local_steps = min_local_steps
        self._local_step_size_multiplier = local_step_size_multiplier

        steps = min_local_steps
        if target_communicate_compute_ratio is not None:
            communicate_time = self._communication_pipe.to_server._time_to_communicate
            scaled_compute_time = function._time_to_calculate * target_communicate_compute_ratio
            steps = max(min_local_steps, int(1 + communicate_time // scaled_compute_time))
        self._local_steps = steps

    @staticmethod
    def _accumulate(gradient_sum, gradient):
        """Return the element-wise running sum; the first gradient is kept as is."""
        if gradient_sum is None:
            return gradient
        return [acc + grad.to(acc.device) for acc, grad in zip(gradient_sum, gradient)]

    def _apply_local_step(self, point, gradient):
        """Update ``point`` in place: point <- point - multiplier * gamma * gradient."""
        for param, grad in zip(point, gradient):
            param.data.add_(-self._local_step_size_multiplier * self._gamma * grad.to(param.data.device))

    def run(self):
        """Serve requests forever.

        Each round: receive a ``(point, iteration)`` pair, take
        ``_local_steps`` local gradient steps starting from ``point``, and
        reply with ``(gradient_sum, iteration, num_calculated)``.

        NOTE(review): the received ``point`` is modified in place by the local
        steps; whether the server hands out a private copy is not visible
        from here -- confirm.
        """
        pipe = self._communication_pipe
        while True:
            point, server_iter = yield pipe.from_server.get()

            gradient_sum = None
            for _ in range(self._local_steps):
                gradient = yield self._function.stochastic_gradient(point)
                gradient_sum = self._accumulate(gradient_sum, gradient)
                self._apply_local_step(point, gradient)
                self._num_calculated += 1

            pipe.to_server.put((gradient_sum, server_iter, self._num_calculated))
