import os
import torch
from dataset import DataLoader
from utils.models.resnet import resnet18
import torchvision.transforms as transforms
import warnings
from torch.utils.data import Dataset
from PIL import Image

# Suppress every warning for the whole interpreter, including deprecation
# notices from torch/torchvision. NOTE(review): consider filtering specific
# categories/modules instead so real problems stay visible.
warnings.filterwarnings("ignore")


def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Return the sum of ``kernel_num`` Gaussian (RBF) kernel matrices on source+target.

    ``source`` and ``target`` are stacked row-wise into one matrix of N = n + m
    rows and the pairwise squared Euclidean distances D are computed. For each
    bandwidth ``b_i`` in a geometric ladder, ``exp(-D / b_i)`` is evaluated, and
    the ``kernel_num`` matrices are summed (not averaged).

    Args:
        source: 2-D tensor with one sample per row (n rows).
        target: 2-D tensor with the same number of columns (m rows).
        kernel_mul: ratio between consecutive bandwidths in the ladder.
        kernel_num: number of kernels in the ladder.
        fix_sigma: base bandwidth. When falsy (``None`` or 0) it is estimated
            from the data as the mean off-diagonal squared distance, without
            gradient. A tensor passed here is not modified.

    Returns:
        (N, N) tensor holding the summed kernel matrix. The first n rows and
        columns correspond to ``source`` and the remaining m to ``target``.
    """
    n_samples = int(source.size(0)) + int(target.size(0))
    total = torch.cat([source, target], dim=0)

    # Pairwise squared distances. cdist evaluates pair by pair, so it needs
    # O(N^2) memory instead of the O(N^2 * d) intermediate created by
    # broadcasting ``total[None] - total[:, None]``. The non-matmul mode avoids
    # the ||x||^2 + ||y||^2 - 2<x, y> shortcut, which loses precision for
    # near-identical rows.
    l2_distance = torch.cdist(
        total, total, p=2.0, compute_mode='donot_use_mm_for_euclid_dist'
    ) ** 2

    if fix_sigma:
        bandwidth = fix_sigma
    else:
        # Mean over the N*(N-1) off-diagonal entries. The result is detached so
        # the bandwidth acts as a constant during backprop.
        bandwidth = torch.sum(l2_distance.detach()) / (n_samples ** 2 - n_samples)

    # Centre the ladder on the base bandwidth. Use out-of-place division:
    # ``/=`` would mutate a tensor supplied as ``fix_sigma``.
    bandwidth = bandwidth / kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
    kernel_val = [torch.exp(-l2_distance / bw) for bw in bandwidth_list]
    return sum(kernel_val)

def mmd(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Biased multi-kernel MMD^2 estimate between ``source`` and ``target``.

    Builds the summed Gaussian kernel matrix over both sample sets with
    ``guassian_kernel`` and returns
    ``mean(K_ss) + mean(K_tt) - mean(K_st) - mean(K_ts)``. Diagonal entries are
    included, so this is the biased estimate.

    When both inputs have the same number of rows, the result equals
    ``mean(K_ss + K_tt - K_st - K_ts)``. Averaging each block separately also
    lets ``source`` and ``target`` have different numbers of rows.

    Args:
        source: 2-D tensor with one sample per row.
        target: 2-D tensor with the same number of columns.
        kernel_mul, kernel_num, fix_sigma: forwarded to ``guassian_kernel``.

    Returns:
        0-d tensor holding the MMD^2 estimate.
    """
    n_source = int(source.size(0))
    kernels = guassian_kernel(source, target,
                              kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
    XX = kernels[:n_source, :n_source]
    YY = kernels[n_source:, n_source:]
    XY = kernels[:n_source, n_source:]
    YX = kernels[n_source:, :n_source]
    return XX.mean() + YY.mean() - XY.mean() - YX.mean()


class CommonDataset(Dataset):
    """Image-folder dataset with one sub-directory per class under ``data_dir``.

    The layout depends on whether the substring ``'iid'`` occurs anywhere in
    ``data_dir``:

    * default: ``data_dir/<class>/<image>``
    * ``'iid'`` in the path: ``data_dir/<class>/<property>/<image>``

    Class ids follow the sorted order of the class directory names, and files
    are listed in sorted order. This makes labels and sample order independent
    of filesystem listing order.

    Args:
        data_dir: root directory of the dataset.
        is_train: stored as ``self.is_train``; the same transform is used
            either way.
        img_size: side length of the centre crop.
        corrupt: stored as ``self.corrupt``; not used by this class.
        per_imgs: stored as ``self.per_imgs``; not used by this class (no
            per-class image limit is applied here).
        num_classes: value for ``self.classes_num``. It defaults to the number
            of entries in ``data_dir``.
    """

    def __init__(self, data_dir, is_train=True, img_size=224, corrupt=None, per_imgs=None, num_classes=None):
        self.is_train = is_train
        self.img_size = img_size
        if num_classes is None:
            self.classes_num = len(os.listdir(data_dir))
        else:
            self.classes_num = num_classes
        self.corrupt = corrupt
        self.per_imgs = per_imgs
        self.num_classes = num_classes

        # Paths containing 'iid' use the extra <property> directory level.
        self.data, self.labels = self._collect_samples(data_dir, nested='iid' in data_dir)

        self.data_transform = transforms.Compose([
            # interpolation=0: presumably nearest-neighbour (PIL code 0);
            # confirm against the installed torchvision version.
            transforms.Resize((256, 256), 0),
            transforms.CenterCrop(self.img_size),
            transforms.ToTensor(),
            # Maps each channel from [0, 1] to [-1, 1].
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])
        ])

    @staticmethod
    def _collect_samples(data_dir, nested):
        """Return parallel lists ``(paths, labels)`` for every file under ``data_dir``.

        ``nested`` selects the ``<class>/<property>/<image>`` layout. Otherwise
        the layout is ``<class>/<image>``. Labels are indices into the sorted
        class directory names.
        """
        paths, labels = [], []
        for class_id, class_name in enumerate(sorted(os.listdir(data_dir))):
            class_dir = os.path.join(data_dir, class_name)
            if nested:
                for prop in sorted(os.listdir(class_dir)):
                    prop_dir = os.path.join(class_dir, prop)
                    for img in sorted(os.listdir(prop_dir)):
                        paths.append(os.path.join(prop_dir, img))
                        labels.append(class_id)
            else:
                for img in sorted(os.listdir(class_dir)):
                    paths.append(os.path.join(class_dir, img))
                    labels.append(class_id)
        return paths, labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        """Return ``(transformed RGB image tensor, integer class label)``."""
        path = self.data[item]
        # The context manager closes the file handle deterministically.
        # convert() returns a fully loaded copy, so it stays valid after close.
        with Image.open(path) as img:
            image = img.convert('RGB')

        image = self.data_transform(image)
        label = self.labels[item]
        return image, label

