import os

from PIL import Image

import cv2
import numpy as np
import  matplotlib.pyplot as plt

# Matplotlib 3.6 renamed the bundled seaborn style to 'seaborn-v0_8' and 3.8
# removed the old 'seaborn' name; prefer the new name, fall back on old releases.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

from LIMA.utils import *
import matplotlib

from LIMA.models.submodular_single_modal import BlackBoxSingleModalSubModularExplanationQuicktest, BlackBoxSingleModalSubModularExplanationQuicktestRest
from transformers import ViTForImageClassification

# NOTE(review): the returned cache path is discarded; this call looks like a
# leftover from debugging matplotlib's font cache — confirm before removing.
matplotlib.get_cachedir()
# Render all figure text in Times New Roman.
plt.rc('font', family="Times New Roman")

from sklearn import metrics
import torch
from torchvision import transforms


def preprocess(image):
    """
    Convert an H x W x 3 uint8 array into a float CUDA tensor (3, H, W) in [0, 1].

    NOTE(review): the array is wrapped as 'RGB' without reordering channels.
    Images loaded with cv2.imread are BGR, so the model receives BGR data —
    confirm this matches how the checkpoints were trained.
    """
    # Resize / CenterCrop / Normalize are disabled in this pipeline, so only
    # the uint8 -> float conversion performed by ToTensor is applied.
    pipeline = transforms.Compose([transforms.ToTensor()])
    pil_image = Image.fromarray(image, 'RGB')
    return pipeline(pil_image).cuda()


def imshow(img, save_path="example_plot.png"):
    """
    Draw an image with matplotlib (axes hidden) and save the figure to a file.

    Despite the name, nothing is shown interactively (``plt.show()`` is not
    called); the current figure is written to ``save_path``.

    :param img: [H, W, 3] BGR image (channels are reversed to RGB) or a
        single-channel [H, W] image (expanded to RGB)
    :param save_path: destination file; defaults to ``example_plot.png``
    """
    plt.axis('off')
    if len(img.shape) == 3:
        img = img[:, :, ::-1]  # BGR -> RGB for matplotlib
    else:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    plt.imshow(img)
    plt.savefig(save_path)


def generate_masks_lima(shape, patch_size):
    """
    Split an image into a ``patch_size`` x ``patch_size`` grid of binary masks,
    leaving a one-pixel frame around the grid (layout used for the LIMA overlays).

    :param shape: (height, width, channels) of the image
    :param patch_size: number of grid cells per side (not a pixel size)
    :return: uint8 array [patch_size**2, height, width, channels]; entry
        ``row * patch_size + col`` is 1 inside that cell and 0 elsewhere.
        Cells are ``(width - 2) // patch_size`` x ``(height - 2) // patch_size``
        pixels and the grid starts at pixel (1, 1), so at least the first and
        last row/column are covered by no cell.
    """
    height, width, channels = shape
    cell_w = (width - 2) // patch_size
    cell_h = (height - 2) // patch_size

    grid = np.zeros((patch_size * patch_size, height, width, channels), dtype=np.uint8)
    for idx in range(patch_size * patch_size):
        row, col = divmod(idx, patch_size)
        top = 1 + row * cell_h
        left = 1 + col * cell_w
        grid[idx, top:top + cell_h, left:left + cell_w, :] = 1
    return grid


def generate_masks(shape, patch_size):
    """
    Split an image into a ``patch_size`` x ``patch_size`` grid of binary masks.

    :param shape: (height, width, channels) of the image
    :param patch_size: number of grid cells per side (not a pixel size)
    :return: uint8 array [patch_size**2, height, width, channels]; entry
        ``row * patch_size + col`` is 1 inside that cell and 0 elsewhere.
        Cells are ``width // patch_size`` x ``height // patch_size`` pixels, so
        any remainder rows/columns at the bottom/right belong to no cell.
    """
    height, width, channels = shape
    cell_w, cell_h = width // patch_size, height // patch_size

    grid = np.zeros((patch_size * patch_size, height, width, channels), dtype=np.uint8)
    for idx in range(patch_size * patch_size):
        row, col = divmod(idx, patch_size)
        top, left = row * cell_h, col * cell_w
        grid[idx, top:top + cell_h, left:left + cell_w, :] = 1
    return grid


