# -*- coding: UTF-8 -*-

import torch
from torch.utils.data import Dataset, random_split
from torchvision import transforms, datasets
from torchvision.datasets import CIFAR10, MNIST, ImageFolder, ImageNet, DatasetFolder, CIFAR100, STL10, SVHN, Flowers102, Food101
from dataset.stanford_dog import SDog120
from dataset.flower102 import Flower102
import numpy as np
import cv2


def _eval_transform(img_size):
    """Build the deterministic pipeline: tensor conversion, then resize to ``img_size``."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(img_size),
    ])


def _augment_transform(img_size, padding=32, resize=True):
    """Build the training pipeline: optional resize, reflect pad, random crop, random flip.

    NOTE(review): reflect padding generally needs the pad width to be smaller
    than the image side, so the default of 32 presumably assumes images larger
    than 32 px -- confirm against the ``img_size`` values callers actually use.
    """
    steps = [transforms.ToTensor()]
    if resize:
        steps.append(transforms.Resize(img_size))
    steps.extend([
        transforms.Pad(padding, padding_mode="reflect"),
        transforms.RandomCrop(img_size),
        transforms.RandomHorizontalFlip(),
    ])
    return transforms.Compose(steps)


def _center_crop_transform():
    """Build the fixed Resize(256) -> CenterCrop(224) -> ToTensor pipeline."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])


def _split_image_folder(root, img_size):
    """Split an ``ImageFolder`` 80/20 with a fixed seed, augmenting only the train part.

    Assigning ``.transform`` on the ``Subset`` returned by ``random_split`` has
    no effect, because ``Subset`` simply indexes its parent dataset.  Two
    ``ImageFolder`` views of the same directory are therefore created: one with
    the plain pipeline (test part) and one with the augmenting pipeline (train
    part), both addressed through the same seeded index split.
    """
    plain_view = ImageFolder(root, transform=_eval_transform(img_size))
    train_size = int(0.8 * len(plain_view))
    val_size = len(plain_view) - train_size
    train_split, test_split = random_split(
        plain_view, [train_size, val_size],
        generator=torch.Generator().manual_seed(42))
    augmented_view = ImageFolder(
        root, transform=_augment_transform(img_size, padding=4))
    # Re-use the indices of the seeded split so both parts stay disjoint.
    train_split = torch.utils.data.Subset(augmented_view, train_split.indices)
    return train_split, test_split