def initial(batch_size):
    """Extract and cache L2-normalised ImageNet reference features.

    Runs the pretrained ``resnet18`` over at most 100 shuffled batches of the
    ImageNet train split. Each feature row is L2-normalised, the results are
    concatenated along dim 0, and the tensor is saved to
    ``./data/imagenet_100.pth``. The ``./data`` directory is created if needed.

    Reads the module-level ``device`` list (assigned in ``__main__``) and uses
    ``device[0]`` as the CUDA device index.
    """
    imagenet_dataset = CommonDataset('/data/datasets/ILSVRC2012/train')
    imagenet_loader = DataLoader(imagenet_dataset,
                                 batch_size=batch_size,
                                 num_workers=8,
                                 shuffle=True)
    pre_trained_model = resnet18(pretrained=True)
    pre_trained_model = pre_trained_model.cuda(device=device[0]).eval()

    feature_batches = []
    with torch.no_grad():
        for batch_idx, (images, _labels) in enumerate(imagenet_loader):
            if batch_idx == 100:
                break
            # Labels are not needed, so only the images go to the GPU.
            images = images.cuda(device=device[0])
            # Presumably the project's resnet18 returns (features, logits);
            # only the first element is used. TODO confirm.
            features, _ = pre_trained_model(images)
            features = features / torch.norm(features, 2, 1, True)
            feature_batches.append(features.cpu())
    features_imagenet = torch.cat(feature_batches)
    # torch.save does not create missing parent directories.
    os.makedirs('./data', exist_ok=True)
    torch.save(features_imagenet, './data/imagenet_100.pth')

def test(dataset, batch_size):
    """Print the average MMD between ``dataset`` features and cached ImageNet features.

    Uses at most the first 5 full batches of ``/data/datasets/<dataset>/train``
    and stops at the first batch whose size differs from ``batch_size``. Each
    batch's L2-normalised ``resnet18`` features are compared with ``mmd``
    against every full ``batch_size`` chunk of ``./data/imagenet_100.pth``
    (written by ``initial``). The mean over all comparisons is printed.

    Reads the module-level ``device`` list and uses ``device[0]`` as the CUDA
    device index.
    """
    train_dir = '/data/datasets/' + dataset + '/train'
    train_dataset = CommonDataset(train_dir)
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              num_workers=8,
                              shuffle=True)

    pre_trained_model = resnet18(pretrained=True)
    pre_trained_model = pre_trained_model.cuda(device=device[0]).eval()

    mmd_value = 0
    it = 0
    # Locally produced cache file (see initial()). torch.load unpickles data,
    # so never point this at untrusted files.
    features_imagenet = torch.load('./data/imagenet_100.pth')
    num_iter = len(features_imagenet) // batch_size
    with torch.no_grad():
        for batch_idx, (images, _labels) in enumerate(train_loader):
            if batch_idx >= 5 or images.size(0) != batch_size:
                break
            images = images.cuda(device=device[0])
            features, _ = pre_trained_model(images)
            features = features / torch.norm(features, 2, 1, True)
            # Move to CPU once per batch instead of once per reference chunk.
            features = features.cpu()
            for i in range(num_iter):
                benchmark = features_imagenet[i * batch_size:(i + 1) * batch_size]
                # initial() already normalised these rows; re-normalising is a
                # cheap safeguard in case the cache came from elsewhere.
                benchmark = benchmark / torch.norm(benchmark, 2, 1, True)
                mmd_value += mmd(features, benchmark)
                it += 1

    if it == 0:
        # No full batch, or fewer cached rows than batch_size: avoid a
        # ZeroDivisionError that would abort the remaining datasets.
        print('dataset:{}, mmd: no comparisons (need full batches of {})'.format(dataset, batch_size))
        return
    print('dataset:{}, mmd:{}'.format(dataset, mmd_value / it))




if __name__ == '__main__':
    # CUDA device ids; initial() and test() read device[0] from module scope,
    # so this name must stay a module-level global.
    device = [2]
    batch_size = 1000
    target_datasets = ['alphabet', 'cifar10', 'food', 'flowers', 'cars', 'pets', 'animal_iid']

    # Build the ImageNet reference cache first, then score each dataset against it.
    initial(batch_size)
    for name in target_datasets:
        test(dataset=name, batch_size=batch_size)