def generate_masks_torch(patch_size, size=224, channels=3):
    """
    Build one binary mask per cell of a ``patch_size`` x ``patch_size`` grid.

    Args:
    patch_size: number of cells along each side of the grid (not a pixel size).
    size: side length in pixels of the square image the masks cover (default 224).
    channels: number of channels in each mask (default 3).

    Returns:
    masks: uint8 tensor of shape (patch_size * patch_size, channels, size, size);
        mask ``row * patch_size + col`` is 1 on the cell at that grid position
        and 0 elsewhere. Cells are ``size // patch_size`` pixels wide, so when
        ``size`` is not divisible by ``patch_size`` the last
        ``size % patch_size`` rows/columns are covered by no mask.
    """
    cell = size // patch_size
    masks = torch.zeros((patch_size * patch_size, channels, size, size), dtype=torch.uint8)
    for row in range(patch_size):
        for col in range(patch_size):
            top, left = row * cell, col * cell
            masks[row * patch_size + col, :, top:top + cell, left:left + cell] = 1
    return masks


def add_value_decrease(smdl_mask, json_file):
    """
    Turn a submodular selection sequence into a per-pixel attribution map.

    The combined score at step t is consistency[t] + collaboration[t]; its
    marginal gain is measured against step t - 1, with the baseline before the
    first step taken as (1 - consistency[-1]) + (1 - collaboration[-1]).
    A running value starts at 0 and drops by |gain| at every step; the pixels
    of the element chosen at that step are set to the running value, so
    later-selected elements get lower scores (and overwrite earlier ones where
    they overlap).

    :param smdl_mask: array [N, H, W, C] of the selected elements in order; a
        pixel belongs to element i when its channel sum is > 0
    :param json_file: mapping with "consistency_score" and
        "collaboration_score" sequences (lists or arrays) indexed by step
    :return: (attribution_map, values) -- an [H, W] map min-max scaled to
        [0, 1] (all zeros when every pixel has the same score, instead of NaN)
        and the array of running values assigned at each step
    """
    consistency = np.asarray(json_file["consistency_score"], dtype=float)
    collaboration = np.asarray(json_file["collaboration_score"], dtype=float)

    totals = consistency + collaboration
    baseline = (1 - consistency[-1]) + (1 - collaboration[-1])
    previous = np.concatenate(([baseline], totals[:-1]))
    gains = totals - previous

    single_mask = np.zeros_like(smdl_mask[0].mean(-1))
    values = []
    value = 0
    for element, gain in zip(smdl_mask, gains):
        value = value - abs(gain)
        single_mask[element.sum(-1) > 0] = value
        values.append(value)

    attribution_map = single_mask - single_mask.min()
    peak = attribution_map.max()
    # A constant map would otherwise be divided by zero and become NaN.
    if peak > 0:
        attribution_map /= peak

    return attribution_map, np.array(values)


def gen_cam(image, mask):
    """
    Generate a heat-map from an attribution mask and blend it onto an image.

    :param image: [H, W, 3] background image (assumed BGR, as from cv2.imread)
    :param mask: [H', W'] attribution map already scaled to 0-255 (e.g. the
        output of norm_image). It is cast with ``np.uint8`` before colour
        mapping, so a 0-1 float map would collapse to all zeros.
    :return: tuple(cam, heatmap) -- ``cam`` is the float32 blend (30 % heat-map,
        70 % image) and ``heatmap`` the uint8 colour-mapped mask, both [H, W, 3]
    """
    # Trim the outer pixels of the attribution map (224x224 in this script;
    # presumably to hide artefacts at the patch-grid border), then stretch it
    # to the background's size so the blend below always broadcasts.
    mask = mask[3:222, 3:222]
    height, width = image.shape[:2]
    mask = cv2.resize(mask, (width, height))  # cv2 expects (width, height)

    # Alternative colour map: cv2.COLORMAP_JET
    heatmap = np.float32(cv2.applyColorMap(np.uint8(mask), cv2.COLORMAP_COOL))

    # Merge the heat-map into the original image.
    cam = 0.3 * heatmap + 0.7 * np.float32(image)
    return cam, heatmap.astype(np.uint8)


def norm_image(image):
    """
    Min-max scale an array to the full uint8 range.

    :param image: array of any shape (e.g. [H, W] or [H, W, C]); floating
        inputs keep their dtype during scaling, other dtypes are promoted to
        float64 (in-place division would fail on integer arrays)
    :return: uint8 array of the same shape where the minimum maps to 0 and the
        maximum to 255 (values are truncated, not rounded). A constant input
        maps to all zeros instead of dividing by zero.
    """
    image = np.array(image, copy=True)
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float64)
    image -= image.min()
    peak = image.max()
    if peak > 0:
        image /= peak
    image *= 255.
    return np.uint8(image)


