import numpy as np

from simulator.base_server import BaseServer


class RennalaSGDWorker:
    """Worker-side process of Rennala SGD in the simulator.

    The worker first waits for a ``(point, iteration)`` tuple from the server.
    It then runs two cooperating generator processes:

    * ``run`` handles server messages. A new ``(point, iteration)`` tuple
      resets the local accumulator. The string ``'allreduce'`` is answered with
      ``(local_sum, num_local_sum, iteration)``.
    * ``run_gradient`` repeatedly requests a stochastic gradient at the current
      point. It adds each result to the accumulator unless the iteration
      changed while the gradient was being computed. Each accepted gradient is
      announced to the server with ``('calculated', iteration)``.

    Gradients are combined element-wise via ``.to(device)`` and ``+``. This
    presumes a sequence of tensor-like objects (TODO confirm against the
    ``function`` implementation).
    """

    def __init__(self, worker_index, env, communication_pipe, function):
        self._worker_index = worker_index
        self._env = env
        self._communication_pipe = communication_pipe
        self._function = function
        # Everything below is filled in once the first (point, iteration)
        # message arrives from the server.
        self._point = None
        self._iter = None
        self._local_sum = None
        self._num_local_sum = None

    def _reset_accumulator(self):
        """Forget all gradients accumulated for the previous point."""
        self._local_sum = None
        self._num_local_sum = 0

    def run(self):
        """Main message loop (generator yielding receive events)."""
        pipe = self._communication_pipe

        # Handshake: an 'allreduce' request may overtake the initial point
        # because it is sent with _fast=True. Answer such messages with a
        # dummy reply; iteration -1 never matches a real server iteration.
        while self._point is None:
            message = yield pipe.from_server.get()
            if not isinstance(message, tuple):
                print(f"Worker {self._worker_index} received unexpected message on init: {message}")
                pipe.to_server.put((0, 0, -1))
                continue
            self._point, self._iter = message

        self._reset_accumulator()
        self._env.process(self.run_gradient())

        while True:
            message = yield pipe.from_server.get()
            if isinstance(message, str) and message == 'allreduce':
                pipe.to_server.put((self._local_sum, self._num_local_sum, self._iter))
                continue
            # Any other message is a fresh (point, iteration) pair.
            self._point, self._iter = message
            self._reset_accumulator()

    def run_gradient(self):
        """Endless gradient loop (generator yielding gradient computations)."""
        pipe = self._communication_pipe
        while True:
            started_at_iter = self._iter
            grad = yield self._function.stochastic_gradient(self._point)
            if started_at_iter != self._iter:
                # The server moved to a new point meanwhile; the result is stale.
                continue
            if self._local_sum is None:
                self._local_sum = grad
            else:
                self._local_sum = [
                    acc + part.to(acc.device)
                    for acc, part in zip(self._local_sum, grad)
                ]
            self._num_local_sum += 1
            pipe.to_server.put(('calculated', self._iter), _fast=True)


