import os

from PIL import Image

import cv2
import numpy as np
import  matplotlib.pyplot as plt

# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8*' and
# Matplotlib 3.8 removed the old aliases, so fall back to the new name when
# the legacy 'seaborn' style is no longer available.
plt.style.use('seaborn' if 'seaborn' in plt.style.available else 'seaborn-v0_8')

from LIMA.utils import *
import matplotlib

from LIMA.models.submodular_single_modal import BlackBoxSingleModalSubModularExplanationQuicktest, BlackBoxSingleModalSubModularExplanationQuicktestRest
from transformers import ViTForImageClassification

# NOTE(review): the return value is discarded, so this call has no visible
# effect here; presumably left over from locating the font cache — confirm
# and remove if unneeded.
matplotlib.get_cachedir()
# Set the default font family used by matplotlib text.
plt.rc('font', family="Times New Roman")

from sklearn import metrics
import torch
from torchvision import transforms


def SubRegionDivision(image, mode="slico"):
    """
    Split an image into per-superpixel sub-images.

    Args:
    image: Image array indexed as (row, column, channel).
    mode: "slico" segments with cv2.ximgproc.createSuperpixelSLIC
        (region_size=30, ruler=20.0, 20 iterations, library-default algorithm);
        "seeds" segments with cv2.ximgproc.createSuperpixelSEEDS
        (50 superpixels, 3 levels, 10 iterations).
        Any other value produces an empty list.

    Returns:
    A list with one array per superpixel: the input image multiplied by the
    boolean mask of that superpixel's label.
    """
    if mode == "slico":
        segmenter = cv2.ximgproc.createSuperpixelSLIC(image, region_size=30, ruler = 20.0)
        # More iterations generally give a better segmentation.
        segmenter.iterate(20)
    elif mode == "seeds":
        width, height, channels = image.shape[1], image.shape[0], image.shape[2]
        segmenter = cv2.ximgproc.createSuperpixelSEEDS(width, height, channels, num_superpixels=50, num_levels=3)
        # The image passed here must match the size given at construction.
        segmenter.iterate(image, 10)
    else:
        # Unrecognised modes yield no regions.
        return []

    labels = segmenter.getLabels()
    region_count = segmenter.getNumberOfSuperpixels()

    # Keep only the pixels of one label at a time; the multiplication
    # allocates a new array, so the input image is never modified.
    return [image * (labels == region_id)[:, :, np.newaxis] for region_id in range(region_count)]


def preprocess(image):
    """
    Convert an image array into a tensor on the CUDA device.

    The array is wrapped as a PIL image in 'RGB' mode without any channel
    reordering and passed through torchvision's ToTensor only; the Resize,
    CenterCrop and Normalize steps are currently disabled.

    NOTE(review): callers in this file pass cv2.imread output, which OpenCV
    documents as BGR, so the channel order reaching the model is BGR —
    confirm this matches how the model was trained.
    """
    pil_image = Image.fromarray(image, 'RGB')
    to_tensor = transforms.Compose([transforms.ToTensor()])
    tensor = to_tensor(pil_image)
    return tensor.cuda()


def imshow(img):
    """
    Draw an image with matplotlib and save the figure as "example_plot.png".

    A 3-dimensional input has its last axis reversed before display (BGR to
    RGB for OpenCV images); any other input is converted from grayscale to
    RGB with cv2.cvtColor. The figure is written to disk, not shown.
    """
    plt.axis('off')
    has_channel_axis = len(img.shape) == 3
    # Reverse the channel axis for colour input, expand grayscale otherwise.
    rgb = img[:, :, ::-1] if has_channel_axis else cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    plt.imshow(rgb)
    plt.savefig("example_plot.png")


