from torch.utils.data import Dataset

class VQARTaskIndexDataset(Dataset):
    """Map-style dataset whose items are the keys of a task mapping.

    Item ``i`` is the ``i``-th key of the ``task_list`` mapping passed at
    construction, in that mapping's key iteration order. The mapping's values
    are not kept.
    """

    def __init__(self, task_list):
        """Snapshot the keys of ``task_list``.

        Args:
            task_list: Any object exposing ``.keys()`` (e.g. a ``dict``). Only
                its keys are stored, in iteration order.
        """
        # Materialize the keys into a list so they support positional indexing
        # and ``len``.
        self.task_list = list(task_list.keys())

    def __len__(self):
        """Return the number of stored keys."""
        return len(self.task_list)

    def __getitem__(self, index):
        """Return the key stored at position ``index``.

        Args:
            index: A position into the stored key list (negative positions
                work as for ``list``). As a special case, a ``str`` index is
                returned unchanged without any lookup.

        Returns:
            The key at ``index``, or ``index`` itself when it is a ``str``.

        Raises:
            IndexError: If a positional ``index`` is out of range.
        """
        # HACK: string indices are passed straight through instead of being
        # looked up. The pass-through is preserved for callers that may depend
        # on it. NOTE(review): presumably some caller/sampler already yields
        # task keys; confirm and remove this branch. The former debug prints
        # were dropped so they no longer spam stdout on every access.
        if isinstance(index, str):
            return index
        return self.task_list[index]
