import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import torch

def save_visualizations(level_idx, original_images, feature_maps, masked_feature_maps, save_dir="visualizations"):
    """
    可视化原图、特征图和被掩码后的特征图，并保存到文件
    Args:
    - original_images (np.ndarray): 原始图像，形状为 (batch_size, H_original, W_original, 3)
    - feature_maps (np.ndarray): 特征图，形状为 (batch_size, C, H_f, W_f)
    - masked_feature_maps (np.ndarray): 被掩码后的特征图，形状为 (batch_size, C, H_f, W_f)
    - save_dir (str): 保存图片的目录
    
    Returns:
    - None: 将可视化结果保存为文件
    """
    # 创建保存目录
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    
    batch_size = original_images.shape[0]
    
    if isinstance(original_images, torch.Tensor):
        original_images = original_images.detach().cpu().numpy()
    if isinstance(feature_maps, torch.Tensor):
        feature_maps = feature_maps.detach().cpu().numpy()
    if isinstance(masked_feature_maps, torch.Tensor):
        masked_feature_maps = masked_feature_maps.detach().cpu().numpy()
    
    for i in range(batch_size):
        # 获取当前 batch 的原图、特征图和被掩码后的特征图
        original_img = original_images[i].transpose(1, 2, 0)
        feature_map = feature_maps[i]
        masked_feature_map = masked_feature_maps[i]

        # 取出其中一个通道的特征图进行可视化，或者对所有通道进行求和以获得整体特征图
        # 这里我们对所有通道求和再可视化，展示每个像素的整体特征强度
        summed_feature_map = np.sum(feature_map, axis=0)
        summed_masked_feature_map = np.sum(masked_feature_map, axis=0)

        # 保存原图
        original_img_path = os.path.join(save_dir, f"original_image_level{level_idx}_batch{i}.png")
        cv2.imwrite(original_img_path, cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
        
        # # 保存特征图
        # feature_map_path = os.path.join(save_dir, f"feature_map_level{level_idx}_batch{i}.png")
        # plt.imsave(feature_map_path, summed_feature_map, cmap='viridis')

        # # 保存被掩码后的特征图
        # masked_feature_map_path = os.path.join(save_dir, f"masked_feature_map_level{level_idx}_batch{i}.png")
        # plt.imsave(masked_feature_map_path, summed_masked_feature_map, cmap='viridis')
        
        # 找到被掩码的区域（masked_feature_map中为0的区域）
        mask = np.sum(feature_map, axis=0) == 0
        # 将被掩码的区域用红色标记（在特征图上）
        feature_map_with = summed_feature_map.copy()
        feature_map_with[mask] = np.max(summed_feature_map)  # 高亮掩码区域
        feature_map_path = os.path.join(save_dir, f"feature_map_level{level_idx}_batch{i}.png")
        plt.imsave(feature_map_path, feature_map_with, cmap='viridis')
        
        # 找到被掩码的区域（masked_feature_map中为0的区域）
        mask = np.sum(masked_feature_map, axis=0) == 0
        # 将被掩码的区域用红色标记（在特征图上）
        feature_map_with_mask = summed_feature_map.copy()
        feature_map_with_mask[mask] = np.max(summed_masked_feature_map)  # 高亮掩码区域
        masked_feature_map_path = os.path.join(save_dir, f"masked_feature_map_level{level_idx}_batch{i}.png")
        plt.imsave(masked_feature_map_path, feature_map_with_mask, cmap='viridis')

        print(f"Saved level{level_idx}_batch {i} visualizations:")
        print(f" - Original image saved to: {original_img_path}")
        print(f" - Feature map saved to: {feature_map_path}")
        print(f" - Masked feature map saved to: {masked_feature_map_path}")
        


if __name__ == '__main__':
    # 示例使用
    batch_size = 2
    H_original, W_original = 512, 512  # 假设原始图像大小为 512x512
    original_images = np.random.randint(0, 255, (batch_size, H_original, W_original, 3), dtype=np.uint8)

    # 假设特征图和掩码后的特征图
    C, H_f, W_f = 256, 64, 64  # 假设特征图有 256 个通道，大小为 64x64
    feature_maps = np.random.rand(batch_size, C, H_f, W_f)
    masked_feature_maps = feature_maps.copy()

    # 对某些通道的某些区域进行掩码（将它们设置为0）
    masked_feature_maps[:, :, 20:40, 20:40] = 0  # 对一些区域进行掩码

    # 保存可视化结果
    save_visualizations(original_images, feature_maps, masked_feature_maps, save_dir="/home/jiask/Open-GroundingDino-main/visualizations")