"""Helper abstraction to use multiple worker processes to perform a random search type of workload.

See `SearchWorkers` which is the main abstraction class.
"""
import copy
import ctypes
import enum
import multiprocessing.sharedctypes
import sys
import torch
import torch.multiprocessing
import traceback
import unittest

from torch.multiprocessing.queue import ConnectionWrapper
from typing import Callable, Optional, Tuple, TypeVar
# Generic placeholder for the type of result returned by a search function.
T = TypeVar("T")

class SpawnMethod(enum.Enum):
    """Available ways of creating a worker process.

    The value of each member is the name of the corresponding multiprocessing start method.
    Pytorch guidelines about them can be found at
    https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing
    According to the benchmark of this module they all seem to perform similarly in steady-state.
    """

    Fork = "fork"
    ForkServer = "forkserver"
    Spawn = "spawn"

class Campaign:
    """Handle on the current search campaign, given to the search_fn as its first argument.

    The search_fn polls `is_running()` (or its negation `has_ended()`) to learn whether the campaign has ended,
    in which case it should stop searching and return None.
    """

    def __init__(self, alive_flag):
        # Shared flag: true while the campaign is in progress.
        self._alive_flag: multiprocessing.sharedctypes.RawValue = alive_flag

    def has_ended(self) -> bool:
        """Return True once the campaign is no longer running."""
        return not self.is_running()

    def is_running(self) -> bool:
        """Return the current value of the shared campaign flag."""
        return self._alive_flag.value

