import numpy as np

from simulator.base_server import BaseServer


class RennalaSGDSoftReduceWorker:
    """Simulated worker that accumulates stochastic gradients locally.

    The worker runs two cooperating generator processes:

    * ``run`` listens on ``communication_pipe.from_server``. A message is
      either a ``(point, iteration)`` pair that starts a new round, or a tuple
      whose first element is the string ``'allreduce'``, which asks the worker
      to send back its accumulated ``(local_sum, num_local_sum, iteration)``.
    * ``run_gradient`` repeatedly evaluates ``function.stochastic_gradient``
      at the current point and adds the result to the local sum, announcing
      each accepted gradient to the server with a ``'calculated'`` message.
    """

    def __init__(self, worker_index, env, communication_pipe, function):
        """Store the collaborators; round state is filled in by ``run``.

        Args:
            worker_index: Identifier of this worker (stored, not read by this
                class).
            env: Simulation environment; must provide ``process`` and
                ``timeout``.
            communication_pipe: Object exposing ``from_server`` (with
                ``get()``) and ``to_server`` (with ``put()``).
            function: Object exposing ``stochastic_gradient(point)`` whose
                result can be yielded to the environment.
        """
        self._worker_index = worker_index
        self._env = env
        self._communication_pipe = communication_pipe
        self._function = function

        # Round state: unknown until the first (point, iteration) arrives.
        self._point = None
        self._iter = None
        self._local_sum = None
        self._num_local_sum = None

        # Soft deadline currently in effect; reported in 'calculated' messages.
        self._resent_soft_deadline = 0

    def _begin_round(self, message):
        """Adopt a ``(point, iteration)`` message and clear the accumulator."""
        self._point, self._iter = message
        self._local_sum = 0
        self._num_local_sum = 0

    def run(self):
        """Main process: start gradient computation and serve server messages."""
        first_message = yield self._communication_pipe.from_server.get()
        self._begin_round(first_message)
        self._env.process(self.run_gradient())

        while True:
            message = yield self._communication_pipe.from_server.get()

            # The isinstance check runs first so a non-string first element
            # (e.g. the point itself) is never compared against a string.
            wants_allreduce = isinstance(message[0], str) and message[0] == 'allreduce'
            if not wants_allreduce:
                self._begin_round(message)
                continue

            # The soft deadline is disabled: it is fixed at zero, so the
            # timeout below never fires. An earlier formula (now removed from
            # the active code) derived it from the server-provided maximum
            # communication time, this worker's own communication time, and
            # the gradient computation time.
            soft_deadline = 0
            self._resent_soft_deadline = soft_deadline
            if soft_deadline > 0:
                yield self._env.timeout(soft_deadline)
            self._resent_soft_deadline = 0

            reply = (self._local_sum, self._num_local_sum, self._iter)
            self._communication_pipe.to_server.put(reply)

    def run_gradient(self):
        """Gradient process: compute gradients forever, keeping only fresh ones."""
        while True:
            iter_at_start = self._iter
            gradient = yield self._function.stochastic_gradient(self._point)

            # A new round began while this gradient was being computed:
            # it belongs to an outdated point, so drop it.
            if iter_at_start != self._iter:
                continue

            # Deliberately not `+=`: a fresh object is built each time so that
            # a sum already handed to the pipe is never mutated afterwards.
            self._local_sum = self._local_sum + gradient
            self._num_local_sum = self._num_local_sum + 1

            notification = ('calculated', self._iter, self._resent_soft_deadline)
            self._communication_pipe.to_server.put(notification, _fast=True)


