import os
from PIL import Image
from torch.utils.data import Dataset


class Caltech256Dataset(Dataset):
    """Image-classification dataset read from a ``256_ObjectCategories`` tree.

    Every sub-directory of ``<root>/256_ObjectCategories`` is one class.
    Directories are taken in sorted order, and the class name is the part of
    the directory name after the first ``.`` (the whole name when there is no
    ``.``). Within each class the image paths are sorted, and the first
    ``int(n * train_ratio)`` of them form the ``'train'`` split while the
    remainder form the ``'test'`` split, so the two splits never overlap.

    Args:
        root: Directory that contains the ``256_ObjectCategories`` folder.
        split: Either ``'train'`` or ``'test'``.
        transform: Optional callable applied to each RGB PIL image before it
            is returned from ``__getitem__``.
        train_ratio: Fraction of every class assigned to the train split;
            must lie in ``[0, 1]``.

    Raises:
        ValueError: If ``split`` is not a known split name or ``train_ratio``
            is outside ``[0, 1]``.
    """

    # File extensions treated as images (matched case-insensitively).
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
    _SPLITS = ('train', 'test')

    def __init__(self, root, split='train', transform=None, train_ratio=0.8):
        super().__init__()

        # Validate before touching the file system so a bad argument is
        # reported even when the root holds no class directories.
        if split not in self._SPLITS:
            raise ValueError(
                f"Unknown split: {split!r} (expected one of {self._SPLITS})"
            )
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(
                f"train_ratio must be within [0, 1], got {train_ratio!r}"
            )

        self.root = os.path.join(root, '256_ObjectCategories')
        self.split = split
        self.transform = transform
        self.train_ratio = train_ratio

        raw_classes = sorted(
            d for d in os.listdir(self.root)
            if os.path.isdir(os.path.join(self.root, d))
        )

        # Drop everything up to and including the first '.', when present.
        self.classes = [
            raw.split('.', 1)[1] if '.' in raw else raw
            for raw in raw_classes
        ]
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}
        # Class name -> on-disk directory name.
        self.dir_map = dict(zip(self.classes, raw_classes))
        self.data, self.targets = self._load_dataset()

    def _load_dataset(self):
        """Collect the image paths and integer labels of the selected split.

        Returns:
            A ``(data, targets)`` pair of equal-length lists: image file paths
            and the class index of each path.

        Raises:
            ValueError: If ``self.split`` is not ``'train'`` or ``'test'``.
        """
        # Re-checked here because ``self.split`` may have been reassigned
        # after construction.
        if self.split not in self._SPLITS:
            raise ValueError(
                f"Unknown split: {self.split!r} (expected one of {self._SPLITS})"
            )

        data, targets = [], []
        for cls_name in self.classes:
            cls_dir = os.path.join(self.root, self.dir_map[cls_name])
            # Sorting makes the train/test boundary independent of the order
            # in which the OS happens to list the directory.
            imgs = sorted(
                os.path.join(cls_dir, f)
                for f in os.listdir(cls_dir)
                if f.lower().endswith(self._IMAGE_EXTENSIONS)
            )

            n_train = int(len(imgs) * self.train_ratio)
            if self.split == 'train':
                split_imgs = imgs[:n_train]
            else:
                split_imgs = imgs[n_train:]

            data.extend(split_imgs)
            targets.extend([self.class_to_idx[cls_name]] * len(split_imgs))

        return data, targets

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is converted to RGB and, when a transform was supplied,
        passed through it. ``label`` is the integer class index.
        """
        img_path = self.data[index]
        label = self.targets[index]

        # The context manager closes the underlying file as soon as the
        # pixel data has been copied into the converted image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label
