import os
import torch
import torchvision
from torchvision import transforms, datasets

from torch.utils.data import DataLoader

# Per-channel normalization statistics used by the preprocessing pipelines below.
# NOTE(review): the CIFAR std (0.2023, 0.1994, 0.2010) is the widely copied value,
# not the exact per-channel std of CIFAR-10; it is kept because evaluated models
# were presumably trained with it — confirm before changing.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2023, 0.1994, 0.2010)
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _small_scale_pipeline():
    """Return a new pipeline: resize to 32x32, convert to tensor, CIFAR-normalize."""
    steps = [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
    ]
    return transforms.Compose(steps)


def _large_scale_pipeline():
    """Return a new pipeline: Resize(256), convert to tensor, ImageNet-normalize."""
    steps = [
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
    return transforms.Compose(steps)


# TODO: train transforms — the training pipelines are currently identical to the
# test pipelines (no augmentation). Each name gets its own Compose instance.
transform_train = _small_scale_pipeline()
transform_test = _small_scale_pipeline()
transform_train_largescale = _large_scale_pipeline()
transform_test_largescale = _large_scale_pipeline()

# Extra DataLoader keyword arguments shared by the in-distribution loaders.
kwargs = {'num_workers': 2, 'pin_memory': True}
# Number of classes for each supported in-distribution dataset name.
num_classes_dict = {'CIFAR-100': 100, 'CIFAR-10': 10, 'imagenet': 1000}

def get_id_loader(args, config_type='eval', split='val'):
    """Build DataLoaders for the in-distribution (ID) dataset ``args.in_dataset``.

    Args:
        args: Namespace-like object providing ``batch_size`` and ``in_dataset``
            ('CIFAR-10', 'CIFAR-100' or 'imagenet'); ``imagenet_root`` is also
            read when ``in_dataset == 'imagenet'``.
        config_type: 'eval' (default) applies the test-time transforms to the
            train split as well; 'default' applies the training transforms.
        split: Any container supporting ``in`` (e.g. the string 'val' or the
            tuple ('train', 'val')); a loader is built for 'train' and/or 'val'
            when they are members.

    Returns:
        dict with 'train_loader' and 'val_loader' (None when not requested)
        and 'num_classes' for ``args.in_dataset``.

    Raises:
        KeyError: if ``config_type`` or ``args.in_dataset`` is not recognized.
    """
    # Each config selects which preprocessing the train / val splits receive.
    # The configs are plain dicts, so values are read with [] (attribute access
    # on a dict raises AttributeError).
    configs = {
        "default": {
            'transform_train': transform_train,
            'transform_test': transform_test,
            'batch_size': args.batch_size,
            'transform_test_largescale': transform_test_largescale,
            'transform_train_largescale': transform_train_largescale,
        },
        "eval": {
            'transform_train': transform_test,
            'transform_test': transform_test,
            'batch_size': args.batch_size,
            'transform_test_largescale': transform_test_largescale,
            'transform_train_largescale': transform_test_largescale,
        },
    }
    config = configs[config_type]
    batch_size = config['batch_size']

    train_loader, val_loader = None, None

    if args.in_dataset in ("CIFAR-10", "CIFAR-100"):
        cifar_cls = datasets.CIFAR10 if args.in_dataset == "CIFAR-10" else datasets.CIFAR100
        cifar_root = '../../datasets/data'
        if 'train' in split:
            trainset = cifar_cls(root=cifar_root, train=True, download=True,
                                 transform=config['transform_train'])
            train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, **kwargs)
        if 'val' in split:
            valset = cifar_cls(root=cifar_root, train=False, download=True,
                               transform=config['transform_test'])
            val_loader = DataLoader(valset, batch_size=batch_size, shuffle=True, **kwargs)
    elif args.in_dataset == "imagenet":
        root = args.imagenet_root
        # NOTE(review): ImageNet loaders are not shuffled, including the train split —
        # presumably intended for deterministic feature extraction; confirm.
        if 'train' in split:
            trainset = datasets.ImageFolder(os.path.join(root, 'train'),
                                            config['transform_train_largescale'])
            train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=False, **kwargs)
        if 'val' in split:
            valset = datasets.ImageFolder(os.path.join(root, 'val'),
                                          config['transform_test_largescale'])
            val_loader = DataLoader(valset, batch_size=batch_size, shuffle=False, **kwargs)

    return {
        "train_loader": train_loader,
        "val_loader": val_loader,
        "num_classes": num_classes_dict[args.in_dataset],
    }