def get_full_dataset(dataset_name, img_size=(32, 32)):
    """Create the train/test datasets registered under ``dataset_name``.

    No pipeline applies ``transforms.Normalize``.

    Args:
        dataset_name: Case-sensitive key, one of ``'mnist'``, ``'cifar10'``,
            ``'cifar100'``, ``'stl10'``, ``'svhn'``, ``'Flower102'``,
            ``'SDog120'``, ``'pets'``, ``'flowers'``, ``'food'``,
            ``'imagenet10'``, ``'imagenet'``, ``'imagenet100'``,
            ``'imagenet200'`` or ``'vggface2'``.
        img_size: Target ``(height, width)`` used by the resize and crop steps.
            It is ignored by ``'Flower102'`` and ``'SDog120'``, which always
            use a 224x224 centre crop.

    Returns:
        A tuple ``(train_dataset, test_dataset, num_classes, num_channels)``.

    Raises:
        SystemExit: If ``dataset_name`` is not one of the supported keys.
    """
    if dataset_name == 'mnist':
        # MNIST training uses only a random flip (no pad / crop).
        train_dataset = MNIST('./data/mnist/', train=True, download=True,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),
                                  transforms.Resize(img_size),
                                  transforms.RandomHorizontalFlip(),
                              ]))
        test_dataset = MNIST('./data/mnist/', train=False, download=True,
                             transform=_eval_transform(img_size))
        num_classes, num_channels = 10, 1
    elif dataset_name == 'cifar10':
        train_dataset = CIFAR10('./data/cifar10/', train=True, download=True,
                                transform=_augment_transform(img_size))
        test_dataset = CIFAR10('./data/cifar10/', train=False, download=True,
                               transform=_eval_transform(img_size))
        num_classes, num_channels = 10, 3
    elif dataset_name == 'cifar100':
        train_dataset = CIFAR100('./data/cifar100/', train=True, download=True,
                                 transform=_augment_transform(img_size))
        test_dataset = CIFAR100('./data/cifar100/', train=False, download=True,
                                transform=_eval_transform(img_size))
        num_classes, num_channels = 100, 3
    elif dataset_name == 'stl10':
        train_dataset = STL10('./data/stl10/', split='train', download=False,
                              transform=_augment_transform(img_size))
        test_dataset = STL10('./data/stl10/', split='test', download=True,
                             transform=_eval_transform(img_size))
        num_classes, num_channels = 10, 3
    elif dataset_name == 'svhn':
        train_dataset = SVHN('./data/svhn/', split='train', download=False,
                             transform=_augment_transform(img_size))
        test_dataset = SVHN('./data/svhn/', split='test', download=False,
                            transform=_eval_transform(img_size))
        num_classes, num_channels = 10, 3
    elif dataset_name == 'Flower102':
        train_dataset = Flower102('./data/Flower102', True,
                                  _center_crop_transform())
        test_dataset = Flower102('./data/Flower102', False,
                                 _center_crop_transform())
        num_classes, num_channels = 102, 3
    elif dataset_name == 'SDog120':
        train_dataset = SDog120('./data/SDog120', True,
                                _center_crop_transform())
        test_dataset = SDog120('./data/SDog120', False,
                               _center_crop_transform())
        num_classes, num_channels = 120, 3
    elif dataset_name == 'pets':
        train_dataset, test_dataset = _split_image_folder(
            './data/PetImages/', img_size)
        num_classes, num_channels = 2, 3
    elif dataset_name == 'flowers':
        train_dataset = Flowers102('./data/Flowers102/', split='train', download=True,
                                   transform=_augment_transform(img_size))
        test_dataset = Flowers102('./data/Flowers102/', split='val', download=True,
                                  transform=_eval_transform(img_size))
        num_classes, num_channels = 102, 3
    elif dataset_name == 'food':
        train_dataset = Food101('./data/Food101/', split='train', download=True,
                                transform=_augment_transform(img_size))
        # The test split must come from Food101 as well (previously Flowers102).
        test_dataset = Food101('./data/Food101/', split='test', download=True,
                               transform=_eval_transform(img_size))
        num_classes, num_channels = 101, 3
    elif dataset_name == 'imagenet10':
        train_dataset, test_dataset = _split_image_folder(
            './data/imagenet10/train_set/', img_size)
        num_classes, num_channels = 10, 3
    elif dataset_name == "imagenet":
        train_dataset = ImageNet("./data/imagenet/", split="train",
                                 transform=_augment_transform(img_size))
        # Same root as the train split; the differently-cased path used before
        # breaks on case-sensitive filesystems.
        test_dataset = ImageNet("./data/imagenet/", split="val",
                                transform=_eval_transform(img_size))
        num_classes, num_channels = 1000, 3
    elif dataset_name == "imagenet100":
        # Training images are not resized here; they are only padded and cropped.
        train_dataset = ImageFolder("./data/benign_100/train/",
                                    transform=_augment_transform(img_size, resize=False))
        test_dataset = ImageFolder("./data/benign_100/val/",
                                   transform=_eval_transform(img_size))
        num_classes, num_channels = 100, 3
    elif dataset_name == "imagenet200":
        train_dataset = ImageFolder("./data/sub-imagenet-200/train/",
                                    transform=_augment_transform(img_size))
        test_dataset = ImageFolder("./data/sub-imagenet-200/val/",
                                   transform=_eval_transform(img_size))
        num_classes, num_channels = 200, 3
    elif dataset_name == "vggface2":
        train_dataset = ImageFolder("./data/vggface2/train/",
                                    transform=_augment_transform(img_size))
        test_dataset = ImageFolder("./data/vggface2/val/",
                                   transform=_eval_transform(img_size))
        num_classes, num_channels = 20, 3
    else:
        # Equivalent to exit(...), without relying on the site-injected helper.
        raise SystemExit("Unknown Dataset")
    return train_dataset, test_dataset, num_classes, num_channels


# if __name__ == '__main__':
#     train_set, test_set, num_classes, num_channels = get_full_dataset('imagenet10', (224, 224))
#     # print(train_set.class_to_idx)
#     print(len(train_set))
#     print(len(test_set))
#     # print(train_set.targets)
