import torch
import torch.nn as nn
import numpy as np
import torchvision
from torchvision.datasets import CIFAR10
from PIL import Image, ImageEnhance, ImageOps
import random
from timm.data import create_transform
from torchvision import datasets, transforms
from timm.data.transforms import str_to_pil_interp


def load_cifar10(simple_augmentation=False):
    """Return the CIFAR-10 ``(train, test)`` datasets with transforms attached.

    Both splits live under ``./data`` and are downloaded if not already present.

    Args:
        simple_augmentation: When True, train with a light 32x32 pipeline
            (padded random crop, horizontal flip, ``ColorJitter()``) and
            evaluate at native resolution with only tensor conversion and
            normalization. When False, use the pipelines returned by
            ``build_transform(True)`` / ``build_transform(False)``.

    Returns:
        A 2-tuple ``(ds_train, ds_tst)`` of ``CIFAR10`` dataset objects.
    """
    if not simple_augmentation:
        train_tf, test_tf = build_transform(True), build_transform(False)
    else:
        # Per-channel CIFAR-10 statistics used by both simple pipelines.
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        train_tf = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            # NOTE(review): ColorJitter() with all-default (zero) strengths
            # applies no colour change; confirm whether real values were intended.
            transforms.ColorJitter(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        test_tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    data_root = './data'
    # Train split first, then test split, each with its own transform.
    return tuple(
        CIFAR10(root=data_root, train=is_train_split, download=True, transform=tf)
        for is_train_split, tf in ((True, train_tf), (False, test_tf))
    )


def build_transform(is_train, input_size=224,
                    mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Build the training or evaluation transform for the upscaled CIFAR-10 pipeline.

    Args:
        is_train: If True, return timm's training pipeline (RandAugment
            ``rand-m9-mstd0.5-inc1``, colour jitter 0.4, pixel-mode random
            erasing with p=0.25, bicubic interpolation). Otherwise return a
            bicubic resize + center crop evaluation pipeline.
        input_size: Side length of the square output image (default 224).
        mean: Per-channel normalization mean, applied in BOTH modes.
        std: Per-channel normalization std, applied in BOTH modes.

    Returns:
        A callable transform that maps a PIL image to a normalized tensor.
    """
    if is_train:
        # timm's create_transform normalizes with ImageNet statistics unless
        # told otherwise; pass the same mean/std as the eval pipeline so that
        # training and evaluation inputs are normalized identically.
        return create_transform(
            input_size=input_size,
            is_training=True,
            color_jitter=0.4,
            auto_augment='rand-m9-mstd0.5-inc1',
            re_prob=0.25,
            re_mode='pixel',
            re_count=1,
            interpolation='bicubic',
            mean=mean,
            std=std,
        )

    # Resize the shorter side to input_size * 256/224, then center-crop to
    # input_size: the same 224/256 crop ratio used for 224-pixel models.
    # CIFAR-10 images are 32x32, so this upsamples them.
    resize_size = int((256 / 224) * input_size)
    return transforms.Compose([
        transforms.Resize(resize_size, interpolation=str_to_pil_interp('bicubic')),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])