import heapq

import numpy as np

from distributed_optimization_library.factory import Factory
from distributed_optimization_library.function import OptimizationProblemMeta
from distributed_optimization_library.signature import Signature
from distributed_optimization_library.asynchronous.asynchronous_transport import DelayedAsynchronousTransport
from distributed_optimization_library.algorithm import BaseMasterAlgorithm


class FactoryAsyncMaster(Factory):
    """Name-keyed registry of master-side asynchronous algorithms.

    Master classes are added with
    ``@FactoryAsyncMaster.register("<algorithm>_master")`` and retrieved
    with ``FactoryAsyncMaster.get``.  The registration and lookup logic
    itself is inherited from ``Factory``.
    """

class FactoryAsyncNode(Factory):
    """Name-keyed registry of node-side (worker) asynchronous algorithms.

    Node classes are added with
    ``@FactoryAsyncNode.register("<algorithm>_node")`` and retrieved with
    ``FactoryAsyncNode.get``.  The registration and lookup logic itself is
    inherited from ``Factory``.
    """


class StochasticGradientNodeAlgorithm(object):
    """Worker-side wrapper that forwards oracle queries to a local function.

    The wrapped ``function`` must provide ``stochastic_gradient(point)``,
    ``value(point)`` and ``gradient(point)``; each method below simply
    delegates to the matching call and returns its result unchanged.
    """

    def __init__(self, function, **kwargs):
        # Any extra keyword arguments (e.g. a ``seed``) are accepted so that
        # all node classes can be constructed uniformly; they are ignored.
        self._function = function

    def calculate_stochastic_gradient(self, point):
        """Return ``function.stochastic_gradient(point)``."""
        oracle = self._function
        return oracle.stochastic_gradient(point)

    def calculate_function(self, point):
        """Return ``function.value(point)``."""
        oracle = self._function
        return oracle.value(point)

    def calculate_gradient(self, point):
        """Return ``function.gradient(point)``."""
        oracle = self._function
        return oracle.gradient(point)


@FactoryAsyncNode.register("asynchronous_sgd_node")
class AsynchronousSGDNode(StochasticGradientNodeAlgorithm):
    """Node registered as ``asynchronous_sgd_node``.

    Adds nothing to ``StochasticGradientNodeAlgorithm``; it exists so the
    ``asynchronous_sgd`` algorithm has a node class under its own name.
    """


@FactoryAsyncNode.register("rennala_node")
class AsynchronousMiniBatchSGDNode(StochasticGradientNodeAlgorithm):
    """Node registered as ``rennala_node``.

    Adds nothing to ``StochasticGradientNodeAlgorithm``; it exists so the
    ``rennala`` algorithm has a node class under its own name.
    """


@FactoryAsyncNode.register("minibatch_sgd_node")
class MiniBatchSGDNode(StochasticGradientNodeAlgorithm):
    """Node registered as ``minibatch_sgd_node``.

    Adds nothing to ``StochasticGradientNodeAlgorithm``; it exists so the
    ``minibatch_sgd`` algorithm has a node class under its own name.
    """


@FactoryAsyncMaster.register("asynchronous_sgd_master")
class AsynchronousSGD(object):
    """Master of asynchronous SGD with a staleness threshold.

    Every node is kept busy computing a stochastic gradient.  Each call to
    ``step`` takes the computation with the smallest completion time,
    applies its gradient unless it is too stale, and hands the same node
    the current iterate again.  Time is not wall-clock time: it is whatever
    the transport reports for each dispatched computation.

    Args:
        transport: Object providing ``get_number_of_nodes``,
            ``call_available_node_method``, ``call_ready_node`` and
            ``call_nodes_method``.
        point: Initial iterate.
        gamma: Step size.
        gamma_multiply: Optional factor by which ``gamma`` is scaled.
        seed: Stored on the instance; not otherwise used by this class.
        meta: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, transport, point, gamma=None, gamma_multiply=None, seed=None, meta=None):
        self._transport = transport
        self._point = point
        if gamma_multiply is not None:
            # Build a new value instead of using ``*=`` so that a mutable
            # step size owned by the caller (e.g. a NumPy array) is never
            # modified in place.
            gamma = gamma * gamma_multiply
        self._gamma = gamma
        self._seed = seed
        self._time = 0

        # Min-heap of (completion_time, node_index, iteration_at_dispatch).
        self._heap = []
        self._iter = 0
        self._number_of_nodes = self._transport.get_number_of_nodes()

        for node_index in range(self._number_of_nodes):
            self._dispatch(node_index)

    def _dispatch(self, node_index):
        """Start a stochastic-gradient computation on ``node_index`` and queue it."""
        completion_time = self._transport.call_available_node_method(
            self._time, node_index, node_method="calculate_stochastic_gradient", point=self._point)
        heapq.heappush(self._heap, (completion_time, node_index, self._iter))

    def step(self):
        """Process the earliest-finishing gradient and re-dispatch its node.

        The gradient is applied only if at most ``number_of_nodes`` master
        iterations have passed since it was dispatched; otherwise it is
        dropped.  The iteration counter advances in both cases.
        """
        completion_time, node_index, dispatch_iter = heapq.heappop(self._heap)
        self._time = completion_time
        stochastic_gradient = self._transport.call_ready_node(self._time, node_index)
        if self._iter - dispatch_iter <= self._number_of_nodes:
            self._point = self._point - self._gamma * stochastic_gradient
        self._iter += 1
        self._dispatch(node_index)

    def calculate_function(self):
        """Return the mean of the nodes' ``calculate_function`` at the current iterate."""
        return np.mean(self._transport.call_nodes_method(node_method='calculate_function',
                                                         point=self._point))

    def get_point(self):
        """Return the current iterate."""
        return self._point

    def get_time(self):
        """Return the completion time of the last processed computation (0 before any step)."""
        return self._time


