import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.nn.functional import softmax
import PIL




def tensor2im(input_image, imtype=np.uint8,mean=None,std=None):
    """Convert an image tensor to a numpy array and undo its normalization.

    Parameters:
        input_image (tensor or ndarray) -- the image to convert. A torch.Tensor
            is expected channel-first, (C, H, W), optionally with leading
            singleton batch dimensions such as (1, C, H, W). A np.ndarray is
            only cast to ``imtype``; any other object is returned unchanged.
        imtype (type) -- numpy dtype of the returned array.
        mean (sequence of float) -- per-channel mean used by Normalize in the
            DataLoader; defaults to [0.485, 0.456, 0.406] when None.
        std (sequence of float) -- per-channel std used by Normalize in the
            DataLoader; defaults to [0.229, 0.224, 0.225] when None.

    Returns:
        For tensor input: an (H, W, C) array scaled to [0, 255] (values outside
        that range are clipped) and cast to ``imtype``. The input tensor is
        not modified.
    """
    # Fill defaults independently so passing only one of mean/std works, and
    # use identity checks so array-valued mean/std do not raise.
    if mean is None:
        mean = [0.485, 0.456, 0.406]  # mean set in the DataLoader
    if std is None:
        std = [0.229, 0.224, 0.225]  # std set in the DataLoader

    if isinstance(input_image, np.ndarray):
        # numpy input: no processing, just cast.
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image

    image_t = input_image.detach()
    # Strip only leading batch dimensions; a blanket squeeze() would also
    # remove H or W when they equal 1.
    while image_t.dim() > 3 and image_t.shape[0] == 1:
        image_t = image_t[0]
    # .numpy() shares memory with a CPU float32 tensor, so copy before the
    # in-place de-normalization to avoid mutating the caller's tensor.
    image_numpy = image_t.cpu().float().numpy().copy()

    for channel, (m, s) in enumerate(zip(mean, std)):  # undo Normalize
        image_numpy[channel] = image_numpy[channel] * s + m
    image_numpy = image_numpy * 255  # undo ToTensor(): [0, 1] -> [0, 255]
    # (channels, height, width) -> (height, width, channels)
    image_numpy = np.transpose(image_numpy, (1, 2, 0))
    # Saturate instead of letting out-of-range values wrap around on cast.
    image_numpy = np.clip(image_numpy, 0, 255)
    return image_numpy.astype(imtype)


def img_hist_show(image, label_name_list, distribution_list, true_label=None, figure=None,
                  mean=None, std=None):
    """Show an image next to a bar chart of its (top-5) prediction distribution.

    Parameters:
        image (tensor or ndarray) -- the image to display. A torch.Tensor is
            converted with tensor2im using ``mean``/``std``.
        label_name_list (list of str) -- class names, one per bar.
        distribution_list (list of float) -- bar heights; the y-axis is fixed
            to [0, 1], so these are expected to be probabilities.
        true_label (str) -- optional title drawn above the image.
        figure (matplotlib Figure) -- figure to draw into; a new 10x10 figure
            is created when None.
        mean, std -- optional normalization parameters forwarded to tensor2im
            for tensor input (tensor2im's defaults apply when None).

    Returns:
        The matplotlib Figure that was drawn into.

    Raises:
        TypeError -- if ``image`` is neither a torch.Tensor nor a np.ndarray.
    """
    if isinstance(image, torch.Tensor):
        image = tensor2im(image, mean=mean, std=std)
    elif not isinstance(image, np.ndarray):
        raise TypeError('image should be tensor or numpy!')

    if figure is None:
        figure = plt.figure(figsize=(10, 10))

    if len(label_name_list) != len(distribution_list):
        print('label_name_list should match distribution')

    grid = plt.GridSpec(nrows=2, ncols=2, wspace=0, figure=figure)

    # Add axes to `figure` explicitly: plt.subplot() targets the *current*
    # figure, which need not be the one passed in.
    ax_image = figure.add_subplot(grid[0, 1])
    ax_image.axis('off')
    if true_label is not None:
        ax_image.set_title(true_label, fontsize=18)
    ax_image.imshow(image)

    ax_hist = figure.add_subplot(grid[0, 0])
    positions = [i * 10 for i in range(len(label_name_list))]  # spaced bar centres
    ax_hist.bar(x=positions, height=distribution_list, width=4, alpha=0.6)
    ax_hist.set_ylim([0, 1])
    ax_hist.set_xlabel('class', fontsize=18)
    ax_hist.set_xticks(positions)
    ax_hist.set_xticklabels(label_name_list, rotation=30, fontsize=18)

    plt.show()
    return figure


