import os
import wget
import shutil

from torchvision import datasets
import torchvision.transforms as transforms

from ..constants import IMAGENET_PATH, IMAGENET_O_PATH, IMAGENET_A_PATH

# Preprocessing applied to every image returned by the loaders in this module.
transform = transforms.Compose(
    [
        # Rescale first, then take a fixed 224x224 crop from the centre.
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        # PIL image -> float tensor, followed by per-channel standardisation.
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],  # per-channel mean
            [0.229, 0.224, 0.225],  # per-channel std
        ),
    ]
)

def get_imagenet_test():
    """Return the ``val`` folder under ``IMAGENET_PATH`` as an ``ImageFolder``.

    Nothing is downloaded here; the directory must already exist on disk.
    Images are preprocessed with the module-level ``transform``.
    """
    val_dir = os.path.join(IMAGENET_PATH, "val")
    return datasets.ImageFolder(val_dir, transform=transform)

def get_imagenet_o():
    """Return ImageNet-O as an ``ImageFolder``, fetching it when missing.

    Creates ``IMAGENET_O_PATH`` if necessary, lets ``_download_imagenet_o``
    download/unpack the archive there, then loads the resulting
    ``imagenet-o`` directory with the module-level ``transform``.
    """
    os.makedirs(IMAGENET_O_PATH, exist_ok=True)
    _download_imagenet_o(IMAGENET_O_PATH)

    dataset_dir = os.path.join(IMAGENET_O_PATH, "imagenet-o")
    return datasets.ImageFolder(dataset_dir, transform=transform)

def get_imagenet_a():
    """Return ImageNet-A as an ``ImageFolder``, fetching it when missing.

    Creates ``IMAGENET_A_PATH`` if necessary, lets ``_download_imagenet_a``
    download/unpack the archive there, then loads the resulting
    ``imagenet-a`` directory with the module-level ``transform``.
    """
    os.makedirs(IMAGENET_A_PATH, exist_ok=True)
    _download_imagenet_a(IMAGENET_A_PATH)

    dataset_dir = os.path.join(IMAGENET_A_PATH, "imagenet-a")
    return datasets.ImageFolder(dataset_dir, transform=transform)

# Both archives are served from the same location as ``<name>.tar``.
_ARCHIVE_BASE_URL = "https://people.eecs.berkeley.edu/~hendrycks"


def _download_and_unpack(path: str, name: str) -> None:
    """Make sure ``<path>/<name>`` exists, fetching ``<name>.tar`` if needed.

    Args:
        path: Existing directory that holds (or will hold) the archive and
            the extracted dataset directory.
        name: Dataset stem, e.g. ``"imagenet-o"``. The archive is assumed to
            unpack into a top-level directory of the same name.

    Raises:
        Whatever ``wget.download`` or ``shutil.unpack_archive`` raise; the
        exception is propagated unchanged.
    """
    archive = os.path.join(path, name + ".tar")
    target = os.path.join(path, name)

    # An existing dataset directory means a previous call already finished.
    if os.path.exists(target):
        return

    # Reuse a previously downloaded archive instead of fetching it again.
    if not os.path.exists(archive):
        wget.download(_ARCHIVE_BASE_URL + "/" + name + ".tar", archive)

    print("unpacking...")
    try:
        shutil.unpack_archive(archive, path)
    except BaseException:
        # A failed or interrupted extraction would otherwise leave a partial
        # directory behind that every later call mistakes for a complete
        # dataset. ``target`` did not exist before this call, so removing it
        # only discards what this extraction wrote.
        shutil.rmtree(target, ignore_errors=True)
        raise
    print("unpacked")


def _download_imagenet_o(path: str) -> None:
    """Download (if missing) and unpack ImageNet-O into ``path``.

    Leaves ``<path>/imagenet-o`` in place; does nothing when that directory
    already exists.
    """
    _download_and_unpack(path, "imagenet-o")


def _download_imagenet_a(path: str) -> None:
    """Download (if missing) and unpack ImageNet-A into ``path``.

    Leaves ``<path>/imagenet-a`` in place; does nothing when that directory
    already exists.
    """
    _download_and_unpack(path, "imagenet-a")
