import contextlib
from functools import partial
from mmengine.device import get_device
import torch


def _null_amp_context(*args, **kwargs):
    """Return a do-nothing context manager, ignoring all arguments.

    Used as the AMP factory on devices where autocast is not applied, so
    callers can use the same ``with factory(...):`` form on every device.
    """
    return contextlib.nullcontext()


def get_amp_context(enabled=True, device=None):
    """Return a factory for the automatic-mixed-precision context manager.

    The returned callable is meant to be used as ``with factory(): ...``.
    For CUDA and NPU devices it is ``torch.amp.autocast`` with
    ``device_type`` and ``enabled`` pre-bound, so any extra arguments
    (e.g. ``dtype``) are forwarded to autocast. For every other device
    (CPU, MPS, ...), and for NPU when the ``torch_npu`` extension cannot be
    imported, it produces a no-op context manager and ignores its arguments.

    Args:
        enabled (bool): Value bound to autocast's ``enabled`` argument on
            CUDA/NPU. Has no effect on the no-op path.
        device (str, optional): Device string such as ``'cuda'``,
            ``'cuda:0'``, ``'npu'`` or ``'cpu'``. When ``None`` (the
            default), ``mmengine.device.get_device()`` is used.

    Returns:
        Callable[..., ContextManager]: The context-manager factory.
    """
    device_str = get_device() if device is None else device

    if device_str.startswith("cuda"):
        return partial(torch.amp.autocast, device_type="cuda", enabled=enabled)

    if device_str.startswith("npu"):
        # Autocast on "npu" relies on the torch_npu extension being loaded.
        # Building a ``partial`` never raises, so availability has to be
        # probed explicitly here; otherwise any failure would only surface
        # later, when the caller creates the context manager.
        try:
            import torch_npu  # noqa: F401
        except ImportError:
            return _null_amp_context
        return partial(torch.amp.autocast, device_type="npu", enabled=enabled)

    # CPU, MPS, etc.: autocast is deliberately not applied.
    return _null_amp_context
