import numpy as np
from PIL import Image
import torch


def to_tensor(image):
    """Convert an image into a channels-first float tensor with a batch axis.

    Args:
        image: Either a filesystem path (``str`` or ``os.PathLike``) that is
            opened with PIL, or any object ``np.array`` can convert into an
            ``(H, W)`` or ``(H, W, C)`` array (e.g. a PIL image or ndarray).

    Returns:
        A ``torch.float32`` tensor of shape ``(1, C, H, W)``. Two-dimensional
        (single-channel) input yields ``C == 1``. Pixel values are converted
        to float as-is; no rescaling is applied.
    """
    import os

    if isinstance(image, (str, os.PathLike)):
        # PIL loads lazily; np.array forces the pixel data to be read while
        # the file is still open, and the context manager then releases the
        # file handle instead of leaving it to the garbage collector.
        with Image.open(image) as opened:
            array = np.array(opened)
    else:
        array = np.array(image)

    if array.ndim == 2:
        # Grayscale data has no channel axis; add one so permute() below works.
        array = array[:, :, np.newaxis]

    tensor = torch.from_numpy(array)
    # (H, W, C) -> (C, H, W) -> (1, C, H, W)
    return tensor.permute(2, 0, 1).unsqueeze(0).float()


def to_pil(tensor):
    """Convert the first image of a batched, channels-first tensor to PIL.

    Args:
        tensor: A tensor indexed as ``(N, C, H, W)``; only element ``0`` of
            the first axis is converted. It may live on any device and may
            require grad.

    Returns:
        A ``PIL.Image.Image`` built from the ``uint8`` pixel data. Values
        outside ``[0, 255]`` are saturated to that range; fractional values
        are truncated by the integer cast.
    """
    pixels = tensor.detach().cpu()[0].permute(1, 2, 0)

    if pixels.dtype not in (torch.uint8, torch.bool):
        # Casting to uint8 does not saturate: out-of-range values wrap around
        # (or are undefined for floats), which shows up as speckle artifacts.
        # Clamp first so overshoot maps to pure black/white instead.
        pixels = pixels.clamp(0, 255)

    array = pixels.to(torch.uint8).numpy()

    if array.shape[2] == 1:
        # PIL expects single-channel data as a 2-D (H, W) array, not (H, W, 1).
        array = array[:, :, 0]

    return Image.fromarray(array)


def get_random_crop_size(H, W):
    """Sample a random crop size for an image of height ``H`` and width ``W``.

    The shorter side is scaled by a factor drawn uniformly from ``[0.7, 1)``.
    The longer side follows the image's aspect ratio, perturbed by a factor
    drawn uniformly from ``[0.9, 1.1)``, and is capped at the full length.

    Returns:
        A ``(h, w)`` tuple of integer crop dimensions.
    """
    if H > W:
        # Taller than wide: solve the transposed problem and swap back.
        crop_w, crop_h = get_random_crop_size(W, H)
        return crop_h, crop_w

    # Draw order matters for reproducibility under a fixed seed:
    # first the shrink factor, then the aspect-ratio jitter.
    shrink = torch.tensor(0.).uniform_(0.7, 1).item()
    jitter = torch.tensor(0.).uniform_(0.9, 1.1).item()

    crop_h = round(H * shrink)
    crop_w = min(round(W / H * crop_h * jitter), W)
    return crop_h, crop_w


def get_random_crop(H, W):
    """Pick a random crop window inside an ``H`` x ``W`` image.

    The window size comes from ``get_random_crop_size``; its top-left corner
    is then drawn uniformly so that the window lies fully inside the image.

    Returns:
        A ``(rows, cols)`` pair of ``slice`` objects.
    """
    crop_h, crop_w = get_random_crop_size(H, W)

    # Largest admissible offsets; randint's upper bound is exclusive (+1).
    max_top = H - crop_h
    max_left = W - crop_w
    top = torch.randint(max_top + 1, size=[]).item()
    left = torch.randint(max_left + 1, size=[]).item()

    rows = slice(top, top + crop_h)
    cols = slice(left, left + crop_w)
    return rows, cols