class SearchWorkers:
    """Manager for a Pool of worker processes that can perform a search in parallel.

    The search is organized into 'campaigns'.
    During each campaign, all processes will search independently until one finds a result.
    When a result is found, the campaign is stopped and all processes will go into a standby state.
    The processes are kept alive until close() is called, and can be reused for multiple campaigns.

    close() must be called for graceful shutdown after use to avoid keeping useless processes around.
    This can be done safely with a 'with' statement:

        with SearchWorkers(...) as workers:
            workers.campaign(...)
            ...
            workers.campaign(...)

    Worker processes will disable intra-operation parallelism in pytorch using `torch.set_num_threads(1)`.
    A trivially parallel search (no synchronization) is more efficient than intra-op parallelism and scales better.
    They also re-seed various random number generators to search different points if it relies on randomness.

    If used with `num_workers = 1` this will not spawn any processes and `campaign()` will just call the `search_fn`.
    This can be used to enable profiling and reduce resource usage without having 2 variants of the search code.
    Beware that it may hide problems linked to multiprocessing so remember to also test with `num_workers > 1`.
    """
    def __init__(self, method: SpawnMethod = SpawnMethod.Spawn, num_workers: Optional[int] = None):
        """Start worker processes.
        
        - method: see `SpawnMethod` enum for options.
        - num_workers: number of workers. If not set this uses `torch.get_num_threads()` which seems to use hardware core count.
            Another option is to use `torch.multiprocessing.cpu_count()` which uses hardware threads (hyperthreading, if available).

        If a worker fails to start, the workers that were already started are terminated before the error is re-raised.
        """
        # Assigned first so that close() / __del__ remain safe if anything below raises.
        self.workers = []
        if num_workers is None:
            num_workers = torch.get_num_threads()
        assert num_workers > 0, "num_workers must be a positive integer"
        # Flag used by a worker to stop the overall search. No locking required:
        # - no read-modify-write operation, only single write or read
        # - all potential unsynchronized writes use the same value 'False' so no divergent data-race.
        self.campaign_running = multiprocessing.sharedctypes.RawValue(ctypes.c_bool, False)

        if num_workers == 1:
            self.workers = _ExecuteSearchFnLocally()
            return

        mp_context = torch.multiprocessing.get_context(method.value)
        try:
            for i in range(num_workers):
                pipe_parent, pipe_child = _make_pipe(duplex = True)
                worker = mp_context.Process(
                    target = _worker_fn, name = f"search_worker_{i}", daemon = False,
                    args = (pipe_child, self.campaign_running)
                )
                worker.start()
                self.workers.append((worker, pipe_parent))
        except BaseException:
            # Do not leak the (non-daemon) workers that did start when a later one fails to.
            self._terminate_workers()
            raise

    def campaign(self, search_fn: Callable[..., Optional[T]], *args, **kwargs) -> Optional[T]:
        """Run a search campaign to completion and return its overall result.
        
        search_fn is called on every worker as `search_fn(campaign: Campaign, *args, **kwargs)`.
        It should search until it finds a result or another worker finds one and stops the campaign.
        If successful the search_fn must return a non-None result which is used as the result of the campaign() method.
        If the `campaign` class instance indicates campaign completion, the search_fn should stop and return None.

        A typical search_fn structure should look like:

            while campaign.is_running():
                ... # computation / test / attempt
                if successful: return result

        search_fn must be *picklable*, which means it must be defined as a top level function in a python module (no lambda).
        If multiple search_fn find a result at the same time, one is chosen.
        Only the result data passed to `return` is transferred back to the main process, everything else done locally is lost.

        If all search_fn return None, the campaign() method also returns None.
        Thus users can implement a timeout or iteration limit in search_fn.

        An exception emitted from a search_fn will abort the campaign with no result and be raised wrapped as a `WorkerException`.
        The worker processes are put in standby and can be used for new campaigns.
        With `num_workers = 1`, KeyboardInterrupt and SystemExit are not wrapped and propagate unchanged.

        The campaign flag is always reset when this method exits, whether it returns or raises.

        Measures are in place to reduce the likelihood of keeping not-shared data long term in shared memory,
        as in that case an error could occur when shared memory file descriptors hit a limit ("too many files open").
        It is recommended that tensor arguments unique to a single campaign should be sent as a copy.
        """
        if isinstance(self.workers, _ExecuteSearchFnLocally):
            self.campaign_running.value = True
            try:
                return search_fn(Campaign(alive_flag = self.campaign_running), *args, **kwargs)
            except Exception as error:
                # Only regular errors are wrapped: an interruption of the main process must stay an interruption.
                raise WorkerException(traceback.format_exc()) from error
            finally:
                # Same end state as the multi-process path: no campaign is running once we leave.
                self.campaign_running.value = False
        
        assert len(self.workers) > 0, "no worker process available (was close() already called?)"
        request = (search_fn, args, kwargs)
        awaiting = [] # Pipes of the workers that received the request and thus owe exactly one answer
        # Start campaign
        self.campaign_running.value = True
        try:
            try:
                for _, pipe in self.workers:
                    pipe.send(request)
                    awaiting.append(pipe)
            except Exception:
                # The request did not reach every worker (e.g. search_fn is not picklable).
                # Stop the workers that did start and discard their answers to keep every pipe in sync.
                self.campaign_running.value = False
                for pipe in awaiting:
                    pipe.recv()
                raise
            # The first worker process to find a result sets campaign_running to False and send its result.
            # All other workers stop searching when they see the campaign being stopped and send a None.
            results = [pipe.recv() for pipe in awaiting]
        finally:
            # Reset in case no search_fn stopped the campaign (not found), or if an error interrupted us.
            self.campaign_running.value = False
        # Check for exceptions, or return first result.
        for r in results:
            if isinstance(r, WorkerException):
                raise r
        for r in results:
            if r is not None:
                # Use deepcopy to force all tensors in r to be copied away from shared_memory, reducing pressure on shm fds.
                return copy.deepcopy(r)
        return None # No result found

    def close(self):
        """Shutdown worker processes and cleanup resources.

        Every worker is asked to stop and then joined. A worker whose pipe is unusable cannot receive
        the request, so it is terminated instead to guarantee that joining it returns.
        Calling close() again afterwards is a no-op.
        """
        workers = list(self.workers)
        for worker, pipe in workers:
            try:
                pipe.send(_ShutdownRequest)
            except OSError:
                # Broken / closed pipe: the worker is most likely dead already.
                worker.terminate()
            finally:
                pipe.close()
        for worker, _ in workers:
            worker.join()
        # TODO timeout + terminate ?
        # Only forget the workers once they are all joined, so __del__ can still terminate them if we got interrupted.
        self.workers = []

    # Support with statement
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def _terminate_workers(self):
        """Non-graceful cleanup: close the pipe of every remaining worker and terminate it."""
        # getattr: the attribute is missing if the instance was never initialized.
        for worker, pipe in getattr(self, 'workers', ()):
            pipe.close()
            worker.terminate()
        self.workers = []

    def __del__(self):
        # Non-graceful cleanup, list should be empty if close() was called
        self._terminate_workers()

