import os
from torch.utils.data import DataLoader

from datasets import ESC50

__all__ = ['get_dataset']

def get_dataset(args):
    """Build the train / valid / test DataLoaders for ``args.dataset``.

    Currently only ``'esc50'`` is supported; each split is an ``ESC50``
    episodic dataset wrapped in a ``torch.utils.data.DataLoader``.

    Args:
        args: Configuration namespace. Reads ``dataset``, ``root``, ``model``,
            ``ways``, ``shots``, ``query_shots``, ``num_tasks``, ``N``,
            ``N_valid``, ``N_test``, ``train_batch_size``, ``batch_size``,
            ``visualize``, ``num_workers_train`` and ``num_workers``.

    Returns:
        A tuple ``(trainloader, validloader, testloader)``. The train loader
        shuffles and drops the last incomplete batch; the valid/test loaders
        keep the original order.

    Raises:
        NotImplementedError: If ``args.dataset`` is not ``'esc50'``.
    """
    if args.dataset != 'esc50':
        raise NotImplementedError(f"Unsupported dataset: {args.dataset!r}")

    # Keyword arguments shared by every split, kept in one place so the
    # three datasets cannot silently drift apart.
    common = dict(
        root=args.root,
        name='esc50',
        model=args.model,
        nways=args.ways,
        shots=args.shots,
        visualize=args.visualize,
    )

    trainset = ESC50(
        split='train',
        query_shots=args.query_shots,
        num_tasks=args.num_tasks,
        niterations=args.N,
        batch_size=args.train_batch_size,
        **common,
    )
    # NOTE(review): valid/test use ``args.shots`` (not ``args.query_shots``)
    # as the number of query shots. Kept as-is to preserve the existing
    # evaluation protocol; confirm this is intentional.
    validset = ESC50(
        split='valid',
        query_shots=args.shots,
        niterations=args.N_valid,
        batch_size=args.batch_size,
        **common,
    )
    testset = ESC50(
        split='test',
        query_shots=args.shots,
        niterations=args.N_test,
        batch_size=args.batch_size,
        **common,
    )

    trainloader = DataLoader(trainset, batch_size=args.train_batch_size,
                             num_workers=args.num_workers_train,
                             shuffle=True, drop_last=True)
    validloader = DataLoader(validset, batch_size=args.batch_size,
                             num_workers=args.num_workers, shuffle=False)
    testloader = DataLoader(testset, batch_size=args.batch_size,
                            num_workers=args.num_workers, shuffle=False)
    return trainloader, validloader, testloader
