'''
WorkerPool and WorkerBase handle the common problems of managing a
multiprocess pool of workers that multiprocessing.Pool does not address,
including setup with per-process state, debugging by putting the worker
on the main thread, and correct handling of unexpected errors and ctrl-C.

While the pytorch DataLoader is helpful for fast reading of data in
parallel, a utility is also needed for fast writing of data.  This worker
class simplifies that problem - by David Bau.

To use it,
1. Put the per-process setup and the per-task work in the
   setup() and work() methods of your own WorkerBase subclass.
2. To prepare the process pool, instantiate a WorkerPool, passing your
   subclass type as the first (worker) argument, as well as any setup keyword
   arguments.  The WorkerPool will instantiate one of your workers for each
   worker process, calling its setup() with those keyword arguments, and then
   start the processes.  Note that setup() runs during construction, in the
   process that creates the pool, before the worker process is started.
   If debugging, the pool can have num_workers=0 to force all the work
   to be done immediately on the main thread; otherwise all the work
   will be passed to other processes.
3. Whenever there is a new piece of work to distribute, call pool.add(*args).
   The arguments will be queued and passed as worker.work(*args) to the
   next available worker.
4. When all the work has been distributed, call pool.join() to wait for all
   the work to complete and to finish and terminate all the worker processes.
   When pool.join() returns, all the work will have been done.

No arrangement is made to collect the results of the work: for example,
the return value of work() is ignored.  If you need to collect the
results, use your own mechanism (filesystem, shared memory object, queue)
which can be distributed using setup arguments.
'''

from multiprocessing import Process, Queue, cpu_count
import signal
import atexit
import sys


class WorkerBase(Process):
    '''
    Subclass this class and override its work() method (and optionally,
    setup() as well) to define the units of work to be done in a process
    worker in a worker pool.

    Constructor arguments are supplied by WorkerPool, not by user code:
        i: index of this worker within the pool (None in debug mode).
        num_workers: size of the pool; 0 means the worker is used directly
            on the main thread instead of being started as a process.
        queue: queue from which [args, kwargs] work items are read
            (None in debug mode).
        initargs: dict of keyword arguments forwarded to setup().
    '''

    def __init__(self, i, num_workers, queue, initargs):
        if num_workers > 0:
            # This runs in the process constructing the worker; ignoring
            # ctrl-C here means the worker process inherits SIG_IGN when it
            # is started.  The pool restores the main handler afterwards.
            signal.signal(signal.SIGINT, signal.SIG_IGN)
        self.process_id = i
        self.num_workers = num_workers
        self.queue = queue
        super(WorkerBase, self).__init__()
        # setup() is called now, in the constructing process, i.e. before
        # start() launches the worker process.
        self.setup(**initargs)

    def run(self):
        '''
        Worker main loop: dequeue [args, kwargs] items and call
        self.work(*args, **kwargs) for each, until None is dequeued.
        '''
        while True:
            try:
                item = self.queue.get()
            except (KeyboardInterrupt, SystemExit):
                print('Exiting...')
                break
            if item is None:
                # Re-queue the end-of-work sentinel for another worker.
                self.queue.put(None)
                return
            work_batch, work_kw = item
            self.work(*work_batch, **work_kw)

    def setup(self, **initargs):
        '''
        Override this method for any per-worker initialization.
        Keyword args are passed from the WorkerPool constructor.
        '''
        pass

    def work(self, *args, **kwargs):
        '''
        Override this method to do one unit of work; it is called once
        per task.  Args are passed from WorkerPool.add() arguments, and
        the return value is ignored.
        '''
        raise NotImplementedError('worker subclass needed')