class RennalaSGDServer(BaseServer):
    """Server side of Rennala SGD for the simulator.

    Each ``step`` does four things:

    1. Broadcasts a copy of the current parameters and the iteration index to
       every worker.
    2. Waits until at least ``num_grads`` gradients computed at the current
       iteration have been announced.
    3. Collects every worker's local gradient sum through an ``'allreduce'``
       round.
    4. Applies the summed gradient with the optimizer.

    Args:
        env: Simulation environment providing ``any_of``. Presumably a SimPy
            ``Environment``; TODO confirm.
        communication_pipes: One pipe object per worker, each exposing
            ``from_server`` and ``to_server`` channels.
        point: Object whose ``parameters()`` are optimized. Presumably a
            ``torch.nn.Module``; confirm against callers.
        gamma: Learning rate, passed to the optimizer as ``lr``.
        optimizer_cls: Optimizer factory, called as
            ``optimizer_cls(point.parameters(), lr=gamma)``.
        num_grads: Minimum number of fresh gradients to wait for per step.
        calculate_metrics: Forwarded unchanged to ``BaseServer``.
    """

    def __init__(self, env, communication_pipes, point, gamma, optimizer_cls, num_grads, calculate_metrics=None):
        super().__init__(env, calculate_metrics)
        self._point = point
        self._gamma = gamma
        self._num_grads = num_grads
        self._communication_pipes = communication_pipes
        self._num_workers = len(self._communication_pipes)
        self._iter = 0
        # Per-worker counters:
        # - worker_useful: announcements matching the current iteration.
        # - worker_calculated: every announcement seen in the collection phase.
        self._stats["worker_useful"] = [0] * self._num_workers
        self._stats["worker_calculated"] = [0] * self._num_workers
        self.optimizer = optimizer_cls(
            self._point.parameters(), lr=self._gamma
        )

    def prepare(self):
        """Cache per-worker channels and arm one pending receive per worker.

        The unreachable ``yield`` turns this into a generator function. It is
        presumably driven by ``BaseServer`` like the other processes; TODO
        confirm.
        """
        self._pipes_from_server = [pipes.from_server for pipes in self._communication_pipes]
        self._pipes_to_server = [pipes.to_server for pipes in self._communication_pipes]
        self._get_events = [pipe.get() for pipe in self._pipes_to_server]
        return
        yield

    def step(self):
        """Run one Rennala SGD iteration (generator yielding simulation events).

        Raises:
            RuntimeError: if no worker reported a gradient for the current
                iteration (e.g. ``num_grads`` < 1), so there is nothing to apply.
        """
        # 1. Broadcast a private copy of the parameters to every worker.
        for pipe in self._pipes_from_server:
            pipe.put(([p.data.clone() for p in self._point.parameters()], self._iter))

        # 2. Wait until enough gradients for this iteration are announced.
        # Several events may fire together, so num_collected can overshoot.
        num_collected = 0
        while num_collected < self._num_grads:
            filled_pipes_dict = yield self._env.any_of(self._get_events)
            for event, data in filled_pipes_dict.items():
                worker_index = self._get_events.index(event)
                str_, iter_grad = data
                assert str_ == 'calculated'
                if iter_grad == self._iter:
                    num_collected += 1
                    self._stats["worker_useful"][worker_index] += 1
                self._stats["worker_calculated"][worker_index] += 1
                # Re-arm the receive for this worker.
                self._get_events[worker_index] = self._pipes_to_server[worker_index].get()

        # 3. Request every worker's local sum and aggregate the fresh ones.
        for pipe in self._pipes_from_server:
            pipe.put('allreduce', _fast=True)
        check_num_calculated = 0
        sum_gradient = None
        for worker_index, pipe_to_server in enumerate(self._pipes_to_server):
            while True:
                # Skip ('calculated', ...) announcements still in flight.
                # NOTE(review): these are not added to worker_calculated,
                # unlike those seen in phase 2 -- confirm that is intended.
                data = yield self._get_events[worker_index]
                self._get_events[worker_index] = pipe_to_server.get()
                if not isinstance(data[0], str):
                    break
            local_sum, num_local_sum, worker_iter = data
            # Replies for another iteration are ignored; this includes the
            # dummy (0, 0, -1) a worker sends before its first point arrives.
            if worker_iter == self._iter:
                check_num_calculated += num_local_sum
                if sum_gradient is None:
                    sum_gradient = local_sum
                elif local_sum is not None:
                    sum_gradient = [sg + lg.to(sg.device) for sg, lg in zip(sum_gradient, local_sum)]
        assert check_num_calculated == num_collected, (check_num_calculated, num_collected)
        if sum_gradient is None:
            raise RuntimeError(
                f"No gradients were collected for iteration {self._iter} "
                f"(num_grads={self._num_grads}); nothing to apply"
            )

        # 4. Apply the summed (not averaged) gradient.
        self.optimizer.zero_grad()
        for param, grad in zip(self._point.parameters(), sum_gradient):
            param.grad = grad.to(param.device)
        self.optimizer.step()

        self._iter += 1
