import random
import torch
from torch.utils.data import Dataset, DataLoader, random_split


class MaskPatchTransform:
    """
    Randomly mask (zero out) a square patch of an image, as described in the paper.

    Calling an instance returns ``(masked_image, (top, left, mask_size))``, where
    ``top`` and ``left`` are the row and column of the patch's upper-left corner.
    The input tensor is never modified; the mask is applied to a clone.
    """
    def __init__(self, mask_size=32):
        # Side length, in pixels, of the square patch that gets zeroed.
        self.mask_size = mask_size

    def __call__(self, img_tensor):
        """
        Zero a randomly placed ``mask_size x mask_size`` patch of ``img_tensor``.

        Args:
            img_tensor: PyTorch tensor whose last two dimensions are (H, W),
                e.g. (C, H, W) for a single image. Any leading dimensions
                (channels, batch, ...) are masked at the same location.

        Returns:
            (masked_img, (top, left, mask_size)) where masked_img is a masked
            copy of img_tensor.

        Raises:
            ValueError: if img_tensor has fewer than 2 dimensions, or if
                mask_size is negative or larger than H or W.
        """
        if img_tensor.dim() < 2:
            raise ValueError(
                f"img_tensor must have at least 2 dimensions (..., H, W), "
                f"got shape {tuple(img_tensor.shape)}"
            )
        H, W = img_tensor.shape[-2:]
        size = self.mask_size
        # Validate up front: otherwise random.randint fails with an opaque
        # "empty range" error, and a negative size silently yields a bogus mask.
        if size < 0 or size > H or size > W:
            raise ValueError(
                f"mask_size={size} must be in [0, min(H, W)={min(H, W)}] "
                f"for image of spatial size ({H}, {W})"
            )
        # randint is inclusive on both ends, so the patch always fits entirely.
        top = random.randint(0, H - size)
        left = random.randint(0, W - size)
        masked_img = img_tensor.clone()
        masked_img[..., top:top + size, left:left + size] = 0.0
        return masked_img, (top, left, size)

class InpaintingDatasetWrapper(Dataset):
    """
    Dataset adapter that turns items of ``base_dataset`` into inpainting samples.

    Each item is returned as ``(masked_image, original_image, mask_info)``, where
    ``mask_info`` is the ``(top, left, mask_size)`` tuple produced by
    ``MaskPatchTransform``. If the base dataset yields a tuple or list (such as
    an ``(image, label)`` pair), only its first element is treated as the image.
    """

    def __init__(self, base_dataset, mask_size=32):
        self.base_dataset = base_dataset
        # A fresh random patch location is drawn on every __getitem__ call.
        self.mask_transform = MaskPatchTransform(mask_size=mask_size)

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        item = self.base_dataset[idx]
        if isinstance(item, (tuple, list)):
            original = item[0]
        else:
            original = item
        masked, info = self.mask_transform(original)
        return masked, original, info

def custom_collate_fn(batch):
    """
    Collate ``(masked_img, original_img, mask_info)`` samples into a batch.

    The two image columns are stacked along a new leading dimension. The
    ``mask_info`` entries are returned as a plain Python list, so the default
    collation does not merge them into a tensor.

    Returns:
        (masked_batch, original_batch, list_of_mask_infos)
    """
    # Unpack every sample first, so a malformed sample fails before any stacking.
    triples = [(masked, original, info) for masked, original, info in batch]
    masked_batch = torch.stack([masked for masked, _, _ in triples], dim=0)
    original_batch = torch.stack([original for _, original, _ in triples], dim=0)
    return masked_batch, original_batch, [info for _, _, info in triples]



if __name__ == "__main__":
    # No script entry point is implemented; running this file directly does nothing.
    ...