from __future__ import annotations

import contextlib
import csv
import io
import math
import os
import re
import subprocess
import tempfile
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Mapping, Sequence

import torch


# Process-wide cache used by run_ncu_metrics: ncu executable path (str(ncu_path)) -> "X.Y.Z"
# version string parsed from `ncu --version`. Only successful parses are stored.
_NCU_VERSION_TRIPLET_CACHE: Dict[str, str] = {}


@dataclass(frozen=True)
class CudaPeakMemory:
    """Immutable CUDA allocator snapshot returned by ``measure_cuda_peak_memory``.

    All values are byte counts for the measured device.
    """

    # torch.cuda.memory_allocated(device) read just before the measured call.
    baseline_allocated_bytes: int
    # torch.cuda.memory_reserved(device) read just before the measured call.
    baseline_reserved_bytes: int
    # torch.cuda.max_memory_allocated(device) after the call (peak stats are reset right before it).
    peak_allocated_bytes: int
    # torch.cuda.max_memory_reserved(device) after the call.
    peak_reserved_bytes: int
    # max(0, peak_allocated_bytes - baseline_allocated_bytes).
    delta_allocated_bytes: int
    # max(0, peak_reserved_bytes - baseline_reserved_bytes).
    delta_reserved_bytes: int


def sync_cuda(device: torch.device) -> None:
    """Call ``torch.cuda.synchronize(device)`` when ``device`` is a usable CUDA device.

    Does nothing for non-CUDA devices or when CUDA is not available.
    """
    if device.type != "cuda":
        return
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize(device)


def slugify(value: str) -> str:
    """Turn ``value`` into a lowercase identifier of ``[a-z0-9_]`` characters.

    Each run of non-alphanumeric characters becomes a single underscore, and leading or
    trailing underscores are removed. If nothing is left, ``"range"`` is returned.
    """
    lowered = str(value).strip().lower()
    slug = re.sub(r"[^a-zA-Z0-9]+", "_", lowered).strip("_")
    if not slug:
        return "range"
    return slug


@contextlib.contextmanager
def nvtx_range(name: str) -> Iterable[None]:
    """Wrap the ``with`` body in an NVTX range called ``name`` when NVTX is usable.

    If CUDA is unavailable or ``torch.cuda`` has no ``nvtx`` attribute, this is a no-op.
    Failures of ``range_push``/``range_pop`` are swallowed so that profiling markers can
    never break the wrapped code. Exceptions from the body itself propagate normally.
    """
    nvtx_usable = torch.cuda.is_available() and hasattr(torch.cuda, "nvtx")
    if not nvtx_usable:
        yield
        return

    try:
        torch.cuda.nvtx.range_push(str(name))
    except Exception:
        did_push = False
    else:
        did_push = True

    try:
        yield
    finally:
        # Only pop a range that was actually pushed, to keep the NVTX stack balanced.
        if did_push:
            with contextlib.suppress(Exception):
                torch.cuda.nvtx.range_pop()


def measure_cuda_peak_memory(
    fn: Callable[[], None],
    *,
    device: torch.device,
    empty_cache: bool = False,
) -> CudaPeakMemory:
    """Run ``fn`` once and report CUDA allocator usage before the call and at its peak.

    Args:
        fn: Zero-argument callable to measure; its return value is ignored.
        device: CUDA device to measure.
        empty_cache: If True, call ``torch.cuda.empty_cache()`` (bracketed by syncs) before
            the baseline is read.

    Returns:
        CudaPeakMemory with the baseline (read after a sync, before ``fn``), the peaks
        (``max_memory_*`` read after ``fn`` and a sync), and their differences clamped at 0.

    Raises:
        RuntimeError: If ``device`` is not a CUDA device or CUDA is unavailable.
    """
    usable = device.type == "cuda" and torch.cuda.is_available()
    if not usable:
        raise RuntimeError("measure_cuda_peak_memory requires a CUDA device.")

    if empty_cache:
        sync_cuda(device)
        torch.cuda.empty_cache()
        sync_cuda(device)

    # Let queued work finish so the baseline reflects settled allocations.
    sync_cuda(device)
    before_alloc = int(torch.cuda.memory_allocated(device))
    before_reserved = int(torch.cuda.memory_reserved(device))
    # Reset peak tracking so the maxima read below cover only the fn() call.
    torch.cuda.reset_peak_memory_stats(device)

    fn()
    sync_cuda(device)

    top_alloc = int(torch.cuda.max_memory_allocated(device))
    top_reserved = int(torch.cuda.max_memory_reserved(device))
    return CudaPeakMemory(
        baseline_allocated_bytes=before_alloc,
        baseline_reserved_bytes=before_reserved,
        peak_allocated_bytes=top_alloc,
        peak_reserved_bytes=top_reserved,
        delta_allocated_bytes=max(0, top_alloc - before_alloc),
        delta_reserved_bytes=max(0, top_reserved - before_reserved),
    )


