import os
from glob import glob
from tqdm import tqdm
from torch.utils.data import Dataset
import torch
import torch.nn as nn

import torch
import torchvision
import torchvision.transforms as transforms
# import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# d = torch.load('../../../Datasets/Tiny-subsets/train_0')
# Per-channel normalisation constants shared by both pipelines.
# With mean = std = 0.5, Normalize computes (x - 0.5) / 0.5 on each channel.
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

# Training pipeline: resize, then light photometric/geometric augmentation,
# then tensor conversion and normalisation.
train_transform = transforms.Compose([
    # An int size (not a tuple) is passed here, unlike the (224, 224) pair
    # used by the evaluation pipeline below.
    transforms.Resize(224),
    transforms.ColorJitter(brightness=32 / 255.0, saturation=0.5),
    # Pad by 4 px and take a random 224x224 window.
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Evaluation pipeline (used for both validation and test data): deterministic,
# no augmentation; every image is resized to exactly 224x224.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
def load_data(data_dir: str, dataset: str, batch_size: int = 512,
              num_workers: int = 4):
    """Build train/validation/test DataLoaders for one image dataset.

    The images are read with ``torchvision.datasets.ImageFolder`` from three
    sub-directories of ``<data_dir>/<dataset>``: ``tr`` (train), ``va``
    (validation) and ``te`` (test).

    Args:
        data_dir: Directory that contains the dataset folders.
        dataset: Name of the dataset folder inside ``data_dir``.
        batch_size: Batch size used by all three loaders.
        num_workers: Number of worker processes used by all three loaders.

    Returns:
        A tuple ``(trainloader, valloader, testloader, n_classes)``. Note the
        order: the validation loader comes before the test loader.
        ``n_classes`` is the number of classes found in the training split.
    """
    dataset_root = os.path.join(data_dir, dataset)

    # Training data gets the augmenting transform; test and validation data
    # share the deterministic one.
    trainset = torchvision.datasets.ImageFolder(
        root=os.path.join(dataset_root, 'tr'), transform=train_transform)
    testset = torchvision.datasets.ImageFolder(
        root=os.path.join(dataset_root, 'te'), transform=test_transform)
    valset = torchvision.datasets.ImageFolder(
        root=os.path.join(dataset_root, 'va'), transform=test_transform)

    # Take the class count from ImageFolder itself rather than from
    # os.listdir(): a raw directory listing also counts stray non-class
    # entries (e.g. ".DS_Store"), which would disagree with the labels the
    # dataset actually produces.
    n_classes = len(trainset.classes)

    # Only the training loader shuffles; evaluation order stays fixed.
    trainloader = torch.utils.data.DataLoader(
        dataset=trainset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(
        dataset=testset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers)
    valloader = torch.utils.data.DataLoader(
        dataset=valset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers)

    return trainloader, valloader, testloader, n_classes