import os
import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import sys
# Add the current working directory to the module search path. Presumably this
# lets project-local modules (e.g. the `cfg` module imported lazily inside
# prepare_dataset) resolve when the script is launched from the repository
# root -- TODO confirm against how the project is run.
sys.path.append(".")


# Channel statistics shared by every normalized pipeline below (the values
# match the commonly quoted CIFAR-10 mean/std -- TODO confirm the source).
_NORM_MEAN = [0.4914, 0.4822, 0.4465]
_NORM_STD = [0.2471, 0.2435, 0.2616]


def _mnist_transforms():
    """Return (train, test) transforms for MNIST: tensor conversion only."""
    to_tensor = transforms.Compose([transforms.ToTensor()])
    # Train and test intentionally share the very same pipeline object.
    return to_tensor, to_tensor


def _padded_crop_transforms():
    """Return (train, test) transforms using a padded 32-pixel random crop.

    Training applies ``RandomCrop(32, padding=4)`` and a random horizontal
    flip before tensor conversion and normalization; evaluation only converts
    and normalizes.
    """
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ])
    return train_transform, test_transform


def _resized_transforms(size, test_resize=None):
    """Return (train, test) transforms that bring inputs to ``size``.

    Training uses ``RandomResizedCrop(size)`` plus a random horizontal flip.
    Evaluation resizes directly to ``size``; when ``test_resize`` is given it
    instead resizes to ``test_resize`` and center-crops to ``size``.
    """
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ])
    if test_resize is None:
        test_steps = [transforms.Resize(size)]
    else:
        test_steps = [transforms.Resize(test_resize), transforms.CenterCrop(size)]
    test_steps += [
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
    return train_transform, transforms.Compose(test_steps)


def _large_image_transforms():
    """Return (train, test) transforms for the 224x224 pipelines."""
    return _resized_transforms((224, 224), test_resize=(256, 256))


def _gtsrb_transforms():
    """Return (train, test) transforms for GTSRB (resized to 32x32)."""
    return _resized_transforms((32, 32))


# Per-dataset configuration. Dataset classes are stored by name and the
# transform factories are only called on demand, so nothing from torchvision
# is touched until prepare_dataset() actually builds a dataset.
#   cls         -- attribute name of the dataset class in torchvision.datasets
#   train_split -- keyword argument selecting the training split
#   test_split  -- keyword argument selecting the evaluation split
#   transforms  -- zero-argument factory returning (train, test) transforms
#   batch_size  -- batch size used when the caller does not supply one
#   num_workers -- DataLoader worker processes
#   cls_num     -- number of classes reported to the caller
_DATASET_SPECS = {
    "mnist": {
        "cls": "MNIST",
        "train_split": {"train": True},
        "test_split": {"train": False},
        "transforms": _mnist_transforms,
        "batch_size": 128,
        "num_workers": 0,
        "cls_num": 10,
    },
    "cifar10": {
        "cls": "CIFAR10",
        "train_split": {"train": True},
        "test_split": {"train": False},
        "transforms": _padded_crop_transforms,
        "batch_size": 128,
        "num_workers": 0,
        "cls_num": 10,
    },
    "cifar100": {
        "cls": "CIFAR100",
        "train_split": {"train": True},
        "test_split": {"train": False},
        "transforms": _padded_crop_transforms,
        "batch_size": 128,
        "num_workers": 2,
        "cls_num": 100,
    },
    "svhn": {
        "cls": "SVHN",
        "train_split": {"split": "train"},
        "test_split": {"split": "test"},
        "transforms": _padded_crop_transforms,
        "batch_size": 128,
        "num_workers": 2,
        "cls_num": 10,
    },
    "gtsrb": {
        "cls": "GTSRB",
        "train_split": {"split": "train"},
        "test_split": {"split": "test"},
        "transforms": _gtsrb_transforms,
        "batch_size": 512,
        "num_workers": 2,
        "cls_num": 43,
    },
    "flowers102": {
        "cls": "Flowers102",
        "train_split": {"split": "train"},
        "test_split": {"split": "test"},
        "transforms": _large_image_transforms,
        "batch_size": 512,
        "num_workers": 4,
        "cls_num": 102,
    },
    "country211": {
        "cls": "Country211",
        "train_split": {"split": "train"},
        "test_split": {"split": "test"},
        "transforms": _large_image_transforms,
        "batch_size": 512,
        "num_workers": 4,
        "cls_num": 211,
    },
    "oxfordpets": {
        "cls": "OxfordIIITPet",
        "train_split": {"split": "trainval"},
        "test_split": {"split": "test"},
        "transforms": _large_image_transforms,
        "batch_size": 512,
        "num_workers": 4,
        "cls_num": 37,
    },
}


def prepare_dataset(dataset, batch_size=None, pin_memory=False):
    """Build train/test ``DataLoader`` objects for a named torchvision dataset.

    Args:
        dataset: One of the keys of ``_DATASET_SPECS`` ("mnist", "cifar10",
            "cifar100", "svhn", "gtsrb", "flowers102", "country211",
            "oxfordpets"). Matching is exact and case-sensitive.
        batch_size: Batch size for both loaders. ``None`` (the default)
            selects the dataset's own default: 128 for mnist, cifar10,
            cifar100 and svhn; 512 for the others. An explicit value is now
            honoured for every dataset (it used to be silently ignored for
            all but mnist and cifar10).
        pin_memory: Forwarded to both loaders for every dataset (it used to
            be applied to cifar10 only).

    Returns:
        A tuple ``(loaders, cls_num)`` where ``loaders`` is a dict with keys
        ``'train'`` (shuffled) and ``'test'`` (not shuffled) and ``cls_num``
        is the number of classes of the dataset.

    Raises:
        NotImplementedError: If ``dataset`` is not a supported name.
    """
    # Validate first so that a typo fails immediately, before the project
    # config is imported or any download is attempted.
    spec = _DATASET_SPECS.get(dataset) if isinstance(dataset, str) else None
    if spec is None:
        supported = ", ".join(sorted(_DATASET_SPECS))
        raise NotImplementedError(
            f"Unsupported dataset {dataset!r}; expected one of: {supported}"
        )

    # Imported lazily, as in the original code, so the module can be imported
    # without the project config being present.
    from cfg import data_path
    root = os.path.join(data_path, dataset)

    if batch_size is None:
        batch_size = spec["batch_size"]

    train_transform, test_transform = spec["transforms"]()
    dataset_cls = getattr(datasets, spec["cls"])
    num_workers = spec["num_workers"]

    # Keep the original construction order: train set/loader, then test.
    train_data = dataset_cls(
        root=root, download=True, transform=train_transform, **spec["train_split"]
    )
    train_loader = DataLoader(
        train_data, batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    test_data = dataset_cls(
        root=root, download=True, transform=test_transform, **spec["test_split"]
    )
    test_loader = DataLoader(
        test_data, batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    return {
        'train': train_loader,
        'test': test_loader,
    }, spec["cls_num"]
