from torch.utils.data import DataLoader
from torchvision import transforms
from loader.datasets import ImageDataset
from loader.datasets.augmix import AugMixDataset
import config
import numpy as np

# Return an ImageDataset for the training split
def _get_train_set(data_path):
    """Build the training ``ImageDataset`` rooted at *data_path*.

    Each image is resized to 32x32, padded by 4 pixels and randomly cropped
    back to 32x32, randomly flipped horizontally, converted to a tensor and
    normalized per channel.
    """
    # NOTE(review): these look like the widely used CIFAR-10 channel
    # statistics; confirm they are appropriate for this dataset.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    train_pipeline = transforms.Compose([
        transforms.Resize((32, 32)),
        # Augmentation: pad by 4 px on each side, then take a random 32x32 crop.
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    return ImageDataset(image_dir=data_path, transform=train_pipeline)

# Return an ImageDataset for the test / align splits
def _get_test_set(data_path):
    """Build an evaluation ``ImageDataset`` rooted at *data_path*.

    Each image is resized to 32x32 (the same input size the training
    pipeline produces), converted to a tensor and normalized with the same
    per-channel statistics as training. No random augmentation is applied.
    """
    # Same normalization statistics as the training pipeline.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    return ImageDataset(image_dir=data_path,
                        transform=transforms.Compose([
                            # Match the training input size. Without this,
                            # images that are not already 32x32 reach the
                            # model at a different resolution than it was
                            # trained on (a no-op for 32x32 inputs).
                            transforms.Resize((32, 32)),
                            transforms.ToTensor(),
                            transforms.Normalize(channel_mean, channel_std),
                        ]))

# Load images and return a DataLoader
def load_images(data_type=None, num_workers=12, align_batch_size=190):
    """Load one split of the 'dcase' image data and wrap it in a DataLoader.

    Args:
        data_type: One of 'train', 'test', 'align', or None. None (the
            default) loads the test split.
        num_workers: Number of DataLoader worker processes (default 12).
        align_batch_size: Batch size used for the 'align' split, which
            ignores the configured ``batch_size`` (default 190).

    Returns:
        A ``(data_loader, data_config)`` tuple, where ``data_config`` is the
        object returned by ``config.get_data_config('dcase')``.

    Raises:
        ValueError: If ``data_type`` is not one of the accepted values.
    """
    if data_type is not None and data_type not in ('train', 'test', 'align'):
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would let an unknown value silently fall through
        # to the test split.
        raise ValueError(
            "data_type must be None, 'train', 'test' or 'align', got %r"
            % (data_type,))

    data_config = config.get_data_config('dcase')

    if data_type == 'train':
        data_set = _get_train_set(data_config['train_dir'])
        batch_size = data_config['batch_size']
    elif data_type == 'align':
        data_set = _get_test_set(data_config['align_dir'])
        batch_size = align_batch_size
    else:
        # 'test' or None
        data_set = _get_test_set(data_config['test_dir'])
        batch_size = data_config['batch_size']

    # NOTE(review): every split, including test/align, is shuffled; if
    # callers map predictions back to files by order, this should be False
    # for evaluation splits.
    data_loader = DataLoader(dataset=data_set,
                             batch_size=batch_size,
                             num_workers=num_workers,
                             shuffle=True)
    return data_loader, data_config
