import torch
import torchvision
from torchvision import transforms
import os
import wget
from pathlib import Path
import tarfile
import cv2 as hcv2
import shutil
import zipfile

def _download_if_missing(url, dest, **download_kwargs):
    """Download ``url`` to ``dest`` unless that file already exists.

    Returns True when a download was performed and False when ``dest`` was
    already present on disk.
    """
    if Path(dest).is_file():
        return False
    # Make sure the destination folder exists so the very first download on a
    # fresh checkout does not fail on a missing parent directory.
    os.makedirs(os.path.dirname(dest), exist_ok=True)
    wget.download(url, out=dest, **download_kwargs)
    return True


def _unpack_if_needed(archive, extract_to, expected_dir, force=False):
    """Extract ``archive`` into ``extract_to`` when required.

    Extraction happens when ``force`` is true (the archive was just
    downloaded) or when ``expected_dir`` does not exist yet, so an archive
    that is present but was never unpacked is still extracted.  ``.zip``
    files are handled by ``zipfile``; everything else goes through
    ``tarfile``.
    """
    if not force and Path(expected_dir).is_dir():
        return
    # NOTE(review): extractall() is used on downloaded archives without
    # inspecting member paths; the URLs are fixed in this file, but confirm
    # that the sources are trusted.
    if archive.endswith(".zip"):
        with zipfile.ZipFile(archive) as zip_file:
            zip_file.extractall(path=extract_to)
    else:
        with tarfile.open(archive) as tar:
            tar.extractall(path=extract_to)