def get_ood_loader(args, config_type='default', split=('train', 'val')):
    """Build DataLoaders for the out-of-distribution (OOD) dataset ``args.ood_dataset``.

    Args:
        args: Namespace-like object providing ``batch_size``, ``in_dataset`` and
            ``ood_dataset``. Supported ``ood_dataset`` values: 'SVHN', 'CIFAR-10',
            'CIFAR-100', 'dtd', 'places365', 'isun', 'lsun'.
        config_type: Transform configuration; only 'default' exists.
        split: Container checked with ``in``; the validation OOD loader is built
            only when 'val' is a member.

    Returns:
        dict with 'train_ood_loader' (always None; not implemented) and
        'val_ood_loader' (None when 'val' is not requested or the dataset name
        is not recognized).

    Raises:
        KeyError: if ``config_type`` is not recognized.
    """
    # Plain dict: values are read with [] rather than attribute access.
    config = {
        "default": {
            'transform_train': transform_train,
            'transform_test': transform_test,
            'transform_test_largescale': transform_test_largescale,
            'transform_train_largescale': transform_train_largescale,
            'batch_size': args.batch_size
        },
    }[config_type]
    train_ood_loader, val_ood_loader = None, None

    # TODO: training-time OOD loaders (ImageNet / TinyImages) and the extra
    # validation sets 'tim', 'imagenet', 'noise' and 'lfnoise' are not wired up.

    if 'val' in split:
        val_dataset = args.ood_dataset
        batch_size = args.batch_size
        # OOD images go through the same preprocessing as the ID images fed to the
        # same model: ImageNet-scale for an ImageNet ID model, 32x32 otherwise.
        if args.in_dataset == 'imagenet':
            transform = config['transform_test_largescale']
        else:
            transform = config['transform_test']
        loader_kwargs = {'batch_size': batch_size, 'num_workers': 2}

        # ImageFolder-style OOD datasets, keyed by name.
        image_folder_roots = {
            'dtd': "../../datasets/dtd/images",
            'places365': "../../datasets/places365",
            'isun': "../../datasets/iSUN",
            'lsun': "../../datasets/LSUN",
        }

        if val_dataset == 'SVHN':
            dataset = datasets.SVHN(root='../../datasets', split='test',
                                    transform=transform, download=False)
            val_ood_loader = DataLoader(dataset, shuffle=False, **loader_kwargs)
        elif val_dataset in ('CIFAR-10', 'CIFAR-100'):
            cifar_cls = datasets.CIFAR10 if val_dataset == 'CIFAR-10' else datasets.CIFAR100
            dataset = cifar_cls(root='../../datasets', train=False, download=True,
                                transform=transform)
            val_ood_loader = DataLoader(dataset, shuffle=True, **loader_kwargs)
        elif val_dataset in image_folder_roots:
            dataset = datasets.ImageFolder(root=image_folder_roots[val_dataset],
                                           transform=transform)
            val_ood_loader = DataLoader(dataset, shuffle=True, **loader_kwargs)
        else:
            # Best-effort: report and leave val_ood_loader as None.
            print(f"Invalid OOD dataset: {val_dataset!r}")

    return {
        "train_ood_loader": train_ood_loader,
        "val_ood_loader": val_ood_loader,
    }