
import torch
import torchvision
import torchvision.transforms as transforms
# import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from data_utils.tiny_datasets import *
import random
# imagenet_data = torchvision.datasets.ImageNet('/root/datasets/', split='val')

def load_data(root_dir, dataset,  batch_size=512):
    #f'~/Datasets/{dataset}/train/'
    trainset = FeaturesDataset(root=root_dir, dataset=dataset, split='train')
    test_roodir = "../../../Datasets/mini_test_subset/"
    # test_roodir =   '../../../Datasets/mini_res12_test'
    cls = os.listdir(test_roodir)
    testdset = random.choice(cls)

    # fls = glob("~/Datasets/mini_test_subset/")
    testset = FeaturesDataset(root=test_roodir, dataset=testdset, split='test')
    n_classes = trainset.n_classes

    trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size,
                                              num_workers=4, pin_memory=True, shuffle=True)
    valloader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_size,
                                            shuffle=False, num_workers=4)
    return trainloader, valloader, n_classes

