import simpy

from collections import namedtuple
from simulator.worker import Worker, WorkerWithLocalSteps, WorkerWithTargetComputeCommunicateRatio

# Pair of channels handed to a single worker by run_pipeline; judging by the
# field names, ``from_server`` is the server->worker direction and
# ``to_server`` the worker->server direction.
Pipes = namedtuple('Pipes', ['from_server', 'to_server'])

class FixedTimeStochasticGradient:
    """Wrap a function object so that each stochastic gradient costs fixed simulated time.

    Computations are serialized through a capacity-1 ``simpy.Resource``. While
    one computation is in progress, later calls wait for it to finish before
    their own ``time_to_calculate`` starts elapsing.
    """

    def __init__(self, env, func, time_to_calculate=1.):
        """
        Args:
            env: the ``simpy.Environment`` driving the simulation.
            func: object exposing ``stochastic_gradient(point)``.
            time_to_calculate: simulated time consumed by one gradient evaluation.
        """
        self._env = env
        self._func = func
        self._time_to_calculate = time_to_calculate
        # Capacity 1: at most one gradient computation in flight at a time.
        self._lock = simpy.Resource(self._env, capacity=1)

    def stochastic_gradient(self, point):
        """Start a simulated gradient computation at ``point``.

        Args:
            point: argument forwarded to ``func.stochastic_gradient``. It is
                read only after the wait, so an in-place mutation made while
                the computation is pending is visible to the gradient.

        Returns:
            A ``simpy`` process. Yielding it produces
            ``func.stochastic_gradient(point)`` once the lock has been acquired
            and ``time_to_calculate`` has elapsed.
        """
        def _stochastic_gradient():
            with self._lock.request() as request:
                # The request must be yielded. Otherwise the process never
                # waits for the lock, and concurrent calls overlap in time.
                yield request
                yield self._env.timeout(self._time_to_calculate)
                return self._func.stochastic_gradient(point)
        return self._env.process(_stochastic_gradient())


class FixedTimePipe:
    """One-directional channel that delays each message by a fixed simulated time.

    Wraps an underlying store-like ``pipe`` (``run_pipeline`` passes a
    ``simpy.Store``). Sends are serialized through a capacity-1
    ``simpy.Resource``, which grants requests first-come first-served. Each
    message therefore occupies the channel only after the previous one has
    been delivered.
    """

    def __init__(self, env, pipe, time_to_communicate=0):
        """
        Args:
            env: the ``simpy.Environment`` driving the simulation.
            pipe: store-like object exposing event-returning ``put``/``get``.
            time_to_communicate: simulated time one (non-fast) send takes.
        """
        self._env = env
        # Capacity 1: only one message traverses the channel at a time.
        self._lock = simpy.Resource(self._env, capacity=1)
        self._pipe = pipe
        self._time_to_communicate = time_to_communicate

    # todo: improve this logic for arbitrarily sized data and delete _fast
    def put(self, data, _fast=False):
        """Send ``data`` through the channel.

        Args:
            data: payload placed into the underlying pipe.
            _fast: if true, skip the ``time_to_communicate`` delay. The send
                still waits for any in-flight send to finish first.

        Returns:
            A ``simpy`` process that completes once ``data`` is in the pipe.
        """
        def _put():
            with self._lock.request() as request:
                # Yield the request so that the send actually waits for the
                # channel. Without this, the lock never blocks.
                yield request
                if not _fast:
                    yield self._env.timeout(self._time_to_communicate)
                yield self._pipe.put(data)
        return self._env.process(_put())

    def get(self):
        """Receive the next message.

        Returns:
            A ``simpy`` process whose value is the next item taken from the
            underlying pipe.
        """
        def _get():
            data = yield self._pipe.get()
            return data
        return self._env.process(_get())


def run_pipeline(server_cls, worker_cls, functions, point, gamma, optimizer_cls, sim_time=10000,
                 times_to_calculate=None, times_to_communicate=None,
                 server_params=None, worker_params=None, calculate_metrics=None, local_steps=None):
    """Build a simulated server/worker system, run it, and return the server's result.

    One pair of ``FixedTimePipe`` channels is created per entry of
    ``times_to_communicate``. Worker ``i`` uses pair ``i`` and computes
    gradients of ``functions[i]`` through ``FixedTimeStochasticGradient``.

    Args:
        server_cls: constructed as ``server_cls(env, communication_pipes, point,
            gamma, optimizer_cls, **server_params,
            calculate_metrics=calculate_metrics)``. It must provide ``run()``,
            ``get_point()`` and ``get_stats()``.
        worker_cls: constructed as ``worker_cls(index, env, pipes,
            fixed_time_func, **worker_params)``. It must provide ``run()``.
        functions: one function object per worker, each exposing
            ``stochastic_gradient(point)``.
        point: forwarded to ``server_cls``.
        gamma: forwarded to ``server_cls``.
        optimizer_cls: forwarded to ``server_cls``.
        sim_time: simulated time at which ``env.run`` stops.
        times_to_calculate: per-worker gradient time. Defaults to ``1.`` for
            every worker.
        times_to_communicate: per-worker communication time, used for both
            directions. Defaults to ``0.`` for every worker.
        server_params: extra keyword arguments for ``server_cls``; ``None``
            means none.
        worker_params: extra keyword arguments for ``worker_cls``; ``None``
            means none.
        calculate_metrics: forwarded to ``server_cls``.
        local_steps: accepted for interface compatibility; not used here.

    Returns:
        ``(server.get_point(), server.get_stats())`` after the run.

    Note:
        Every pipe pair is handed to the server. Workers, however, are created
        only for positions present in all of ``times_to_communicate``,
        ``functions`` and ``times_to_calculate``, because ``zip`` truncates.
        These sequences are therefore expected to have equal length.
    """
    # ``None`` instead of ``{}`` defaults: avoids a dict shared across calls.
    server_params = {} if server_params is None else server_params
    worker_params = {} if worker_params is None else worker_params

    env = simpy.Environment()
    num_workers = len(functions)
    if times_to_communicate is None:
        times_to_communicate = [0.] * num_workers

    communication_pipes = []
    for time_to_communicate in times_to_communicate:
        from_server = FixedTimePipe(env, simpy.Store(env), time_to_communicate)
        to_server = FixedTimePipe(env, simpy.Store(env), time_to_communicate)
        communication_pipes.append(Pipes(from_server, to_server))

    server = server_cls(env, communication_pipes, point, gamma, optimizer_cls,
                        **server_params, calculate_metrics=calculate_metrics)

    if times_to_calculate is None:
        times_to_calculate = [1.] * num_workers
    workers = []
    for index, (communication_pipe, func, time_to_calculate) in enumerate(
            zip(communication_pipes, functions, times_to_calculate)):
        fixed_time_func = FixedTimeStochasticGradient(env, func, time_to_calculate)
        workers.append(worker_cls(index, env, communication_pipe, fixed_time_func, **worker_params))

    # The server process is registered first, then the workers in index order.
    env.process(server.run())
    for worker in workers:
        env.process(worker.run())
    env.run(until=sim_time)
    return server.get_point(), server.get_stats()
