from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets import ImageFolder


def AffectNetDataLoader(
    batch_size,
    *,
    train_root="/VirtualSanghwa/AffectNet/train",
    val_root="/VirtualSanghwa/AffectNet/val",
    image_size=(240, 240),
    shuffle_test=False,
    num_workers=0,
    pin_memory=False,
):
    """Build train and validation DataLoaders for AffectNet stored as ImageFolder trees.

    Both splits are read with ``torchvision.datasets.ImageFolder`` (one
    sub-directory per class under each root) and share the same
    preprocessing: resize, convert to a tensor, and normalize each of the
    three channels with mean 0.5 / std 0.5.

    Args:
        batch_size: Number of samples per batch for both loaders.
        train_root: Directory holding the training split.
        val_root: Directory holding the validation split; it is returned as
            the "test" loader.
        image_size: Passed to ``transforms.Resize``; a ``(height, width)``
            tuple resizes to exactly that size.
        shuffle_test: Whether to shuffle the validation loader. Defaults to
            False so evaluation iterates in a fixed, reproducible order;
            pass True to get a reshuffled order every epoch.
        num_workers: Worker processes per loader (``0`` loads data in the
            main process).
        pin_memory: Forwarded to ``DataLoader``; useful when batches are
            copied to a CUDA device.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1] per channel.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    training_data = ImageFolder(root=train_root, transform=transform)
    test_data = ImageFolder(root=val_root, transform=transform)

    # Create the data loaders; both share batch size and worker settings.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    train_dataloader = DataLoader(training_data, shuffle=True, **loader_kwargs)
    # Evaluation does not benefit from random order; keep it fixed by default.
    test_dataloader = DataLoader(test_data, shuffle=shuffle_test, **loader_kwargs)

    return train_dataloader, test_dataloader