from torchvision import datasets, transforms as T
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
import torch
from pathlib import Path
import os, wget
import numpy as np

from bypass.core.dataset.imagenet_utils import presets, transforms, utils, sampler
from bypass.core.dataset.imagenet_utils.transforms import *
from bypass.core.dataset.imagenet_utils.sampler import RASampler
from bypass.core.dataset.imagenet_utils.utils import *

# Normalization statistics keyed by dataset name. Every entry is its own
# {'mean', 'std'} dict of 3-tuples; the '*_224' keys carry the same values as
# their base dataset.
NORMALIZE_DICT = {
    key: {'mean': mean, 'std': std}
    for key, mean, std in (
        ('cifar10', (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ('cifar100', (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
        ('cifar10_224', (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ('cifar100_224', (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    )
}

def bar_custom(current, total, width=80):
    """Format a one-line progress message for ``wget.download``.

    Args:
        current: Number of bytes received so far.
        total: Expected size in bytes; zero or negative when the size is
            not known.
        width: Console width. ``wget`` calls its ``bar`` callback as
            ``bar(current, total, width)``, so the parameter must be
            accepted even though this formatter does not use it.

    Returns:
        The progress string to display.
    """
    if total <= 0:
        # No usable total: a percentage cannot be computed (and dividing by
        # zero would raise), so report the byte count only.
        return "Downloading: [%d / unknown] bytes" % current
    return "Downloading: %d%% [%d / %d] bytes" % (current / total * 100, current, total)


def download_imagenet(path):
    """Download the ILSVRC2012 validation-image and devkit archives.

    The directory ``path`` is created when missing and both archives are
    saved into it, validation images first. Nothing is extracted here.
    """
    archive_urls = (
        'http://www.image-net.org/challenges/LSVRC/2012/dd31405981ef5f776aa17412e1f0c112/ILSVRC2012_img_val.tar',
        'http://www.image-net.org/challenges/LSVRC/2012/dd31405981ef5f776aa17412e1f0c112/ILSVRC2012_devkit_t12.tar.gz',
    )

    print("Downloading ImageNet...")
    os.makedirs(path, exist_ok=True)
    for archive_url in archive_urls:
        wget.download(url=archive_url, out=path, bar=bar_custom)
    print('done!')

def load_data_torch(config):
    """Build the train and evaluation ``DataLoader`` pair selected by ``config``.

    Supported ``config.dataset_name`` values are ``'imagenet'``, ``'mnist'``,
    ``'cifar10'``, ``'cifar100'``, ``'benchmark_cifar10'`` and
    ``'benchmark_cifar100'``. The torch global RNG is seeded with
    ``config.torch_seed`` before any dataset is built.

    Args:
        config: Object exposing ``dataset_name``, ``train_batch_size``,
            ``eval_batch_size`` and ``torch_seed``. The ``'imagenet'`` branch
            additionally reads ``model_name`` and ``workers``; the
            ``'benchmark_cifar10'`` branch reads ``input_shape``.

    Returns:
        A ``(train_loader, eval_loader)`` tuple.

    Raises:
        ValueError: If ``config.dataset_name`` is not one of the supported
            names.
        KeyError: If ``'imagenet'`` is requested and the ``IMAGENET_ROOT``
            environment variable is not set.
    """
    dataset_name = config.dataset_name
    supported = ('imagenet', 'mnist', 'cifar10', 'cifar100',
                 'benchmark_cifar10', 'benchmark_cifar100')
    if dataset_name not in supported:
        # Fail early with a clear message instead of hitting an unbound
        # ``train_set`` further down.
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; expected one of {supported}"
        )

    # Datasets live in a "dataset" directory three levels above this file's folder.
    data_path = Path(__file__).resolve().parents[3] / "dataset"
    train_batch_size = config.train_batch_size
    eval_batch_size = config.eval_batch_size
    torch.manual_seed(config.torch_seed)

    if dataset_name == 'imagenet':
        IMAGENET_ROOT = Path(os.environ['IMAGENET_ROOT'])
        resize_size, crop_size = (256, 224)
        if config.model_name == 'imagenet_DeiTBase':
            transform_train = presets.deitimagenet(input_size=crop_size, is_train=True)
            transform_test = presets.deitimagenet(input_size=crop_size, is_train=False)
        else:
            transform_train = presets.ClassificationPresetTrain(
                crop_size=crop_size, auto_augment_policy=None, random_erase_prob=0)
            transform_test = presets.ClassificationPresetEval(
                crop_size=crop_size, resize_size=resize_size)

        train_set = datasets.ImageNet(IMAGENET_ROOT, split='train', transform=transform_train)
        test_set = datasets.ImageNet(IMAGENET_ROOT, split='val', transform=transform_test)

        # The training sampler is sharded by world size / rank; evaluation
        # walks the validation set in order.
        num_tasks = get_world_size()
        global_rank = get_rank()
        train_sampler = RASampler(
            train_set, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        test_sampler = SequentialSampler(test_set)

        # No mixup/cutmix batching is wired in; the default collation is used.
        collate_fn = None

        train_loader = DataLoader(
            train_set,
            batch_size=train_batch_size,
            sampler=train_sampler,
            num_workers=config.workers,
            pin_memory=True,
            drop_last=True,
            collate_fn=collate_fn,
        )
        test_loader = DataLoader(
            test_set,
            batch_size=eval_batch_size,
            sampler=test_sampler,
            num_workers=config.workers,
            pin_memory=True,
            drop_last=False,
            shuffle=False,
        )

        return train_loader, test_loader

    elif dataset_name == 'mnist':
        transform = T.Compose([T.ToTensor(),
                               T.Normalize((0.5,), (0.5,))])
        train_set = datasets.MNIST(data_path / 'mnist', train=True, download=True, transform=transform)
        test_set = datasets.MNIST(data_path / 'mnist', train=False, download=True, transform=transform)

    elif dataset_name == "cifar10":
        transform = T.Compose([T.ToTensor(),
                               T.Normalize((0.5, 0.5, 0.5), (1, 1, 1))])
        train_set = datasets.CIFAR10(data_path / 'cifar10', train=True, download=True, transform=transform)
        test_set = datasets.CIFAR10(data_path / 'cifar10', train=False, download=True, transform=transform)

    elif dataset_name == "cifar100":
        # Statistics are given on the 0-255 scale and rescaled to 0-1.
        mean = [n/255. for n in [129.3, 124.1, 112.4]]
        std = [n/255. for n in [68.2,  65.4,  70.4]]
        train_transform = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomCrop(32, padding=4),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])
        test_transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])
        train_set = datasets.CIFAR100(data_path / 'cifar100', train=True, download=True, transform=train_transform)
        test_set = datasets.CIFAR100(data_path / 'cifar100', train=False, download=True, transform=test_transform)

    elif dataset_name == 'benchmark_cifar10':
        assert np.all(np.array(config.input_shape)==np.array((3, 32, 32)))
        train_transform = T.Compose([
            T.RandomCrop(32, padding=4),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize( **NORMALIZE_DICT['cifar10'] ),
        ])
        val_transform = T.Compose([
            T.ToTensor(),
            T.Normalize( **NORMALIZE_DICT['cifar10'] ),
        ])
        train_set = datasets.CIFAR10(data_path / 'cifar10', train=True, download=True, transform=train_transform)
        # download=False: relies on the train-set call above having already
        # populated the same root directory.
        test_set = datasets.CIFAR10(data_path / 'cifar10', train=False, download=False, transform=val_transform)

    elif dataset_name == 'benchmark_cifar100':
        train_transform = T.Compose([
            T.RandomCrop(32, padding=4),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize( **NORMALIZE_DICT['cifar100'] ),
        ])
        val_transform = T.Compose([
            T.ToTensor(),
            T.Normalize( **NORMALIZE_DICT['cifar100'] ),
        ])
        train_set = datasets.CIFAR100(data_path / 'cifar100', train=True, download=True, transform=train_transform)
        test_set = datasets.CIFAR100(data_path / 'cifar100', train=False, download=True, transform=val_transform)

    # Ceiling division: the loaders below keep the final partial batch, and an
    # exact multiple of the batch size must not be counted as an extra batch.
    train_batches = -(-len(train_set) // train_batch_size)
    eval_batches = -(-len(test_set) // eval_batch_size)
    print(f'[INFO] {dataset_name} dataset loaded. train items: {len(train_set)} ({train_batches} batches), eval items: {len(test_set)} ({eval_batches} batches)')
    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True)
    # NOTE(review): the evaluation loader is shuffled as well; kept as-is so
    # existing callers see the same ordering behavior - confirm this is intended.
    eval_loader = DataLoader(test_set, batch_size=eval_batch_size, shuffle=True)

    return train_loader, eval_loader


if __name__ == '__main__':
    # Script entry point. A previous smoke test (build a config with
    # BypassConfig.from_preset('cifar10_3c3d') and load its data) is disabled.

    # Fall back to the default ImageNet location only when the caller has not
    # already exported IMAGENET_ROOT.
    os.environ.setdefault('IMAGENET_ROOT', '/workspace/dataset/Imagenet')