def generate_masks(shape, patch_size):
    """
    Build one binary mask per cell of a patch_size x patch_size grid.

    Args:
    shape: Tuple (height, width, channels) of the image to be masked.
    patch_size: Number of grid cells along each dimension.

    Returns:
    masks: uint8 array of shape (patch_size * patch_size, height, width, channels).
        Mask i * patch_size + j is 1 inside the cell at grid row i, column j
        and 0 elsewhere. Cell sizes use floor division, so when height or
        width is not a multiple of patch_size the trailing rows/columns are
        not covered by any mask.
    """
    height, width, channels = shape

    # Size of a single grid cell (floor division).
    cell_w = width // patch_size
    cell_h = height // patch_size

    # All masks start as zeros; each cell is switched on in place.
    masks = np.zeros((patch_size * patch_size, height, width, channels), dtype=np.uint8)

    for row in range(patch_size):
        top = row * cell_h
        for col in range(patch_size):
            left = col * cell_w
            masks[row * patch_size + col, top:top + cell_h, left:left + cell_w, :] = 1

    return masks


def generate_masks_torch(patch_size, shape=(224, 224, 3)):
    """
    Generate one binary mask per cell of a patch_size x patch_size grid.

    Args:
    patch_size: The number of patches along one dimension.
    shape: Tuple (height, width, channels) representing the shape of the image.
        Defaults to (224, 224, 3), the values that used to be hard-coded.

    Returns:
    masks: uint8 tensor of shape (patch_size * patch_size, channels, height, width).
        Mask i * patch_size + j is 1 inside the cell at grid row i, column j
        and 0 elsewhere. Cell sizes use floor division, so when height or
        width is not a multiple of patch_size the trailing rows/columns are
        not covered by any mask.
    """
    height, width, channels = shape

    # Size of a single grid cell (floor division).
    patch_width = width // patch_size
    patch_height = height // patch_size

    # All masks start as zeros; each cell is switched on in place.
    masks = torch.zeros((patch_size * patch_size, channels, height, width), dtype=torch.uint8)

    for i in range(patch_size):
        start_y = i * patch_height
        end_y = start_y + patch_height
        for j in range(patch_size):
            start_x = j * patch_width
            end_x = start_x + patch_width
            masks[i * patch_size + j, :, start_y:end_y, start_x:end_x] = 1

    return masks