def find_first_above_threshold(arr, threshold=0.7, default=35):
    """
    Return the index of the first element strictly greater than ``threshold``.

    :param arr: iterable of scores
    :param threshold: an element must exceed this value (strict ``>``)
    :param default: returned when no element exceeds ``threshold``. The
        historical fallback of 35 is presumably the last index of the 36 (6x6)
        regions explained by this script.
    """
    return next((index for index, value in enumerate(arr) if value > threshold), default)


def visualization(image, submodular_image_set, save_dir, edge_iterations=1):
    """
    Dim the regions that were not selected, outline the selected ones in red
    and write the result to disk.

    :param image: [H, W, 3] uint8 image, assumed BGR as returned by cv2.imread
    :param submodular_image_set: [H, W, C] array whose channel mean is 1 on the
        selected regions and 0 elsewhere (e.g. a sum of disjoint region masks)
    :param save_dir: output *file* path handed to cv2.imwrite (despite its name)
    :param edge_iterations: number of dilation passes used to grow the outline;
        the default of 1 reproduces the previous output
    :return: the success flag returned by cv2.imwrite
    """
    kernel = np.ones((3, 3), dtype=np.uint8)

    # 1 where the pixel was NOT selected, 0 inside the selected regions.
    unselected = 1 - submodular_image_set.mean(-1)

    # Growing the unselected area and subtracting the original leaves a band
    # just inside the borders of the selected regions. ``iterations`` is passed
    # by keyword: the third *positional* parameter of cv2.dilate is ``dst``, so
    # a positional count there is silently ignored and only one pass runs.
    grown = cv2.dilate(unselected, kernel, iterations=edge_iterations)
    edge = grown - unselected

    canvas = image.copy()
    dimmed = unselected > 0
    canvas[dimmed] = canvas[dimmed] * 0.4       # keep 40 % brightness outside the selection
    canvas[edge > 0] = np.array([0, 0, 255])    # red in BGR
    return cv2.imwrite(save_dir, canvas.astype(np.uint8))


# Earlier single-image experiments. ``image_path`` is rebuilt for every entry
# of ``image_list`` inside the main loop further down.
# image_path = "/mnt/huawei/jiaoxh/data/ImageNet100/val/n01558993/ILSVRC2012_val_00001598.JPEG"
image_path = "/mnt/huawei/jiaoxh/data/ImageNet100/val/n13040303/ILSVRC2012_val_00000058.JPEG"

# Other candidates that were tried:
# image_path = "/mnt/huawei/jiaoxh/data/ImageNet100/val/n02114855/ILSVRC2012_val_00016911.JPEG"
# image_path = "/mnt/huawei/jiaoxh/data/ImageNet100/val/n04127249/ILSVRC2012_val_00012191.JPEG"
# image_path = "/mnt/huawei/jiaoxh/data/ImageNet100/val/n01983481/ILSVRC2012_val_00034402.JPEG"      # this image is probably misclassified
# image_path = "/mnt/huawei/jiaoxh/data/ImageNet100/val/n02105505/ILSVRC2012_val_00023582.JPEG"
# 'n03837869/ILSVRC2012_val_00044841.JPEG'
# 'n02105505/ILSVRC2012_val_00034063.JPEG'
# image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/n04429376/ILSVRC2012_val_00028100.JPEG'
# image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/n01855672/ILSVRC2012_val_00043968.JPEG'

# Little blue heron
# image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/n02009229/ILSVRC2012_val_00033760.JPEG'

# Final usable images  # TODO
# Meerkat
# image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/n02138441/ILSVRC2012_val_00012478.JPEG'
# image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/n03930630/ILSVRC2012_val_00049789.JPEG'
# image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/n01692333/ILSVRC2012_val_00033706.JPEG'
# image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/n02089973/ILSVRC2012_val_00037826.JPEG'
# image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/n04429376/ILSVRC2012_val_00043870.JPEG'
# image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/n01773797/ILSVRC2012_val_00026030.JPEG'
# image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/n02138441/ILSVRC2012_val_00047118.JPEG'
# image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/n13040303/ILSVRC2012_val_00027169.JPEG'
# image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/n04229816/ILSVRC2012_val_00008102.JPEG'

