import torch
import torch.nn as nn
import numpy as np
import torchvision
from torchvision.datasets import ImageNet
from PIL import Image, ImageEnhance, ImageOps
import random
from timm.data import create_transform
from torchvision import datasets, transforms
from timm.data.transforms import str_to_pil_interp
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import os


def load_imagenet_100(data_path, simple_augmentation=False):
    """Return the ImageNet-100 training and validation datasets.

    ``data_path`` is expected to contain ``train`` and ``val`` subdirectories
    laid out the way ``torchvision.datasets.ImageFolder`` expects.

    Args:
        data_path: Root directory holding the ``train`` and ``val`` splits.
        simple_augmentation: Forwarded to ``build_transform`` for the training
            split. True selects the light crop+flip pipeline; False selects the
            timm RandAugment / random-erasing pipeline.

    Returns:
        A ``(ds_train, ds_tst)`` tuple of ``ImageFolder`` datasets, the first
        using the training transform and the second the evaluation transform.
    """
    # Both transforms are built up front, then the folders are opened in
    # train -> val order.
    split_to_transform = (
        ('train', build_transform(True, simple_augmentation)),
        ('val', build_transform(False)),
    )
    ds_train, ds_tst = [
        datasets.ImageFolder(os.path.join(data_path, split), transform=tfm)
        for split, tfm in split_to_transform
    ]
    return ds_train, ds_tst


def build_transform(is_train, simple_augmentation=False, *, input_size=224):
    """Return the image transform pipeline for one dataset split.

    Args:
        is_train: Build the training pipeline when True, otherwise the
            evaluation pipeline.
        simple_augmentation: Only consulted when ``is_train`` is True.
            True selects a light pipeline: RandomResizedCrop, horizontal flip,
            ToTensor and Normalize. False selects timm's ``create_transform``
            with color jitter, RandAugment ('rand-m9-mstd0.5-inc1') and
            pixel-mode random erasing.
        input_size: Side length of the square output crop. Defaults to 224.
            For evaluation, the image is first resized to
            ``input_size * 256 // 224``, which is 256 at the default.

    Returns:
        A callable transform that maps an input image to a normalized tensor.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    if not is_train:
        # Resize then center-crop, keeping the 256/224 ratio for any input_size.
        # Integer arithmetic avoids float rounding, e.g. 224 -> exactly 256.
        # NOTE(review): this Resize uses torchvision's default interpolation,
        # while the timm training pipeline below uses 'bicubic'. Confirm the
        # mismatch is intended.
        return transforms.Compose([
            transforms.Resize(input_size * 256 // 224),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize,
        ])

    if simple_augmentation:
        return transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

    # With is_training=True, timm's create_transform builds its own training
    # pipeline, including its own normalization. That normalization presumably
    # defaults to IMAGENET_DEFAULT_MEAN/STD; verify it against the installed
    # timm version.
    return create_transform(
        input_size=input_size,
        is_training=True,
        color_jitter=0.4,
        auto_augment='rand-m9-mstd0.5-inc1',
        re_prob=0.25,
        re_mode='pixel',
        re_count=1,
        interpolation='bicubic',
    )