import torch

from typing import Dict


def set_device(
    batch: Dict[str, torch.Tensor],
    device: torch.device,
) -> Dict[str, torch.Tensor]:
    """ Move every movable value of a batch to ``device``.

    Each value that exposes a callable ``.to`` method (e.g. ``torch.Tensor``)
    is replaced by ``value.to(device=device)``. Values without one (strings,
    ints, lists, ``None``, ...) are copied into the result unchanged. Despite
    the annotation, the batch may therefore mix tensors with other objects.
    The input dict is not modified; a new dict is returned.

    Args:
        batch (Dict[str, torch.Tensor]): The batch, mapping names to values.
        device (torch.device): The target device.

    Returns:
        Dict[str, torch.Tensor]: A new dict with the same keys, movable
            values transferred to ``device``.

    Raises:
        Any exception raised by a value's ``.to`` call (e.g. a CUDA
        out-of-memory ``RuntimeError``) propagates. It is deliberately not
        swallowed: hiding it would leave the value on its old device and
        cause a confusing device-mismatch error later.
    """

    output_batch = {}

    for key, value in batch.items():
        # Only values that actually provide ``.to`` are moved; everything
        # else is passed through as-is rather than being probed via an
        # exception.
        move = getattr(value, "to", None)
        if callable(move):
            output_batch[key] = move(device=device)
        else:
            output_batch[key] = value

    return output_batch
