# distributed/parallel_runner.py
"""
Simple multi-process / multi-GPU parallel runner for experiments.

This module offers two modes:
  - local multiprocessing (useful for CPU-bound parallelism)
  - multiprocessing with the "spawn" start method for multi-GPU runs (one process per GPU)

Provides:
  - run_in_parallel(fn, args_list, n_workers, timeout): generic local multiprocess runner
  - run_on_gpus(worker_fn, args_list, gpu_ids): start one process per args tuple and assign one GPU each

Caveats:
  - For robust multi-node or large-scale experiments prefer established frameworks:
    torch.distributed.launch / torchrun, Ray, Slurm job arrays, or Dask.
"""
import multiprocessing as mp
import os
import traceback
from typing import Callable, Iterable, List, Any, Tuple

def _worker_wrapper(fn: Callable, args, result_queue: mp.Queue, worker_id: int):
    """
    Internal wrapper executed in each process.

    Calls fn(*args) and reports the outcome on result_queue as a
    (worker_id, ok, payload) tuple:
      - success: ok is True and payload is the return value of fn
      - failure: ok is False and payload is (exception, traceback_string)

    - fn: callable to run
    - args: positional arguments, unpacked into fn
    - result_queue: any object with a put() method
    - worker_id: identifier echoed back so the consumer can order results

    The exception object is sent as-is when it survives a pickle round trip.
    Otherwise it is replaced by a RuntimeError carrying the original type name
    and message, so that the failure report itself cannot be lost in transit.
    """
    import pickle  # local import: only needed to vet the failure payload

    try:
        res = fn(*args)
        result_queue.put((worker_id, True, res))
    except Exception as exc:
        # Format the traceback first, while `exc` is still the active exception.
        tb = traceback.format_exc()
        reportable = exc
        try:
            # Round trip (not just dumps): some exceptions pickle fine but fail
            # to unpickle, which would break the consumer instead of this worker.
            pickle.loads(pickle.dumps(exc))
        except Exception:
            reportable = RuntimeError(f"{type(exc).__name__}: {exc}")
        result_queue.put((worker_id, False, (reportable, tb)))

def run_in_parallel(fn: Callable, args_list: List[Tuple], n_workers: int = None, timeout: float = None):
    """
    Run fn in parallel on multiple processes.
    - fn: function to call; should be picklable
    - args_list: iterable of arg tuples, one per job
    - n_workers: maximum number of jobs running at the same time
      (defaults to mp.cpu_count(); always clamped to 1..number of jobs)
    - timeout: maximum number of seconds to wait for the *next* result;
      None waits indefinitely

    Returns a list with one slot per job, in the same order as args_list:
      - the return value of fn for a job that succeeded
      - the (exception, traceback_string) tuple reported by the worker for a
        job that raised
      - None for a job that had not reported when the timeout expired

    When the timeout expires, collection stops, jobs that were not started yet
    are skipped, and workers that are still running are terminated.
    """
    from queue import Empty  # local import: raised by result_q.get() on timeout

    jobs = list(args_list)
    results = [None] * len(jobs)
    if not jobs:
        # Nothing to run: avoid starting a manager process for no reason.
        return results

    if n_workers is None:
        n_workers = mp.cpu_count()
    # Never more workers than jobs, never fewer than one.
    n_workers = max(1, min(n_workers, len(jobs)))

    backlog = iter(enumerate(jobs))
    procs = []
    collected = 0

    # The context manager shuts the manager's server process down on exit.
    with mp.Manager() as manager:
        result_q = manager.Queue()

        def _start_next():
            """Start the next job from the backlog, if any is left."""
            job = next(backlog, None)
            if job is None:
                return
            job_id, args = job
            p = mp.Process(target=_worker_wrapper, args=(fn, args, result_q, job_id))
            p.start()
            procs.append(p)

        try:
            for _ in range(n_workers):
                _start_next()
            while collected < len(jobs):
                try:
                    job_id, _ok, payload = result_q.get(timeout=timeout)
                except Empty:
                    break  # no result within `timeout`: give up on the rest
                # Success value or (exception, traceback) tuple: stored as-is.
                results[job_id] = payload
                collected += 1
                # A job just reported, so its slot is free for the next one.
                _start_next()
        finally:
            unfinished = collected < len(jobs)
            for p in procs:
                # Stragglers could no longer report once the manager is gone,
                # so stop them instead of leaving them running in the background.
                if unfinished and p.is_alive():
                    p.terminate()
                p.join(timeout=1.0)
    return results

