
import os
import cv2
import numpy as np
import torch
from PIL import Image, ImageOps, ImageDraw, ImageEnhance
import matplotlib.pyplot as plt
from scipy.ndimage import zoom
from sklearn.metrics.pairwise import cosine_similarity


def find_significant_patches(img_0, img_1, sim_rate_base=0.996):
    """Return the ids of patches that look (almost) identical in two images.

    Both images are cut into patches with ``patchify``, corresponding patches
    are scored with ``calculate_similarity``, and the scores are filtered and
    ranked by ``get_significant_patches``.

    Args:
        img_0: First image; anything ``patchify`` accepts.
        img_1: Second image, compared patch-by-patch against ``img_0``.
        sim_rate_base: Similarity threshold forwarded to
            ``get_significant_patches``.

    Returns:
        The list of patch ids produced by ``get_significant_patches``.
    """
    per_patch_similarity = calculate_similarity(patchify(img_0), patchify(img_1))
    return get_significant_patches(per_patch_similarity, threshold=sim_rate_base)


def patchify(image, patch_size=14, num_patches=16):
    """Cut the top-left region of an image into a grid of square patches.

    Args:
        image: Anything ``np.array`` can convert; the first two axes are
            treated as rows and columns.
        patch_size: Side length of each patch, in pixels.
        num_patches: Number of patches along each of the two axes.

    Returns:
        ``np.ndarray`` stacking ``num_patches ** 2`` patches in row-major
        order (all patches of the first row, then the second row, ...).
    """
    pixels = np.array(image)
    grid = [
        pixels[
            row * patch_size:(row + 1) * patch_size,
            col * patch_size:(col + 1) * patch_size,
        ]
        for row in range(num_patches)
        for col in range(num_patches)
    ]
    return np.array(grid)


def calculate_similarity(patches1, patches2):
    """Compute the cosine similarity of corresponding patches.

    Each pair of patches is flattened to a single row vector and compared
    with scikit-learn's ``cosine_similarity``. An SSIM-based score was tried
    here previously and is no longer used.

    Args:
        patches1: Iterable of array patches.
        patches2: Iterable of array patches, paired with ``patches1`` by
            position (``zip`` stops at the shorter one).

    Returns:
        ``np.ndarray`` stacking one ``cosine_similarity`` result per pair.
        For single-row inputs scikit-learn returns a ``(1, 1)`` matrix, so
        the stacked array is expected to have shape ``(n_pairs, 1, 1)``.
    """
    scores = [
        cosine_similarity(
            first.flatten().reshape(1, -1),
            second.flatten().reshape(1, -1),
        )
        for first, second in zip(patches1, patches2)
    ]
    return np.array(scores)


def get_significant_patches(similarity, threshold=0.99, top_k=150):
    """Select the most similar patches whose similarity exceeds a threshold.

    Args:
        similarity: Array holding one similarity score per patch of a
            16 x 16 grid (224 // 14 patches per side) in row-major order.
            Any shape with exactly 256 elements that supports ``reshape``
            is accepted, e.g. ``(256,)`` or ``(256, 1, 1)``.
        threshold: A patch is kept only if its score is strictly greater
            than this value.
        top_k: Maximum number of patch ids to return. Defaults to 150, the
            cap that used to be hard-coded here.

    Returns:
        list[int]: Patch ids (``row * 16 + col``) ordered by decreasing
        similarity; patches with equal scores keep ascending id order.
    """
    grid_size = 14
    num_grids = 224 // grid_size
    # A patch id is simply the row-major flat index, so a flat view is
    # enough; the reshape also rejects inputs that do not hold exactly
    # num_grids ** 2 scores.
    scores = similarity.reshape(num_grids * num_grids)

    candidates = [
        patch_id
        for patch_id in range(num_grids * num_grids)
        if scores[patch_id] > threshold
    ]
    # list.sort is stable, also with reverse=True, so ties stay in id order.
    candidates.sort(key=lambda patch_id: scores[patch_id], reverse=True)

    return candidates[:top_k]



