import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np


class MyDataset(Dataset):
    """In-memory dataset pooling feature/property pairs from several sources.

    For every dataset name, ``<root>/<name>/after/data.npy`` and
    ``<root>/<name>/after/props.npy`` are loaded and stacked along the first
    axis, in the order the names are given. Both are stored as ``float32``
    tensors in ``self.data`` and ``self.labels``.

    Args:
        names: Iterable of dataset directory names to load. ``None`` (the
            default) loads every name in ``DEFAULT_NAMES``.
        root: Directory that contains one sub-directory per dataset.

    Raises:
        ValueError: If a dataset's ``data.npy`` and ``props.npy`` do not
            contain the same number of entries.
    """

    # Datasets loaded when the caller does not pass ``names``.
    DEFAULT_NAMES = (
        "AIDS",
        "DHFR",
        "Mutagenicity",
        "NCI1",
        "DD",
        "ENZYMES",
        "PROTEINS",
        "Cuneiform",
        "MSRC_21",
    )

    def __init__(self, names=None, root='./data'):
        super().__init__()
        if names is None:
            names = self.DEFAULT_NAMES

        data_chunks = []
        label_chunks = []
        for name in names:
            data = np.load(f'{root}/{name}/after/data.npy')
            props = np.load(f'{root}/{name}/after/props.npy')
            # Fail early: a mismatch here would silently pair samples with
            # the wrong labels for every dataset loaded afterwards.
            if len(data) != len(props):
                raise ValueError(
                    f"Dataset {name!r}: data.npy has {len(data)} entries "
                    f"but props.npy has {len(props)}"
                )
            # Empty datasets contribute nothing; skipping them also keeps
            # np.concatenate from tripping over a bare shape-(0,) array.
            if len(data) == 0:
                continue
            data_chunks.append(data)
            label_chunks.append(props)

        self.data = self._to_tensor(data_chunks)
        self.labels = self._to_tensor(label_chunks)

    @staticmethod
    def _to_tensor(chunks):
        """Stack arrays along axis 0 and return them as one float32 tensor."""
        if not chunks:
            return torch.empty(0, dtype=torch.float32)
        # Concatenate in NumPy and convert once: building a tensor from a
        # Python list of ndarrays is a slow path PyTorch warns about.
        return torch.as_tensor(np.concatenate(chunks), dtype=torch.float32)

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(data, label)`` pair stored at position ``idx``."""
        return self.data[idx], self.labels[idx]


def load_data(batch_size=128, shuffle=True):
    """Build a ``DataLoader`` over a freshly constructed ``MyDataset``.

    Args:
        batch_size: Number of samples per batch (default 128).
        shuffle: Whether the loader reshuffles the samples every epoch
            (default ``True``). Pass ``False`` for a deterministic order,
            e.g. during evaluation.

    Returns:
        A ``DataLoader`` yielding ``(data, labels)`` batches.
    """
    dataset = MyDataset()
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    