from PIL import Image
from glob import glob
import os
from torch.utils.data import Dataset
import math
import numpy as np
import random
from .imagenet_dic import IMAGENET_DIC

def get_imagenet_dataset(data_root, config, class_num=None, random_crop=True, random_flip=False):
    """Build the ('train', 'val') pair of IMAGENET_dataset objects for one class.

    Both splits are constructed with identical options. Note that this
    includes ``random_crop`` / ``random_flip``, so the 'val' split is
    augmented the same way as 'train'.

    Args:
        data_root: root directory passed to IMAGENET_dataset as ``image_root``.
        config: object exposing ``config.data.image_size`` (used as ``img_size``).
        class_num: class index forwarded to IMAGENET_dataset.
        random_crop: forwarded to both datasets.
        random_flip: forwarded to both datasets.

    Returns:
        Tuple ``(train_dataset, test_dataset)`` where the second uses mode 'val'.
    """
    shared_options = dict(
        class_num=class_num,
        img_size=config.data.image_size,
        random_crop=random_crop,
        random_flip=random_flip,
    )
    train_split = IMAGENET_dataset(data_root, mode='train', **shared_options)
    val_split = IMAGENET_dataset(data_root, mode='val', **shared_options)
    return train_split, val_split


###################################################################


class IMAGENET_dataset(Dataset):
    """Image dataset for a single class, yielding float32 CHW arrays in [-1, 1].

    Image files are globbed from::

        <image_root>/<mode>/<class_dir>/<class_dir>/*.jpeg   (mode == 'train')
        <image_root>/<mode>/<class_dir>/<class_dir>/*.JPEG   (any other mode)

    where ``class_dir = IMAGENET_DIC[str(class_num)][0]``. The doubled
    ``<class_dir>`` level matches what the printed ``scp -r`` commands produce
    when copying a class folder into the already-created folder of the same
    name.

    Args:
        image_root: root directory containing the split subdirectories.
        mode: split subdirectory name (e.g. 'train' or 'val').
        class_num: class index, looked up in IMAGENET_DIC as ``str(class_num)``.
            Required; ``None`` aborts.
        img_size: side length of the square crop returned by ``__getitem__``.
        random_crop: use ``random_crop_arr`` (True) or ``center_crop_arr`` (False).
        random_flip: horizontally flip each sample with probability 0.5.

    Raises:
        SystemExit: with status 1 if ``class_num`` is None, or if
            ``<image_root>/train/<class_dir>`` does not exist. In the latter
            case, empty train/val class directories are created first and
            ``scp`` commands to fetch the data are printed.
    """

    def __init__(self, image_root, mode='val', class_num=None, img_size=512, random_crop=True, random_flip=False):
        super().__init__()
        if class_num is None:
            # Loading all classes at once is not supported; abort with a
            # non-zero status instead of relying on the interactive-only
            # `exit()` builtin (which also exited with status 0).
            print("class_num is None")
            raise SystemExit(1)

        class_dir = IMAGENET_DIC[str(class_num)][0]
        if not os.path.isdir(os.path.join(image_root, "train", class_dir)):
            # Data is not present locally: prepare target directories, print
            # the commands to fetch it, and stop.
            os.makedirs(os.path.join(image_root, "val", class_dir), exist_ok=True)
            os.makedirs(os.path.join(image_root, "train", class_dir), exist_ok=True)
            print(f"scp -r mingi@165.132.183.115:/d/datasets/imagenet/val/{class_dir} {os.path.join(image_root, 'val', class_dir)}")
            print(f"scp -r mingi@165.132.183.115:/d/datasets/imagenet/train/{class_dir} {os.path.join(image_root, 'train', class_dir)}")
            raise SystemExit(1)

        # NOTE(review): train files are matched as lowercase '.jpeg' and all
        # other splits as '.JPEG'; presumably this mirrors the local copy's
        # naming. Confirm, since glob is case-sensitive on most filesystems.
        pattern = '*.jpeg' if mode == 'train' else '*.JPEG'
        self.data_dir = os.path.join(image_root, mode, class_dir, class_dir, pattern)
        self.image_paths = sorted(glob(self.data_dir))

        self.img_size = img_size
        self.random_crop = random_crop
        self.random_flip = random_flip
        self.class_num = class_num

    def __getitem__(self, index):
        """Return image ``index`` as a float32 array of shape (3, img_size, img_size) in [-1, 1].

        Only the image is returned; no class label is included.
        """
        path = self.image_paths[index]
        # `with` guarantees the file handle is closed; convert() decodes the
        # pixels into a new, independent image before the source is closed.
        with Image.open(path) as src:
            pil_image = src.convert("RGB")

        if self.random_crop:
            arr = random_crop_arr(pil_image, self.img_size)
        else:
            arr = center_crop_arr(pil_image, self.img_size)

        if self.random_flip and random.random() < 0.5:
            arr = arr[:, ::-1]  # flip along the width axis

        # uint8 [0, 255] -> float32 [-1, 1]
        arr = arr.astype(np.float32) / 127.5 - 1

        return np.transpose(arr, [2, 0, 1])  # HWC -> CHW

    def __len__(self):
        """Return the number of image files found for this class and split."""
        return len(self.image_paths)


