import numpy as np


from simulator.base_server import BaseServer


class SubsetRingReduceWorker:
    """Simulated worker that accumulates stochastic gradients between server resets.

    Message protocol, as handled below:

    * from server ``("init", point, iter)``: initial point; anything else before it
      is reported and ignored.
    * from server ``'allreduce'``: reply with ``(local_sum, num_local_sum, iter)``.
    * from server any other message: treated as a reset ``(point, iter)`` that
      replaces the point and zeroes the accumulated sum.
    * to server ``('calculated', iter)``: sent after every accepted gradient.
    """

    def __init__(self, worker_index, env, communication_pipe, function, gamma=None, local_step_size_multiplier=1):
        """
        Args:
            worker_index: identifier used in log messages.
            env: simulation environment; only ``env.process`` is used here.
            communication_pipe: object exposing ``from_server`` / ``to_server`` queues.
            function: object whose ``stochastic_gradient(point)`` result is yielded
                (presumably an event resolving to a gradient — confirm with caller).
            gamma: optional local step size; when ``None`` the worker never moves
                its own point between resets.
            local_step_size_multiplier: scale applied to ``gamma`` for local steps.
        """
        # Configuration (fixed for the worker's lifetime).
        self._worker_index = worker_index
        self._env = env
        self._communication_pipe = communication_pipe
        self._function = function
        self._gamma = gamma
        self._local_step_size_multiplier = local_step_size_multiplier
        # Mutable state, filled in once the server's 'init' message arrives.
        self._point = self._iter = None
        self._local_sum = self._num_local_sum = None

    def run(self):
        """Main message loop: wait for 'init', then serve allreduce/reset requests forever."""
        # Nothing can be computed until the server supplies a starting point.
        while self._point is None:
            message = yield self._communication_pipe.from_server.get()
            is_init = isinstance(message, tuple) and message[0] == 'init'
            if not is_init:
                print(f"Worker {self._worker_index} received unexpected message on init: {message}")
                continue
            self._point, self._iter = message[1:]

        self._local_sum, self._num_local_sum = 0, 0
        # Gradient computation runs as its own process, concurrently with this loop.
        self._env.process(self.run_gradient())

        while True:
            message = yield self._communication_pipe.from_server.get()
            if isinstance(message, str) and message == 'allreduce':
                self._communication_pipe.to_server.put((self._local_sum, self._num_local_sum, self._iter))
                continue
            # Every other message is a reset carrying a fresh (point, iter).
            print(f"Worker {self._worker_index} received reset: {self._iter}->{message[1]}")
            self._point, self._iter = message
            self._local_sum, self._num_local_sum = 0, 0

    def run_gradient(self):
        """Repeatedly compute gradients at the current point and accumulate them locally."""
        while True:
            iter_at_start = self._iter
            grad = yield self._function.stochastic_gradient(self._point)
            if iter_at_start != self._iter:
                # A reset arrived while this gradient was being computed; it belongs
                # to a point the server no longer cares about, so drop it.
                continue
            # Rebinding (not in-place +=) keeps a previously sent sum object untouched.
            self._local_sum = self._local_sum + grad
            self._num_local_sum += 1
            if self._gamma is not None:
                self._point = self._point - self._local_step_size_multiplier * self._gamma * grad
            self._communication_pipe.to_server.put(('calculated', self._iter), _fast=True)


