import math
import multiprocessing as mp
import os
import traceback
from multiprocessing.pool import ThreadPool
from typing import Any, Callable, List, TypeVar

import psutil
from tqdm.auto import tqdm

# Generic placeholders: ``T`` is the input item type and ``U`` the per-item result type.
T, U = TypeVar("T"), TypeVar("U")


def num_fitting_processes(cpus_per_process: float = 2, memory_per_process: float = 16) -> int:
    """
    Returns the number of processes that can be fitted onto the machine when using a particular
    number of CPUs and a particular amount of memory for every process.

    Args:
        cpus_per_process: The number of CPUs that every process requires.
        memory_per_process: The memory in GiB that every process needs. This is compared against
            the machine's *total* physical memory, not the currently free memory.

    Returns:
        The number of processes to use. This is always at least 1, so the value can be passed
        directly to a worker pool even on machines smaller than a single process's budget.
    """
    # ``os.cpu_count()`` returns ``None`` when the CPU count cannot be determined.
    num_cpus = os.cpu_count() or 1
    num_processes_cpu = math.floor(num_cpus / cpus_per_process)

    total_gib = psutil.virtual_memory().total / (1024 ** 3)
    num_processes_memory = math.floor(total_gib / memory_per_process)

    # Clamp to 1: a result of 0 would make callers start no workers at all.
    return max(1, min(num_processes_cpu, num_processes_memory))


def run_parallel(execute: Callable[[T], U], data: List[T], num_processes: int) -> List[U]:
    """
    Runs a function on multiple processes, parallelizing computations for the provided data.

    Args:
        execute: The function to run in each process.
        data: The data items to execute the function for. Items may be ``None``.
        num_processes: The number of processes to parallelize over.

    Returns:
        The outputs of the function calls, ordered in the same way as the data.

    Raises:
        ValueError: If ``num_processes`` is smaller than 1 (raised by ``ThreadPool``).
        RuntimeError: If ``execute`` raises for any item. The message contains the worker's
            formatted traceback.
    """
    # Initialize queues and put all items into the input queue. Also put one shutdown sentinel
    # per process. The sentinel uses ``None`` as the *index* (never a valid position), so
    # ``None`` data items are still processed like any other item.
    inputs: mp.Queue = mp.Queue()
    outputs: mp.Queue = mp.Queue()
    for i, item in enumerate(data):
        inputs.put((i, item))
    for _ in range(num_processes):
        inputs.put((None, None))

    # Create the processes in a thread pool to speed up creation
    def factory(_i: int) -> mp.Process:
        process = mp.Process(target=_worker, args=(execute, inputs, outputs))
        process.start()
        return process

    result: List[Any] = [None] * len(data)
    try:
        with ThreadPool(num_processes) as pool:
            spawned = pool.map_async(factory, range(num_processes))
            try:
                # Collect results inside the ``with`` statement so processes keep getting
                # spawned while results arrive.
                # NOTE(review): a worker that dies without reporting (e.g. killed by the OS)
                # still makes ``outputs.get()`` block forever.
                with tqdm(total=len(data)) as progress:
                    progress.set_postfix({"num_processes": num_processes})
                    for _ in range(len(data)):
                        i, output, error = outputs.get()
                        if error is not None:
                            raise RuntimeError(
                                f"run_parallel: processing item {i} failed:\n{error}"
                            )
                        result[i] = output
                        progress.update()
            finally:
                # Stop and reap every worker, also when collection failed or was interrupted,
                # so that no processes are leaked or left as zombies.
                for process in spawned.get():
                    process.kill()
                    process.join()
    finally:
        # Unread items (leftover sentinels, or remaining work after a failure) must not block
        # interpreter shutdown while the queue's feeder thread waits to flush them.
        inputs.cancel_join_thread()
        inputs.close()
        outputs.close()
    return result


def _worker(execute: Callable[[Any], Any], inputs: mp.Queue, outputs: mp.Queue) -> None:
    """
    Worker loop: processes ``(index, item)`` pairs from ``inputs`` until a ``(None, None)``
    sentinel arrives.

    Each result is reported on ``outputs`` as ``(index, output, None)``. If ``execute`` raises,
    ``(index, None, traceback_text)`` is reported instead. The traceback is sent as a string so
    the message always pickles, even when the exception object itself does not.
    """
    while True:
        i, item = inputs.get()
        if i is None:
            return
        try:
            output = execute(item)
        except Exception:
            outputs.put((i, None, traceback.format_exc()))
        else:
            outputs.put((i, output, None))
