import logging

import torch
from torch._prims_common import DeviceLikeType

logger = logging.getLogger(__name__)


def set_device(device: DeviceLikeType) -> torch.device:
    """Validate *device* and return it as a ``torch.device``.

    Only CPU and CUDA devices are supported.

    Args:
        device: Either the string ``"cpu"``, a string starting with
            ``"cuda"`` (e.g. ``"cuda"`` or ``"cuda:1"``), or a
            ``torch.device`` whose type is ``"cpu"`` or ``"cuda"``.

    Returns:
        The corresponding ``torch.device`` object.

    Raises:
        ValueError: If *device* is not one of the supported forms above.
        RuntimeError: If a CUDA device is requested but
            ``torch.cuda.is_available()`` is false, or if ``torch.device``
            rejects a malformed ``"cuda..."`` string.
    """
    if isinstance(device, torch.device):
        resolved = device
    elif isinstance(device, str) and (
        device == "cpu" or device.startswith("cuda")
    ):
        resolved = torch.device(device)
    else:
        raise ValueError(
            f"Unsupported device {device!r}: expected 'cpu', a 'cuda[:N]' "
            "string, or a cpu/cuda torch.device"
        )

    if resolved.type == "cuda":
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would silently skip this availability check.
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"CUDA device {str(resolved)!r} requested but CUDA is not "
                "available"
            )
    elif resolved.type != "cpu":
        # Reached only for torch.device instances of another backend.
        raise ValueError(
            f"Unsupported device type {resolved.type!r}: only 'cpu' and "
            "'cuda' are supported"
        )

    # Lazy %-formatting: the message is only built if INFO is enabled.
    logger.info('Set device to: "%s"', resolved)
    return resolved
