
import torch
import torchvision
import torchvision.transforms as transforms
# import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from data_utils.metadatasets import *


def load_data(root_dir, dataset, batch_size=512, num_workers=4, pin_memory=True):
    """Build train and evaluation DataLoaders over ``FeaturesDataset``.

    Args:
        root_dir: Forwarded as ``root`` to ``FeaturesDataset``.
        dataset: Dataset name forwarded to ``FeaturesDataset``.
        batch_size: Samples per batch, used by both loaders.
        num_workers: Worker subprocesses per loader (default 4, matching
            the previously hard-coded value).
        pin_memory: Whether both loaders place batches in page-locked
            memory. Previously only the train loader pinned memory; both
            loaders now follow this flag so evaluation transfers are not
            slower than training transfers.

    Returns:
        A tuple ``(trainloader, valloader, n_classes)``:
        ``trainloader`` reshuffles the ``'tr'`` split every epoch,
        ``valloader`` iterates the ``'te'`` split in fixed order, and
        ``n_classes`` is read from the training split's ``n_classes``.
    """
    # 'tr' / 'te' are the split identifiers passed to FeaturesDataset;
    # presumably train / test — confirm against data_utils.metadatasets.
    trainset = FeaturesDataset(root=root_dir, dataset=dataset, split='tr')
    testset = FeaturesDataset(root=root_dir, dataset=dataset, split='te')
    n_classes = trainset.n_classes

    trainloader = torch.utils.data.DataLoader(
        dataset=trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    # Evaluation loader: no shuffling, so batches come out in a fixed order.
    valloader = torch.utils.data.DataLoader(
        dataset=testset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return trainloader, valloader, n_classes