def task_relevent_attention(multihead_attention, image, significant_patches, primary=True, topk=100):
    """Overlay static and attention-selected patches and list reusable tokens.

    The attention of layer 15 is reduced to one score per vision patch by
    ``token_attention_merge``; ``plot_attention_heatmap`` turns those scores
    into the ``topk`` highest-scoring patch ids. The image is then tinted
    with three colours: patches that are only in ``significant_patches``,
    patches that are only in the attention top-k, and patches in both.

    Args:
        multihead_attention: Indexable collection passed unchanged to
            ``token_attention_merge`` (which reads index 15 and index -1).
        image: PIL image to draw on.
        significant_patches: Iterable of patch ids (e.g. the output of
            ``find_significant_patches``).
        primary: Selects the vision-token offset: 1 when true, 257 otherwise.
        topk: Number of attention-selected patches requested from
            ``plot_attention_heatmap``.

    Returns:
        tuple: ``(overlay, token_ids)`` where ``overlay`` is the tinted image
        as a NumPy array and ``token_ids`` is the sorted list of
        ``patch_id + offset`` for every significant patch that is *not* among
        the attention top-k patches.
    """
    layer = 15

    # Nothing is written by this function itself; the directory is still
    # created so the (currently disabled) debug image dumps keep working.
    output_path = "./experiments/attn_vis/attn_vis_mask.jpg"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    relation_vis_text_score = token_attention_merge(multihead_attention, layer, primary)

    # Only the patch ids are needed here; the blended heat-map is discarded.
    # The top-k selection depends on the scores and the image size only, so
    # the untouched input image can be passed directly.
    _, top_patches = plot_attention_heatmap(image, relation_vis_text_score, topk)

    # Overlay colours.
    color_significant_only = (40, 116, 166)  # Blue
    color_top_only = (241, 196, 15)          # Yellow
    color_overlap = (231, 76, 60)            # Red (for overlap)

    significant_set = set(significant_patches)
    top_set = set(top_patches)
    significant_only = significant_set - top_set
    top_only = top_set - significant_set
    overlap = significant_set & top_set

    result_image = image
    for patch_ids, color in (
        (significant_only, color_significant_only),
        (top_only, color_top_only),
        (overlap, color_overlap),
    ):
        result_image = visualize_significant_patches_mask(
            result_image, patch_ids, color=color, alpha=0.4
        )

    # Shift patch ids into token positions of the selected image stream.
    v_token_start = 1 if primary else 257
    reusable_tokens = sorted(patch_id + v_token_start for patch_id in significant_only)

    return np.array(result_image), reusable_tokens


def token_attention_merge(multihead_attention, layer_id=15, primary=True):
    """Average the attention that text-token rows pay to one block of vision tokens.

    Args:
        multihead_attention: Indexable collection. ``multihead_attention[-1]``
            is a tensor of token positions used to build the row mask;
            ``multihead_attention[layer_id]`` is the attention tensor of the
            chosen layer, laid out as ``(1, num_heads, n_tokens, n_tokens)``
            (the leading axis is squeezed, then heads are averaged).
            NOTE(review): the position tensor is assumed to be 1-D with one
            entry per attention row -- confirm against the caller.
        layer_id: Index of the layer whose attention is used.
        primary: Chooses the 256-column vision block: columns 1..256 when
            true, columns 257..512 otherwise.

    Returns:
        torch.Tensor: float32 CPU tensor with 256 entries, the mean over all
        rows whose position lies in ``[513, 547)`` (presumably the text
        tokens) of the head-averaged attention to each vision column.
    """
    token_positions = multihead_attention[-1]

    v_token_num = 256
    t_token_num = 34
    v_token_start = 1 if primary else 257
    v_token_end = v_token_start + v_token_num

    # The text block starts right after the two 256-token vision blocks.
    t_token_start = v_token_num * 2 + 1
    t_token_end = t_token_start + t_token_num
    t_mask = (t_token_start <= token_positions) & (token_positions < t_token_end)

    # (1, heads, n, n) -> (heads, n, n) -> (n, n)
    head_mean = torch.mean(multihead_attention[layer_id].squeeze(0), dim=0)
    text_to_vision = head_mean[:, v_token_start:v_token_end][t_mask, :]

    return text_to_vision.mean(0).float().cpu()



