import os
import torch
import torchvision
import torchvision.transforms as transforms
# import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from data_utils.mydatasets import *
from data_utils.splitdatasets import SplitDataset
from glob import glob
import random

def load_base_data(root_dir, dataset, batch_size=512):
    """Create the train and test data loaders for a features dataset.

    Args:
        root_dir: Passed to ``FeaturesDataset`` as ``root``.
        dataset: Passed to ``FeaturesDataset`` as ``dataset``.
        batch_size: Batch size used by both loaders.

    Returns:
        A ``(trainloader, valloader, n_classes)`` tuple. The first loader
        shuffles the ``'train'`` split, the second iterates the ``'test'``
        split in order, and ``n_classes`` is read from the train split.
    """
    train_features = FeaturesDataset(root=root_dir, dataset=dataset, split='train')
    test_features = FeaturesDataset(root=root_dir, dataset=dataset, split='test')
    num_classes = train_features.n_classes

    make_loader = torch.utils.data.DataLoader
    # Only the training loader shuffles and pins memory.
    train_loader = make_loader(
        dataset=train_features,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )
    eval_loader = make_loader(
        dataset=test_features,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )
    return train_loader, eval_loader, num_classes



def load_test_data(test_roodir, dataset, batch_size=512):
    """Create a data loader over the test split of a features dataset.

    Args:
        test_roodir: Passed to ``FeaturesDataset`` as ``root``.
        dataset: Passed to ``FeaturesDataset`` as ``dataset``.
        batch_size: Batch size used by the loader.

    Returns:
        A ``(valloader, n_classes)`` tuple: a non-shuffling loader over the
        ``'test'`` split and the ``n_classes`` attribute of that dataset.
    """
    # NOTE: an earlier version listed ``test_roodir`` and drew a random entry
    # that was never used. That draw raised IndexError on an empty directory
    # and advanced the global ``random`` state on every call, so it is removed.
    testset = FeaturesDataset(root=test_roodir, dataset=dataset, split='test')
    n_classes = testset.n_classes

    valloader = torch.utils.data.DataLoader(
        dataset=testset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )
    return valloader, n_classes

def load_split_data(root_dir, dataset, batch_size=512, subset=1, max_ways=10):
    """Create the train and test data loaders for a split dataset.

    Args:
        root_dir: Passed to ``SplitDataset`` as ``root``.
        dataset: Passed to ``SplitDataset`` as ``dataset``.
        batch_size: Batch size used by both loaders.
        subset: Forwarded unchanged to ``SplitDataset`` for both splits.
        max_ways: Forwarded unchanged to ``SplitDataset`` for both splits.

    Returns:
        A ``(trainloader, valloader, n_classes)`` tuple. The first loader
        shuffles the ``'train'`` split, the second iterates the ``'test'``
        split in order, and ``n_classes`` is read from the train split.
    """
    train_split = SplitDataset(
        root=root_dir,
        dataset=dataset,
        split='train',
        subset=subset,
        max_ways=max_ways,
    )
    test_split = SplitDataset(
        root=root_dir,
        dataset=dataset,
        split='test',
        subset=subset,
        max_ways=max_ways,
    )
    num_classes = train_split.n_classes

    make_loader = torch.utils.data.DataLoader
    # Only the training loader shuffles and pins memory.
    train_loader = make_loader(
        dataset=train_split,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )
    eval_loader = make_loader(
        dataset=test_split,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )
    return train_loader, eval_loader, num_classes