# =3, pick stage
# image_list = ['n02009229/ILSVRC2012_val_00033760.JPEG', 'n02018207/ILSVRC2012_val_00002370.JPEG', 'n02087046/ILSVRC2012_val_00038107.JPEG', 'n02087046/ILSVRC2012_val_00019566.JPEG', 'n03642806/ILSVRC2012_val_00045326.JPEG', 'n03930630/ILSVRC2012_val_00021802.JPEG', 'n02231487/ILSVRC2012_val_00041596.JPEG', 'n02138441/ILSVRC2012_val_00038452.JPEG', 'n02138441/ILSVRC2012_val_00012478.JPEG', 'n04136333/ILSVRC2012_val_00028859.JPEG', 'n04435653/ILSVRC2012_val_00039351.JPEG', 'n04435653/ILSVRC2012_val_00026714.JPEG', 'n02804414/ILSVRC2012_val_00027152.JPEG', 'n03775546/ILSVRC2012_val_00035190.JPEG', 'n03775546/ILSVRC2012_val_00028647.JPEG', 'n03775546/ILSVRC2012_val_00011309.JPEG', 'n04127249/ILSVRC2012_val_00022604.JPEG', 'n07753275/ILSVRC2012_val_00018941.JPEG', 'n02113799/ILSVRC2012_val_00018934.JPEG', 'n04229816/ILSVRC2012_val_00000550.JPEG', 'n04229816/ILSVRC2012_val_00005577.JPEG', 'n01729322/ILSVRC2012_val_00040890.JPEG', 'n02701002/ILSVRC2012_val_00047635.JPEG']
# =3, formal stage
# image_list = ['n03775546/ILSVRC2012_val_00011309.JPEG',
#               'n04229816/ILSVRC2012_val_00000550.JPEG',
#               'n04229816/ILSVRC2012_val_00005577.JPEG',
#               'n04136333/ILSVRC2012_val_00028859.JPEG',
#               'n04435653/ILSVRC2012_val_00039351.JPEG']
# =2, large difference from the baseline
# image_list = ['n01855672/ILSVRC2012_val_00043968.JPEG', 'n01855672/ILSVRC2012_val_00029351.JPEG', 'n02087046/ILSVRC2012_val_00019128.JPEG', 'n03594734/ILSVRC2012_val_00044132.JPEG', 'n04517823/ILSVRC2012_val_00033301.JPEG', 'n04517823/ILSVRC2012_val_00039615.JPEG', 'n02086910/ILSVRC2012_val_00048254.JPEG', 'n02877765/ILSVRC2012_val_00018878.JPEG', 'n01735189/ILSVRC2012_val_00002452.JPEG', 'n03891251/ILSVRC2012_val_00038706.JPEG', 'n02113978/ILSVRC2012_val_00028901.JPEG', 'n03775546/ILSVRC2012_val_00006688.JPEG', 'n02109047/ILSVRC2012_val_00023740.JPEG', 'n02123045/ILSVRC2012_val_00020136.JPEG', 'n03494278/ILSVRC2012_val_00024537.JPEG', 'n02859443/ILSVRC2012_val_00002227.JPEG', 'n04099969/ILSVRC2012_val_00018123.JPEG']

# Images explained in the current run (relative to the ImageNet100 val root).
image_list = [
    'n02123045/ILSVRC2012_val_00020136.JPEG',
    'n02877765/ILSVRC2012_val_00018878.JPEG',
]

# Former single-image setup, now performed per image inside the main loop:
# image = cv2.imread(image_path)
# image = cv2.resize(image, (224, 224))
#
# img = cv2.imread(image_path)
# img = cv2.resize(img, (400, 400))
# # image = cv2.resize(image, (256, 256))
# # image = image[16:240, 16:240]
#
# # element_sets_V = SubRegionDivision(image, mode="seeds")
#
# patches = generate_masks(image.shape, 6)
#
# image_tensor = preprocess(image)
# patches_tensor = generate_masks_torch(6).cuda()
# element_sets_V = patches_tensor * image_tensor

# When True ("pick stage") only the heat-maps are written; otherwise the
# LIMA-style region overlays (``*_lima*.png``) are produced as well.
pick_stage = False

# Fine-tuned ViT-B/16 (224 px) checkpoints to explain. ``name`` holds, at the
# same position, the label that prefixes every output file of that checkpoint.
ckpt_list = [
    'checkpoint/b16_224/epoch10-baseline/model_best.pth.tar',
    'checkpoint/b16_224/eclip-epoch10-gcc-lambda0.5/model_best.pth.tar',
    'checkpoint/b16_224/eclip-epoch10-cgc-lambda0.5/model_best.pth.tar',
    'checkpoint/lima/b16_224/ab-lambda/epoch10-lambda0.1-reg_freq20/model_best.pth.tar',
]

