import contextlib
import time
from typing import Any, Callable, Dict, Tuple

try:  # Prefer resource over psutil to avoid extra deps
    import resource  # type: ignore
except Exception:  # pragma: no cover
    resource = None  # type: ignore

import torch


def _get_rss_mb() -> float:
    """Return current/max resident set size in MiB when available."""
    try:
        if resource is None:
            return 0.0
        usage = resource.getrusage(resource.RUSAGE_SELF)
        # ru_maxrss is kilobytes on Linux, bytes on macOS; normalize to MiB
        rss_kb = float(usage.ru_maxrss)
        # Heuristic: assume Linux (kilobytes)
        rss_mb = rss_kb / 1024.0
        # If value is suspiciously small, treat ru_maxrss as bytes
        if rss_mb < 1.0:
            rss_mb = rss_kb / (1024.0 * 1024.0)
        return rss_mb
    except Exception:
        return 0.0


@contextlib.contextmanager
def record_cost(device: str = "cpu"):
    """Context manager that records wall time, CUDA time and peak memory.

    Args:
        device: Device string such as ``"cpu"``, ``"cuda"`` or ``"cuda:1"``.
            CUDA metrics are collected only when it starts with ``"cuda"``
            and CUDA is available; they are measured on that specific
            device rather than on whichever device happens to be current.
            The string is parsed with ``torch.device`` in that case, so a
            malformed CUDA device string raises there.

    Yields:
        A dict that is filled in when the ``with`` block exits (also when
        the block raises):

        * ``time_s`` - wall-clock seconds spent in the block; on CUDA this
          includes the final device synchronization.
        * ``cuda_time_ms`` - elapsed time between two CUDA events recorded
          around the block (0.0 when CUDA is not used).
        * ``gpu_mem_mb`` - peak allocated GPU memory in MiB since the peak
          counter was reset on entry (0.0 when CUDA is not used).
        * ``cpu_peak_rss_mb`` - value returned by ``_get_rss_mb()``.
    """
    stats: Dict[str, float] = {
        "time_s": 0.0,
        "cuda_time_ms": 0.0,
        "gpu_mem_mb": 0.0,
        "cpu_peak_rss_mb": 0.0,
    }

    # GPU metrics: bind every CUDA call to the requested device.
    use_cuda = device.startswith("cuda") and torch.cuda.is_available()
    cuda_dev = torch.device(device) if use_cuda else None
    start_evt = end_evt = None
    if use_cuda:
        # Drain previously queued work so it is not billed to this block.
        torch.cuda.synchronize(cuda_dev)
        torch.cuda.reset_peak_memory_stats(cuda_dev)
        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)
        start_evt.record(torch.cuda.current_stream(cuda_dev))

    # Start the wall clock only after CUDA setup, so the initial
    # synchronization is not counted as part of the measured block.
    t0 = time.perf_counter()

    try:
        yield stats
    finally:
        # Stop timers
        if start_evt is not None and end_evt is not None:
            end_evt.record(torch.cuda.current_stream(cuda_dev))
            # elapsed_time() requires both events to have completed.
            torch.cuda.synchronize(cuda_dev)
            stats["cuda_time_ms"] = float(start_evt.elapsed_time(end_evt))
            stats["gpu_mem_mb"] = float(
                torch.cuda.max_memory_allocated(cuda_dev) / (1024 * 1024)
            )
        stats["time_s"] = float(time.perf_counter() - t0)
        stats["cpu_peak_rss_mb"] = _get_rss_mb()


def run_with_cost(fn: Callable[..., Any], *args, device: str = "cpu", **kwargs) -> Tuple[Any, Dict[str, float]]:
    """Call ``fn(*args, **kwargs)`` inside ``record_cost`` and report its cost.

    The keyword-only ``device`` argument is consumed here and handed to
    ``record_cost``; it is not forwarded to ``fn``.  Any exception raised by
    ``fn`` propagates to the caller.

    Returns:
        A ``(result, stats)`` tuple: the value returned by ``fn`` and the
        stats dict yielded by ``record_cost``.
    """
    cost_tracker = record_cost(device=device)
    with cost_tracker as cost:
        outcome = fn(*args, **kwargs)
    return outcome, cost