def ILSVRC_2012_TOP5_IMAGE_HIST(image_torch, output, class_name_list, true_label=None,mean=None,std=None):
    """Print and plot the top-5 ILSVRC-2012 predictions for a single image.

    Parameters:
        image_torch (tensor) -- the input image, (C, H, W) or with singleton
            batch dims; it is cloned and squeezed before de-normalization.
        output (tensor) -- model prediction scores, indexed along dim=1 by
            class. Only the first row is used, so this assumes shape
            (1, num_classes) -- TODO confirm for batched callers.
        class_name_list (sequence of str) -- class names indexed by class id.
        true_label (str) -- ground-truth label; printed and used as the title.
        mean, std -- normalization parameters forwarded to tensor2im.

    Returns:
        The matplotlib Figure produced by img_hist_show.
    """
    scores = output.detach()
    probabilities = softmax(scores, dim=1)[0]  # computed once, not per class

    # Index row 0 consistently; a flat argmax over a batched output would
    # yield an index beyond the number of classes.
    top_1_index = int(scores[0].argmax())
    top_1_name = class_name_list[top_1_index]

    top_5_index = torch.topk(scores, k=5, dim=1, largest=True).indices[0].tolist()
    top_5_name = [class_name_list[i] for i in top_5_index]
    top_5_distribution = [float(probabilities[i]) for i in top_5_index]

    print('truth:', true_label)
    print('top 1:', top_1_name)
    print('top 5 :', top_5_name)

    # Clone so the de-normalization cannot touch the caller's tensor.
    image_np = tensor2im(image_torch.clone().squeeze(), mean=mean, std=std)

    figure = plt.figure(figsize=(10, 10))
    figure = img_hist_show(image=image_np,
                           label_name_list=top_5_name,
                           distribution_list=top_5_distribution,
                           figure=figure,
                           true_label=true_label)
    return figure



def patch_attach(image,patch,patch_mask,loc):
    """Place a patch and its mask onto zero canvases the size of ``image``.

    Parameters:
        image (tensor) -- reference image. Indexing with image[0] assumes a
            leading batch dimension, e.g. (1, C, H, W).
        patch (tensor) -- the patch; its last two dims are (patch_h, patch_w).
        patch_mask (tensor) -- the patch's mask; after squeeze() it must be
            broadcastable to (C, patch_h, patch_w).
        loc (sequence of int) -- (row, column) of the patch's top-left corner.

    Returns:
        (patch_im, patch_mask_im):
            patch_im -- zeros shaped like ``image`` with ``patch`` written at
                ``loc`` in the first batch item.
            patch_mask_im -- zeros shaped like ``image[0]`` (no batch dim)
                with ``patch_mask`` written at ``loc``.
    """
    top = loc[0]
    left = loc[1]
    bottom = top + patch.shape[-2]
    right = left + patch.shape[-1]

    # image[0] (rather than squeeze()) keeps the channel dim for
    # single-channel images, so the 3-index slice below stays valid.
    # NOTE(review): this mask has no batch dim while patch_im does; the
    # shape is kept as-is for existing callers.
    patch_mask_im = torch.zeros_like(image[0])
    patch_mask_im[:, top:bottom, left:right] = patch_mask.squeeze()

    patch_im = torch.zeros_like(image)
    patch_im[0][:, top:bottom, left:right] = patch
    return patch_im, patch_mask_im

