import os

from PIL import Image
import numpy as np
import torch


class ImagePreprocesser:
    """Load images and convert them into batched, channels-first float tensors."""

    def __init__(self):
        # Holds no state; kept so existing ``ImagePreprocesser()`` calls work.
        pass

    def pil_to_numpy(self, pil_image):
        """Convert an image to a float32 NumPy array divided by 255.

        Args:
            pil_image: A PIL image, or anything ``np.array`` accepts
                (e.g. an existing ndarray).

        Returns:
            ``np.ndarray`` of dtype float32. For 8-bit input the values lie
            in ``[0, 1]``.
        """
        image = np.array(pil_image).astype(np.float32) / 255.0
        return image

    def numpy_to_pt(self, image, device):
        """Convert a channels-last array into a ``(1, C, H, W)`` tensor on *device*.

        Args:
            image: Array shaped ``(H, W)`` or ``(H, W, C)`` (the layout
                ``np.array`` produces for a PIL image). A 2-D array is
                treated as single-channel.
            device: Target device, passed to ``torch.Tensor.to``.

        Returns:
            A tensor of shape ``(1, C, H, W)`` with the same dtype as *image*.
            It is a permuted (non-contiguous) view.
        """
        if image.ndim == 2:
            # Give grayscale input an explicit channel axis so the permute works.
            image = np.expand_dims(image, axis=2)
        # Add a batch axis: (H, W, C) -> (1, H, W, C), then move channels first.
        image = torch.from_numpy(image).unsqueeze(0).to(device)
        return image.permute(0, 3, 1, 2)

    def load_image(self, image_path, image_size, device):
        """Load an image and return it as a ``(1, C, H, W)`` float32 tensor.

        Args:
            image_path: Path to an image file (``str`` or ``os.PathLike``), or
                an already-loaded image (e.g. a PIL image or ndarray).
                Already-loaded images are used as-is: they are NOT converted
                to RGB and their size is NOT checked.
            image_size: Required width and height, in pixels, of an image
                read from disk.
            device: Target device, passed to ``torch.Tensor.to``.

        Returns:
            A float32 tensor of shape ``(1, C, H, W)``. ``C`` is 3 for images
            read from disk, because they are converted to RGB.

        Raises:
            FileNotFoundError: If *image_path* is a path that does not exist.
            ValueError: If the image read from disk is not
                ``image_size`` x ``image_size`` pixels.
        """
        if isinstance(image_path, (str, os.PathLike)):
            path = os.fspath(image_path)
            # Raise rather than assert: asserts are stripped under ``python -O``.
            if not os.path.exists(path):
                raise FileNotFoundError(f"Image path {path} does not exist")
            # The context manager closes the file handle. convert() loads the
            # pixel data, so the result stays valid after the file is closed.
            with Image.open(path) as opened:
                pil_image = opened.convert("RGB")
            width, height = pil_image.size
            if width != image_size or height != image_size:
                raise ValueError(
                    f"Image {path} has size {width}x{height}, "
                    f"expected {image_size}x{image_size}"
                )
        else:
            pil_image = image_path
        image = self.pil_to_numpy(pil_image)
        return self.numpy_to_pt(image, device)