class WorkerException(Exception):
    """Error raised when a search function failed; its message carries the formatted traceback of the original exception.

    The original exception object is deliberately not wrapped: doing so is complex,
    and could fail with a pickling error when it is transferred between processes.
    """

def _worker_fn(pipe: ConnectionWrapper, campaign_running: multiprocessing.sharedctypes.RawValue):
    """Main loop of a worker process: serve campaign requests received on `pipe` until asked to stop.

    - pipe: connection to the parent. Each request is either `_ShutdownRequest` or a `(search_fn, args, kwargs)` tuple.
    - campaign_running: shared flag, set to False by the worker that ends the campaign (result found or failure).

    For every campaign request exactly one answer is sent back: the result of search_fn, None, or a `WorkerException`.
    The loop ends on a shutdown request, when the pipe becomes unusable, or on KeyboardInterrupt.
    """
    # Using intra-op parallelism during a parallel search is not effective in general so disable it.
    # This can be re-enabled by calling set_num_threads in search_fn if needed
    torch.set_num_threads(1)
    _set_random_seeds()
    worker_name = multiprocessing.current_process().name
    while True:
        try:
            request = pipe.recv()
        except (EOFError, OSError):
            return # Connection to the parent is gone: no more work can arrive, shutdown quietly
        if request is _ShutdownRequest:
            return
        search_fn, args, kwargs = request
        del request

        # For each campaign the worker MUST send a result : value, None, or exception log
        try:
            result = search_fn(Campaign(alive_flag = campaign_running), *args, **kwargs)
        except KeyboardInterrupt:
            return # Killed by parent, shutdown quietly (no pipe.send as pipe is broken in parent)
        except BaseException:
            # Everything else is reported, otherwise the parent would wait forever for this worker's answer.
            campaign_running.value = False # Abort campaign
            result = WorkerException(f"in worker '{worker_name}':\n{traceback.format_exc()}")
        else:
            if result is not None:
                campaign_running.value = False # Signal other processes to stop searching
        finally:
            # Drop request data on success AND failure, so it is not kept alive while waiting for the next campaign.
            del search_fn, args, kwargs

        try:
            pipe.send(result)
        except Exception:
            # The answer could not be transferred (e.g. it cannot be pickled): report that instead of dying silently.
            campaign_running.value = False
            try:
                pipe.send(WorkerException(f"in worker '{worker_name}':\n{traceback.format_exc()}"))
            except Exception:
                return # The pipe itself is unusable, nothing more can be done
        finally:
            del result

class _ShutdownRequest:
    """Private tag: the class object itself is sent to a subprocess to request its graceful shutdown."""
    # An earlier design closed pipe_parent and relied on pipe_child.recv() raising EOFError.
    # It did not work: references to pipe_parent are kept alive somewhere (processes), which prevents pipe shutdown.

def _make_pipe(duplex: bool) -> Tuple[ConnectionWrapper, ConnectionWrapper]:
    """Create a pipe and return both of its ends wrapped in a `ConnectionWrapper`.

    torch.multiprocessing relies on ConnectionWrapper in its redefinitions of Queue/SimpleQueue to enable
    shared memory for tensor transfers. It does not redefine Pipe, but doing it here is simple enough.

    - duplex: forwarded as is to `torch.multiprocessing.Pipe`.
    """
    first_end, second_end = map(ConnectionWrapper, torch.multiprocessing.Pipe(duplex = duplex))
    return (first_end, second_end)

