import os
import pickle
from typing import Tuple

import torch
from PIL import Image
from torch.utils.data import random_split, Dataset
from torchvision import datasets, transforms

from dataset.augmentations import ImageNetRTransform


class ImageNetR(Dataset):
    """ImageNet-R dataset with a reproducible, non-overlapping train/test split.

    Images are discovered with ``torchvision.datasets.ImageFolder`` under
    ``data/imagenet-r``. The samples are shuffled with a fixed-seed generator
    and split 80/20, so ``ImageNetR(train=True)`` and ``ImageNetR(train=False)``
    always partition the same permutation and never share a sample.

    Attributes set by the constructor:
        root_dir: Directory scanned by ``ImageFolder``.
        test_size: Fraction of samples assigned to the test split.
        transform: Callable applied to every loaded image.
        wnid_to_class_name_mapping: Mapping unpickled from
            ``data/imagenet-r-class-names.pkl``.
        data: Image paths of the selected split.
        targets: Class indices aligned with ``data``.
        idx_to_class_name: Class index -> class name.
    """

    # Seed of the private RNG used for the split. It is a class-level constant
    # on purpose: the train and test instances must use the same value,
    # otherwise their index sets would overlap.
    SPLIT_SEED = 0

    def __init__(self, train: bool = True, image_size: int = 224, transform=None, *args, **kwargs) -> None:
        """Build the file list and labels for one side of the split.

        Args:
            train: Select the training split when true, the test split otherwise.
            image_size: Forwarded to ``ImageNetRTransform`` when no custom
                ``transform`` is supplied.
            transform: Optional callable applied to each RGB image; defaults
                to ``ImageNetRTransform(image_size)``.
            *args: Accepted for call-site compatibility; ignored.
            **kwargs: Accepted for call-site compatibility; ignored.

        Raises:
            KeyError: If a class folder name has no entry in the class-name
                mapping.
        """
        super().__init__()
        self.root_dir = os.path.join('data', 'imagenet-r')
        self.test_size = 0.2
        self.transform = ImageNetRTransform(image_size) if transform is None else transform

        # NOTE: unpickling can execute arbitrary code, so this file must come
        # from a trusted source.
        with open('data/imagenet-r-class-names.pkl', 'rb') as f:
            self.wnid_to_class_name_mapping = pickle.load(f)

        # ImageFolder is used only to enumerate (path, class_index) samples;
        # image decoding happens lazily in __getitem__.
        full_dataset = datasets.ImageFolder(root=self.root_dir)

        train_indices, test_indices = self._split_indices(
            len(full_dataset), self.test_size, self.SPLIT_SEED
        )
        selected_indices = train_indices if train else test_indices
        self.data = [full_dataset.samples[i][0] for i in selected_indices]  # image paths
        self.targets = [full_dataset.samples[i][1] for i in selected_indices]  # labels

        # class_to_idx is keyed by class folder name (presumably the WordNet
        # id, given the mapping's name); translate each index to its name.
        self.idx_to_class_name = {
            idx: self.wnid_to_class_name_mapping[cls]
            for cls, idx in full_dataset.class_to_idx.items()
        }

    @staticmethod
    def _split_indices(num_samples: int, test_fraction: float, seed: int) -> Tuple[list, list]:
        """Return disjoint ``(train_indices, test_indices)`` covering ``range(num_samples)``.

        A dedicated generator seeded with ``seed`` drives the shuffle, so the
        result depends only on the arguments, never on the global RNG state,
        and calling this does not advance the global RNG either.
        """
        train_size = int((1 - test_fraction) * num_samples)
        test_size = num_samples - train_size
        generator = torch.Generator().manual_seed(seed)
        train_subset, test_subset = random_split(
            range(num_samples), [train_size, test_size], generator=generator
        )
        return list(train_subset.indices), list(test_subset.indices)

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[object, int]:
        """Load sample ``idx`` and return ``(image, label)``.

        The image is decoded as RGB and passed through ``self.transform``
        when one is set, so the type of the first element is whatever the
        transform returns (a PIL image if no transform is applied).
        """
        img_path = self.data[idx]
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied out by convert().
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        label = self.targets[idx]
        return image, label