name = [
    'baseline',
    'gc',
    'cgc',
    'reg',
]

def _load_model(ckpt_path):
    """Build the ViT-B/16 classifier, load the weights in ``ckpt_path`` and wrap it in DataParallel on the GPU."""
    model = ViTForImageClassification.from_pretrained(
        'pretrained_model', subfolder='vit-base-patch16-224', ignore_mismatched_sizes=True)
    # Deserialize on the CPU: load_state_dict copies the tensors into the model
    # anyway, so a second GPU-resident copy of the checkpoint is unnecessary.
    state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
    # Strip the 'module.' prefix (presumably added by DataParallel during training).
    state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    return torch.nn.DataParallel(model).cuda()


def _explain_and_save(smdl, element_set, background, lima_image, lima_masks, out_dir, label, suffix):
    """
    Explain one element set and write its visualisations into ``out_dir``.

    Always writes ``<label><suffix>.png`` (attribution heat-map blended onto
    ``background``). When ``lima_masks`` is not None it also writes
    ``<label>_lima<suffix>.png``: the regions selected up to the index returned
    by find_first_above_threshold, drawn on ``lima_image``.
    """
    _, submodular_image_set, saved_json_file = smdl(element_set, [])
    # NOTE(review): add_value_decrease assigns a pixel to an element when its
    # channel sum is > 0; if submodular_image_set holds the masked images
    # themselves, pixels that are black in the source are never attributed.
    # An earlier variant passed patches_tensor[saved_json_file['sequence']].
    attribution_map, _ = add_value_decrease(
        submodular_image_set.cpu().numpy().transpose(0, 2, 3, 1), saved_json_file)
    cam, _ = gen_cam(background, norm_image(attribution_map))
    cv2.imwrite('{}/{}{}.png'.format(out_dir, label, suffix), cam.astype(np.uint8))

    if lima_masks is None:
        return
    index = find_first_above_threshold(saved_json_file['consistency_score'])
    selected = lima_masks[saved_json_file['sequence'][:index + 1]].sum(0)
    visualization(lima_image, selected, '{}/{}_lima{}.png'.format(out_dir, label, suffix))


for image_name in image_list:
    image_path = '/mnt/huawei/jiaoxh/data/ImageNet100/val/' + image_name
    raw_image = cv2.imread(image_path)
    if raw_image is None:
        raise FileNotFoundError('cannot read image: {}'.format(image_path))
    image = cv2.resize(raw_image, (224, 224))  # model input resolution
    img = cv2.resize(raw_image, (400, 400))    # background for the heat-maps
    img_flip = cv2.flip(img, 1)

    image_tensor = preprocess(image)
    patches_tensor = generate_masks_torch(6).cuda()
    element_sets_V = patches_tensor * image_tensor
    # Mirroring every element keeps index i pointing at the same source
    # region, which is why the flipped LIMA overlay is still drawn on the
    # unflipped canvas.
    element_sets_V_flip = torch.flip(element_sets_V, [3])

    out_dir = 'result/' + image_path.split('/')[-1].split('.')[0]
    if pick_stage:
        img_lima, patches_lima = None, None
    else:
        # LIMA overlays use a 422x422 canvas with a 6x6 grid inset by one pixel.
        img_lima = cv2.resize(raw_image, (422, 422))
        patches_lima = generate_masks_lima(img_lima.shape, 6)

    for label, ckpt_path in zip(name, ckpt_list):
        if not os.path.exists(ckpt_path):
            print('checkpoint not found, skipping: {}'.format(ckpt_path))
            continue
        os.makedirs(out_dir, exist_ok=True)
        # NOTE(review): every checkpoint is reloaded for every image; swapping
        # the two loops would load each model only once.
        model = _load_model(ckpt_path)
        smdl = BlackBoxSingleModalSubModularExplanationQuicktestRest(
            model, preprocess, k=50, lambda1=1, lambda2=1, lambda3=1)
        smdl.k = len(element_sets_V)

        _explain_and_save(smdl, element_sets_V, img, img_lima, patches_lima,
                          out_dir, label, '')
        _explain_and_save(smdl, element_sets_V_flip, img_flip, img_lima, patches_lima,
                          out_dir, label, '_flip')