def _set_random_seeds():
    """Re-seed the pseudo random generators so that each worker produces different results.

    Pytorch is always re-seeded. Python's `random` and numpy are re-seeded too, with seeds
    derived from the pytorch one, but only if they are in use (module already imported).
    """
    # Reproducibility is already lost anyway: the search order between processes is non-deterministic.
    torch_seed = torch.seed() # Picks a new seed for pytorch and returns it
    loaded_modules = sys.modules
    if 'random' in loaded_modules:
        import random
        random.seed(torch_seed + 1234)
    if 'numpy' in loaded_modules:
        import numpy
        numpy_seed = (torch_seed + 2345) % 2**32 # numpy requires an uint32
        numpy.random.seed(numpy_seed)

class _ExecuteSearchFnLocally:
    """*Tag* class standing in for the list of workers in the special case num_workers = 1.

    When it is used no worker is spawned: `campaign()` executes search_fn directly.
    """

    def __iter__(self):
        """Iterate over nothing, so that close()/enter/exit/del see a list of 0 workers."""
        return iter(())

####################################### Tests ########################################

class ParallelTests(unittest.TestCase):
    """Behavior tests of `SearchWorkers`, with and without worker processes."""

    def _check_all_endings(self, pool_size: int):
        """Run one campaign per ending of `_testing_search_fn` on a single pool of `pool_size` workers."""
        with SearchWorkers(num_workers = pool_size) as workers:
            # Success, with the argument given by keyword and then by position
            found_with_keyword = workers.campaign(_testing_search_fn, ending = 'success')
            self.assertEqual(found_with_keyword, 42)
            found_with_positional = workers.campaign(_testing_search_fn, 'success')
            self.assertEqual(found_with_positional, 42)
            # Search giving up
            self.assertIsNone(workers.campaign(_testing_search_fn, ending = 'not_found'))
            # Failing search: the error is reported as a WorkerException carrying the traceback
            with self.assertRaisesRegex(WorkerException, 'RuntimeError'):
                workers.campaign(_testing_search_fn, 'crash')

    def test_behavior_multiprocess(self):
        self._check_all_endings(pool_size = 2)

    def test_behavior_no_workers_optimization(self):
        self._check_all_endings(pool_size = 1)

    def test_example_class(self):
        example = _ExampleSearchClass()
        sequential_index = example.search_target_index(nb_workers = 1) # Non-parallel code
        parallel_index = example.search_target_index(nb_workers = 2) # Parallel code
        if sequential_index is None or parallel_index is None:
            return # A search timed out, nothing to compare
        # Both searches should agree. This could fail due to random float collisions but very unlikely.
        self.assertEqual(sequential_index, parallel_index)

def _testing_search_fn(campaign: Campaign, ending: str):
    """Search function used by the tests: loops while the campaign runs and acts on its 42nd iteration.

    - ending: selects what happens at iteration 42.
        'success' returns the iteration count (42), 'not_found' returns None, 'crash' raises a RuntimeError.
        Any other value raises an AssertionError.

    Returns None if the campaign is stopped before iteration 42 is reached.
    """
    iterations = 0
    while campaign.is_running():
        iterations += 1
        # "success" after some iterations
        if iterations == 42:
            if ending == 'crash':
                raise RuntimeError("simulated crash")
            if ending == 'not_found':
                return None
            if ending == 'success':
                return iterations
            raise AssertionError(f"unknown ending: {ending!r}")
    return None # Campaign stopped by someone else

