import torch


def check_device(tensor, device):
    """Return ``tensor`` placed on ``device``.

    The tensor is moved with ``Tensor.to`` only when its current ``.device``
    compares unequal to ``device``; otherwise the same object is returned.

    Args:
        tensor: the tensor to place.
        device: target device (anything accepted by ``Tensor.to``).

    Returns:
        The original tensor, or a copy of it on ``device``.
    """
    return tensor.to(device) if tensor.device != device else tensor


def check_image(tensor):
    """Validate that every element of ``tensor`` lies in the closed range [-1, 1].

    Args:
        tensor: the tensor to check. An empty tensor is accepted as in range.

    Raises:
        AssertionError: if any element is below -1, above 1, or NaN.
    """
    # torch.min/torch.max raise a RuntimeError on an empty tensor (the
    # reduction has no identity); an empty tensor is trivially in range.
    if tensor.numel() == 0:
        return
    lo = torch.min(tensor)
    hi = torch.max(tensor)
    # Explicit raise instead of an ``assert`` statement so the check still runs
    # under ``python -O``. AssertionError is kept so existing callers that catch
    # it keep working. torch.min/max propagate NaN, and NaN fails both
    # comparisons, so NaN inputs are rejected too.
    if not (lo >= -1. and hi <= 1.):
        raise AssertionError(
            f"Output images should be in [-1, 1], "
            f"got min={lo.item()}, max={hi.item()}")


def normalize_tensor(tensor):
    """Rescale ``tensor`` from the range [-1, 1] to [0, 1].

    The input is first validated with ``check_image``, which raises if any
    value falls outside [-1, 1].

    Args:
        tensor: a tensor whose values are expected to lie in [-1, 1].

    Returns:
        A new tensor equal to ``(tensor + 1) / 2``.
    """
    check_image(tensor)
    shifted = tensor + 1.
    return shifted / 2.