def get_process_rss_bytes() -> int:
    """Return this process's current resident set size in bytes, or 0 if it cannot be read.

    On Windows this is ``WorkingSetSize`` from psapi ``GetProcessMemoryInfo``. Elsewhere it
    is the resident-page field of ``/proc/self/statm`` multiplied by the page size, so
    platforms without ``/proc`` report 0. Every failure is mapped to 0; this never raises.
    """
    if os.name != "nt":
        try:
            with open("/proc/self/statm", "r", encoding="utf-8") as statm_file:
                fields = statm_file.read().split()
            # statm fields: size resident shared text lib data dt (in pages).
            resident_pages = int(fields[1]) if len(fields) > 1 else 0
            return int(resident_pages * os.sysconf("SC_PAGE_SIZE"))
        except Exception:
            return 0

    try:
        import ctypes
        from ctypes import wintypes

        class PROCESS_MEMORY_COUNTERS(ctypes.Structure):
            _fields_ = [
                ("cb", wintypes.DWORD),
                ("PageFaultCount", wintypes.DWORD),
                ("PeakWorkingSetSize", ctypes.c_size_t),
                ("WorkingSetSize", ctypes.c_size_t),
                ("QuotaPeakPagedPoolUsage", ctypes.c_size_t),
                ("QuotaPagedPoolUsage", ctypes.c_size_t),
                ("QuotaPeakNonPagedPoolUsage", ctypes.c_size_t),
                ("QuotaNonPagedPoolUsage", ctypes.c_size_t),
                ("PagefileUsage", ctypes.c_size_t),
                ("PeakPagefileUsage", ctypes.c_size_t),
            ]

        info = PROCESS_MEMORY_COUNTERS()
        info.cb = ctypes.sizeof(PROCESS_MEMORY_COUNTERS)
        process_handle = ctypes.windll.kernel32.GetCurrentProcess()
        succeeded = ctypes.windll.psapi.GetProcessMemoryInfo(process_handle, ctypes.byref(info), info.cb)
        return int(info.WorkingSetSize) if succeeded else 0
    except Exception:
        return 0


class RssSampler:
    """Context manager that tracks the peak resident set size (RSS) of this process.

    On ``__enter__`` the current RSS (from ``get_process_rss_bytes``) becomes both
    ``baseline_bytes`` and the starting ``peak_bytes``. A daemon thread then re-samples every
    ``interval_sec`` seconds (clamped to at least 1e-4) and raises ``peak_bytes`` whenever it
    sees a larger value. ``__exit__`` stops the thread, takes one final sample, and never
    suppresses exceptions. The sampler may be entered again after it has exited.

    Where ``get_process_rss_bytes`` cannot read RSS it returns 0, so both values stay 0.
    """

    def __init__(self, interval_sec: float = 0.001) -> None:
        self.interval_sec = max(1e-4, float(interval_sec))
        self.baseline_bytes = 0
        self.peak_bytes = 0
        self._stop = threading.Event()
        # Created per __enter__: a Thread object can only be started once, so building it
        # here would make a second `with sampler:` raise RuntimeError.
        self._thread: threading.Thread | None = None

    def _sample(self) -> None:
        """Read the current RSS once and fold it into ``peak_bytes``."""
        rss = get_process_rss_bytes()
        if rss > self.peak_bytes:
            self.peak_bytes = rss

    def _run(self, stop_event: threading.Event) -> None:
        """Sampling loop; returns promptly once ``stop_event`` is set."""
        while not stop_event.is_set():
            self._sample()
            # Event.wait instead of time.sleep so a long interval cannot delay shutdown
            # past the join timeout in __exit__.
            stop_event.wait(self.interval_sec)

    def __enter__(self) -> "RssSampler":
        self.baseline_bytes = get_process_rss_bytes()
        self.peak_bytes = self.baseline_bytes
        # Fresh event per run, so a thread left over from a timed-out join keeps its own
        # (already set) event and cannot be revived by this run.
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, args=(self._stop,), daemon=True)
        self._thread.start()
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
        self._stop.set()
        if self._thread is not None:
            self._thread.join(timeout=1.0)
            self._thread = None
        # Final sample so a peak reached after the last poll is still recorded.
        self._sample()
        return False


