import os
from torchvision import transforms, datasets


def get_imagenet_stats():
    '''
    Return the ImageNet normalization statistics as a (mean, std) tuple.

    Both entries are newly created three-element lists of floats, so a caller
    may modify them without affecting later calls.
    '''

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    return mean, std


def get_transforms(resize):
    '''
    Build the image preprocessing pipeline.

    The pipeline applies, in order: a resize to (256, 256), a second resize to
    `resize` (passed unchanged to transforms.Resize), conversion with
    transforms.ToTensor, and normalization with the statistics returned by
    get_imagenet_stats().

    Returns the resulting transforms.Compose object.
    '''

    mean, std = get_imagenet_stats()

    # order matters: both resizes run before tensor conversion and normalization
    steps = [
        transforms.Resize((256, 256)),
        transforms.Resize(resize),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]

    return transforms.Compose(steps)


def get_data(domain_name, P):
    '''
    Given a parameter dictionary P, initialize and return the specified dataset.

    Parameters:
        domain_name: name of the domain directory to load, located under
            P['data_root']/P['dataset'].
        P: parameter dictionary. The keys read here are 'dataset', 'data_root'
            and 'img_resize'.

    Returns:
        A datasets.ImageFolder rooted at the domain directory and using the
        transforms built by get_transforms(P['img_resize']).

    Raises:
        ValueError: if P['dataset'] is not one of the supported dataset names.
    '''

    supported_datasets = ('office31', 'visda2017')
    dataset_name = P['dataset']

    # validate first so an unsupported name fails before any transform or
    # path is built, and so the error names the offending value:
    if dataset_name not in supported_datasets:
        raise ValueError(
            'Unknown dataset: {!r} (supported: {}).'.format(
                dataset_name, ', '.join(supported_datasets)
            )
        )

    # define transforms:
    tx = get_transforms(P['img_resize'])

    # office31 is read from an extra 'images' directory below the domain
    # directory; the other supported dataset is read from the domain directory.
    path_parts = [P['data_root'], dataset_name, domain_name]
    if dataset_name == 'office31':
        path_parts.append('images')
    domain_data_path = os.path.join(*path_parts)

    return datasets.ImageFolder(domain_data_path, transform=tx)