class _ExampleSearchClass:
    """Demonstrates how SearchWorkers can be used from within a class."""

    def __init__(self):
        """Build the search problem held as internal state: a list of numbers, and one of them picked as the target."""
        self.numbers = torch.rand(100)
        target_position = self.draw_index()
        self.target = self.numbers[target_position]

    def draw_index(self) -> int:
        """Search step: choose at random the next index to test."""
        upper_bound = self.numbers.size(0)
        drawn = torch.randint(0, upper_bound, (1,))
        return int(drawn)

    def check_index(self, i: int) -> bool:
        """Search step: tell whether the number at index `i` equals the target."""
        is_match = self.numbers[i] == self.target
        return bool(is_match)

    def search_target_index(self, nb_workers: int, nb_attempts: int = 1000) -> Optional[int]:
        """User facing search procedure: returns an index matching the target, or None after too many attempts."""
        if nb_workers != 1:
            # The search written with the SearchWorkers abstraction.
            # Real code does not need both variants, as SearchWorkers also works with nb_workers=1.
            attempts_per_worker = nb_attempts // nb_workers
            with SearchWorkers(num_workers = nb_workers) as workers:
                return workers.campaign(_ExampleSearchClass._search_fn, self, attempts_per_worker)

        # The same search written without any multiprocessing
        for _ in range(nb_attempts):
            candidate = self.draw_index()
            if self.check_index(candidate):
                return candidate
        return None # Optional, None would be returned implicitly

    @staticmethod
    def _search_fn(campaign: Campaign, self, worker_nb_attempts: int):
        """Search function executed by each worker.

        It is a plain function, only declared as a staticmethod of _ExampleSearchClass for namespacing.
        A top level function would work as well, but a staticmethod is more idiomatic for code tied to the class internals.
        A regular method is NOT possible, because `campaign` is always passed as the first argument.

        Everything the search needs MUST be passed as argument, this is how python knows what to send to the workers.
        `self` is the instance holding the search data only because `search_target_index` passes it
        to `workers.campaign(...)`; the name is free ('instance' would do).

        Anything to be retrieved MUST be part of the returned value.
        Only the index is returned here; use a tuple / dict / class to return several values.
        """
        # The bounded loop acts as a timeout. It is optional, and could be based on time instead.
        for _ in range(worker_nb_attempts):
            # Stop as soon as another worker completed: this is the main difference with non-parallel code.
            if campaign.has_ended():
                return None
            candidate = self.draw_index()
            if self.check_index(candidate):
                return candidate
        return None # Optional, None would be returned implicitly

####################################### Benchmark ########################################

def benchmark(n_cpu: Optional[int] = None, n_repeat: int = 40):
    """Synthetic load benchmark with torch.mm. Compares time-to-completion of:
    - intra-op:  (n_cpu * n_repeat) * work using intra-op parallelism over n_cpu.
    - workers: n_cpu workers doing n_repeat * work in single_thread mode, once per `SpawnMethod`.

    - n_cpu: number of threads / workers. Defaults to `torch.get_num_threads()`, read when the benchmark is called.
    - n_repeat: amount of work per cpu.

    Timings are printed. Start methods not supported by the platform are reported and skipped.
    The pytorch thread count of the calling process is restored before returning.
    """
    import timeit
    initial_num_threads = torch.get_num_threads()
    if n_cpu is None:
        n_cpu = initial_num_threads
    try:
        if n_cpu != initial_num_threads:
            torch.set_num_threads(n_cpu)
        # Intra-op
        timeit.timeit(_bench_expensive_work, number = n_cpu * n_repeat) # warmup
        print("intra-op", timeit.timeit(_bench_expensive_work, number = n_cpu * n_repeat))
        # Workers
        supported_methods = multiprocessing.get_all_start_methods()
        for method in SpawnMethod:
            label = f"workers {method.value}"
            if method.value not in supported_methods:
                print(label, "skipped (start method not supported on this platform)")
                continue
            with SearchWorkers(method = method, num_workers = n_cpu) as workers:
                workers.campaign(_bench_search_fn, n_repeat) # warmup
                print(label, timeit.timeit(lambda: workers.campaign(_bench_search_fn, n_repeat), number = 1))
    finally:
        # Do not leave the caller with a modified global thread count.
        if torch.get_num_threads() != initial_num_threads:
            torch.set_num_threads(initial_num_threads)

def _bench_expensive_work():
    """Unit of work of the benchmark: multiply two random 400x400 matrices and discard the product."""
    side = 400
    left = torch.rand(side, side)
    right = torch.rand(side, side)
    torch.mm(left, right)

def _bench_search_fn(campaign: Campaign, n_repeat: int):
    """Search function of the benchmark: repeat the expensive work `n_repeat` times, or until the campaign stops.

    It never returns a result (always None).
    """
    completed = 0
    while True:
        if not campaign.is_running():
            break
        if completed >= n_repeat:
            break
        _bench_expensive_work()
        completed += 1