def calculate_normalized_entropies(multihead_attention):
    """Derive a per-layer proportion from the entropy of each layer's attention.

    Args:
        multihead_attention: Sequence whose elements, except the last one,
            are attention tensors indexed as ``(batch, heads, tokens, tokens)``
            (heads are averaged over axis 1 and batch entry 0 is used). The
            last element is skipped.

    Returns:
        torch.Tensor: One value per processed layer. Entropies are min-max
        scaled to ``[0, 1]`` and inverted (``1 - scaled``); afterwards every
        entry that is larger than its (possibly already replaced) predecessor
        is replaced by the running sum of ``0.55 * increase``.
    """

    def attention_entropy(layer_attention):
        # Average over heads and take the first batch entry.
        head_mean = torch.mean(layer_attention, axis=1)[0].float()

        # Renormalise each row; rows that sum to zero turn into NaN and are
        # reset to zero.
        row_normalized = head_mean / torch.sum(head_mean, dim=-1, keepdim=True)
        row_normalized = torch.nan_to_num(row_normalized, nan=0.0)

        flat = row_normalized.view(-1)
        return -torch.sum(flat * torch.log(flat + 1e-10)).item()

    layer_count = len(multihead_attention) - 1
    entropies = [
        attention_entropy(multihead_attention[layer_idx])
        for layer_idx in range(layer_count)
    ]

    # Min-max scale the entropies to [0, 1].
    entropies_tensor = torch.tensor(entropies)
    lowest = torch.min(entropies_tensor)
    highest = torch.max(entropies_tensor)
    spread = highest - lowest

    if spread > 1e-10:  # guard against division by zero
        normalized_entropies = (entropies_tensor - lowest) / spread
    else:
        # All layers have the same entropy.
        normalized_entropies = torch.ones_like(entropies_tensor)
    layer_resuing_proportion = 1 - normalized_entropies

    accumulated_offset = 0
    for idx in range(1, len(layer_resuing_proportion)):
        # Increase relative to the previous entry (which may already have
        # been overwritten by an earlier iteration).
        offset = layer_resuing_proportion[idx] - layer_resuing_proportion[idx - 1]
        if offset > 0:
            # Accumulate the weighted positive increase.
            accumulated_offset += offset * 0.55
            layer_resuing_proportion[idx] = accumulated_offset

    return layer_resuing_proportion


def visualize_significant_patches_mask(image, significant_patches, patch_size=14, alpha=0.5, color=(255, 255, 255)):
    """
    Highlights patches by compositing a semi-transparent coloured overlay.

    Patch ids are interpreted row-major on a grid with
    ``image.width // patch_size`` patches per row.

    :param image: PIL Image instance.
    :param significant_patches: Iterable of patch indices to highlight.
    :param patch_size: Side length of each (square) patch in pixels.
    :param alpha: Opacity of the overlay (0.0 to 1.0, where 1.0 is fully opaque).
    :param color: Tuple (R, G, B) specifying the overlay colour.
    :return: New RGB PIL Image with the selected patches tinted.
    """
    # Work in RGBA so the overlay can carry per-pixel transparency.
    image = image.convert("RGBA")
    overlay = Image.new("RGBA", image.size, (color[0], color[1], color[2], 0))  # fully transparent
    draw = ImageDraw.Draw(overlay)

    width, _ = image.size
    num_patches = width // patch_size
    fill = (color[0], color[1], color[2], int(255 * alpha))

    for patch_id in significant_patches:
        row, col = divmod(patch_id, num_patches)
        left = col * patch_size
        top = row * patch_size
        # ImageDraw.rectangle includes both corner points, so the far corner
        # is the last pixel *inside* the patch; otherwise every highlight
        # would spill one pixel into the patches to its right and below.
        far_corner = (left + patch_size - 1, top + patch_size - 1)
        draw.rectangle([(left, top), far_corner], fill=fill)

    # Composite the overlay onto the original image.
    return Image.alpha_composite(image, overlay).convert("RGB")