def visualization(image, submodular_image_set, saved_json_file, save_name):
    """
    Save a two-panel figure: the best insertion region drawn over the image
    (left) and the insertion curve (right).

    Args:
    image: Image array indexed as (row, column, channel). It is displayed with
        the channel axis reversed, i.e. it is treated as BGR.
    submodular_image_set: Ordered sequence of image-shaped arrays; they are
        summed cumulatively to build the step-by-step insertion images.
    saved_json_file: Mapping whose "consistency_score" entry holds the score
        of each insertion step.
    save_name: Destination handed to savefig. When it is a filesystem path,
        its parent directory is created if it does not exist yet.

    Returns:
    None. The figure is written to save_name and then closed so repeated
    calls do not accumulate open figures.
    """
    # Cumulative insertion images: step k reveals the first k + 1 regions.
    revealed = submodular_image_set[0]
    insertion_ours_images = [revealed]
    for smdl_sub_mask in submodular_image_set[1:]:
        revealed = revealed + smdl_sub_mask  # new array, earlier steps untouched
        insertion_ours_images.append(revealed)

    insertion_ours_images_input_results = np.array(saved_json_file["consistency_score"])
    ours_best_index = np.argmax(insertion_ours_images_input_results)

    # Fraction of pixels with any non-zero channel at each step.
    pixel_count = image.shape[0] * image.shape[1]
    x_ = [(step_image.sum(-1) != 0).sum() / pixel_count for step_image in insertion_ours_images]
    # Use only as many scores as there are insertion steps.
    ours_y = insertion_ours_images_input_results[:len(x_)]

    fig, [ax2, ax3] = plt.subplots(1, 2, gridspec_kw={'width_ratios': [1, 1.5]}, figsize=(24, 8))
    try:
        # Left panel: bare image, no frame or ticks.
        for side in ("left", "right", "top", "bottom"):
            ax2.spines[side].set_visible(False)
        ax2.xaxis.set_visible(False)
        ax2.yaxis.set_visible(False)
        ax2.set_title('Ours', fontsize=54)
        ax2.set_facecolor('white')

        # Right panel: insertion curve. The pyplot tick helpers below act on
        # the current axes, so select it explicitly instead of relying on it
        # being the last axes created.
        plt.sca(ax3)
        ax3.set_xlim((0, 1))
        ax3.set_ylim((0, 1))
        plt.xticks(fontsize=36)
        plt.yticks(fontsize=36)
        ax3.set_title('Insertion', fontsize=54)
        ax3.set_ylabel('Recognition Score', fontsize=44)
        ax3.set_xlabel('Percentage of image revealed', fontsize=44)

        ax3.plot(x_, ours_y, color='dodgerblue', linewidth=3.5)  # draw curve
        ax3.scatter(x_[-1], ours_y[-1], color='dodgerblue', s=54)  # Plot latest point
        # Red vertical line at the step with the highest score.
        ax3.plot([x_[ours_best_index], x_[ours_best_index]], [0, 1], color='red', linewidth=3.5)

        # Pixels NOT yet revealed at the best step (1 = hidden, 0 = revealed).
        mask = (image - insertion_ours_images[ours_best_index]).mean(-1)
        mask[mask > 0] = 1

        kernel = np.ones((3, 3), dtype=np.uint8)
        # NOTE(review): the third positional argument of cv2.dilate is `dst`,
        # not `iterations`, so this appears to dilate once. If three
        # iterations were intended, use `iterations=3` — left unchanged
        # because it alters the rendered outline.
        dilate = cv2.dilate(mask, kernel, 3)
        # Ring of pixels added by the dilation = border of the hidden area.
        edge = dilate - mask

        image_debug = image.copy()
        image_debug[mask > 0] = image_debug[mask > 0] * 0.5  # dim hidden pixels
        image_debug[edge > 0] = np.array([0, 0, 255])  # border colour (red in BGR)
        ax2.imshow(image_debug[..., ::-1])

        # Make sure the target directory exists before saving.
        if isinstance(save_name, (str, os.PathLike)):
            out_dir = os.path.dirname(save_name)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_name)
    finally:
        # Release the figure; otherwise every call leaves one open.
        plt.close(fig)



# ---- Baseline model -------------------------------------------------------
# Build the ViT classifier from the local 'pretrained_model' directory.
model = ViTForImageClassification.from_pretrained('pretrained_model',
                                                      subfolder='vit-base-patch16-224',
                                                      ignore_mismatched_sizes=True)

# Load the baseline checkpoint; the commented lines are alternative checkpoints.
state_dict = torch.load('checkpoint/b16_224/epoch10-baseline/model_best.pth.tar')['state_dict']
# state_dict = torch.load('checkpoint/lima/b16_224/-epoch10-lambda0.5-reg_freq20/checkpoint_2000.pth.tar')['state_dict']
# state_dict = torch.load('checkpoint/lima/b16_224/-epoch10-lambda0.5-reg_freq20/checkpoint_001.pth.tar')['state_dict']
# Remove 'module.' from the keys (the prefix torch.nn.DataParallel adds when a
# wrapped model is saved) so they match the unwrapped model built above.
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
# Wrap for multi-GPU execution and move the weights to CUDA.
model = torch.nn.DataParallel(model)
model = model.cuda()


