import os
import os.path
import hashlib
import errno
import torch
from torchvision import transforms
import numpy as np
import random
from PIL import Image, ImageEnhance, ImageOps, ImageFilter
from torchvision import transforms as T
import PIL
import PIL.ImageOps
import PIL.ImageEnhance
import PIL.ImageDraw
from PIL import Image

# Per-dataset preprocessing constants, keyed by dataset name. Every entry
# carries per-channel "mean" and "std" tuples plus an integer "size".
# The values below are shared by the "ImageNet_R" and "CUB200" entries.
_SHARED_MEAN = (0.485, 0.456, 0.406)
_SHARED_STD = (0.229, 0.224, 0.225)
# All three datasets currently use the same "size".
_INPUT_SIZE = 224

dataset_stats = {
    "CIFAR100": dict(
        mean=(0.5070751592371323, 0.48654887331495095, 0.4409178433670343),
        std=(0.2673342858792409, 0.25643846291708816, 0.2761504713256834),
        size=_INPUT_SIZE,
    ),
    "ImageNet_R": dict(mean=_SHARED_MEAN, std=_SHARED_STD, size=_INPUT_SIZE),
    "CUB200": dict(mean=_SHARED_MEAN, std=_SHARED_STD, size=_INPUT_SIZE),
}

# transformations
def get_transform(dataset="cifar100", phase="test", aug=True, resize_imnet=False):
    """Build the torchvision preprocessing pipeline for a dataset and phase.

    Args:
        dataset: Key into ``dataset_stats``; also selects the evaluation
            resize strategy.
        phase: ``"train"`` selects the random-crop/flip pipeline; any other
            value selects the deterministic evaluation pipeline.
        aug: Unused; kept for interface compatibility.
        resize_imnet: Unused; kept for interface compatibility.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor`` and ``Normalize``.

    Raises:
        KeyError: If ``dataset`` is not a key of ``dataset_stats``.

    NOTE(review): the default ``"cifar100"`` does not match the
    ``"CIFAR100"`` key of the ``dataset_stats`` table visible in this file,
    so calling with defaults raises ``KeyError`` unless the table is
    extended elsewhere. The same holds for the ``"ImageNet32"`` /
    ``"ImageNet84"`` names handled below.
    """
    # Looked up first so an unknown dataset name fails before anything is built.
    out_size = dataset_stats[dataset]["size"]

    # Normalization is deliberately the identity (mean 0, std 1); the
    # per-dataset mean/std stored in dataset_stats are not applied here.
    identity_mean = (0.0, 0.0, 0.0)
    identity_std = (1.0, 1.0, 1.0)

    pipeline = []
    if dataset in ("ImageNet32", "ImageNet84"):
        pipeline.append(transforms.Resize((out_size, out_size)))

    if phase == "train":
        pipeline += [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
        ]
    elif dataset.startswith("ImageNet") or dataset == "CUB200":
        pipeline += [
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ]
    else:
        pipeline.append(transforms.Resize(224))

    # Every branch finishes with the same tensor conversion + normalization.
    pipeline += [
        transforms.ToTensor(),
        transforms.Normalize(identity_mean, identity_std),
    ]
    return transforms.Compose(pipeline)


# from hide
def build_transform(is_train, input_size):
    """Build a pipeline that crops to ``input_size`` without normalization.

    Args:
        is_train: Truthy selects the random-crop/flip training pipeline;
            falsy selects the deterministic evaluation pipeline.
        input_size: Side length handed to ``RandomResizedCrop`` (training)
            or ``CenterCrop`` (evaluation).

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor``.
    """
    # Evaluated up front, before the training early-return.
    needs_resize = input_size > 32

    if is_train:
        return transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    input_size,
                    scale=(0.05, 1.0),
                    ratio=(3.0 / 4.0, 4.0 / 3.0),
                ),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
            ]
        )

    steps = []
    if needs_resize:
        # to maintain same ratio w.r.t. 224 images (resize 256 -> crop 224)
        resized = int((256 / 224) * input_size)
        steps += [
            # interpolation=3 is PIL's integer code for bicubic resampling.
            transforms.Resize(resized, interpolation=3),
            transforms.CenterCrop(input_size),
        ]
    steps.append(transforms.ToTensor())

    return transforms.Compose(steps)


