import os
import sys
sys.path.append('..')

import numpy as np
from PIL import Image
import torch
import torchvision

import data
from config import opt


# Per-dataset normalization statistics (one value per channel), keyed by
# dataset name. visualize_sub_dataset looks them up by opt.surrogate_dataset
# and passes them to save_img/tensor2im to undo the normalization for display.
# NOTE(review): the 'TinyImageNet' entry is identical to 'cifar100' — confirm
# this is intentional and not a copy-paste placeholder.
normalize_dict = dict(
    cifar10=dict(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
    cifar100=dict(mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2762)),
    TinyImageNet=dict(mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2762)),
    ImageNet=dict(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
)

def tensor2im(input_image, mean, std, imtype=np.uint8):
    """Convert a normalized image tensor into an HWC numpy array.

    Tensor input is copied to host memory as float32, expanded from one
    channel to three if needed, de-normalized channel by channel
    (``x * std + mean``), scaled by 255, clipped to [0, 255] and transposed
    from CHW to HWC. A numpy array is only cast to ``imtype``; any other
    object is returned unchanged.

    Parameters:
        input_image (tensor | np.ndarray) -- image to convert; a tensor is
                                             expected in CHW layout
        mean (sequence of float)          -- per-channel mean used for normalization
        std (sequence of float)           -- per-channel std used for normalization
        imtype (type)                     -- numpy dtype of the returned array

    Returns:
        np.ndarray of dtype ``imtype``, or ``input_image`` itself when it is
        neither a tensor nor an ndarray.
    """
    if isinstance(input_image, np.ndarray):
        # Already an array: no de-normalization, only the dtype cast.
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image

    # Tensor.numpy() shares memory with a CPU float32 tensor, so copy before
    # the in-place per-channel writes below; otherwise the caller's tensor
    # would be de-normalized in place.
    image_numpy = input_image.detach().cpu().float().numpy().copy()
    if image_numpy.shape[0] == 1:  # grayscale -> RGB
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    for channel, (m, s) in enumerate(zip(mean, std)):
        image_numpy[channel] = image_numpy[channel] * s + m
    # Clip before the integer cast: out-of-range floats (e.g. perturbed
    # inputs) would otherwise wrap around / be undefined in astype(uint8).
    image_numpy = np.clip(image_numpy * 255, 0, 255)
    image_numpy = np.transpose(image_numpy, (1, 2, 0))  # CHW -> HWC
    return image_numpy.astype(imtype)


def save_img(im, mean, std, path, size):
    """Tile a batch of normalized image tensors into one grid and save it.

    ``im`` may be raw, unprocessed tensor data straight from the model or
    loader; de-normalization happens here before writing to ``path``.

    Parameters:
        im (tensor)   -- batch of image tensors (make_grid accepts a
                         (B, C, H, W) tensor or a list of same-size images)
        mean (tuple)  -- per-channel mean used to undo the normalization
        std (tuple)   -- per-channel std used to undo the normalization
        path (str)    -- destination file path for the saved image
        size (int)    -- number of images per grid row (ideally a multiple of 2)
    """
    # Merge the whole batch into a single grid image.
    grid = torchvision.utils.make_grid(im, nrow=size)
    # Convert to a numpy array (de-normalized) and write it out via PIL.
    Image.fromarray(tensor2im(grid, mean, std)).save(path)


def visualize_sub_dataset(sub_dataset, loop):
    """Save one shuffled batch of ``sub_dataset`` as an image grid.

    For every sample in the batch, the ``opt.n_fuse`` entries of ``images``
    are laid out followed by the matching perturbed image; each grid row
    holds five such groups. The result is written to
    ``{opt.work_dir}/results/visualize/{victim}_{surrogate}/{sub_model}_{seed}_loop{loop}.png``.

    Parameters:
        sub_dataset -- dataset yielding ``(images, perturbed_images, _, _)``;
                       presumably ``images`` collates to ``n_fuse`` tensors of
                       shape (B, C, H, W) — TODO confirm against the dataset
        loop (int)  -- iteration index, used only in the output file name
    """
    n_fuse = opt.n_fuse
    # Images per grid row: 5 groups of (n_fuse source images + 1 perturbed).
    size = 5 * (n_fuse + 1)

    mean = normalize_dict[opt.surrogate_dataset]['mean']
    std = normalize_dict[opt.surrogate_dataset]['std']

    dataloader = torch.utils.data.DataLoader(
        sub_dataset,
        batch_size=50,
        shuffle=True,
        num_workers=4
    )
    # Only the first batch is visualized.
    batch = next(iter(dataloader), None)
    if batch is None:
        print('Query dataset is empty; nothing to visualize.')
        return
    images, perturbed_images, _, _ = batch

    # Collect tiles and stack once: avoids the quadratic repeated torch.cat
    # and takes C/H/W from the data instead of hard-coding 3x32x32.
    tiles = []
    for j in range(perturbed_images.shape[0]):  # batch
        tiles.extend(images[k][j] for k in range(n_fuse))
        tiles.append(perturbed_images[j])
    batch_images = torch.stack(tiles)

    save_dir = os.path.join(opt.work_dir, f'results/visualize/{opt.victim_dataset}_{opt.surrogate_dataset}')
    os.makedirs(save_dir, exist_ok=True)
    save_name = f'{opt.sub_model}_{opt.seed}_loop{loop}.png'
    save_path = os.path.join(save_dir, save_name)
    save_img(batch_images, mean, std, save_path, size)

    print('Query dataset visualized.')



# def tensor_to_pil_image(tensor):
#     np_array = np.moveaxis(tensor.cpu().numpy() * 255, 0, -1)
#     return Image.fromarray(np_array.astype(np.uint8))
#
#
# def visualize(tensor, save_path=None):
#     pil_image = tensor_to_pil_image(tensor)
#     if save_path:
#         pil_image.save(save_path)
#     return pil_image

if __name__ == '__main__':
    # Smoke test: dump the first train and test batch of CIFAR-100 as image grids.
    # NOTE(review): assumes data.CIFAR100 normalizes with the 'cifar100'
    # statistics from normalize_dict — confirm against the data module.
    mean = normalize_dict['cifar100']['mean']
    std = normalize_dict['cifar100']['std']
    out_dir = os.path.join(opt.work_dir, 'results/visualize')
    os.makedirs(out_dir, exist_ok=True)

    cifar100 = data.CIFAR100()
    loader_factories = {
        'train': cifar100.train_dataloader,
        'test': cifar100.test_dataloader,
    }
    for split, make_loader in loader_factories.items():
        images, _ = next(iter(make_loader()))
        # save_img needs mean/std as well; the original 3-argument call raised TypeError.
        save_img(images, mean, std, os.path.join(out_dir, f'{split}.png'), 16)