# smdl = BlackBoxSingleModalSubModularExplanation(
#         model,
#         preprocess,
#         k=50,
#         lambda1=1,
#         lambda2=1,
#         lambda3=1)
#
# # image_path = "/mnt/huawei/jiaoxh/data/ImageNet100/val/n01558993/ILSVRC2012_val_00001598.JPEG"
# image_path = "/mnt/huawei/jiaoxh/data/ImageNet100/val/n13040303/ILSVRC2012_val_00000058.JPEG"
# image = cv2.imread(image_path)
# image = cv2.resize(image, (224, 224))
#
# # element_sets_V = SubRegionDivision(image, mode="seeds")
#
# patches = generate_masks(image.shape, 8)
# element_sets_V = patches * image
#
# smdl.k = len(element_sets_V)
# submodular_image, submodular_image_set, saved_json_file = smdl(element_sets_V)
# print(saved_json_file["sequence"])
# visualization(image, submodular_image_set, saved_json_file, "example_plot.png")
# # print(submodular_image.shape, submodular_image_set.shape)
# # for i in range(50):
# #     cv2.imwrite('masked_img/1/{}.jpg'.format(i), submodular_image_set[i])
# #     if saved_json_file["consistency_score"][i]>0.8:
# #         break
#
# image_flip = cv2.flip(image, 1)
# element_sets_V_flip = []
# for element in element_sets_V:
#     element_sets_V_flip.append(cv2.flip(element, 1))
#
# # element_sets_V_flip = SubRegionDivision(image_flip, mode="slico")
#
# # patches = generate_masks(image.shape, 8)
# # element_sets_V_flip = patches * image_flip
#
# submodular_image, submodular_image_set, saved_json_file = smdl(element_sets_V_flip)
# visualization(image_flip, submodular_image_set, saved_json_file, "example_plot_flip.png")
# # for i in range(50):
# #     cv2.imwrite('masked_img/2/{}.jpg'.format(i), submodular_image_set[i])
# #     if saved_json_file["consistency_score"][i]>0.8:
# #         break
#
# #
# # image_flip = cv2.flip(image, 0)
# # element_sets_V_flip = []
# # for element in element_sets_V:
# #     element_sets_V_flip.append(cv2.flip(element, 0))
# #
# # # element_sets_V_flip = SubRegionDivision(image_flip, mode="slico")
# #
# # # patches = generate_masks(image.shape, 8)
# # # element_sets_V_flip = patches * image_flip
# #
# # submodular_image, submodular_image_set, saved_json_file = smdl(element_sets_V_flip)
# # visualization(image_flip, submodular_image_set, saved_json_file, "example_plot_flip1.png")
# # for i in range(50):
# #     cv2.imwrite('masked_img/3/{}.jpg'.format(i), submodular_image_set[i])
# #     if saved_json_file["consistency_score"][i]>0.8:
# #         break
# #
# #
# # image_flip = cv2.flip(image, -1)
# # element_sets_V_flip = []
# # for element in element_sets_V:
# #     element_sets_V_flip.append(cv2.flip(element, -1))
# #
# # # element_sets_V_flip = SubRegionDivision(image_flip, mode="slico")
# #
# # # patches = generate_masks(image.shape, 8)
# # # element_sets_V_flip = patches * image_flip
# #
# # submodular_image, submodular_image_set, saved_json_file = smdl(element_sets_V_flip)
# # visualization(image_flip, submodular_image_set, saved_json_file, "example_plot_flip2.png")
# # for i in range(50):
# #     cv2.imwrite('masked_img/4/{}.jpg'.format(i), submodular_image_set[i])
# #     if saved_json_file["consistency_score"][i]>0.8:
# #         break







# Explainer for the baseline model. The commented block is the alternative
# "Quicktest" variant with the same arguments.
# NOTE(review): the meaning of k and lambda1..lambda3 is defined in
# LIMA.models.submodular_single_modal and is not visible here.
# smdl = BlackBoxSingleModalSubModularExplanationQuicktest(
#         model,
#         preprocess,
#         k=50,
#         lambda1=1,
#         lambda2=1,
#         lambda3=1)
smdl = BlackBoxSingleModalSubModularExplanationQuicktestRest(
        model,
        preprocess,
        k=50,
        lambda1=1,
        lambda2=1,
        lambda3=1)
