import os
import random
import torch
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import pickle
import torchvision.datasets as datasets
import os
import os.path as osp


def set_seed(seed=2020):
    """Seed every random number generator in use and pin cuDNN to deterministic mode.

    Seeds Python's ``random`` module, NumPy's global generator, torch's CPU
    generator and the CUDA generators (current device and all devices).
    It also exports ``PYTHONHASHSEED`` and disables cuDNN benchmarking and
    cuDNN itself, which trades speed for reproducibility.

    Args:
        seed: integer seed shared by every generator (default 2020).
    """
    # NOTE: hash randomization is fixed when an interpreter starts, so this
    # variable only affects interpreters launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers multi-GPU setups
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    cudnn.enabled = False

def load_obj(path):
    """Read and return the object pickled in the file at ``path``.

    Security: ``pickle.load`` can execute arbitrary code embedded in the file.
    Only call this on files produced by trusted code (e.g. ``save_obj``).
    """
    with open(path, mode='rb') as handle:
        obj = pickle.load(handle)
    return obj


def makedirs(dir):
    """Create directory ``dir`` (and any missing parents) unless the path already exists.

    If something already exists at ``dir`` (a directory, or even a regular
    file), this is a no-op, exactly like the original exists-then-create
    check. Attempting creation and tolerating ``FileExistsError`` also avoids
    the race where another process creates the directory between the check
    and the ``os.makedirs`` call.

    Args:
        dir: path of the directory to create.
    """
    try:
        os.makedirs(dir)
    except FileExistsError:
        # Already present (possibly created concurrently): nothing to do.
        pass


def save_obj(obj, path):
    """Pickle ``obj`` to the file at ``path`` using the highest pickle protocol.

    The file is opened in ``'wb'`` mode, so an existing file is overwritten.
    The ``with`` block closes the file, so no explicit ``close()`` is needed.

    Args:
        obj: any picklable object.
        path: destination file path.
    """
    with open(path, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)


def accuracy(out, y):
    """Return top-1 accuracy of ``out`` against ``y`` as a percentage tensor.

    Args:
        out: score tensor; the index of the maximum along dim 1 is taken as
            the predicted class for each row.
        y: tensor of target class indices, one per row of ``out``.

    Returns:
        A float tensor equal to ``100 * (#correct) / y.size(0)``.
    """
    predictions = out.max(dim=1).indices
    n_correct = (predictions == y).sum().float()
    return 100 * n_correct / y.size(0)


def get_data(args, shuffle=False):
    # mean/std stats
    if args.dataset == 'cifar10':
        data_class = 'CIFAR10'
        num_classes = 10
        stats = {
            'mean': [0.491, 0.482, 0.447],
            'std': [0.247, 0.243, 0.262]
        }
        if args.data_augmentation:
            trans_da = [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor(),
                lambda t: t.type(torch.get_default_dtype()),
                transforms.Normalize(**stats)
            ]

            tr_data = getattr(datasets, data_class)(
                root=args.path,
                train=True,
                download=True,
                transform=transforms.Compose(trans_da)
            )

            te_data = getattr(datasets, data_class)(
                root=args.path,
                train=False,
                download=True,
                transform=transforms.Compose(trans_da)
            )
        else:
            trans = [
                transforms.ToTensor(),
                lambda t: t.type(torch.get_default_dtype()),
                transforms.Normalize(**stats)
            ]

            tr_data = getattr(datasets, data_class)(
                root=args.path,
                train=True,
                download=True,
                transform=transforms.Compose(trans)
            )

            te_data = getattr(datasets, data_class)(
                root=args.path,
                train=False,
                download=True,
                transform=transforms.Compose(trans)
            )

        # get tr_loader for train/eval and te_loader for eval
        train_loader = torch.utils.data.DataLoader(
            dataset=tr_data,
            batch_size=args.batch_size_train,
            shuffle=shuffle,
        )

        train_loader_eval = torch.utils.data.DataLoader(
            dataset=tr_data,
            batch_size=args.batch_size_eval,
            shuffle=False,
        )
        test_loader_eval = torch.utils.data.DataLoader(
            dataset=te_data,
            batch_size=args.batch_size_eval,
            shuffle=False,
        )
    elif args.dataset == "mnist":
        data_class = 'MNIST'
        num_classes = 10
        stats = {
            'mean': [0.1307],
            'std': [0.3081]
        }
        # input transformation w/o preprocessing for now
        trans_da = [
            transforms.ToTensor(),
            lambda t: t.type(torch.get_default_dtype()),
            transforms.Normalize(**stats)
        ]

        # get tr and te data with the same normalization
        tr_data = getattr(datasets, data_class)(
            root=args.path,
            train=True,
            download=True,
            transform=transforms.Compose(trans_da)
        )

        te_data = getattr(datasets, data_class)(
            root=args.path,
            train=False,
            download=True,
            transform=transforms.Compose(trans_da)
        )
        # get tr_loader for train/eval and te_loader for eval
        train_loader = torch.utils.data.DataLoader(
            dataset=tr_data,
            batch_size=args.batch_size_train,
            shuffle=False,
        )

        train_loader_eval = torch.utils.data.DataLoader(
            dataset=tr_data,
            batch_size=args.batch_size_eval,
            shuffle=False,
        )
        test_loader_eval = torch.utils.data.DataLoader(
            dataset=te_data,
            batch_size=args.batch_size_eval,
            shuffle=False,
        )
    else:
        raise NotImplementedError(f"Unknown dataset: {args.dataset}.")

    return train_loader, test_loader_eval, train_loader_eval, num_classes