def get_outdist_datasets(load_out_names, image_size=32):
    """Load the requested out-of-distribution datasets.

    Recognised names are ``"cifar10"``, ``"cifar100"``, ``"mnist"``,
    ``"tiny_imagenet"``, ``"places365"``, ``"LSUN"``, ``"iSUN"``, ``"birds"``,
    ``"flowers"`` and ``"coil_100"``.  Names that are not recognised are
    ignored.  Missing archives are downloaded into ``./datasets/data`` and
    unpacked there.

    Args:
        load_out_names: Container of dataset names; each known name is tested
            with ``in``.
        image_size: Size handed to ``transforms.Resize`` (or to ``cv2.resize``
            as ``(image_size, image_size)`` for birds/flowers).  CIFAR, LSUN
            and iSUN skip the resize when it equals 32.

    Returns:
        A tuple ``(out_names, out_datasets)`` of two parallel lists.  The
        entries follow the fixed order of the checks in this function, not
        the order of ``load_out_names``.
    """
    out_datasets = []
    out_names = []

    # The transform builders are functions so that torchvision is only
    # touched for datasets that were actually requested.
    def resize_transform():
        return transforms.Compose([transforms.ToTensor(),
                                   transforms.Resize(image_size)])

    def native_or_resize_transform():
        return transforms.ToTensor() if image_size == 32 else resize_transform()

    def read_resized(path):
        # NOTE(review): cv2.imread yields BGR channel order while the
        # PIL-backed datasets above are RGB -- confirm this is intended.
        img = hcv2.imread(path)
        return hcv2.resize(img, (image_size, image_size))

    if "cifar10" in load_out_names:
        cifar10 = torchvision.datasets.CIFAR10(root='./datasets/data', train=False, download=True,
                                               transform=native_or_resize_transform())
        out_names.append("cifar10")
        out_datasets.append(cifar10)
        print("finish loading cifar10")

    if "cifar100" in load_out_names:
        cifar100 = torchvision.datasets.CIFAR100(root='./datasets/data', train=False, download=True,
                                                 transform=native_or_resize_transform())
        out_names.append("cifar100")
        out_datasets.append(cifar100)
        print("finish loading cifar100")

    if "mnist" in load_out_names:
        mnist = torchvision.datasets.MNIST(root='./datasets/data', train=False, download=True,
                                           transform=transforms.Compose([transforms.ToTensor(),
                                                                         transforms.Resize(image_size),
                                                                         # repeat the channel axis 3 times
                                                                         transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                                                                         ]))
        out_names.append("mnist")
        out_datasets.append(mnist)
        print("finish loading mnist")

    # tiny-imagenet
    if "tiny_imagenet" in load_out_names:
        dest = "./datasets/data/tiny-imagenet-200.zip"
        fresh = _download_if_missing("http://cs231n.stanford.edu/tiny-imagenet-200.zip", dest,
                                     bar=wget.bar_thermometer)
        _unpack_if_needed(dest, "./datasets/data", "./datasets/data/tiny-imagenet-200", force=fresh)
        tiny_imagenet = torchvision.datasets.ImageFolder(root='./datasets/data/tiny-imagenet-200/test',
                                                         transform=resize_transform())
        out_names.append("tiny_imagenet")
        out_datasets.append(tiny_imagenet)
        print("finish loading tiny_imagenet")

    # places 365
    if "places365" in load_out_names:
        _download_if_missing("https://dl.dropboxusercontent.com/s/enr71zpolzi1xzm/categories_places365.txt",
                             "./datasets/data/categories_places365.txt")
        _download_if_missing("https://dl.dropboxusercontent.com/s/gaf1ygpdnkhzyjo/places365_val.txt",
                             "./datasets/data/places365_val.txt")
        dest = "./datasets/data/val_256.tar"
        _download_if_missing("https://dl.dropboxusercontent.com/s/3pwqsyv33f6if3z/val_256.tar", dest)
        _unpack_if_needed(dest, "./datasets/data", "./datasets/data/val_256")

        # The files were fetched manually above, hence download=False.
        places365 = torchvision.datasets.Places365(root='./datasets/data', split='val', small=True,
                                                   download=False, transform=resize_transform())
        out_names.append("places365")
        out_datasets.append(places365)
        print("finish loading places365")

    # LSUN
    if "LSUN" in load_out_names:
        dest = "./datasets/data/LSUN_resize.tar.gz"
        fresh = _download_if_missing("https://bit.ly/3wA55Wb", dest)
        _unpack_if_needed(dest, "./datasets/data/LSUN_resize", "./datasets/data/LSUN_resize", force=fresh)
        LSUN = torchvision.datasets.ImageFolder(root='./datasets/data/LSUN_resize',
                                                transform=native_or_resize_transform())
        out_names.append("LSUN")
        out_datasets.append(LSUN)
        print("finish loading LSUN")

    # isun
    if "iSUN" in load_out_names:
        dest = "./datasets/data/iSUN.tar.gz"
        fresh = _download_if_missing("https://bit.ly/3yRMTJe", dest)
        _unpack_if_needed(dest, "./datasets/data/iSUN", "./datasets/data/iSUN", force=fresh)
        iSUN = torchvision.datasets.ImageFolder(root='./datasets/data/iSUN',
                                                transform=native_or_resize_transform())
        out_names.append("iSUN")
        out_datasets.append(iSUN)
        print("finish loading iSUN")

    # birds
    if "birds" in load_out_names:
        dest = "./datasets/data/images.tgz"
        fresh = _download_if_missing("https://bit.ly/3lwLTCm", dest, bar=wget.bar_thermometer)
        _unpack_if_needed(dest, "./datasets/data/images", "./datasets/data/images", force=fresh)

        def birds_loader(path):
            # Treat both '/' and '\\' as separators, then drop a leading
            # '._' from the file name before reading the image
            # (presumably '._' entries shadow the real files -- TODO confirm).
            parts = path.replace("\\", "/").split("/")
            if parts[-1].startswith('._'):
                parts[-1] = parts[-1][2:]
            return read_resized('/'.join(parts))

        birds = torchvision.datasets.ImageFolder(root='./datasets/data/images', loader=birds_loader,
                                                 transform=transforms.ToTensor())
        out_names.append("birds")
        out_datasets.append(birds)
        print("finish loading birds")

    # Flowers
    if "flowers" in load_out_names:
        dest = "./datasets/data/102flowers.tgz"
        _download_if_missing("https://dl.dropboxusercontent.com/s/hbt8e7wjiplryoo/102flowers.tgz", dest)
        dest_folder = "./datasets/data/flowers"
        if not Path(dest_folder).is_dir():
            os.makedirs("./datasets/data/flowers/jpg2", exist_ok=True)
            with tarfile.open(dest) as tar:
                tar.extractall(path="./datasets/data/flowers")
            # One image is moved into a second sub-folder, so ImageFolder
            # sees two directories under the root (presumably to give it
            # more than one class -- TODO confirm).
            shutil.move("./datasets/data/flowers/jpg/image_00009.jpg", "./datasets/data/flowers/jpg2")

        flowers = torchvision.datasets.ImageFolder(root='./datasets/data/flowers', loader=read_resized,
                                                   transform=transforms.ToTensor())
        out_names.append("flowers")
        out_datasets.append(flowers)
        print("finish loading flowers")

    # coil100
    if "coil_100" in load_out_names:
        dest = "./datasets/data/coil-100.zip"
        fresh = _download_if_missing(
            "http://www.cs.columbia.edu/CAVE/databases/SLAM_coil-20_coil-100/coil-100/coil-100.zip", dest)
        _unpack_if_needed(dest, './datasets/data/coil', './datasets/data/coil', force=fresh)

        coil_100 = torchvision.datasets.ImageFolder(root='./datasets/data/coil',
                                                    transform=resize_transform())
        out_names.append("coil_100")
        out_datasets.append(coil_100)
        print("finish loading coil_100")
    return out_names, out_datasets


