import os
import sys
sys.path.insert(0, './')
import numpy as np

from dataset.Cifar10 import cifar10, syn_cifar10
from dataset.Cifar100 import cifar100, cifar100_superclass
from dataset.ImageNet import imagenet
from dataset.Adults import adults
from dataset.Purchase import purchase100

def parse_data(name, batch_size, root, valid_ratio = None, shuffle = True, augmentation = True, **kwargs):
    """Dispatch to the dataset-specific builder for ``name`` and return its loaders.

    Matching on ``name`` is case-insensitive; ``'imagenet'`` and
    ``'imagenet100'`` both use the ``imagenet`` builder. For ``'syn_cifar10'``
    the caller's ``valid_ratio``, ``shuffle`` and ``augmentation`` are ignored:
    that builder is always called with ``valid_ratio=None, shuffle=False,
    augmentation=False``.

    Args:
        name: Dataset identifier (cifar10, syn_cifar10, cifar100,
            cifar100_superclass, imagenet, imagenet100, adults, purchase100).
        batch_size: Forwarded as the builder's first positional argument.
        root: Forwarded to the builder as ``root=``.
        valid_ratio: Forwarded unchanged (ignored for syn_cifar10).
        shuffle: Forwarded unchanged (ignored for syn_cifar10).
        augmentation: Forwarded unchanged (ignored for syn_cifar10).
        **kwargs: Extra keyword arguments forwarded to the builder.

    Returns:
        The 4-tuple ``(train_loader, valid_loader, test_loader, classes)``
        produced by the builder.

    Raises:
        NotImplementedError: if ``name`` is not a recognised dataset.
    """
    dataset = name.lower()

    if dataset == 'syn_cifar10':
        # Synthetic CIFAR-10 always uses fixed settings, whatever the caller asked for.
        result = syn_cifar10(batch_size, root=root, valid_ratio=None,
                             shuffle=False, augmentation=False, **kwargs)
    else:
        if dataset == 'cifar10':
            builder = cifar10
        elif dataset == 'cifar100':
            builder = cifar100
        elif dataset == 'cifar100_superclass':
            builder = cifar100_superclass
        elif dataset in ('imagenet', 'imagenet100'):
            builder = imagenet
        elif dataset == 'adults':
            builder = adults
        elif dataset == 'purchase100':
            builder = purchase100
        else:
            raise NotImplementedError('Invalid dataset: %s' % name)
        result = builder(batch_size, root=root, valid_ratio=valid_ratio,
                         shuffle=shuffle, augmentation=augmentation, **kwargs)

    # Unpack explicitly so a builder returning the wrong arity still fails here.
    train_loader, valid_loader, test_loader, classes = result
    return train_loader, valid_loader, test_loader, classes


def parse_data_mia(name, root, batch_size, split, **kwargs):
    """Return one loader (train or test split) for ``name``, built without randomness.

    Every builder is called with ``valid_ratio=None, shuffle=False,
    augmentation=False`` — presumably so membership-inference evaluation sees
    a fixed order and un-augmented inputs (TODO confirm against the builders).
    Note the argument order differs from ``parse_data``: ``root`` comes before
    ``batch_size``.

    Args:
        name: Case-insensitive dataset identifier (cifar10, syn_cifar10,
            cifar100, cifar100_superclass, imagenet, imagenet100, adults,
            purchase100).
        root: Forwarded to the builder as ``root=``.
        batch_size: Forwarded as the builder's first positional argument.
        split: ``'train'`` or ``'test'``.
        **kwargs: Extra keyword arguments forwarded to the builder.

    Returns:
        The builder's train loader when ``split == 'train'``, otherwise its
        test loader.

    Raises:
        KeyError: if ``split`` is not ``'train'`` or ``'test'``. This is checked
            before any data is loaded. KeyError (not ValueError) is kept for
            compatibility with the previous dict-lookup behaviour.
        NotImplementedError: if ``name`` is not a recognised dataset.
    """
    # Validate the cheap argument first so a typo in ``split`` fails
    # immediately instead of after the whole dataset has been built.
    if split not in ('train', 'test'):
        raise KeyError("Invalid split: %s (expected 'train' or 'test')" % (split,))

    dataset = name.lower()
    if dataset == 'cifar10':
        builder = cifar10
    elif dataset == 'syn_cifar10':
        builder = syn_cifar10
    elif dataset == 'cifar100':
        builder = cifar100
    elif dataset == 'cifar100_superclass':
        builder = cifar100_superclass
    elif dataset in ('imagenet', 'imagenet100'):
        builder = imagenet
    elif dataset == 'adults':
        builder = adults
    elif dataset == 'purchase100':
        builder = purchase100
    else:
        raise NotImplementedError('Invalid dataset: %s' % name)

    # All datasets share the same deterministic call; only the builder differs.
    train_loader, _, test_loader, _ = builder(
        batch_size, root=root, valid_ratio=None, shuffle=False, augmentation=False, **kwargs)
    return train_loader if split == 'train' else test_loader