def center_crop_arr(pil_image, image_size):
    """Resize so the shorter side is ~image_size, then cut out the centered square.

    Args:
        pil_image: source PIL image. ``resize`` returns new images, so the
            input object itself is not altered.
        image_size: side length of the square crop.

    Returns:
        numpy array holding the centered ``image_size`` x ``image_size`` crop
        (first two axes are rows and columns).
    """
    img = pil_image
    # Pre-shrink by repeated BOX halving for better downsampling quality.
    # This stands in for PIL's `reducing_gap` option, which the targeted PIL
    # version did not support.
    while min(img.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in img.size)
        img = img.resize(halved, resample=Image.BOX)

    factor = image_size / min(img.size)
    target = tuple(round(side * factor) for side in img.size)
    pixels = np.array(img.resize(target, resample=Image.BICUBIC))

    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    return pixels[top:top + image_size, left:left + image_size]


def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
    """Resize to a random shorter-side length, then take a random square crop.

    The shorter side is drawn uniformly from the integers
    ``[ceil(image_size / max_crop_frac), ceil(image_size / min_crop_frac)]``.
    As a result, the ``image_size`` crop spans roughly between
    ``min_crop_frac`` and ``max_crop_frac`` of the resized shorter side.

    Args:
        pil_image: source PIL image (not altered; ``resize`` returns new images).
        image_size: side length of the square crop.
        min_crop_frac: smallest crop fraction of the shorter side.
        max_crop_frac: largest crop fraction of the shorter side.

    Returns:
        numpy array holding a randomly positioned ``image_size`` x ``image_size``
        crop (first two axes are rows and columns).
    """
    shortest = math.ceil(image_size / max_crop_frac)
    longest = math.ceil(image_size / min_crop_frac)
    target_short = random.randrange(shortest, longest + 1)

    img = pil_image
    # Pre-shrink by repeated BOX halving (manual substitute for PIL's
    # `reducing_gap`, unavailable on the targeted PIL version).
    while min(img.size) >= 2 * target_short:
        halved = tuple(side // 2 for side in img.size)
        img = img.resize(halved, resample=Image.BOX)

    factor = target_short / min(img.size)
    target = tuple(round(side * factor) for side in img.size)
    pixels = np.array(img.resize(target, resample=Image.BICUBIC))

    rows, cols = pixels.shape[0], pixels.shape[1]
    top = random.randrange(rows - image_size + 1)
    left = random.randrange(cols - image_size + 1)
    return pixels[top:top + image_size, left:left + image_size]