# Candidate validation images tried during development. Only the single
# uncommented assignment at the end of this list is used.
# image_path = "/mnt/huawei/jiaoxh/data/ImageNet100/val/n01558993/ILSVRC2012_val_00001598.JPEG"
# image_path = "/mnt/huawei/jiaoxh/data/ImageNet100/val/n13040303/ILSVRC2012_val_00000058.JPEG"
# image_path = "/mnt/huawei/jiaoxh/data/ImageNet100/val/n02114855/ILSVRC2012_val_00016911.JPEG"
# image_path = "/mnt/huawei/jiaoxh/data/ImageNet100/val/n04127249/ILSVRC2012_val_00012191.JPEG"
# image_path = "/mnt/huawei/jiaoxh/data/ImageNet100/val/n01983481/ILSVRC2012_val_00034402.JPEG"      # this image was probably misclassified
# image_path = "/mnt/huawei/jiaoxh/data/ImageNet100/val/n02105505/ILSVRC2012_val_00023582.JPEG"
# 'n03837869/ILSVRC2012_val_00044841.JPEG'
# 'n02105505/ILSVRC2012_val_00034063.JPEG'
# image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/n04429376/ILSVRC2012_val_00028100.JPEG'
# image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/n01855672/ILSVRC2012_val_00043968.JPEG'
# image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/n02009229/ILSVRC2012_val_00033760.JPEG'

# Final usable image  # todo
# image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/n02138441/ILSVRC2012_val_00012478.JPEG'

image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/n04229816/ILSVRC2012_val_00005577.JPEG'


image = cv2.imread(image_path)
if image is None:
    # cv2.imread reports a missing or unreadable file by returning None
    # instead of raising, which would otherwise surface as an obscure
    # cv2.resize error on the next line.
    raise FileNotFoundError(f"Could not read image: {image_path}")
image = cv2.resize(image, (224, 224))
# Alternative: resize to 256 and take the central 224 x 224 crop.
# image = cv2.resize(image, (256, 256))
# image = image[16:240, 16:240]

# ---- Baseline model, original image ---------------------------------------
# Alternative region proposal via superpixels:
# element_sets_V = SubRegionDivision(image, mode="seeds")

# NumPy masks in (N, H, W, C) layout, used only to build the images passed to
# visualization(). With a 224 x 224 image and a 6 x 6 grid each cell is
# 37 x 37 (224 // 6), so the last two rows/columns are in no patch.
patches = generate_masks(image.shape, 6)

# Torch masks in (N, C, H, W) layout, multiplied with the preprocessed image
# to form the candidate sub-regions handed to the explainer.
image_tensor = preprocess(image)
patches_tensor = generate_masks_torch(6).cuda()
element_sets_V = patches_tensor * image_tensor



# element_sets_V = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(element_sets_V)

# Set k to the number of candidate regions.
# NOTE(review): presumably k is the number of regions to rank — confirm.
smdl.k = len(element_sets_V)
# submodular_image, submodular_image_set, saved_json_file = smdl(element_sets_V)
# NOTE(review): the meaning of the second argument ([]) is defined by the
# explainer class and is not visible here.
submodular_image, submodular_image_set, saved_json_file = smdl(element_sets_V, [])
# visualization(image, submodular_image_set, saved_json_file, "example_plot.png")
print(saved_json_file["sequence"])
# print(saved_json_file["confidence_score"])
print(saved_json_file["consistency_score"])
# print(saved_json_file["collaboration_score"])
# Reorder the NumPy patches by the returned sequence and plot the result.
visualization(image, (patches * image)[saved_json_file["sequence"]], saved_json_file, "result/baseline.png")
# print(submodular_image.shape, submodular_image_set.shape)
# for i in range(50):
#     cv2.imwrite('masked_img/1/{}.jpg'.format(i), submodular_image_set[i])
#     if saved_json_file["consistency_score"][i]>0.8:
#         break

# ---- Baseline model, horizontally flipped image ---------------------------
image_flip = cv2.flip(image, 1)
# element_sets_V_flip = []
# for element in element_sets_V:
#     element_sets_V_flip.append(cv2.flip(element, 1))
# Flip every candidate region along its last axis (width in the
# (N, C, H, W) layout), the tensor counterpart of cv2.flip(..., 1).
element_sets_V_flip = torch.flip(element_sets_V, [3])

# element_sets_V_flip = SubRegionDivision(image_flip, mode="slico")

# patches = generate_masks(image.shape, 8)
# element_sets_V_flip = patches * image_flip