def build_cifar_transform(is_train, input_size):
    """Build a pipeline that normalizes with fixed CIFAR channel statistics.

    Args:
        is_train: Truthy selects the augmenting training pipeline; falsy
            selects the deterministic evaluation pipeline.
        input_size: Evaluation crop size. Values above 32 add a
            resize + center-crop step; the training pipeline does not use it
            beyond the comparison below.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor`` and ``Normalize``.
    """
    # Evaluated up front, before the training early-return.
    needs_resize = input_size > 32

    # Both phases normalize with the same constants; build the step once.
    normalize = transforms.Normalize(
        mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)
    )

    if is_train:
        # NOTE(review): the crop is fixed at 224 and relies on torchvision's
        # default scale/ratio. The previously assigned (but never passed)
        # scale=(0.05, 1.0) / ratio=(3/4, 4/3) locals were dead code and have
        # been removed -- confirm the defaults are what is wanted here.
        return transforms.Compose(
            [
                # interpolation=3 is PIL's integer code for bicubic resampling.
                transforms.RandomResizedCrop(224, interpolation=3),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=63 / 255),
                transforms.ToTensor(),
                normalize,
            ]
        )

    steps = []
    if needs_resize:
        # to maintain same ratio w.r.t. 224 images (resize 256 -> crop 224)
        resized = int((256 / 224) * input_size)
        steps.append(transforms.Resize(resized, interpolation=3))
        steps.append(transforms.CenterCrop(input_size))
    steps.append(transforms.ToTensor())
    steps.append(normalize)

    return transforms.Compose(steps)


# NOTE (original author): worse than the previous get_transform; becomes even
# worse when using npz. That observation was written against the version that
# forwarded the string `phase` as a truthy flag (so evaluation calls received
# the training augmentations); it may be worth re-measuring.
def get_transform_hide(dataset="cifar100", phase="test", aug=True, resize_imnet=False):
    """Select and build a pipeline via the ``build_*_transform`` helpers.

    Args:
        dataset: Key into ``dataset_stats``; names containing ``"cifar"``
            use ``build_cifar_transform``, all others ``build_transform``.
        phase: ``"train"`` requests the training pipeline; any other string
            requests the evaluation pipeline. A non-string value is
            interpreted by its truthiness.
        aug: Unused; kept for interface compatibility.
        resize_imnet: Unused; kept for interface compatibility.

    Returns:
        The ``transforms.Compose`` produced by the selected builder.

    Raises:
        KeyError: If ``dataset`` is not a key of ``dataset_stats``.
    """
    input_size = dataset_stats[dataset]["size"]

    # The builders take a boolean flag. Passing the phase string through
    # unchanged made every non-empty value (including "test") truthy, which
    # silently selected the training augmentations at evaluation time.
    if isinstance(phase, str):
        is_train = phase == "train"
    else:
        is_train = bool(phase)

    # NOTE(review): this match is case-sensitive, while the dataset_stats
    # table visible in this file uses the key "CIFAR100" -- so that dataset
    # takes the generic branch. Confirm whether that is intended.
    if "cifar" in dataset:
        return build_cifar_transform(is_train, input_size)
    return build_transform(is_train, input_size)


def check_integrity(fpath, md5):
    """Report whether the file at ``fpath`` has the expected MD5 digest.

    Args:
        fpath: Path of the file to verify.
        md5: Expected hexadecimal MD5 digest.

    Returns:
        ``True`` when ``fpath`` is an existing regular file whose digest
        equals ``md5``; ``False`` otherwise (including a missing file).
    """
    if not os.path.isfile(fpath):
        return False

    digest = hashlib.md5()
    block_size = 1024 * 1024  # hash in 1MB pieces to bound memory use
    with open(fpath, "rb") as stream:
        while True:
            data = stream.read(block_size)
            if not data:
                break
            digest.update(data)

    return digest.hexdigest() == md5


def download_url(url, root, filename, md5):
    """Download ``url`` to ``root/filename`` unless a verified copy exists.

    Args:
        url: Source URL.
        root: Destination directory; ``~`` is expanded and the directory is
            created if missing.
        filename: Name of the file inside ``root``.
        md5: Expected MD5 hex digest, used only to decide whether an
            already-present file can be reused.

    Raises:
        Exception: Whatever the download raised, when the URL is not https
            (no fallback available) or when the http retry also fails.

    NOTE(review): a freshly downloaded file is not checked against ``md5``.
    """
    # Standard-library client; imported locally like the shim it replaces.
    import urllib.request

    root = os.path.expanduser(root)
    fpath = os.path.join(root, filename)
    os.makedirs(root, exist_ok=True)

    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    print("Downloading " + url + " to " + fpath)
    try:
        urllib.request.urlretrieve(url, fpath)
    except Exception:
        # Only https URLs have a fallback. Anything else must surface the
        # failure instead of returning as if the file had been fetched.
        if not url.startswith("https"):
            raise
        url = url.replace("https:", "http:")
        print(
            "Failed download. Trying https -> http instead."
            " Downloading " + url + " to " + fpath
        )
        urllib.request.urlretrieve(url, fpath)