def plot_attention_heatmap(image, attention_scores, topk):
    """
    Map per-patch attention scores onto an image and pick the hottest patches.

    Args:
        image (PIL.Image): Input image; its ``size`` and ``mode`` are used.
        attention_scores (torch.Tensor): 256 attention scores, one per patch
            of a 16 x 16 grid (the tensor is reshaped to ``(16, 16)``).
        topk (int): Number of patch ids requested from
            ``get_top_patches_with_radiation``.

    Returns:
        tuple: ``(blended, top_patches)`` where ``blended`` is an RGB
        PIL image (original blended with a JET-coloured heat-map at
        alpha 0.4) and ``top_patches`` is the list of selected patch ids.
    """
    width, height = image.size

    # Min-max normalise the scores. Constant scores divide by zero here; the
    # resulting non-finite values are zeroed out below.
    attention_scores = attention_scores - attention_scores.min()
    attention_scores = attention_scores / attention_scores.max()

    # Arrange the scores as a 16 x 16 grid.
    attention_scores = attention_scores.reshape(16, 16)

    # Upsample to the image resolution. Array axes are (rows, cols), i.e.
    # (height, width), so the zoom factors must be given in that order.
    heatmap = zoom(attention_scores, (height / 16, width / 16), order=1)

    heatmap = np.where(np.isfinite(heatmap), heatmap, 0)
    heatmap = (heatmap * 255).astype(np.uint8)
    top_patches = get_top_patches_with_radiation(heatmap, top_k=topk)

    # Apply the JET pseudo-colour map (OpenCV returns BGR).
    heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

    # Convert to a PIL image in RGB channel order.
    heatmap_colored = Image.fromarray(cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB))

    # Image.blend needs both images in the same mode.
    if image.mode != "RGBA":
        image = image.convert("RGBA")

    # Make sure the heat-map has exactly the image size.
    heatmap_colored = heatmap_colored.resize((width, height), Image.BILINEAR)
    heatmap_colored = heatmap_colored.convert("RGBA")

    # Blend the original image with the heat-map.
    blended = Image.blend(image, heatmap_colored, alpha=0.4)

    # Convert to RGB so the result can be saved as JPEG.
    blended = blended.convert("RGB")

    return blended, top_patches


def get_top_patches_with_radiation(heatmap, top_k=120):
    """
    Return the ids of the hottest patches of a heat-map.

    The heat-map is resized to a 16 x 16 grid with ``cv2.resize`` and the
    grid cells are ranked by value. Despite the function name, no
    "radiation" (spreading heat to neighbouring patches) is applied by
    this code.

    Tuning notes kept from the original docstring (presumably
    ``top_k, radiation_factor, radius`` per task suite; only ``top_k`` is
    used here -- TODO confirm):
        spatial: 50, 0.8, 1
        object:
        goal: 120, 0.9, 3
        10: 60, 0.9, 1

    Args:
        heatmap (np.array): Heat-map at image resolution.
        top_k (int): Number of patch ids to return.

    Returns:
        list: Patch ids (``row * 16 + col``) ordered from hottest to coldest;
        cells with equal heat keep ascending id order.
    """
    num_grids = 16

    # Shrink the heat-map to one value per patch.
    heat_grid = cv2.resize(heatmap, (num_grids, num_grids))

    # sorted() is stable, also with reverse=True, so ties stay in id order.
    ranked_ids = sorted(
        range(num_grids * num_grids),
        key=lambda patch_id: heat_grid[patch_id // num_grids, patch_id % num_grids],
        reverse=True,
    )

    return ranked_ids[:top_k]