# Torch-based GPU spawn runner
def _torch_worker_spawn(worker_fn: Callable, args, gpu_id: int):
    """
    Pin the current process to one GPU, then run worker_fn(*args).

    - worker_fn: callable to run once the device setup was attempted
    - args: positional arguments, unpacked into worker_fn
    - gpu_id: written (as a string) to the CUDA_VISIBLE_DEVICES environment variable

    The torch part is best effort: any error while importing torch or
    selecting the device is ignored, so the worker also runs without torch.
    Returns whatever worker_fn returns.
    """
    visible_devices = str(gpu_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices

    def _select_first_visible_device():
        import torch
        if not torch.cuda.is_available():
            return
        # Index 0, because the visible devices were just masked down to one.
        torch.cuda.set_device(0)

    try:
        _select_first_visible_device()
    except Exception:
        # Deliberately ignored: device selection is optional.
        pass

    outcome = worker_fn(*args)
    return outcome

def run_on_gpus(worker_fn: Callable, args_list: List[Tuple], gpu_ids: List[int]):
    """
    Start one process per args tuple, each assigned its own GPU, and run worker_fn in it.
    - worker_fn: picklable function (module-level; it is sent to a "spawn" process)
    - args_list: list of args tuples, length <= len(gpu_ids); job i uses gpu_ids[i]
    - gpu_ids: list of GPU ids to use (e.g., [0,1,2])

    Returns a list with one slot per job, in the same order as args_list:
      - the return value of worker_fn for a job that succeeded
      - the (exception, traceback_string) tuple reported by the worker for a
        job that raised
      - a (RuntimeError, "") tuple for a worker process that exited without
        reporting anything (e.g. it crashed or was killed)

    Raises ValueError if there are more jobs than GPUs.
    """
    from queue import Empty  # local import: raised by Queue.get() when a poll times out

    jobs = list(args_list)
    gpus = list(gpu_ids)
    if len(jobs) > len(gpus):
        raise ValueError("args_list length must be <= number of provided GPUs")
    results = [None] * len(jobs)
    if not jobs:
        return results

    # Use a private "spawn" context instead of mp.set_start_method(..., force=True),
    # so the start method of the rest of the program is left untouched.
    ctx = mp.get_context("spawn")
    result_q = ctx.Queue()
    poll_seconds = 1.0  # how often worker liveness is checked while waiting

    procs = []
    pending = set()  # ids of started workers that have not reported yet
    try:
        for worker_id, (args, gpu) in enumerate(zip(jobs, gpus)):
            # "spawn" pickles everything handed to the child, and lambdas or
            # closures cannot be pickled. Pass the module-level helper plus its
            # arguments instead: the wrapper calls
            # _torch_worker_spawn(worker_fn, args, gpu).
            p = ctx.Process(
                target=_worker_wrapper,
                args=(_torch_worker_spawn, (worker_fn, args, gpu), result_q, worker_id),
            )
            p.start()
            procs.append(p)
            pending.add(worker_id)

        while pending:
            try:
                worker_id, _ok, payload = result_q.get(timeout=poll_seconds)
            except Empty:
                exited = [i for i in pending if not procs[i].is_alive()]
                if not exited:
                    continue  # everyone is still working; keep waiting
                # A worker that exited flushed its result first, so one more
                # read tells a finished worker apart from a dead one.
                try:
                    worker_id, _ok, payload = result_q.get(timeout=poll_seconds)
                except Empty:
                    for i in exited:
                        results[i] = (
                            RuntimeError(
                                f"worker {i} (GPU {gpus[i]}) exited with code "
                                f"{procs[i].exitcode} without reporting a result"
                            ),
                            "",
                        )
                        pending.discard(i)
                    continue
            # Success value or (exception, traceback) tuple: stored as-is.
            results[worker_id] = payload
            pending.discard(worker_id)
    finally:
        for p in procs:
            # `pending` is only non-empty here when an error interrupted the
            # run; stop the remaining workers rather than leaving them orphaned.
            if pending and p.is_alive():
                p.terminate()
            p.join()
    return results