def parse_shadow_data(name, batch_size, root, n_shadow, valid_ratio = None, shuffle = True, augmentation = True, **kwargs):
    """Build train/test loaders for ``n_shadow`` shadow datasets of ``name``.

    The dataset is resolved once, before any loading, so an unknown ``name``
    always raises — including when ``n_shadow`` is 0, which previously
    returned ``([], [])`` silently.

    For shadow index ``i``, most builders receive ``shadow_seed=i`` together
    with the caller's ``valid_ratio``, ``shuffle`` and ``augmentation``. The
    exception is ``syn_cifar10``: it receives ``chunk=i`` instead and is
    always called with ``valid_ratio=None, shuffle=False, augmentation=False``.

    Args:
        name: Case-insensitive dataset identifier (cifar10, syn_cifar10,
            cifar100, cifar100_superclass, imagenet, imagenet100, adults,
            purchase100).
        batch_size: Forwarded as the builder's first positional argument.
        root: Forwarded to the builder as ``root=``.
        n_shadow: Number of shadow datasets to build (indices ``0..n_shadow-1``).
        valid_ratio: Forwarded unchanged (ignored for syn_cifar10).
        shuffle: Forwarded unchanged (ignored for syn_cifar10).
        augmentation: Forwarded unchanged (ignored for syn_cifar10).
        **kwargs: Extra keyword arguments forwarded to every builder call.

    Returns:
        ``(shadow_train_loaders, shadow_test_loaders)``: two lists of length
        ``max(n_shadow, 0)``, ordered by shadow index.

    Raises:
        NotImplementedError: if ``name`` is not a recognised dataset.
    """
    dataset = name.lower()
    if dataset == 'cifar10':
        builder = cifar10
    elif dataset == 'syn_cifar10':
        builder = syn_cifar10
    elif dataset == 'cifar100':
        builder = cifar100
    elif dataset == 'cifar100_superclass':
        builder = cifar100_superclass
    elif dataset in ('imagenet', 'imagenet100'):
        builder = imagenet
    elif dataset == 'adults':
        builder = adults
    elif dataset == 'purchase100':
        builder = purchase100
    else:
        raise NotImplementedError('Invalid dataset: %s' % name)

    shadow_train_loaders = []
    shadow_test_loaders = []
    for i in range(n_shadow):
        if dataset == 'syn_cifar10':
            # Synthetic CIFAR-10 selects shadow data by ``chunk`` and ignores
            # the caller's split/shuffle/augmentation options.
            train_loader, _, test_loader, _ = builder(
                batch_size, root=root, valid_ratio=None, chunk=i,
                shuffle=False, augmentation=False, **kwargs)
        else:
            train_loader, _, test_loader, _ = builder(
                batch_size, root=root, valid_ratio=valid_ratio, shadow_seed=i,
                shuffle=shuffle, augmentation=augmentation, **kwargs)
        shadow_train_loaders.append(train_loader)
        shadow_test_loaders.append(test_loader)
    return shadow_train_loaders, shadow_test_loaders