class WorkerPool(object):
    '''
    Instantiate this object (passing a WorkerBase subclass type
    as its first argument) to create a worker pool.  Then call
    pool.add(*args) to queue args to distribute to worker.work(*args),
    and call pool.join() to wait for all the workers to complete.

    Args:
        worker: WorkerBase subclass; one instance is constructed per worker
            process (or a single instance when num_workers == 0).
        num_workers: number of worker processes; defaults to cpu_count().
            0 runs all work synchronously on the calling thread (debugging).
        **initargs: keyword arguments passed to each worker's setup().

    Raises:
        ValueError: if num_workers is negative.
    '''

    def __init__(self, worker=WorkerBase, num_workers=None, **initargs):
        if num_workers is None:
            num_workers = cpu_count()
        if num_workers < 0:
            # A negative count would create no workers, so every queued
            # task would be silently dropped.
            raise ValueError(
                'num_workers must be >= 0, got %r' % (num_workers,))
        if num_workers == 0:
            # zero num_workers uses only main process, for debugging.
            self.queue = None
            self.processes = None
            self.worker = worker(None, 0, None, initargs)
            return
        # Initialize before anything can fail, so early_terminate() and
        # __del__ always find these attributes.
        self.queue = None
        self.processes = None
        # Ctrl-C strategy: worker processes should ignore ctrl-C.  Set
        # this up to be inherited by child processes before forking.
        original_sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
        try:
            active_pools[id(self)] = self
            self.queue = Queue(maxsize=(num_workers * 3))
            self.processes = [worker(i, num_workers, self.queue, initargs)
                              for i in range(num_workers)]
            for p in self.processes:
                p.start()
        except BaseException:
            # A failing setup() or start() must not leave already-started
            # workers running or this half-built pool registered.
            self.early_terminate()
            raise
        finally:
            # The main process should handle ctrl-C.  Restore this now,
            # even if construction failed.
            signal.signal(signal.SIGINT, original_sigint_handler)

    def add(self, *work_batch, **work_kw):
        '''
        Queue one unit of work, to be run as worker.work(*work_batch,
        **work_kw).

        In debug mode (num_workers == 0) the work runs immediately on the
        calling thread.  In multiprocess mode, once the pool has been
        joined or terminated, the work is dropped with a message on stderr.
        Blocks while the queue is full; if ctrl-C (or SystemExit) arrives
        while waiting, the workers are terminated and the exception is
        re-raised.
        '''
        if self.queue is None:
            if hasattr(self, 'worker'):
                self.worker.work(*work_batch, **work_kw)
            else:
                print('WorkerPool shutting down.', file=sys.stderr)
            return
        try:
            # The queue can block if the work is so slow it gets full.
            self.queue.put([work_batch, work_kw])
        except (KeyboardInterrupt, SystemExit):
            # Handle ctrl-C while waiting for the queue: stop the workers,
            # then let the interrupt reach the caller as usual.
            self.early_terminate()
            raise

    def join(self):
        '''
        End the queue and wait for every worker process to finish.
        When this returns, all the work queued by add() has been done.
        '''
        if self.queue is not None:
            # One sentinel suffices: each worker re-queues it for the next.
            self.queue.put(None)
            for p in self.processes or ():
                p.join()
            self.queue = None
        # Remove myself from the set of pools that need cleanup on shutdown.
        active_pools.pop(id(self), None)

    def early_terminate(self):
        '''
        Shut down without waiting: stop accepting work and forcibly
        terminate all worker processes.  Safe to call more than once.
        '''
        # When shutting down unexpectedly, first end the queue.
        if self.queue is not None:
            try:
                self.queue.put_nowait(None)  # Nonblocking put throws if full.
            except Exception:
                pass  # Best effort: the workers are terminated below anyway.
            # Clear the queue even if the put failed; otherwise a later
            # join() (e.g. from __del__) would block forever putting into
            # a full queue whose workers are gone.
            self.queue = None
        # But then don't wait: just forcibly terminate workers.
        if self.processes is not None:
            for p in self.processes:
                # pid is None for a worker that was never started (e.g. if
                # construction failed partway); terminate() would raise.
                if p.pid is not None:
                    p.terminate()
            self.processes = None
        active_pools.pop(id(self), None)

    def __del__(self):
        # A pool that still holds a queue was never joined or terminated.
        # getattr guards against a partially constructed object.
        if getattr(self, 'queue', None) is not None:
            print('ERROR: workerpool.join() not called!', file=sys.stderr)
            self.join()


# Error and ctrl-C handling: kill worker processes if the main process ends.
# Maps id(pool) -> pool for every WorkerPool that still needs cleanup.
active_pools = {}


def early_terminate_pools():
    '''
    Call early_terminate() on every pool still registered in active_pools.

    Registered with atexit so worker processes are terminated when the
    main process ends.  A failure while shutting down one pool is reported
    on stderr and does not prevent the remaining pools from being
    terminated.
    '''
    # Iterate over a snapshot: early_terminate() removes the pool.
    for pool in list(active_pools.values()):
        try:
            pool.early_terminate()
        except Exception as e:
            print('WorkerPool cleanup failed: %r' % (e,), file=sys.stderr)


atexit.register(early_terminate_pools)
