
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import glob

class Caltech256(Dataset):
    """Map-style dataset over a directory of JPEG images grouped by class.

    Images are looked up at ``<img_dir>/<class_name>/<file>.jpg``; the name of
    the directory that directly contains an image is used as its class name.
    """

    def __init__(self, img_dir, classes, transform=None):
        """Collect and sort the image paths found under ``img_dir``.

        Params:
            img_dir: str, path of the directory that holds the dataset
            classes: list, every class name in the dataset; an image's label
                is the position of its class name in this list
            transform: optional callable applied to each loaded image
        Attributes:
            img_pths: list, sorted paths of all images in the dataset
            classes: list, every class name in the dataset
            transform: the ``transform`` argument, stored unchanged
        Raises:
            AssertionError: if no ``.jpg`` file is found under ``img_dir``
        """
        # Without recursive=True, '**' matches like '*', so only files that
        # sit exactly one directory below img_dir are picked up.
        # (The original author noted 34745 images for their dataset copy.)
        img_pths = sorted(glob.glob(img_dir + os.sep + '**' + os.sep + '**.jpg'))
        if not img_pths:
            # Raised explicitly instead of via `assert` so the check is not
            # stripped under `python -O`; exception type and message are the
            # same ones the assert statement produced.
            raise AssertionError('no jpg file in ' + img_dir)
        self.img_pths = img_pths
        self.classes = classes
        self.transform = transform

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.img_pths)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return ``(image, label)``.

        The image is converted to RGB and passed through ``self.transform``
        when one is set. ``label`` is the index of the image's parent
        directory name within ``self.classes``.

        Raises:
            ValueError: if the parent directory name is not in ``self.classes``
        """
        img_pth = self.img_pths[idx]
        # The directory that directly contains the file names its class.
        cls_name = os.path.basename(os.path.dirname(img_pth))

        label = self.classes.index(cls_name)
        # convert() returns an independent, fully loaded image, so the source
        # file can be closed right away instead of waiting for garbage
        # collection to release the handle.
        with Image.open(img_pth) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform:
            image = self.transform(image)
        return image, label