import torch
from torchvision import transforms
from torchvision.io import read_image
from torch.utils.data import Dataset


class SingleImageDataset(Dataset):
    """A single image split into interleaved training / evaluation pixels.

    The image read from ``path`` is center-cropped, resized, rearranged so
    that its leading axis becomes the trailing one (``read_image`` returns
    channels first, so the result is height x width x channels), converted
    to float and min-max rescaled to ``[-1, 1]``.

    Attributes:
        image: The preprocessed image.
        shape: ``image.size()``.
        spacing: Stride used to sub-sample the pixel grid.
        position: Tensor of shape ``(H, W, 2)`` where
            ``position[i, j] == (i / H, j / W)``.
        training_image / training_position: Every ``spacing``-th pixel
            along both leading axes, starting at index 0.
        eval_image / eval_position: Every ``spacing``-th pixel along both
            leading axes, starting at index 1.
        eval_mask: Tensor shaped like ``image`` holding 0 at the training
            pixels and 1 everywhere else.
    """

    def __init__(self, path, *, spacing=2, crop_size=1024, image_size=512):
        """Load and preprocess the image, then build the train/eval splits.

        Args:
            path: Location of the image file, passed to ``read_image``.
            spacing: Stride between consecutive training (and evaluation)
                pixels along each spatial axis. Must be at least 1.
            crop_size: Side length of the square center crop.
            image_size: Side length of the square image after resizing.

        Raises:
            ValueError: If ``spacing`` is smaller than 1.
        """
        super().__init__()
        if spacing < 1:
            raise ValueError(f"spacing must be >= 1, got {spacing}")
        self.spacing = spacing

        preprocess = transforms.Compose([
            transforms.CenterCrop([crop_size, crop_size]),
            transforms.Resize(
                [image_size, image_size],
                # Explicit enum instead of the legacy integer code 2
                # (the Pillow constant for bilinear filtering).
                interpolation=transforms.InterpolationMode.BILINEAR,
                antialias=True,
            ),
        ])
        pixels = preprocess(read_image(path))
        # Move the leading axis to the end: (C, H, W) -> (H, W, C).
        pixels = pixels.permute(1, 2, 0).to(torch.float)

        self.image = self._rescale_to_signed_unit(pixels)
        self.shape = self.image.size()
        self.position = self._normalized_grid(self.shape[0], self.shape[1])

        # Training samples: the sub-grid anchored at pixel (0, 0).
        self.training_image = self.image[::spacing, ::spacing]
        self.training_position = self.position[::spacing, ::spacing]

        # Evaluation samples: the same stride, anchored at pixel (1, 1).
        self.eval_image = self.image[1::spacing, 1::spacing]
        self.eval_position = self.position[1::spacing, 1::spacing]

        # NOTE(review): the mask keeps *every* non-training pixel, which is
        # a superset of the locations sampled by eval_image - confirm that
        # this is what the consumers of eval_mask expect.
        self.eval_mask = torch.ones_like(self.image)
        self.eval_mask[::spacing, ::spacing] = 0

    @staticmethod
    def _rescale_to_signed_unit(image):
        """Min-max rescale ``image`` to ``[-1, 1]``; a constant image maps to -1."""
        lowest = image.min()
        value_range = image.max() - lowest
        if value_range == 0:
            # A constant image would otherwise produce 0 / 0 = NaN everywhere.
            return torch.full_like(image, -1.0)
        return (image - lowest) / value_range * 2 - 1

    @staticmethod
    def _normalized_grid(height, width):
        """Return an ``(height, width, 2)`` grid of ``(i / height, j / width)``."""
        rows = torch.arange(0, height, dtype=torch.float) / height
        cols = torch.arange(0, width, dtype=torch.float) / width
        grid_rows, grid_cols = torch.meshgrid(rows, cols, indexing='ij')
        return torch.stack([grid_rows, grid_cols], dim=-1)

    