@FactoryAsyncMaster.register("rennala_master")
class AsynchronousMiniBatchSGD(object):
    """Master of asynchronous mini-batch SGD (registered as ``rennala``).

    Every node is kept busy computing a stochastic gradient.  Gradients
    that were dispatched at the current iterate are accumulated; once
    ``batch_size`` of them have arrived the iterate moves along their
    average.  Gradients dispatched at an earlier iterate are discarded when
    they arrive.  Time is not wall-clock time: it is whatever the transport
    reports for each dispatched computation.

    Args:
        transport: Object providing ``get_number_of_nodes``,
            ``call_available_node_method``, ``call_ready_node`` and
            ``call_nodes_method``.
        point: Initial iterate.
        gamma: Step size.
        gamma_multiply: Optional factor by which ``gamma`` is scaled.
        batch_size: Number of fresh gradients averaged per update.  ``step``
            compares the running count against it, so it must be provided
            before stepping.
        seed: Stored on the instance; not otherwise used by this class.
        meta: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, transport, point, gamma=None, gamma_multiply=None, batch_size=None, seed=None, meta=None):
        self._transport = transport
        self._point = point
        if gamma_multiply is not None:
            # Build a new value instead of using ``*=`` so that a mutable
            # step size owned by the caller (e.g. a NumPy array) is never
            # modified in place.
            gamma = gamma * gamma_multiply
        self._gamma = gamma
        self._batch_size = batch_size
        self._seed = seed
        self._time = 0

        # Running sum and count of gradients taken at the current iterate.
        self._gradient_estimator = 0
        self._current_batch = 0

        # Min-heap of (completion_time, node_index, iteration_at_dispatch).
        self._heap = []
        self._iter = 0
        self._number_of_nodes = self._transport.get_number_of_nodes()

        for node_index in range(self._number_of_nodes):
            self._dispatch(node_index)

    def _dispatch(self, node_index):
        """Start a stochastic-gradient computation on ``node_index`` and queue it."""
        completion_time = self._transport.call_available_node_method(
            self._time, node_index, node_method="calculate_stochastic_gradient", point=self._point)
        heapq.heappush(self._heap, (completion_time, node_index, self._iter))

    def step(self):
        """Process the earliest-finishing gradient and re-dispatch its node.

        A gradient counts towards the batch only if the iterate has not
        changed since it was dispatched.  When the batch is full the iterate
        is updated with the batch average and the accumulator is reset.
        """
        completion_time, node_index, dispatch_iter = heapq.heappop(self._heap)
        self._time = completion_time
        stochastic_gradient = self._transport.call_ready_node(self._time, node_index)
        if dispatch_iter == self._iter:
            self._gradient_estimator = self._gradient_estimator + stochastic_gradient
            self._current_batch += 1
        if self._current_batch >= self._batch_size:
            # The count grows by one at a time, so it can never overshoot.
            assert self._current_batch == self._batch_size
            self._point = self._point - self._gamma * (self._gradient_estimator / self._current_batch)
            self._iter += 1
            self._current_batch = 0
            self._gradient_estimator = 0
        # Re-dispatch after a possible update so the node gets the newest iterate.
        self._dispatch(node_index)

    def calculate_function(self):
        """Return the mean of the nodes' ``calculate_function`` at the current iterate."""
        return np.mean(self._transport.call_nodes_method(node_method='calculate_function',
                                                         point=self._point))

    def get_point(self):
        """Return the current iterate."""
        return self._point

    def get_time(self):
        """Return the completion time of the last processed computation (0 before any step)."""
        return self._time


