# Custom PyTorch dataset for loading pre-rotated MNIST images with optional transforms.

import torch
from torch.utils.data import Dataset

class RotatedMNISTDataset(Dataset):
    """Map-style dataset backed by a single preprocessed ``.pt`` file.

    The file must deserialize to a mapping with the keys ``"images"``,
    ``"labels"`` and ``"angles"``. Indexing yields ``(image, label)`` pairs;
    the angles are kept on the instance as ``self.angles`` but are not part
    of the returned sample.
    """

    def __init__(self, pt_path, transform=None):
        """Load the preprocessed data and remember the optional transform.

        Args:
            pt_path: Path (or file-like object) handed straight to
                ``torch.load``.
            transform: Optional callable applied to each image in
                ``__getitem__``. ``None`` disables it.

        Raises:
            KeyError: If one of the expected keys is missing from the file.
            ValueError: If the number of images and labels differ.
        """
        # NOTE(review): torch.load deserializes via pickle (unless restricted
        # by weights_only), so only point this at files from a trusted source.
        data = torch.load(pt_path)
        self.images = data["images"]
        self.labels = data["labels"]
        self.angles = data["angles"]
        self.transform = transform

        # Fail at construction time instead of with an IndexError at some
        # late index in the middle of an epoch.
        n_images = len(self.images)
        n_labels = len(self.labels)
        if n_images != n_labels:
            raise ValueError(
                f"images and labels must have the same length, "
                f"got {n_images} images and {n_labels} labels"
            )

    def __len__(self) -> int:
        """Return the number of samples (length of the image container)."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        The label is converted to a plain Python ``int``. The image is passed
        through ``self.transform`` when one was supplied.
        """
        img = self.images[idx]
        label = int(self.labels[idx])
        # Compare against None explicitly: a callable that defines __len__ or
        # __bool__ may evaluate as falsy and would otherwise be skipped.
        if self.transform is not None:
            img = self.transform(img)
        return img, label
