import torch
from torch.utils import data

class SubDataset(data.Dataset):
    """Labeled dataset for substitute training.

    Records are kept in the plain list ``items``. Each record must be
    indexable at positions 0..3; those positions are returned as
    ``(seed, data, prob, next_data)``.
    """

    def __init__(self):
        self.name = 'labeled dataset for substitute training'
        # Filled by the caller after construction; starts out empty.
        self.items = []

    def __len__(self):
        """Return the number of stored records."""
        return len(self.items)

    def __getitem__(self, index):
        """Return the ``(seed, data, prob, next_data)`` tuple stored at ``index``."""
        record = self.items[index]
        # Positional indexing (not slicing) so a too-short record raises
        # IndexError and any fields beyond the fourth are ignored.
        return record[0], record[1], record[2], record[3]


class UnlabeledDataset(data.Dataset):
    """Unlabeled dataset for substitute training.

    Records are kept in the plain list ``items``. Each record must be
    indexable at positions 0..3; those positions are returned as
    ``(data, prob, label, used)``.
    """

    def __init__(self):
        self.name = 'unlabeled dataset for substitute training'
        # Filled by the caller after construction; starts out empty.
        self.items = []

    def __len__(self):
        """Return the number of stored records."""
        return len(self.items)

    def __getitem__(self, index):
        """Return the ``(data, prob, label, used)`` tuple stored at ``index``."""
        entry = self.items[index]
        # ``used`` marks whether the sample was involved in the querying process.
        sample, prob, label, used = entry[0], entry[1], entry[2], entry[3]
        return sample, prob, label, used


class LabeledDataset(data.Dataset):
    """Labeled dataset.

    Records are kept in the plain list ``items``. Each record must be
    indexable at positions 0 and 1; those positions are returned as
    ``(data, target)``.
    """

    def __init__(self):
        self.name = 'labeled dataset'
        # Filled by the caller after construction; starts out empty.
        self.items = []

    def __len__(self):
        """Return the number of stored records."""
        return len(self.items)

    def __getitem__(self, index):
        """Return the ``(data, target)`` pair stored at ``index``."""
        pair = self.items[index]
        return pair[0], pair[1]