class SubsetRingReduceServer(BaseServer):
    """Server that applies one optimizer step per all-reduce over a subset of workers.

    Protocol, as implemented below:

    * ``prepare`` sends ``("init", point, iter)`` to every worker.
    * Workers notify each accumulated gradient as ``('calculated', iter)``.
    * ``step`` waits for ``num_grads`` counted notifications, asks the workers with
      the most gradients accumulated at their current iteration for
      ``(local_sum, num_local_sum, iter)`` via ``'allreduce'``, steps the optimizer
      on the summed gradient and sends ``(point, iter)`` back to those workers.
    * Workers whose last known iteration lags more than ``history_window`` behind
      are also sent the current ``(point, iter)``.
    """

    def __init__(self, env, communication_pipes, point, gamma, optimizer_cls, num_grads, calculate_metrics=None, history_window=10, max_allreduce_workers_ratio=1):
        """
        Args:
            env: simulation environment; ``step`` uses ``env.any_of`` (presumably a
                simpy ``Environment`` — confirm).
            communication_pipes: one object per worker exposing ``from_server`` and
                ``to_server`` queues.
            point: parameter passed to ``optimizer_cls``; its ``.grad`` is replaced by
                the summed gradient before every optimizer step.
            gamma: learning rate, passed as ``lr`` to ``optimizer_cls``.
            optimizer_cls: called as ``optimizer_cls([point], lr=gamma)``.
            num_grads: number of counted gradient notifications to wait for per step.
            calculate_metrics: forwarded to ``BaseServer``.
            history_window: maximum lag, in server iterations, for a gradient to count.
            max_allreduce_workers_ratio: fraction of workers reduced per step; at
                least one worker always takes part.
        """
        super().__init__(env, calculate_metrics)
        self._point = point
        self._gamma = gamma
        self._num_grads = num_grads
        self._communication_pipes = communication_pipes
        self._num_workers = len(self._communication_pipes)
        self._iter = 0
        self._stats["worker_useful"] = [0] * self._num_workers
        self._stats["worker_calculated"] = [0] * self._num_workers

        self._history_window = history_window
        self._max_allreduce_workers_ratio = max_allreduce_workers_ratio
        # Clamp to >= 1: with an empty subset the summed gradient stays the plain
        # integer 0, which would then be assigned as ``point.grad`` (torch, e.g.,
        # rejects non-tensor gradients).
        self._max_allreduce_workers = max(1, int(self._max_allreduce_workers_ratio * self._num_workers))

        self.optimizer = optimizer_cls([self._point], lr=self._gamma)

    def prepare(self):
        """Cache the per-worker queues, arm one pending ``get`` per worker, send ``init``.

        This is a generator (the bare ``yield`` after ``return`` is never reached),
        presumably so ``BaseServer`` can drive it like ``step`` — confirm.
        """
        self._pipes_from_server = [pipes.from_server for pipes in self._communication_pipes]
        self._pipes_to_server = [pipes.to_server for pipes in self._communication_pipes]
        # One outstanding get() per worker; ``step`` re-arms each one after it fires.
        self._get_events = [pipe.get() for pipe in self._pipes_to_server]

        # worker index -> (gradients counted since the last reset, iteration of the
        # most recent point sent to / reported by that worker)
        self._current_calc_registry = {w_id: (0, self._iter) for w_id in range(self._num_workers)}

        for pipe in self._pipes_from_server:
            pipe.put(("init", self._point, self._iter))
        return
        yield

    def _reset_worker(self, worker_index):
        """Send the current ``(point, iter)`` to a worker and restart its registry entry."""
        self._current_calc_registry[worker_index] = (0, self._iter)
        # NOTE(review): ``self._point`` is sent by reference; if ``optimizer.step()``
        # updates it in place (torch optimizers do), every worker holding it sees
        # later updates too — confirm whether a copy is intended.
        self._pipes_from_server[worker_index].put((self._point, self._iter))

    def step(self):
        """Run one server iteration (generator yielding simulation events).

        1. Wait until ``num_grads`` fresh ``'calculated'`` notifications are counted.
        2. Pick the workers with the largest counts and send them ``'allreduce'``.
        3. Sum the replies within ``history_window`` into ``point.grad`` and take
           one optimizer step.
        4. Send the new point to the reduced workers and to every worker whose last
           known iteration has fallen outside ``history_window``.
        """
        # --- 1. collect gradient notifications ---
        num_collected = 0
        while num_collected < self._num_grads:
            filled_events = yield self._env.any_of(self._get_events)
            for event, data in filled_events.items():
                worker_index = self._get_events.index(event)
                tag, iter_grad = data
                assert tag == 'calculated'
                count, registry_iter = self._current_calc_registry[worker_index]
                # A notification older than the last point sent to this worker was
                # emitted before the worker processed that reset; the worker zeroes
                # its local sum on reset, so such a gradient no longer exists and
                # must neither count nor roll the registry entry back.
                if iter_grad >= registry_iter and self._history_window + iter_grad >= self._iter:
                    num_collected += 1
                    self._current_calc_registry[worker_index] = (
                        count + 1 if iter_grad == registry_iter else 1,
                        iter_grad,
                    )

                self._stats["worker_calculated"][worker_index] += 1
                # Re-arm the consumed get() so the next message can be awaited.
                self._get_events[worker_index] = self._pipes_to_server[worker_index].get()

        # --- 2. choose the subset and request local sums ---
        # sorted() is stable (also with reverse=True): ties keep registry order, and
        # reduced workers are re-inserted at the end below, so tied peers go first
        # next time.
        ranked_workers = sorted(self._current_calc_registry.items(), key=lambda item: item[1][0], reverse=True)
        workers_to_reduce = ranked_workers[:self._max_allreduce_workers]
        for worker_index, (num_grads, _) in workers_to_reduce:
            self._stats["worker_useful"][worker_index] += num_grads
            self._current_calc_registry.pop(worker_index)

        print(f"Allreduce for subset: {workers_to_reduce}")
        for worker_index, _ in workers_to_reduce:
            self._pipes_from_server[worker_index].put('allreduce', _fast=True)

        # --- 3. gather the replies and update the model ---
        check_num_calculated = 0
        sum_gradient = 0
        for worker_index, (_, reduced_iter) in workers_to_reduce:
            while True:
                data = yield self._get_events[worker_index]
                self._get_events[worker_index] = self._pipes_to_server[worker_index].get()
                # Notifications are ('calculated', iter); the reply is
                # (local_sum, num_local_sum, iter). Check the type first so an
                # array-valued local_sum is never compared against a string.
                if isinstance(data, tuple) and isinstance(data[0], str) and data[0] == 'calculated':
                    # Late notification queued ahead of the reply.
                    self._stats["worker_calculated"][worker_index] += 1
                    if data[1] >= reduced_iter:
                        # Its gradient is part of the local_sum in the upcoming reply.
                        self._stats["worker_useful"][worker_index] += 1
                    continue
                local_sum, num_local_sum, reply_iter = data
                if self._history_window + reply_iter >= self._iter:
                    check_num_calculated += num_local_sum
                    sum_gradient += local_sum
                break
        print(f"check_num_calculated: {check_num_calculated}/{num_collected}")

        self.optimizer.zero_grad()
        self._point.grad = sum_gradient
        self.optimizer.step()
        self._iter += 1

        # --- 4. redistribute the new point ---
        for worker_index, _ in workers_to_reduce:
            self._reset_worker(worker_index)

        # Reset workers whose last known iteration is too old to be useful.
        for worker_index, (_, last_iter) in list(self._current_calc_registry.items()):
            if self._history_window + last_iter < self._iter:
                print(f"reject worker {worker_index}, {self._iter}, {self._current_calc_registry[worker_index]}")
                self._reset_worker(worker_index)
