import numpy as np
import torch

DTYPE = torch.float
DEVICE = 'cuda:0'

def to_np(x):
    if torch.is_tensor(x):
        x = x.detach().cpu().numpy()
    return x

def to_torch(x, dtype=None, device=None):
    dtype = dtype or DTYPE
    device = device or DEVICE
    return torch.tensor(x, dtype=dtype, device=device)

def to_device(*xs, device=DEVICE):
    return [x.to(device) for x in xs]

def normalize(x):
    """
        scales `x` to [0, 1]
    """
    x = x - x.min()
    x = x / x.max()
    return x

def to_img(x):
    normalized = normalize(x)
    array = to_np(normalized)
    array = np.transpose(array, (1,2,0))
    return (array * 255).astype(np.uint8)

def set_device(device):
    global DEVICE
    DEVICE = device
    if 'cuda' in device:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        torch.cuda.set_device(device)
