"""Device parsing utilities.

Deterministic, minimal helper to normalize device inputs across CLIs/tests.
"""

from __future__ import annotations

import re

import torch


def parse_device(value: str | int | torch.device) -> torch.device:
    """Parse a device spec into a canonical ``torch.device``.

    Accepts ``torch.device`` instances, "cpu", "cuda", "cuda:<idx>",
    integers, or numeric strings. Strings are matched case-insensitively
    after stripping surrounding whitespace. Bare integers and numeric strings
    are interpreted as CUDA device indices. Every CUDA result carries an
    explicit index ("cuda" becomes ``cuda:0``). Non-CUDA ``torch.device``
    instances are returned unchanged.

    Args:
        value: Device specification.

    Returns:
        torch.device: Canonical torch device.

    Raises:
        TypeError: If ``value`` is not a recognized device spec (this includes
            ``bool``, which would otherwise be treated as an ``int``).
        ValueError: If CUDA is requested but unavailable, or the index is
            negative or not below ``torch.cuda.device_count()``.

    """
    if isinstance(value, torch.device):
        if value.type != "cuda":
            return value
        index = 0 if value.index is None else value.index
    elif isinstance(value, bool):
        # ``bool`` subclasses ``int``; without this branch True/False would be
        # silently accepted as cuda:1/cuda:0.
        msg = f"Unrecognized device spec: {value}"
        raise TypeError(msg)
    elif isinstance(value, int):
        # Interpret bare integers as CUDA indices.
        index = value
    elif isinstance(value, str):
        spec = value.strip().lower()
        if spec == "cpu":
            return torch.device("cpu")
        if spec == "cuda":
            index = 0
        else:
            # Use ASCII digits only: ``\d`` also matches non-ASCII digits that
            # torch.device cannot parse. Converting with int() also normalizes
            # leading zeros ("cuda:01"), which torch's device-string parser
            # rejects.
            match = re.fullmatch(r"(?:cuda:)?([0-9]+)", spec)
            if match is None:
                msg = f"Unrecognized device spec: {value}"
                raise TypeError(msg)
            index = int(match.group(1))
    else:
        msg = f"Unrecognized device spec: {value}"
        raise TypeError(msg)

    # Validate before constructing the device. Otherwise torch raises its own
    # RuntimeError for negative or unparsable indices, bypassing the
    # documented ValueError.
    if not torch.cuda.is_available():
        msg = "CUDA requested but not available"
        raise ValueError(msg)
    if index < 0 or index >= torch.cuda.device_count():
        msg = f"Requested CUDA device index {index} out of range"
        raise ValueError(msg)

    return torch.device("cuda", index)


# Public API of this module: only the device parser is exported.
__all__: list[str] = ["parse_device"]
