from .base import *

class SOP(BaseDataset):
    """Stanford Online Products dataset split, indexed from its metadata file.

    The split is selected by ``mode``:

    * ``'train'`` keeps labels in ``range(0, 11318)`` and reads
      ``Ebay_train.txt``.
    * ``'eval'`` keeps labels in ``range(11318, 22634)`` and reads
      ``Ebay_test.txt``.

    For every kept metadata row the label, image index and image path are
    appended to ``self.ys``, ``self.I`` and ``self.im_paths``.
    """

    def __init__(self, root, mode, transform=None):
        """Build the sample lists for the requested split.

        Args:
            root: Directory that contains the ``Stanford_Online_Products``
                folder.
            mode: Either ``'train'`` or ``'eval'``.
            transform: Forwarded unchanged to ``BaseDataset.__init__``.

        Raises:
            ValueError: If ``mode`` is not ``'train'`` or ``'eval'``, or if a
                non-blank metadata row does not have exactly four
                whitespace-separated fields.
        """
        self.root = root + '/Stanford_Online_Products'
        self.mode = mode
        self.transform = transform
        if self.mode == 'train':
            self.classes = range(0, 11318)
            metadata_name = 'Ebay_train.txt'
        elif self.mode == 'eval':
            self.classes = range(11318, 22634)
            metadata_name = 'Ebay_test.txt'
        else:
            # Fail early with a clear message instead of hitting an
            # AttributeError on the missing ``self.classes`` further down.
            raise ValueError(
                "mode must be 'train' or 'eval', got {!r}".format(mode)
            )

        # NOTE(review): self.ys, self.I and self.im_paths are extended below
        # without being created here, so they are presumably initialised by
        # BaseDataset.__init__ (not visible in this file) -- confirm.
        BaseDataset.__init__(self, self.root, self.mode, self.transform)

        # ``with`` guarantees the metadata file is closed, also on errors.
        with open(os.path.join(self.root, metadata_name)) as metadata:
            # The first line is discarded (treated as a header row).
            next(metadata, None)
            for line in metadata:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines, e.g. a trailing newline at EOF.
                    continue
                # Four fields per row; the third one is not used.
                image_id, class_id, _, path = fields
                # Both ids are stored shifted down by one.
                label = int(class_id) - 1
                if label in self.classes:
                    self.ys.append(label)
                    self.I.append(int(image_id) - 1)
                    self.im_paths.append(os.path.join(self.root, path))
