import os
import torch
import torch.utils
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Legacy, machine-specific default for the third ("test") loader. Kept only so
# callers that do not pass ``test_dir`` see unchanged behaviour; pass an
# explicit directory on any other machine.
_DEFAULT_TEST_DIR = '/home/admin1/Syh/Training-free-quant/mixed_bit/data/train'


def _make_loader(dataset, batch_size, shuffle, workers):
    """Wrap *dataset* in a DataLoader with the settings shared by every split."""
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=workers,
        pin_memory=True,
    )


def build_imagenet_data(data_path='', input_size=224, batch_size=256, workers=24,
                        test_dir=None, val_resize=256):
    """Build train / val / test DataLoaders over ImageFolder directory trees.

    Args:
        data_path: Root directory containing ``train`` and ``val``
            sub-directories, each in the ``<class>/<image>`` layout expected
            by ``torchvision.datasets.ImageFolder``.
        input_size: Side length of the square crops produced for every split.
        batch_size: Batch size used by all three loaders.
        workers: Number of DataLoader worker processes per loader.
        test_dir: ImageFolder root for the third loader. ``None`` falls back to
            the legacy hard-coded path (``_DEFAULT_TEST_DIR``) for backward
            compatibility.
        val_resize: Resize target applied before the validation center crop.
            The default 256 matches the original behaviour. It should be at
            least ``input_size``.

    Returns:
        Tuple ``(train_dataloader, val_dataloader, test_dataloader)``. Only
        the train loader shuffles.

    NOTE(review): the test split uses the same random augmentation as training
    (RandomResizedCrop + RandomHorizontalFlip), and the default path points at a
    ``train`` folder. It is therefore presumably calibration data rather than a
    held-out evaluation set. Confirm this before reporting accuracy on it.
    """
    print('==> Using Pytorch Dataset')

    if test_dir is None:
        test_dir = _DEFAULT_TEST_DIR
    train_dir = os.path.join(data_path, 'train')
    val_dir = os.path.join(data_path, 'val')

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # Optional faster decoding: torchvision.set_image_backend('accimage')

    # Random crop/flip augmentation, shared by the train and test splits.
    # Compose is a plain callable, so sharing one instance is equivalent to
    # building two identical pipelines.
    augment = transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # Deterministic resize + center crop for validation.
    evaluate = transforms.Compose([
        transforms.Resize(val_resize),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.ImageFolder(root=train_dir, transform=augment)
    val_dataset = datasets.ImageFolder(root=val_dir, transform=evaluate)
    test_dataset = datasets.ImageFolder(root=test_dir, transform=augment)

    train_dataloader = _make_loader(train_dataset, batch_size, True, workers)
    val_dataloader = _make_loader(val_dataset, batch_size, False, workers)
    test_dataloader = _make_loader(test_dataset, batch_size, False, workers)

    return train_dataloader, val_dataloader, test_dataloader