#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import torch

def convert_np_img_to_torch(img, ensure_color_channels=False, device=None):
    """Convert a NumPy image to a torch tensor with a leading batch axis.

    Axis 2 of the input is treated as the channel axis and moved to the
    front, and a batch axis of size 1 is prepended, so an ``(H, W, C)``
    array becomes a ``(1, C, H, W)`` tensor.

    Args:
        img: NumPy array of shape ``(H, W)`` or ``(H, W, C)``.
        ensure_color_channels: If True and ``img`` is 2-D, replicate it into
            3 identical channels. If False, a 2-D image gets a single
            channel axis instead, giving a ``(1, 1, H, W)`` tensor.
        device: Optional torch device; when given, the tensor is moved there.

    Returns:
        A ``torch.Tensor`` of shape ``(1, C, H, W)``. Because
        ``torch.from_numpy`` does not copy, the result may share memory
        with ``img``.
    """
    if img.ndim == 2:
        if ensure_color_channels:
            # Grayscale -> 3 identical channels.
            img = np.tile(img[:, :, None], (1, 1, 3))
        else:
            # Add a singleton channel axis so that moving axis 2 is valid
            # for 2-D input as well.
            img = img[:, :, None]
    arr = np.moveaxis(img, 2, 0)[None]
    # torch.from_numpy rejects arrays with negative strides (e.g. a channel
    # flip via img[..., ::-1]); copy only in that case so ordinary arrays
    # keep the zero-copy path.
    if any(stride < 0 for stride in arr.strides):
        arr = arr.copy()
    tensor = torch.from_numpy(arr)
    if device is not None:
        tensor = tensor.to(device=device)
    return tensor

def convert_torch_img_to_np(img):
    """Convert the first image of a batched torch tensor to a NumPy array.

    Only batch element 0 is converted. The tensor is detached from the
    autograd graph and moved to the CPU before conversion.

    Args:
        img: ``torch.Tensor`` of shape ``(N, C, H, W)``.

    Returns:
        A NumPy array of shape ``(H, W)`` when ``C == 1``, otherwise
        ``(H, W, C)`` with the channel axis moved last.
    """
    if img.shape[1] == 1:
        # Single channel: drop the channel axis and return a 2-D image.
        return img[0, 0, :, :].detach().cpu().numpy()
    return np.moveaxis(img[0, :, :, :].detach().cpu().numpy(), 0, 2)