import pdb
import torchvision
import torch.utils.data as data
import torch
import sys
import math
from PIL import Image
import numpy as np
from torchvision import datasets, transforms
import os
from torch.utils.data import Subset, DataLoader, SubsetRandomSampler
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import ConcatDataset
from torch.utils.data import DataLoader, random_split
# Per-dataset normalization statistics, keyed by dataset name.
# Each entry holds a per-channel ``mean`` and ``std`` sequence (one value for
# the single-channel 'mnist' entry, three values for every other entry).
# ``load_data`` looks an entry up by ``args.dataset`` to build its transforms
# and the returned ``Normalizer``; a name missing here raises ``KeyError``.
# 'mini_imagenet' reuses the 'imagenet' values, and the remaining datasets use
# 0.5 for every channel's mean and std.
NORMALIZE_DICT = {
    'mnist': dict(mean=(0.1307,), std=(0.3081,)),
    'cifar10': dict(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
    'cifar100': dict(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),
    'imagenet': dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    'mini_imagenet': dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    'OfficeHome': dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    'cub200': dict(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'stanford_dogs': dict(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'stanford_cars': dict(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'places365_32x32': dict(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'places365_64x64': dict(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'places365': dict(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'svhn': dict(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
}


def normalize(tensor, mean, std, reverse=False, keep_zero=True):
    """Apply (or undo) per-channel normalization and return a new tensor.

    Args:
        tensor: 3-D or 4-D tensor. The statistics are broadcast along
            dim 0 for a 3-D input and along dim 1 for a 4-D input.
        mean: sequence of per-channel means.
        std: sequence of per-channel standard deviations.
        reverse: when true, invert the normalization, i.e. compute
            ``tensor * std + mean`` instead of ``(tensor - mean) / std``.
        keep_zero: when true, every element that was exactly 0 in the
            input is forced back to 0 in the result.

    Returns:
        A new tensor with the same shape as ``tensor``; the input itself
        is not modified.

    Raises:
        ValueError: if ``tensor`` is neither 3-D nor 4-D.
    """
    rank = tensor.dim()
    if rank != 3 and rank != 4:
        raise ValueError("The input tensor must have 3 or 4 dimensions")

    # Remember where the input is exactly zero before any arithmetic happens.
    zero_positions = (tensor == 0) if keep_zero else None

    if reverse:
        # (x - (-m / s)) / (1 / s) == x * s + m, so the forward formula below
        # can be reused unchanged to undo a previous normalization.
        offset = [-m / s for m, s in zip(mean, std)]
        scale = [1 / s for s in std]
    else:
        offset, scale = mean, std

    offset = torch.as_tensor(offset, dtype=tensor.dtype, device=tensor.device)
    scale = torch.as_tensor(scale, dtype=tensor.dtype, device=tensor.device)

    # Insert singleton axes so the statistics line up with the channel axis.
    if rank == 4:
        channel_view = (None, slice(None), None, None)
    else:
        channel_view = (slice(None), None, None)

    result = (tensor - offset[channel_view]) / scale[channel_view]

    if zero_positions is not None:
        result[zero_positions] = 0

    return result


class Normalizer(object):
    """Callable that applies ``normalize`` with fixed statistics.

    Args:
        mean: sequence of per-channel means.
        std: sequence of per-channel standard deviations.
        keep_zero: forwarded to ``normalize``; when true, elements that are
            exactly 0 in the input stay 0 in the output.
    """

    def __init__(self, mean, std, keep_zero=True):
        self.mean = mean
        self.std = std
        self.keep_zero = keep_zero

    def __call__(self, x, reverse=False):
        """Normalize ``x``, or undo the normalization when ``reverse`` is true."""
        return normalize(x, self.mean, self.std, reverse=reverse, keep_zero=self.keep_zero)

    def __repr__(self):
        # Show the configured statistics instead of the default object address
        # when the instance (or a pipeline containing it) is printed.
        return "{}(mean={!r}, std={!r}, keep_zero={!r})".format(
            type(self).__name__, self.mean, self.std, self.keep_zero
        )


def build_transform(input_size=224, interpolation="bicubic",
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
                    crop_pct=0.875, aug=False, keep_zero=False, inverse_img=False):
    """Compose the image preprocessing pipeline.

    For ``input_size`` greater than 32 the pipeline is, in order: an optional
    resize to ``floor(input_size / crop_pct)`` (skipped unless
    ``inverse_img == False``), a center crop to ``input_size``, and a random
    horizontal flip when ``aug`` is true. Whatever the size, the pipeline
    always ends with ``ToTensor`` followed by a ``Normalizer``.

    Args:
        input_size: target crop size; values of 32 or less disable the
            resize, crop and flip steps entirely.
        interpolation: "bicubic", "lanczos" or "hamming"; any other value
            falls back to bilinear resampling.
        mean, std: per-channel statistics handed to ``Normalizer``.
        crop_pct: ratio between the crop size and the resized size.
        aug: add a random horizontal flip (only when ``input_size`` > 32).
        keep_zero: handed to ``Normalizer``.
        inverse_img: when equal to ``False`` the resize step is included.

    Returns:
        A ``transforms.Compose`` wrapping the assembled steps.
    """
    steps = []

    if input_size > 32:
        # Equality (not truthiness) test, so non-bool values behave as before.
        if inverse_img == False:  # noqa: E712
            # Resize so that the later center crop keeps ``crop_pct`` of the
            # image, i.e. the same ratio as for 224-pixel inputs.
            resized = int(math.floor(input_size / crop_pct))

            resample = None
            for name, attr in (("bicubic", "BICUBIC"),
                               ("lanczos", "LANCZOS"),
                               ("hamming", "HAMMING")):
                if interpolation == name:
                    resample = getattr(Image, attr)
                    break
            else:
                resample = Image.BILINEAR

            steps.append(transforms.Resize(resized, interpolation=resample))

        steps.append(transforms.CenterCrop(input_size))

        if aug:
            steps.append(transforms.RandomHorizontalFlip())

    steps.append(transforms.ToTensor())
    steps.append(Normalizer(mean, std, keep_zero))
    return transforms.Compose(steps)


def _sample_proxy_subset(train_dataset, fraction=0.2):
    """Return a random ``Subset`` holding ``fraction`` of ``train_dataset``."""
    # The sampling calls are the same as in the previous inline code, so runs
    # with a fixed seed draw the same indices.
    sampler = SubsetRandomSampler(torch.randperm(len(train_dataset)))
    keep = int(fraction * len(train_dataset))
    # The sampler may yield 0-d tensors when built from a tensor of indices;
    # store plain ints so the subset indexes the dataset with ordinary integers.
    indices = [int(i) for i in list(sampler)[:keep]]
    return Subset(train_dataset, indices)


def load_data(args):
    """Build the train/test/proxy datasets for ``args.dataset``.

    Args:
        args: object providing
            ``dataset``  - dataset name; must be a key of ``NORMALIZE_DICT``;
            ``data_dir`` - root directory containing the dataset folders;
            ``use_test`` - when true the test set is used as the proxy set,
            otherwise a random 20% subset of the training set is used.

    Returns:
        ``(train_dataset, test_dataset, proxy_dataset, y_train,
        train_transform, normalizer)``. ``y_train`` is a NumPy array of the
        training labels. For a dataset name this function has no loader for,
        the first four items are ``None``.

    Raises:
        KeyError: if ``args.dataset`` is not a key of ``NORMALIZE_DICT``.
    """
    crop_pct = 0.875
    stats = NORMALIZE_DICT[args.dataset]
    mean, std = stats['mean'], stats['std']

    # Training images skip the resize step (inverse_img=True) and get a random
    # horizontal flip; test images are resized before the center crop.
    train_transform = build_transform(input_size=224, interpolation="bicubic", mean=mean, std=std,
                                      crop_pct=crop_pct, aug=True, keep_zero=True, inverse_img=True)
    test_transform = build_transform(input_size=224, interpolation="bicubic", mean=mean, std=std,
                                     crop_pct=crop_pct, aug=False, keep_zero=True, inverse_img=False)
    normalizer = Normalizer(**stats)

    # The two CIFAR variants differ only in dataset class and folder name.
    cifar_variants = {
        "cifar10": ("CIFAR10", torchvision.datasets.CIFAR10),
        "cifar100": ("CIFAR100", torchvision.datasets.CIFAR100),
    }

    if args.dataset in cifar_variants:
        folder, dataset_cls = cifar_variants[args.dataset]
        data_dir = os.path.join(args.data_dir, folder)
        train_dataset = dataset_cls(root=data_dir, train=True, download=True,
                                    transform=train_transform)
        test_dataset = dataset_cls(root=data_dir, train=False, download=False,
                                   transform=test_transform)
        y_train = np.array(train_dataset.targets)
    elif args.dataset in ["mini_imagenet", "OfficeHome"]:
        train_data_dir = os.path.join(args.data_dir, args.dataset, 'train')
        test_data_dir = os.path.join(args.data_dir, args.dataset, 'val')
        train_dataset = datasets.ImageFolder(root=train_data_dir, transform=train_transform)
        test_dataset = datasets.ImageFolder(root=test_data_dir, transform=test_transform)
        # ``imgs`` holds (path, class_index) pairs.
        y_train = np.array([int(label) for _, label in train_dataset.imgs])
    else:
        # Known normalization stats but no loader here: return empty datasets.
        return None, None, None, None, train_transform, normalizer

    if args.use_test:
        proxy_dataset = test_dataset
    else:
        proxy_dataset = _sample_proxy_subset(train_dataset)

    return train_dataset, test_dataset, proxy_dataset, y_train, train_transform, normalizer
