
import os
import numpy as np

from torch.utils.data import Dataset
from .base import BaseDataset

class MLCDataset(BaseDataset, Dataset):
    """Dataset that serves image paths taken from a pre-built ``.npy`` index.

    The index file is located under ``base_path/data_name`` and loaded once
    at construction time. Indexing the dataset returns a path string, not
    decoded image data.

    NOTE(review): ``self.image_path`` is read in ``__getitem__`` but never
    assigned in this class; presumably ``BaseDataset.__init__`` sets it --
    confirm against the base class.
    """

    def __init__(self, data_name, base_path):
        """Resolve the index file for ``data_name`` and load it.

        Args:
            data_name: Dataset identifier; also used as the sub-directory
                name below ``base_path``.
            base_path: Root directory containing the per-dataset folders.
        """
        super().__init__(data_name, base_path)

        # Path components (below base_path/data_name) for datasets whose
        # index file does not live at the default location.
        index_locations = {
            'objects365': ('o251', 'formatted_train_images_012.npy'),
            'coco2017': ('formatted_unlabeled_images.npy',),
        }
        default_location = ('formatted_train_images.npy',)
        tail = index_locations.get(data_name, default_location)

        self.file_path = os.path.join(base_path, data_name, *tail)
        self.images = self.load()

    def load(self):
        """Read and return the array stored at ``self.file_path``."""
        entries = np.load(self.file_path)
        return entries

    def __getitem__(self, index):
        """Return ``self.image_path`` joined with the index entry at ``index``."""
        entry = self.images[index]
        return os.path.join(self.image_path, entry)

    def __len__(self):
        """Return the number of entries in the loaded index."""
        return len(self.images)