def get_outdist_dataloaders(load_out_names, image_size=32, batch_size=128, num_workers=0):
    """Wrap each requested out-of-distribution dataset in a DataLoader.

    Every entry of ``load_out_names`` is resolved on its own through
    ``get_outdist_datasets``, so the result follows the order of
    ``load_out_names``.

    Args:
        load_out_names: Iterable of dataset names to load.
        image_size: Forwarded to ``get_outdist_datasets``.
        batch_size: Batch size of every DataLoader.
        num_workers: Worker count of every DataLoader.

    Returns:
        A list of ``(name, dataloader)`` tuples; loaders are not shuffled.
    """
    names = []
    datasets = []
    for requested in load_out_names:
        # One name per call keeps the output in the caller's order.
        loaded_names, loaded_sets = get_outdist_datasets([requested], image_size)
        names.extend(loaded_names)
        datasets.extend(loaded_sets)

    loaders = [
        torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                    shuffle=False, num_workers=num_workers)
        for dataset in datasets
    ]
    return list(zip(names, loaders))


def get_trainloader_cifar10(image_size=32, batch_size=128, num_workers=0):
    """Build an unshuffled DataLoader over the CIFAR-10 training split.

    The dataset is downloaded to ``./datasets/data`` if it is not there yet.

    Args:
        image_size: When different from 32, images are resized to
            ``(image_size, image_size)`` before conversion to tensors.
        batch_size: Batch size of the returned loader.
        num_workers: Number of DataLoader worker processes (default 0, the
            value that used to be hard-coded).

    Returns:
        The tuple ``("CIFAR 10", loader, 10)``.
    """
    # code from  ccj@mit.edu
    # Loads the *training* split (train=True), without shuffling.
    if image_size == 32:
        transform = torchvision.transforms.ToTensor()
    else:
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((image_size, image_size)),
            torchvision.transforms.ToTensor(),
        ])
    train_set = torchvision.datasets.CIFAR10(root='./datasets/data', train=True,
                                             download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                               shuffle=False, num_workers=num_workers)
    return "CIFAR 10", train_loader, 10


def get_trainloader_cifar100(image_size=32, batch_size=128, num_workers=0):
    """Build an unshuffled DataLoader over the CIFAR-100 training split.

    The dataset is downloaded to ``./datasets/data`` if it is not there yet.

    Args:
        image_size: When different from 32, images are resized to
            ``(image_size, image_size)`` before conversion to tensors.
        batch_size: Batch size of the returned loader.
        num_workers: Number of DataLoader worker processes (default 0, the
            value that used to be hard-coded).

    Returns:
        The tuple ``("CIFAR 100", loader, 100)``.
    """
    # code from  ccj@mit.edu
    # Loads the *training* split (train=True), without shuffling.
    if image_size == 32:
        transform = torchvision.transforms.ToTensor()
    else:
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((image_size, image_size)),
            torchvision.transforms.ToTensor(),
        ])
    train_set = torchvision.datasets.CIFAR100(root='./datasets/data', train=True,
                                              download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                               shuffle=False, num_workers=num_workers)
    return "CIFAR 100", train_loader, 100

def get_trainloader_tiny_imagenet(image_size=32, batch_size=128, num_workers=0):
    """Build an unshuffled DataLoader over the Tiny ImageNet training folder.

    The archive is downloaded to ``./datasets/data`` when missing and
    unpacked when it was just downloaded or when the extracted
    ``tiny-imagenet-200`` directory does not exist yet.

    Args:
        image_size: Size handed to ``transforms.Resize`` after the image has
            been converted to a tensor.
        batch_size: Batch size of the returned loader.
        num_workers: Number of DataLoader worker processes (default 0, the
            value that used to be hard-coded).

    Returns:
        The tuple ``("Tiny imagenet", loader, 200)``.
    """
    archive = "./datasets/data/tiny-imagenet-200.zip"
    extracted_root = "./datasets/data/tiny-imagenet-200"

    downloaded = False
    if not Path(archive).is_file():
        # Make sure the target folder exists before the first download.
        os.makedirs("./datasets/data", exist_ok=True)
        wget.download("http://cs231n.stanford.edu/tiny-imagenet-200.zip", out=archive, bar=wget.bar_thermometer)
        downloaded = True
    # Also unpack an archive that is present but was never extracted.
    if downloaded or not Path(extracted_root).is_dir():
        with zipfile.ZipFile(archive) as zip_file:
            zip_file.extractall(path="./datasets/data")

    # This is the transform that was always applied; a second, unused
    # transform that used to be built at the top of the function is gone.
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize(image_size)])
    tiny_imagenet = torchvision.datasets.ImageFolder(root='./datasets/data/tiny-imagenet-200/train',
                                                     transform=transform)
    train_loader = torch.utils.data.DataLoader(tiny_imagenet, batch_size=batch_size,
                                               shuffle=False, num_workers=num_workers)
    return "Tiny imagenet", train_loader, 200