"""PyTorch-related helper functions for device and dtype resolution.

Centralizes logic for selecting execution devices and mapping user-friendly
string dtype specifiers to torch.dtype objects.
"""
from __future__ import annotations

from typing import Dict

import torch

import constants

# Lookup table from user-facing dtype names to torch dtypes. Each dtype is
# registered under a canonical name and a short alias, both taken from
# ``constants``. resolve_dtype() looks keys up after strip().lower(), so these
# constants presumably must be lower-case strings -- verify in ``constants``.
_DTYPE_MAP: Dict[str, torch.dtype] = {
    **dict.fromkeys(
        (constants.DTYPE_FP32, constants.DTYPE_ALIAS_FP32), torch.float32
    ),
    **dict.fromkeys(
        (constants.DTYPE_FP16, constants.DTYPE_ALIAS_FP16), torch.float16
    ),
    **dict.fromkeys(
        (constants.DTYPE_BF16, constants.DTYPE_ALIAS_BF16), torch.bfloat16
    ),
}


def _mps_available() -> bool:
    """Return True if this torch build exposes an MPS backend that is usable."""
    # ``torch.backends`` may lack an ``mps`` attribute on some builds, so probe
    # for it before calling ``is_available()``. The attribute name is torch's,
    # not a device-string constant.
    mps_backend = getattr(torch.backends, "mps", None)
    return mps_backend is not None and bool(mps_backend.is_available())


def _has_device_type(spec: str, device_type: str) -> bool:
    """Return True if *spec* is exactly ``device_type`` or ``device_type:<index>``."""
    return spec == device_type or spec.startswith(f"{device_type}:")


def setup_device(device_str: str) -> torch.device:
    """Resolve a user-supplied device string into a ``torch.device``.

    The string is stripped and lower-cased before matching. Accepted forms:

      - ``constants.DEFAULT_DEVICE`` or ``constants.DEVICE_AUTO``: pick the
        best available backend, in the order CUDA > MPS > CPU.
      - ``constants.DEVICE_CUDA``, optionally followed by ``:<index>``. CUDA
        must be available, and the index must be below
        ``torch.cuda.device_count()``.
      - ``constants.DEVICE_MPS``, optionally followed by ``:<index>``. MPS must
        be available; the plain MPS device is returned.
      - ``constants.DEVICE_CPU``.

    Args:
        device_str: Device specifier to resolve.

    Returns:
        The resolved ``torch.device``.

    Raises:
        RuntimeError: if the requested backend is unavailable, the device index
            is malformed or out of range, or the string is not recognized.
    """
    d = device_str.strip().lower()

    if d in (constants.DEFAULT_DEVICE, constants.DEVICE_AUTO):
        if torch.cuda.is_available():
            return torch.device(constants.DEVICE_CUDA)
        if _mps_available():
            return torch.device(constants.DEVICE_MPS)
        return torch.device(constants.DEVICE_CPU)

    if _has_device_type(d, constants.DEVICE_CUDA):
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"CUDA requested ('{device_str}'), but torch.cuda.is_available() is False."
            )
        # Build from the normalized string: torch's parser is case-sensitive
        # and rejects surrounding whitespace. It also raises RuntimeError on a
        # malformed index such as "cuda:x".
        device = torch.device(d)
        count = torch.cuda.device_count()
        if device.index is not None and device.index >= count:
            raise RuntimeError(
                f"CUDA device index {device.index} requested ('{device_str}'), "
                f"but only {count} CUDA device(s) are visible."
            )
        return device

    if _has_device_type(d, constants.DEVICE_MPS):
        if not _mps_available():
            raise RuntimeError(
                f"MPS requested ('{device_str}'), but torch.backends.mps.is_available() is False."
            )
        # Parse the full spec only to reject a malformed index; the plain MPS
        # device is what callers receive.
        torch.device(d)
        return torch.device(constants.DEVICE_MPS)

    if d == constants.DEVICE_CPU:
        return torch.device(constants.DEVICE_CPU)

    raise RuntimeError(f"Unsupported or unknown device string: '{device_str}'")


def resolve_dtype(dtype_str: str) -> torch.dtype:
    """Translate a user-supplied dtype name into the matching torch.dtype.

    Matching ignores surrounding whitespace and letter case. Both the long and
    the short spelling of each supported type are recognized:
      - float32 / fp32 -> torch.float32
      - float16 / fp16 -> torch.float16
      - bfloat16 / bf16 -> torch.bfloat16

    Raises:
        ValueError: if the normalized name is not a key of ``_DTYPE_MAP``. The
            message lists every accepted name, and the original KeyError is
            chained as the cause.
    """
    normalized = dtype_str.strip().lower()
    try:
        dtype = _DTYPE_MAP[normalized]
    except KeyError as err:  # pragma: no cover - error path
        accepted = sorted(_DTYPE_MAP)
        message = f"Unsupported dtype string '{dtype_str}'. Supported: {accepted}"
        raise ValueError(message) from err
    return dtype


# Public API of this module: a star-import (`from ... import *`) exports only
# these names; `_DTYPE_MAP` stays private.
__all__ = ["setup_device", "resolve_dtype"]