@FactoryAsyncMaster.register("minibatch_sgd_master")
class MiniBatchSGD(object):
    """Master of synchronous mini-batch SGD.

    In every round each node computes one stochastic gradient at the
    current iterate.  The master waits until the slowest node is done,
    averages the gradients over all nodes, updates the iterate and starts
    the next round.  Time is not wall-clock time: it is whatever the
    transport reports for each dispatched computation.

    Args:
        transport: Object providing ``get_number_of_nodes``,
            ``call_available_node_method``, ``call_ready_node`` and
            ``call_nodes_method``.
        point: Initial iterate.
        gamma: Step size.
        gamma_multiply: Optional factor by which ``gamma`` is scaled.
        batch_size: Stored on the instance but not used; each update
            averages exactly one gradient per node.
        seed: Stored on the instance; not otherwise used by this class.
        meta: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, transport, point, gamma=None, gamma_multiply=None, batch_size=None, seed=None, meta=None):
        self._transport = transport
        self._point = point
        if gamma_multiply is not None:
            # Build a new value instead of using ``*=`` so that a mutable
            # step size owned by the caller (e.g. a NumPy array) is never
            # modified in place.
            gamma = gamma * gamma_multiply
        self._gamma = gamma
        self._batch_size = batch_size
        self._seed = seed
        self._time = 0

        self._number_of_nodes = self._transport.get_number_of_nodes()
        self._dispatch_all()

    def _dispatch_all(self):
        """Start a gradient computation on every node and record completion times."""
        self._current_times = [
            self._transport.call_available_node_method(
                self._time, node_index, node_method="calculate_stochastic_gradient", point=self._point)
            for node_index in range(self._number_of_nodes)
        ]

    def step(self):
        """Run one synchronous round: wait for all nodes, average, update, re-dispatch."""
        # The round ends when the slowest node finishes.
        self._time = max(self._current_times, default=-np.inf)
        gradient_estimator = sum(
            self._transport.call_ready_node(self._time, node_index)
            for node_index in range(self._number_of_nodes)
        )
        self._point = self._point - self._gamma * (gradient_estimator / self._number_of_nodes)
        self._dispatch_all()

    def calculate_function(self):
        """Return the mean of the nodes' ``calculate_function`` at the current iterate."""
        return np.mean(self._transport.call_nodes_method(node_method='calculate_function',
                                                         point=self._point))

    def get_point(self):
        """Return the current iterate."""
        return self._point

    def get_time(self):
        """Return the end time of the last completed round (0 before any step)."""
        return self._time


def _generate_seed(generator):
    """Draw one integer from ``generator`` for use as a per-node seed.

    ``generator`` is expected to offer ``integers(bound)`` the way a NumPy
    ``Generator`` does.  Note that the literal ``10e9`` equals ``1e10``
    (ten billion), not ``1e9``.
    """
    upper_bound = 10e9
    return generator.integers(upper_bound)


def get_algorithm(functions, point, seed,
                  algorithm_name, delays,
                  algorithm_master_params=None, algorithm_node_params=None,
                  meta=None):
    """Build the master of the named algorithm, wired to freshly described nodes.

    The node and master classes are looked up in ``FactoryAsyncNode`` and
    ``FactoryAsyncMaster`` under ``algorithm_name + "_node"`` and
    ``algorithm_name + "_master"``.  One node ``Signature`` is created per
    entry of ``functions``, each with its own seed drawn from a generator
    seeded with ``seed``.

    Args:
        functions: Iterable of per-node objective functions.
        point: Initial point handed to the master.
        seed: Seeds the generator for the per-node seeds and is also passed
            to the master.
        algorithm_name: Base name of the algorithm, e.g. ``"rennala"``.
        delays: Passed through to ``DelayedAsynchronousTransport``.
        algorithm_master_params: Extra keyword arguments for the master
            class; ``None`` means no extra arguments.
        algorithm_node_params: Extra keyword arguments for every node
            signature; ``None`` means no extra arguments.
        meta: Problem metadata passed to the master; ``None`` means a fresh
            ``OptimizationProblemMeta()``.

    Returns:
        The constructed master instance.
    """
    # Defaults are resolved per call: the previous ``{}`` / instance defaults
    # were created once at definition time and shared by every call.
    master_params = {} if algorithm_master_params is None else algorithm_master_params
    node_params = {} if algorithm_node_params is None else algorithm_node_params
    if meta is None:
        meta = OptimizationProblemMeta()

    node_name = algorithm_name + "_node"
    master_name = algorithm_name + "_master"
    node_cls = FactoryAsyncNode.get(node_name)
    master_cls = FactoryAsyncMaster.get(master_name)
    generator = np.random.default_rng(seed)
    nodes = [Signature(node_cls, function, seed=_generate_seed(generator), **node_params)
             for function in functions]
    transport = DelayedAsynchronousTransport(nodes, delays)
    return master_cls(transport, point, seed=seed, meta=meta, **master_params)
