from torchvision import datasets, transforms as T
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
import torch
from pathlib import Path
import os
import numpy as np

import imagenet_presets as presets

# Per-channel (mean, std) statistics for 3-channel CIFAR inputs, stored in the
# keyword form consumed by ``T.Normalize(**NORMALIZE_DICT[name])``.
# The ``_224`` variants carry exactly the same values as their base dataset.
_CIFAR10_MEAN_STD = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
_CIFAR100_MEAN_STD = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

NORMALIZE_DICT = {
    key: {'mean': mean, 'std': std}
    for key, (mean, std) in (
        ('cifar10', _CIFAR10_MEAN_STD),
        ('cifar100', _CIFAR100_MEAN_STD),
        ('cifar10_224', _CIFAR10_MEAN_STD),
        ('cifar100_224', _CIFAR100_MEAN_STD),
    )
}

def bar_custom(current, total, width=80):
    """Format a one-line progress message for ``wget.download``'s ``bar`` callback.

    ``wget.download`` invokes its progress callback as ``bar(current, total, width)``,
    so ``width`` must be accepted even though this formatter does not use it.

    Args:
        current: Number of bytes received so far.
        total: Expected total size in bytes; non-positive when the size is unknown.
        width: Console width supplied by wget (ignored).

    Returns:
        The progress string to display for this update.
    """
    if total <= 0:
        # Size unknown: a percentage would divide by zero or be meaningless.
        return "Downloading: [%d / unknown] bytes" % current
    # The reported byte count can overshoot the real size on the final block;
    # clamp so the message never exceeds 100%.
    current = min(current, total)
    return "Downloading: %d%% [%d / %d] bytes" % (current / total * 100, current, total)


def download_imagenet(path):
    """Download the ILSVRC2012 validation-image and devkit archives into ``path``.

    The archives are saved as downloaded (nothing here extracts them). The
    destination directory is created if it does not exist.

    Args:
        path: Destination directory (``str`` or ``os.PathLike``).

    Raises:
        ImportError: If the optional ``wget`` package is not installed.
    """
    # wget is only needed for this download helper, so import it lazily rather
    # than making every importer of this module depend on it.
    try:
        import wget
    except ImportError as err:
        raise ImportError(
            "download_imagenet requires the 'wget' package (pip install wget)"
        ) from err

    # NOTE(review): these tokenized image-net.org URLs may have expired — confirm.
    val_url = 'http://www.image-net.org/challenges/LSVRC/2012/dd31405981ef5f776aa17412e1f0c112/ILSVRC2012_img_val.tar'
    devkit_url = 'http://www.image-net.org/challenges/LSVRC/2012/dd31405981ef5f776aa17412e1f0c112/ILSVRC2012_devkit_t12.tar.gz'

    print("Downloading ImageNet...")
    os.makedirs(path, exist_ok=True)
    for url in (val_url, devkit_url):
        wget.download(url=url, out=str(path), bar=bar_custom)
    print('done!')

def _build_imagenet_loader(config, eval_batch_size):
    """Return a sequential, pinned-memory DataLoader over the ImageNet 'val' split.

    The dataset root is taken from the ``IMAGENET_ROOT`` environment variable.
    """
    try:
        imagenet_root = Path(os.environ['IMAGENET_ROOT'])
    except KeyError as err:
        raise KeyError('IMAGENET_ROOT environment variable is not set') from err
    if not imagenet_root.exists():
        raise FileNotFoundError(f'Imagenet dataset missing: {imagenet_root}')

    resize_size, crop_size = 256, 224
    transform_test = presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size)
    test_set = datasets.ImageNet(imagenet_root, split='val', transform=transform_test)
    return DataLoader(
        test_set,
        batch_size=eval_batch_size,
        sampler=SequentialSampler(test_set),
        num_workers=config.workers,
        pin_memory=True,
    )