class RennalaSGDSoftReduceServer(BaseServer):
    """Server that counts worker gradients, then reduces their local sums.

    One ``step`` does the following:

    1. Sends ``(point, iteration)`` to every worker.
    2. Waits until ``num_grads`` ``'calculated'`` notifications tagged with
       the current iteration have arrived.
    3. Sends an ``'allreduce'`` request to every worker and collects the
       ``(local_sum, num_local_sum, iteration)`` replies until the gradients
       received for the current iteration cover the number counted in step 2.
    4. Assigns the summed gradient to ``point.grad`` and runs one optimizer
       step.

    NOTE(review): ``self._env`` and ``self._stats`` are not assigned here;
    presumably ``BaseServer.__init__`` provides them -- confirm.
    """

    def __init__(self, env, communication_pipes, point, gamma, optimizer_cls, num_grads, calculate_metrics=None, history_window=0):
        """Set up the server state and the optimizer.

        Args:
            env: Simulation environment, forwarded to ``BaseServer``.
            communication_pipes: One pipe object per worker, each exposing
                ``from_server`` and ``to_server``.
            point: Parameter object being optimized; its ``grad`` attribute is
                assigned in ``step``.
            gamma: Learning rate passed to the optimizer as ``lr``.
            optimizer_cls: Callable building the optimizer from
                ``([point], lr=gamma)``; the result must provide
                ``zero_grad()`` and ``step()``.
            num_grads: Number of current-iteration gradients to wait for
                before requesting the reduction.
            calculate_metrics: Forwarded to ``BaseServer``.
            history_window: Accepted but not used by this class.
        """
        super().__init__(env, calculate_metrics)
        self._point = point
        self._gamma = gamma
        self._num_grads = num_grads
        self._communication_pipes = communication_pipes
        self._num_workers = len(communication_pipes)
        self._iter = 0

        # Per-worker counters, indexed by position in communication_pipes.
        for counter_name in ("worker_useful", "worker_calculated"):
            self._stats[counter_name] = [0] * self._num_workers

        self.optimizer = optimizer_cls([self._point], lr=self._gamma)

    def prepare(self):
        """Cache both pipe directions and arm one pending ``get`` per worker.

        This is a generator that finishes immediately without yielding
        anything.
        """
        self._pipes_from_server = [pipe.from_server for pipe in self._communication_pipes]
        self._pipes_to_server = [pipe.to_server for pipe in self._communication_pipes]
        # One outstanding get-event per worker; each is replaced after it fires.
        self._get_events = [inbound.get() for inbound in self._pipes_to_server]
        yield from ()

    def step(self):
        """Run one optimization round (generator driven by the environment)."""
        # Phase 1: hand the current point and iteration to every worker.
        for outbound in self._pipes_from_server:
            outbound.put((self._point, self._iter))

        # Phase 2: count notifications until enough fresh gradients exist.
        num_collected = 0
        while num_collected < self._num_grads:
            fired = yield self._env.any_of(self._get_events)
            for event, message in fired.items():
                sender = self._get_events.index(event)
                tag, message_iter, _ = message
                assert tag == 'calculated'
                # Only notifications for the current iteration count towards
                # the target; every notification counts as work performed.
                if message_iter == self._iter:
                    num_collected += 1
                self._stats["worker_calculated"][sender] += 1
                self._get_events[sender] = self._pipes_to_server[sender].get()

        # Phase 3: request the local sums.
        # NOTE(review): `_fast=True` presumably skips the simulated
        # communication delay -- confirm against the pipe implementation.
        slowest_link = max(pipe._time_to_communicate for pipe in self._pipes_from_server)
        for outbound in self._pipes_from_server:
            outbound.put(('allreduce', slowest_link, 0), _fast=True)

        received = 0
        sum_gradient = 0
        while received < num_collected:
            fired = yield self._env.any_of(self._get_events)
            for event, payload in fired.items():
                sender = self._get_events.index(event)
                self._get_events[sender] = self._pipes_to_server[sender].get()

                # String-tagged messages are late 'calculated' notifications;
                # they carry no gradient data and are ignored in this phase.
                if isinstance(payload[0], str):
                    continue

                local_sum, local_count, payload_iter = payload
                # Counted for every reply, including ones from an older
                # iteration, exactly as the reply reports it.
                self._stats["worker_useful"][sender] += local_count
                if payload_iter == self._iter:
                    received += local_count
                    sum_gradient += local_sum

        # Phase 4: apply the aggregated gradient.
        self.optimizer.zero_grad()
        self._point.grad = sum_gradient
        self.optimizer.step()

        self._iter += 1