def safe_float(text: str | None) -> float | None:
    """Parse a numeric CSV cell, returning None when it does not hold a usable number.

    Surrounding whitespace and thousands separators (",") are removed before parsing.
    Blank cells, "N/A" (any case) and every NaN spelling accepted by ``float()`` (e.g.
    "nan", "NAN", "-nan") yield None. This matters because a single NaN added to a running
    total would make the whole total NaN.

    Args:
        text: Cell text, or None.

    Returns:
        The parsed float, or None if ``text`` is None, blank, "N/A", NaN or unparsable.
    """
    if text is None:
        return None
    cleaned = str(text).strip().replace(",", "")
    if cleaned.lower() in {"", "n/a"}:
        return None
    try:
        value = float(cleaned)
    except ValueError:
        return None
    # float() accepts many NaN spellings beyond "nan"/"NaN"; treat all of them as missing.
    if math.isnan(value):
        return None
    return value


@dataclass(frozen=True)
class NcuFlops:
    """Per-metric totals collected from Nsight Compute output.

    ``metrics`` maps a metric name to its value summed over the profiled kernels.
    """

    metrics: Dict[str, float]

    @property
    def total(self) -> float:
        """Sum of all metric values as a float (0.0 when ``metrics`` is empty)."""
        per_metric_values = self.metrics.values()
        return float(sum(per_metric_values))


def parse_ncu_csv_metrics(text: str, *, metrics: Sequence[str]) -> NcuFlops:
    """Sum each requested metric over all kernel rows in ``ncu --csv`` output.

    Three row layouts are recognised:

    * Long (one row per kernel x metric): a header row containing "Metric Name" and
      "Metric Value" columns selects those columns for the rows that follow.
    * Wide (one row per kernel, one column per metric, as printed for ``--page raw``; verify
      against your ncu version if parsing comes back empty): a header row containing an "ID"
      or "Kernel Name" cell plus one or more requested metric names as whole cells. Later
      rows with the same number of cells contribute the value in each metric's column. A
      units row under the header contributes nothing, because its cells are not numeric.
    * Headerless: any other row containing a requested metric name as a whole cell adds
      its last numeric cell to that metric.

    When a "Kernel Name" column is known, rows whose kernel name is empty, "summary",
    "total", or starts with "summary" are skipped. Cells are parsed with ``safe_float``, and
    unparsable values are ignored. Rows that fit none of the layouts are ignored.

    Args:
        text: CSV text captured from ncu's stdout.
        metrics: Metric names to collect.

    Returns:
        NcuFlops mapping each metric whose total is strictly positive to that total. Metrics
        that were absent or summed to 0 are omitted.
    """
    totals: Dict[str, float] = {m: 0.0 for m in metrics}

    # Long layout: column indices of "Metric Name" / "Metric Value".
    idx_metric_name: int | None = None
    idx_metric_value: int | None = None
    # "Kernel Name" column (either layout); used to drop summary/total rows.
    idx_kernel_name: int | None = None
    # Wide layout: requested metric -> its column, plus the header's cell count.
    wide_columns: Dict[str, int] = {}
    wide_width = 0

    def _first_index(cells: List[str], needle: str) -> int | None:
        """Index of the first cell containing ``needle`` as a substring, or None."""
        return next((i for i, cell in enumerate(cells) if needle in cell), None)

    def _maybe_long_header(row: List[str]) -> bool:
        """Adopt ``row`` as a long-layout header if it has Metric Name/Value columns."""
        nonlocal idx_metric_name, idx_metric_value, idx_kernel_name, wide_columns, wide_width
        lowered = [cell.strip().lower() for cell in row]
        name_idx = _first_index(lowered, "metric name")
        value_idx = _first_index(lowered, "metric value")
        if name_idx is None or value_idx is None:
            return False
        idx_metric_name = name_idx
        idx_metric_value = value_idx
        idx_kernel_name = _first_index(lowered, "kernel name")
        wide_columns = {}
        wide_width = 0
        return True

    def _maybe_wide_header(row: List[str]) -> bool:
        """Adopt ``row`` as a wide-layout header if requested metrics appear as column names."""
        nonlocal idx_metric_name, idx_metric_value, idx_kernel_name, wide_columns, wide_width
        stripped = [cell.strip() for cell in row]
        lowered = [cell.lower() for cell in stripped]
        # Requiring an identity column ("ID"/"Kernel Name") keeps a headerless long-layout
        # data row, which also holds the metric name as a cell, from being taken as a header.
        if "id" not in lowered and "kernel name" not in lowered:
            return False
        columns: Dict[str, int] = {}
        for i, cell in enumerate(stripped):
            if cell in totals:
                columns.setdefault(cell, i)
        if not columns:
            return False
        idx_metric_name = None
        idx_metric_value = None
        idx_kernel_name = _first_index(lowered, "kernel name")
        wide_columns = columns
        wide_width = len(row)
        return True

    def _is_summary_row(row: List[str]) -> bool:
        """True for rows whose kernel name is empty or marks a summary/total line."""
        if idx_kernel_name is None or idx_kernel_name >= len(row):
            return False
        kernel = row[idx_kernel_name].strip().lower()
        return kernel in {"", "summary", "total"} or kernel.startswith("summary")

    for row in csv.reader(io.StringIO(text)):
        if not row:
            continue
        if _maybe_long_header(row) or _maybe_wide_header(row):
            continue

        if wide_columns:
            # Wide layout: genuine data rows have exactly as many cells as the header, which
            # filters out interleaved log lines from ncu or the profiled program.
            if len(row) != wide_width or _is_summary_row(row):
                continue
            for metric, col in wide_columns.items():
                value = safe_float(row[col])
                if value is not None:
                    totals[metric] += float(value)
            continue

        # Long layout: use the explicit Metric Name / Metric Value columns.
        if (
            idx_metric_name is not None
            and idx_metric_value is not None
            and len(row) > max(idx_metric_name, idx_metric_value)
        ):
            metric = row[idx_metric_name].strip()
            if metric in totals:
                value = safe_float(row[idx_metric_value])
                if value is not None and not _is_summary_row(row):
                    totals[metric] += float(value)
            continue

        # Headerless fallback: locate the metric token anywhere in the row and take the
        # last numeric field.
        trimmed_cells = {cell.strip() for cell in row}
        for metric in metrics:
            if metric not in trimmed_cells:
                continue
            value = None
            for cell in reversed(row):
                value = safe_float(cell)
                if value is not None:
                    break
            if value is None:
                continue
            totals[metric] += float(value)
            break

    totals = {k: float(v) for k, v in totals.items() if float(v) > 0.0}
    return NcuFlops(metrics=totals)


