import torch
from torch.utils.data import DataLoader, Dataset

class CombinData(Dataset):
    """Pair the leading field of each ``data1`` sample with an activation vector.

    Item ``index`` of this dataset is the tuple
    ``(first element of data1[index], activate_vector[index])``.
    The reported length is that of ``data1`` only; ``activate_vector`` is
    indexed with the same ``index`` and its length is not checked here.
    """

    def __init__(self, data1: Dataset, activate_vector):
        """
        :param data1: A small labeled dataset, or a subset of the full data.
            Each sample must be unpackable and contain at least one element.
        :param activate_vector: Activation vectors corresponding to ``data1``;
            indexed with the same index as ``data1``.
        """
        super().__init__()

        self.activate_vector = activate_vector
        self.data1 = data1

    def __len__(self):
        """Return the number of samples in ``data1``."""
        size = len(self.data1)
        return size

    def __getitem__(self, index):
        """Return ``(first field of data1[index], activate_vector[index])``."""
        sample = self.data1[index]
        # Only the leading field of the sample is kept (presumably an image
        # path, going by the original variable name); the rest is discarded.
        leading_field, *_unused = sample
        vector = self.activate_vector[index]

        return leading_field, vector

if __name__ == '__main__':
    # Smoke test: build a CombinData from a subset of the Animals dataset and
    # its stored activation vectors, then print the first field of each batch.
    #
    # NOTE(review): `Animals` is not imported in the visible part of this
    # file; it must be brought into scope before this script can run.
    data = Animals(root='data/Animals_with_Attributes2', test=False)
    # Presumably `Animals` overloads multiplication so that `0.01 * data`
    # yields a 1% subset -- TODO confirm against the Animals implementation.
    data1 = 0.01 * data

    # NOTE(review): torch.load unpickles the file; only load trusted files.
    activate_vectors = torch.load('activate_vector/animals_train/id0_activate.pth')
    cdata = CombinData(data1, activate_vectors)

    loader = DataLoader(cdata, shuffle=False, batch_size=4)
    for batch in loader:
        # batch[0] is the collated first element of each CombinData item.
        print(batch[0])


