import logging
import os
import signal
import multiprocessing as mp
from multiprocessing.pool import AsyncResult
from signal import signal, SIGINT, SIG_IGN
from typing import List, Optional, Tuple

import networkx as nx

from egr.algo_wrappers import SubgraphIsomorphism


# Logger for this module.
LOG = logging.getLogger(__name__)

# Type alias: the handles returned by ``Pool.apply_async``, one per job.
Results = List[AsyncResult]

# Extra keyword arguments for ``mp.Pool``: every worker process runs
# ``signal(SIGINT, SIG_IGN)`` at startup, so workers ignore SIGINT and only
# the parent process reacts to an interrupt.
POOL_ARGS = {
    'initializer': signal,
    'initargs': (SIGINT, SIG_IGN),
}


def compute_isomorphism(
    G: nx.Graph,
    H: nx.Graph,
    indices: List[int],
    nproc: Optional[int] = None,
    timeout: Optional[int] = None,
) -> Tuple[List[bool], int]:
    """Run ``SubgraphIsomorphism(G, H).is_isomorphic`` over ``indices`` in parallel.

    One job is scheduled per entry of ``indices``: job ``n`` calls
    ``checker.is_isomorphic(indices[n], n)``. The collectors unpack each
    job's result as an ``(n, iso)`` pair, so ``is_isomorphic`` is presumably
    expected to echo ``n`` back (verify against ``SubgraphIsomorphism``).

    Args:
        G: first graph handed to ``SubgraphIsomorphism``.
        H: second graph handed to ``SubgraphIsomorphism``.
        indices: values passed as the first argument of each job.
        nproc: number of worker processes. When falsy, falls back to
            ``os.cpu_count()``, or 1 if the CPU count cannot be determined.
        timeout: if ``None``, wait for every job. Otherwise each result is
            waited on for at most this many seconds (see ``timed_compute``).

    Returns:
        ``(results, num_timeouts)``. ``results[n]`` is the ``iso`` value
        reported by job ``n``; it is ``False`` for a job that timed out.
        ``num_timeouts`` is the number of jobs that timed out (always 0 when
        ``timeout`` is ``None``).
    """
    # Mirror Pool's own fallback so the debug log below never sees None.
    nproc = nproc or os.cpu_count() or 1
    num_indices = len(indices)

    checker = SubgraphIsomorphism(G, H)

    LOG.debug('Invoking %d parallel processes for %d jobs', nproc, num_indices)
    num_timeouts = 0
    # Workers ignore SIGINT (POOL_ARGS). Leaving the ``with`` block calls
    # ``pool.terminate()``, which also stops any job still running after
    # its result timed out.
    with mp.Pool(processes=nproc, **POOL_ARGS) as pool:
        procs: Results = [
            pool.apply_async(checker.is_isomorphic, args=(idx, n))
            for n, idx in enumerate(indices)
        ]
        if timeout is None:
            results = compute(num_indices, procs)
        else:
            results, num_timeouts = timed_compute(num_indices, procs, timeout)

    return results, num_timeouts


def compute(num_indices, procs):
    """Collect every result in ``procs``, waiting as long as each job needs.

    Args:
        num_indices: number of output slots to allocate.
        procs: ``AsyncResult`` handles whose ``get()`` returns ``(n, iso)``.

    Returns:
        A list of length ``num_indices`` with ``results[n] = iso`` for every
        job. A slot that no job reports stays ``False``. An exception raised
        inside a worker is re-raised here by ``get()``.
    """
    LOG.debug('computing in parallel without timeout')
    results: List[bool] = [False] * num_indices
    pending = list(procs)
    # Poll with short waits rather than one blocking get() per handle
    # (presumably to keep the parent responsive to interrupts). Each handle is
    # consumed exactly once and then dropped. Counting a finished handle again
    # on a later pass would end the loop before slower jobs complete.
    while pending:
        still_running = []
        for result in pending:
            if result.ready():
                n, iso = result.get()
                results[n] = iso
            else:
                result.wait(0.01)
                still_running.append(result)
        pending = still_running
    return results


def timed_compute(num_indices, procs, timeout):
    """Collect results from ``procs``, giving up on any job that is too slow.

    Handles are read in order, and each ``get()`` waits at most ``timeout``
    seconds. The limit therefore runs from the moment a handle is polled, not
    from when its job was scheduled. The total wait can reach
    ``timeout * len(procs)``.

    Args:
        num_indices: number of output slots to allocate.
        procs: ``AsyncResult`` handles whose ``get()`` returns ``(n, iso)``.
        timeout: seconds to wait for each individual handle.

    Returns:
        ``(results, num_timeouts)``. ``results[n] = iso`` for every job that
        finished in time; slots of timed-out jobs stay ``False``.
        ``num_timeouts`` is the number of handles that timed out.
    """
    LOG.debug('computing in parallel with timeout of %s seconds', timeout)
    results: List[bool] = [False] * num_indices
    num_timeouts = 0
    for position, proc in enumerate(procs):
        try:
            n, iso = proc.get(timeout=timeout)
        except mp.TimeoutError:
            # mp.TimeoutError is raised without a message, so log which
            # handle (by position in ``procs``) failed to finish.
            LOG.debug('job %d timed out after %s seconds', position, timeout)
            num_timeouts += 1
        else:
            results[n] = iso
    return results, num_timeouts
