import torch

class DeviceUtils:
    """
    Utility class for determining training devices and mixed precision support.

    Example:
    >>> device = DeviceUtils.get_training_device()
    >>> device.type in ("cuda", "cpu")
    True
    >>> DeviceUtils.get_mixed_precision_dtype(device="cpu")
    torch.bfloat16
    """

    @staticmethod
    def get_training_device() -> torch.device:
        """
        Return the best available training device, preferring CUDA over CPU.

        Returns:
            torch.device: ``torch.device("cuda")`` if ``torch.cuda.is_available()``,
            otherwise ``torch.device("cpu")``.
        """
        # NOTE: Apple MPS is intentionally not selected. A check on
        # torch.backends.mps.is_available()/is_built() used to live here but was
        # disabled; the reason is not recorded here — confirm before re-enabling.
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    @staticmethod
    def get_mixed_precision_dtype(
        default_dtype: torch.dtype = torch.float32,
        *,
        device: "torch.device | str | None" = None,
    ) -> torch.dtype:
        """
        Choose the mixed-precision dtype to use on a device.

        Returns ``torch.bfloat16`` for:
          * CUDA devices with compute capability major version >= 8 (Ampere or newer);
          * the CPU (always, without probing the CPU's instruction set).
        For every other case (older CUDA GPUs, any other device type such as
        ``mps`` or ``meta``) the given ``default_dtype`` is returned.

        Args:
            default_dtype (torch.dtype): Fallback dtype when bf16 is not chosen.
            device (torch.device | str | None): Device to query. When omitted,
                the device from ``DeviceUtils.get_training_device()`` is used,
                which preserves the original behaviour.

        Returns:
            torch.dtype: Either ``torch.bfloat16`` or ``default_dtype``.
        """
        if device is None:
            device = DeviceUtils.get_training_device()
        elif not isinstance(device, torch.device):
            # Accept strings such as "cpu" or "cuda:1".
            device = torch.device(device)

        if device.type == "cuda":
            # The capability is queried for this specific device; an index-less
            # "cuda" device refers to the current CUDA device.
            major, _ = torch.cuda.get_device_capability(device)
            return torch.bfloat16 if major >= 8 else default_dtype

        if device.type == "cpu":
            # bf16 is chosen unconditionally on CPU. NOTE(review): CPUs without
            # native bf16 instructions may run it slowly — confirm this is intended.
            return torch.bfloat16

        return default_dtype