from __future__ import annotations

import torch


def get_device(preferred: str | None = None) -> torch.device:
    """Return the device to run on, falling back to CPU when CUDA is unavailable.

    Args:
        preferred: Requested device string. ``None`` means ``"cuda"``. Any
            string starting with ``"cuda"`` (e.g. ``"cuda"`` or ``"cuda:1"``)
            selects CUDA when ``torch.cuda.is_available()`` is true. Every other
            value, or a CUDA request on a machine without CUDA, yields CPU.

    Returns:
        A ``torch.device``. When ``preferred`` parses as a CUDA device, its
        index (if any) is preserved, so ``"cuda:1"`` yields ``cuda:1``.
    """
    if preferred is None:
        preferred = "cuda"
    if not (preferred.startswith("cuda") and torch.cuda.is_available()):
        return torch.device("cpu")
    try:
        device = torch.device(preferred)
    except RuntimeError:
        # Not a parseable device string (e.g. "cudafoo"); keep the historical
        # behaviour of falling back to the default CUDA device.
        return torch.device("cuda")
    # Keep an explicit index such as "cuda:1" instead of silently collapsing
    # it to the current default CUDA device.
    return device if device.type == "cuda" else torch.device("cuda")


class Autocast:
    """Context manager that enters ``torch.autocast`` only when enabled.

    When disabled, entering and exiting the manager is a no-op, so call sites
    can always write ``with Autocast(...):`` regardless of hardware.

    Args:
        device_type: Device to autocast on. Accepts a bare type string
            (``"cuda"``, ``"cpu"``), an indexed string (``"cuda:0"``) or a
            ``torch.device``. Only the device *type* is kept, because
            ``torch.autocast`` expects a bare type such as ``"cuda"``.
        enabled: Whether to enable autocast. ``None`` (the default) enables
            it only when the device type is ``"cuda"`` and
            ``torch.cuda.is_available()`` is true.
        dtype: Optional autocast dtype forwarded to ``torch.autocast``;
            ``None`` keeps torch's per-device default.
    """

    def __init__(
        self,
        device_type: str | torch.device = "cuda",
        enabled: bool | None = None,
        dtype: torch.dtype | None = None,
    ):
        # torch.autocast rejects indexed device strings such as "cuda:0", and
        # get_device() returns torch.device objects; normalise to the type.
        if isinstance(device_type, torch.device):
            device_type = device_type.type
        else:
            device_type = device_type.split(":", 1)[0]
        if enabled is None:
            enabled = (device_type == "cuda") and torch.cuda.is_available()
        self.enabled = enabled
        self.device_type = device_type
        self.dtype = dtype
        # Live torch.autocast instance while entered; None otherwise.
        self.ctx = None

    def __enter__(self):
        if self.enabled:
            ctx = torch.autocast(self.device_type, dtype=self.dtype)
            ctx.__enter__()
            # Only record the context once it has actually been entered.
            self.ctx = ctx
        return self

    def __exit__(self, exc_type, exc, tb):
        # Clear the stored context first so a reused instance never sees a
        # stale one, even if the inner __exit__ raises.
        ctx, self.ctx = self.ctx, None
        if ctx is not None:
            return ctx.__exit__(exc_type, exc, tb)
        return False

