import os
from PIL import Image
from torch.utils.data import Dataset


class Caltech101Dataset(Dataset):
    """Caltech-101 images with a deterministic per-class train/test split.

    Images are read from ``<root>/caltech-101/101_ObjectCategories``, where
    every sub-directory except ``BACKGROUND_Google`` is treated as one class.
    Class indices follow the sorted order of the directory names. Within each
    class the image paths are sorted, and the first
    ``int(n * train_ratio)`` of them form the ``'train'`` split while the
    remainder form the ``'test'`` split.

    Args:
        root: Directory that contains the ``caltech-101`` folder.
        split: Either ``'train'`` or ``'test'``.
        transform: Optional callable applied to each RGB ``PIL.Image`` before
            it is returned by ``__getitem__``.
        train_ratio: Fraction of every class assigned to the training split;
            must lie in ``[0, 1]``.

    Raises:
        ValueError: If ``split`` is not a known split name or ``train_ratio``
            is outside ``[0, 1]``.
    """

    _SPLITS = ('train', 'test')
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
    _IGNORED_DIRS = ('BACKGROUND_Google',)

    def __init__(self, root, split='train', transform=None, train_ratio=0.8):
        super().__init__()
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(
                f"train_ratio must be within [0, 1], got {train_ratio}")

        # NOTE: ``self.root`` is the category directory, not the ``root``
        # argument itself.
        self.root = os.path.join(root, 'caltech-101/101_ObjectCategories')
        self.split = split
        self.transform = transform
        self.train_ratio = train_ratio

        self.classes = sorted(
            entry for entry in os.listdir(self.root)
            if entry not in self._IGNORED_DIRS
            and os.path.isdir(os.path.join(self.root, entry))
        )
        self.class_to_idx = {
            cls_name: idx for idx, cls_name in enumerate(self.classes)}
        self.data, self.targets = self._load_dataset()

    def _load_dataset(self):
        """Collect the image paths and labels belonging to ``self.split``.

        Returns:
            A ``(paths, labels)`` pair of equally long lists, grouped by class
            in ``self.classes`` order.

        Raises:
            ValueError: If ``self.split`` is not ``'train'`` or ``'test'``.
        """
        # Validate once up front so that a bad split is reported even when
        # no class directories are present.
        if self.split not in self._SPLITS:
            raise ValueError(f"Unknown split: {self.split}")
        use_train = self.split == 'train'

        data, targets = [], []
        for cls_name in self.classes:
            cls_dir = os.path.join(self.root, cls_name)
            # Sorting makes the split reproducible regardless of the order
            # in which the file system lists the directory.
            imgs = sorted(
                os.path.join(cls_dir, fname) for fname in os.listdir(cls_dir)
                if fname.lower().endswith(self._IMAGE_EXTENSIONS)
            )
            n_train = int(len(imgs) * self.train_ratio)
            split_imgs = imgs[:n_train] if use_train else imgs[n_train:]

            data.extend(split_imgs)
            targets.extend([self.class_to_idx[cls_name]] * len(split_imgs))

        return data, targets

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.data)

    def __getitem__(self, index):
        """Load the image at ``index`` and return ``(image, label)``.

        The image is converted to RGB and passed through ``self.transform``
        when one was supplied; ``label`` is the integer class index.
        """
        img_path = self.data[index]
        label = self.targets[index]
        # ``Image.open`` is lazy and keeps the file open; converting inside
        # the ``with`` block loads the pixels and then closes the file
        # deterministically instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label