def run_ncu_metrics(
    target_cmd: Sequence[str],
    *,
    ncu_path: str = "ncu",
    nvtx_include: str | None = None,
    metrics: Sequence[str],
    cwd: str | None = None,
    env: Mapping[str, str] | None = None,
    timeout_sec: float | None = None,
) -> NcuFlops:
    """Profile ``target_cmd`` under Nsight Compute (``ncu``) and return the requested metrics.

    The command is ``ncu --csv --page raw --metrics <m1,m2,...> [--nvtx --nvtx-include <range>]
    --apply-rules no --section-folder <tmp>/ncu_sections_<uid> <target_cmd...>``. It runs
    with a HOME in which HOME and HOME/Documents are writable. If the caller's HOME fails
    that check, HOME is pointed at a temp/cwd-based directory. The per-version
    ``$HOME/Documents/NVIDIA Nsight Compute/<X.Y.Z>/{Sections,Rules}`` directories are
    pre-created on a best-effort basis.

    If ncu exits non-zero with output that looks like a section/rules deployment problem, it
    is retried once, after creating the Sections/Rules directories named in the error message
    if any. If that retry also fails, one more attempt is made with a patched HOME, but only
    when patching yields a different HOME. Any other failure raises immediately.

    Args:
        target_cmd: argv of the program to profile; each part is converted with str().
        ncu_path: ncu executable name or path.
        nvtx_include: If set, passed as ``--nvtx --nvtx-include <value>`` to restrict profiling
            to that NVTX range.
        metrics: Metric names passed to ``--metrics`` and parsed from the CSV output.
        cwd: Working directory for every ncu invocation (including ``ncu --version``).
        env: Base environment; a copy of ``os.environ`` is used when None.
        timeout_sec: Timeout for each profiling run. ``subprocess.TimeoutExpired`` is not
            caught.

    Returns:
        NcuFlops parsed from ncu's stdout by ``parse_ncu_csv_metrics``.

    Raises:
        RuntimeError: If the ncu attempts all exit non-zero (the message includes each
            attempt's output), or if no requested metric was parsed from stdout.
    """
    def _ncu_version_triplet(run_env: Mapping[str, str]) -> str | None:
        """Return ncu's "X.Y.Z" version from ``ncu --version`` (cached per ncu_path), or None."""
        cache_key = str(ncu_path)
        cached = _NCU_VERSION_TRIPLET_CACHE.get(cache_key)
        if cached:
            return cached

        try:
            proc = subprocess.run(
                [str(ncu_path), "--version"],
                cwd=cwd,
                env=dict(run_env),
                capture_output=True,
                text=True,
                timeout=10,
            )
        except Exception:
            # Missing executable, timeout, etc.: version unknown (result not cached).
            return None

        # Search both stdout and stderr for the version string.
        text = "\n".join([(proc.stdout or "").strip(), (proc.stderr or "").strip()]).strip()
        match = re.search(r"Version\s+(\d+\.\d+\.\d+)", text)
        if not match:
            return None
        triplet = match.group(1)
        _NCU_VERSION_TRIPLET_CACHE[cache_key] = triplet
        return triplet

    def _is_writable_directory(path: Path) -> bool:
        """Create ``path`` (with parents) if needed and check a probe file can be written there.

        The probe file is removed afterwards. Returns False on any failure; never raises.
        """
        try:
            path.mkdir(parents=True, exist_ok=True)
        except Exception:
            return False
        # Unique name (pid + ns timestamp) so concurrent probes do not collide.
        probe = path / f".write_probe_{os.getpid()}_{time.time_ns()}"
        try:
            probe.write_text("probe", encoding="utf-8")
            probe.unlink()
            return True
        except Exception:
            try:
                if probe.exists():
                    probe.unlink()
            except Exception:
                pass
            return False

    def _extract_section_deploy_dir(detail: str) -> Path | None:
        """Return the directory named in ncu's "Could not deploy stock section files" warning."""
        # Example:
        # ==WARNING== Could not deploy stock section files to "/tmp/.../Documents/NVIDIA Nsight Compute/2021.3.1/Sections".
        match = re.search(r'Could not deploy stock section files to "([^"]+)"', detail or "")
        if not match:
            return None
        try:
            return Path(match.group(1))
        except Exception:
            return None

    def _looks_like_section_deploy_error(detail: str) -> bool:
        """True if ncu output contains a known HOME/section/rules deployment error phrase."""
        lowered = (detail or "").lower()
        return any(
            token in lowered
            for token in (
                "set the home environment variable",
                "could not deploy stock section files",
                "failed to add default section search path",
                "failed to add rules search path",
            )
        )

    def _ensure_ncu_user_dirs(run_env: Dict[str, str]) -> None:
        """Best-effort pre-creation of $HOME/Documents/NVIDIA Nsight Compute/<X.Y.Z>/{Sections,Rules}.

        Skipped when HOME is unset or the ncu version cannot be determined; the results of
        the writability checks are ignored.
        """
        home = run_env.get("HOME")
        if not home:
            return
        triplet = _ncu_version_triplet(run_env)
        if not triplet:
            return
        root = Path(home) / "Documents" / "NVIDIA Nsight Compute" / str(triplet)
        _is_writable_directory(root)
        _is_writable_directory(root / "Sections")
        _is_writable_directory(root / "Rules")

    def _patched_env_with_writable_home(base_env: Mapping[str, str], cwd: str | None) -> Dict[str, str]:
        """Return a copy of ``base_env`` with HOME moved to a writable candidate directory.

        Candidates, in order: ``<tmp>/ncu_home_<suffix>``, ``<cwd>/.ncu_home_<suffix>`` (if
        cwd is given), then the temp dir itself. The first one where both the directory and
        its ``Documents`` subdirectory are writable wins. If none qualifies, HOME is left as
        it was in ``base_env``.
        """
        patched: Dict[str, str] = dict(base_env)
        # Per-user suffix: uid on POSIX, else USERNAME, else the pid.
        suffix = str(os.getuid()) if hasattr(os, "getuid") else (patched.get("USERNAME") or str(os.getpid()))
        candidates: List[Path] = [Path(tempfile.gettempdir()) / f"ncu_home_{suffix}"]
        if cwd:
            candidates.append(Path(cwd) / f".ncu_home_{suffix}")
        candidates.append(Path(tempfile.gettempdir()))

        for candidate in candidates:
            if _is_writable_directory(candidate) and _is_writable_directory(candidate / "Documents"):
                patched["HOME"] = str(candidate)
                return patched
        return patched

    # Base invocation: CSV output of the raw page, restricted to the requested metrics.
    cmd: List[str] = [str(ncu_path), "--csv", "--page", "raw", "--metrics", ",".join(metrics)]
    if nvtx_include:
        cmd += ["--nvtx", "--nvtx-include", str(nvtx_include)]

    # Nsight Compute may fail hard if its default stock 'sections' directory is missing (common in minimal containers).
    # Providing an explicit (even empty) section search path avoids relying on the installation layout.
    section_folder_suffix = str(os.getuid()) if hasattr(os, "getuid") else str(os.getpid())
    section_folder = Path(tempfile.gettempdir()) / f"ncu_sections_{section_folder_suffix}"
    _is_writable_directory(section_folder)
    cmd += ["--apply-rules", "no", "--section-folder", str(section_folder)]

    cmd += [str(part) for part in target_cmd]

    # Keep the caller's HOME only if HOME and HOME/Documents are writable (these checks also
    # create them if missing); otherwise point HOME at a writable fallback.
    base_env: Dict[str, str] = dict(env) if env is not None else dict(os.environ)
    run_env = base_env
    if not (run_env.get("HOME") and _is_writable_directory(Path(run_env["HOME"])) and _is_writable_directory(Path(run_env["HOME"]) / "Documents")):
        run_env = _patched_env_with_writable_home(run_env, cwd=cwd)
    _ensure_ncu_user_dirs(run_env)

    proc = subprocess.run(cmd, cwd=cwd, env=run_env, capture_output=True, text=True, timeout=timeout_sec)
    if proc.returncode != 0:
        stderr = (proc.stderr or "").strip()
        stdout = (proc.stdout or "").strip()
        detail = stderr or stdout
        if _looks_like_section_deploy_error(detail):
            # Retry 1: create the Sections dir (and sibling Rules dir) named in the error, then
            # rerun with the same environment.
            deploy_dir = _extract_section_deploy_dir(detail)
            if deploy_dir is not None:
                _is_writable_directory(deploy_dir)
                _is_writable_directory(deploy_dir.parent / "Rules")

            retry = subprocess.run(cmd, cwd=cwd, env=run_env, capture_output=True, text=True, timeout=timeout_sec)
            if retry.returncode == 0:
                proc = retry
            else:
                # Retry 2: switch to a patched HOME, but only if that actually changes HOME.
                patched_env = _patched_env_with_writable_home(run_env, cwd=cwd)
                if patched_env.get("HOME") and patched_env.get("HOME") != run_env.get("HOME"):
                    _ensure_ncu_user_dirs(patched_env)
                    retry2 = subprocess.run(
                        cmd,
                        cwd=cwd,
                        env=patched_env,
                        capture_output=True,
                        text=True,
                        timeout=timeout_sec,
                    )
                    if retry2.returncode == 0:
                        proc = retry2
                    else:
                        retry_detail = ((retry.stderr or "").strip() or (retry.stdout or "").strip())
                        retry2_detail = ((retry2.stderr or "").strip() or (retry2.stdout or "").strip())
                        raise RuntimeError(
                            f"ncu failed (exit={proc.returncode}). {detail}\n"
                            f"Retry after creating Sections/Rules failed (exit={retry.returncode}). {retry_detail}\n"
                            f"Retry with HOME={patched_env.get('HOME')} failed (exit={retry2.returncode}). {retry2_detail}"
                        )
                else:
                    retry_detail = ((retry.stderr or "").strip() or (retry.stdout or "").strip())
                    raise RuntimeError(
                        f"ncu failed (exit={proc.returncode}). {detail}\n"
                        f"Retry after creating Sections/Rules failed (exit={retry.returncode}). {retry_detail}"
                    )
        else:
            raise RuntimeError(f"ncu failed (exit={proc.returncode}). {detail}")

    # proc now refers to whichever attempt succeeded.
    parsed = parse_ncu_csv_metrics(proc.stdout or "", metrics=metrics)
    if not parsed.metrics:
        raise RuntimeError(
            "ncu ran but no requested metrics were parsed. "
            "Double-check `--nvtx-include` matches your NVTX range name and that the selected metrics exist."
        )
    return parsed