# submodular_image, submodular_image_set, saved_json_file = smdl(element_sets_V_flip)
submodular_image, submodular_image_set, saved_json_file = smdl(element_sets_V_flip, [])
# NumPy counterpart of the flipped regions, in the same patch order, used
# only for visualization. It is reused later for the second model.
element_sets_V_flip_ = []
for element in (patches * image):
    element_sets_V_flip_.append(cv2.flip(element, 1))
element_sets_V_flip_ = np.array(element_sets_V_flip_)

visualization(image_flip, element_sets_V_flip_[saved_json_file["sequence"]], saved_json_file, "result/baseline_flip.png")
print(saved_json_file["sequence"])
print(saved_json_file["consistency_score"])




# ---- Second model: same pipeline with a different checkpoint --------------
# Rebuild the classifier, load another checkpoint, and repeat both runs on the
# regions prepared above (element_sets_V, element_sets_V_flip, patches).
model = ViTForImageClassification.from_pretrained('pretrained_model',
                                                      subfolder='vit-base-patch16-224',
                                                      ignore_mismatched_sizes=True)
# Alternative checkpoints tried earlier:
# state_dict = torch.load('checkpoint/b16_224/epoch10-baseline/model_best.pth.tar')['state_dict']
# state_dict = torch.load('checkpoint/lima/b16_224/-epoch10-lambda0.5-reg_freq20/checkpoint_001.pth.tar')['state_dict']
# state_dict = torch.load('checkpoint/lima/b16_224/0904-epoch10-lambda0.5-reg_freq20/model_best.pth.tar')['state_dict']
# state_dict = torch.load('/mnt/huawei/jiaoxh/CGC/checkpoint/lima/b16_224/0905_1-epoch10-lambda0.5-reg_freq20/model_best.pth.tar')['state_dict']
# state_dict = torch.load('/mnt/huawei/jiaoxh/CGC/checkpoint/lima/b16_224/0905_2-epoch10-lambda0.5-reg_freq20/model_best.pth.tar')['state_dict']
# state_dict = torch.load('checkpoint/lima/b16_224/0908_2-epoch10-lambda0.5-reg_freq20/checkpoint_9200.pth.tar')['state_dict']

# NOTE(review): judging by the path this is the regularised (lambda=0.1)
# model of a lambda ablation — confirm.
state_dict = torch.load('checkpoint/lima/b16_224/ab-lambda/epoch10-lambda0.1-reg_freq20/model_best.pth.tar')['state_dict']
# Remove 'module.' from the keys (the prefix torch.nn.DataParallel adds when a
# wrapped model is saved) so they match the unwrapped model built above.
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model = torch.nn.DataParallel(model)
model = model.cuda()
# New explainer bound to the newly loaded model.
# smdl = BlackBoxSingleModalSubModularExplanationQuicktest(model, preprocess, k=50, lambda1=1, lambda2=1, lambda3=1)
smdl = BlackBoxSingleModalSubModularExplanationQuicktestRest(model, preprocess, k=50, lambda1=1, lambda2=1, lambda3=1)

# Original image.
smdl.k = len(element_sets_V)
# submodular_image, submodular_image_set, saved_json_file = smdl(element_sets_V)
submodular_image, submodular_image_set, saved_json_file = smdl(element_sets_V, [])
print(saved_json_file["sequence"])
print(saved_json_file["consistency_score"])
# print(saved_json_file["confidence_score"])
# print(saved_json_file["consistency_score"])
# print(saved_json_file["collaboration_score"])
visualization(image, (patches * image)[saved_json_file["sequence"]], saved_json_file, "result/reg.png")

# Horizontally flipped image, reusing the flipped regions built earlier.
# submodular_image, submodular_image_set, saved_json_file = smdl(element_sets_V_flip)
submodular_image, submodular_image_set, saved_json_file = smdl(element_sets_V_flip, [])
visualization(image_flip, element_sets_V_flip_[saved_json_file["sequence"]], saved_json_file, "result/reg_flip.png")
print(saved_json_file["sequence"])
print(saved_json_file["consistency_score"])