def _build_test_set(dataset_name, config, data_path):
    """Return the torchvision test split for ``dataset_name`` with its eval transform.

    Datasets are downloaded into ``data_path/<name>`` on first use.

    Raises:
        ValueError: If ``dataset_name`` is unsupported, or if 'benchmark_cifar10'
            is requested with ``config.input_shape`` other than (3, 32, 32).
    """
    if dataset_name == 'mnist':
        transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])
        return datasets.MNIST(data_path / 'mnist', train=False, download=True, transform=transform)

    if dataset_name == 'cifar10':
        # Mean 0.5 with unit std: pixels end up in [-0.5, 0.5]. This deliberately
        # differs from the 'benchmark_cifar10' statistics below.
        transform = T.Compose([T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (1, 1, 1))])
        return datasets.CIFAR10(data_path / 'cifar10', train=False, download=True, transform=transform)

    if dataset_name == 'cifar100':
        transform = T.Compose([
            T.ToTensor(),
            T.Normalize(
                mean=[n / 255. for n in [129.3, 124.1, 112.4]],
                std=[n / 255. for n in [68.2, 65.4, 70.4]],
            ),
        ])
        return datasets.CIFAR100(data_path / 'cifar100', train=False, download=True, transform=transform)

    if dataset_name == 'benchmark_cifar10':
        # Explicit check instead of `assert` so validation survives `python -O`.
        if tuple(config.input_shape) != (3, 32, 32):
            raise ValueError(
                f'benchmark_cifar10 expects input_shape (3, 32, 32), got {tuple(config.input_shape)}'
            )
        transform = T.Compose([T.ToTensor(), T.Normalize(**NORMALIZE_DICT['cifar10'])])
        return datasets.CIFAR10(data_path / 'cifar10', train=False, download=True, transform=transform)

    if dataset_name == 'benchmark_cifar100':
        transform = T.Compose([T.ToTensor(), T.Normalize(**NORMALIZE_DICT['cifar100'])])
        return datasets.CIFAR100(data_path / 'cifar100', train=False, download=True, transform=transform)

    raise ValueError(f'Unsupported dataset_name: {dataset_name!r}')


def load_data_torch(config):
    """Build the evaluation DataLoader for ``config.dataset_name``.

    Only the test/evaluation split is loaded. Supported names are 'imagenet',
    'mnist', 'cifar10', 'cifar100', 'benchmark_cifar10' and 'benchmark_cifar100'.

    Args:
        config: Object exposing ``dataset_name``, ``eval_batch_size`` and
            ``torch_seed``. ``workers`` is also read for 'imagenet', and
            ``input_shape`` for 'benchmark_cifar10'.

    Returns:
        A ``torch.utils.data.DataLoader`` over the test split. ImageNet is
        iterated sequentially; every other dataset is shuffled.

    Raises:
        ValueError: For an unsupported ``dataset_name`` or an invalid
            'benchmark_cifar10' input shape.
        KeyError: If 'imagenet' is requested and ``IMAGENET_ROOT`` is unset.
        FileNotFoundError: If the ImageNet root directory does not exist.
    """
    dataset_name = config.dataset_name
    eval_batch_size = config.eval_batch_size
    torch.manual_seed(config.torch_seed)

    if dataset_name == 'imagenet':
        return _build_imagenet_loader(config, eval_batch_size)

    data_path = Path(__file__).resolve().parent / "dataset"
    test_set = _build_test_set(dataset_name, config, data_path)
    # NOTE(review): unlike ImageNet, these eval loaders shuffle; the order is drawn
    # from torch's global RNG seeded above. Confirm shuffling is intended.
    return DataLoader(test_set, batch_size=eval_batch_size, shuffle=True)


if __name__ == '__main__':
    # Smoke test: build the CIFAR-10 evaluation loader and fetch one batch.
    # SimpleNamespace supplies exactly the attributes load_data_torch reads for
    # this dataset; load_data_torch returns a single evaluation loader.
    from types import SimpleNamespace

    demo_config = SimpleNamespace(
        dataset_name='cifar10',
        eval_batch_size=8,
        torch_seed=0,
        workers=0,
    )
    test_loader = load_data_torch(demo_config)
    images, labels = next(iter(test_loader))
    print(f'{len(test_loader.dataset)} test items; first batch images {tuple(images.shape)}, labels {tuple(labels.shape)}')