import mmcv
import numpy as np
import matplotlib.pyplot as plt
import os

def _as_query_maps(feat, num_queries=16, bev_h=100, bev_w=100):
    """Return ``feat`` as a NumPy array of shape (num_queries, bev_h, bev_w).

    Accepts a NumPy array, any sequence ``np.asarray`` understands, or a
    tensor-like object exposing ``.cpu().numpy()`` (e.g. a torch tensor).

    Raises:
        ValueError: if the number of elements is not
            ``num_queries * bev_h * bev_w``.
    """
    # Tensors that still track gradients refuse .numpy(); detach first.
    if hasattr(feat, 'detach'):
        feat = feat.detach()
    if hasattr(feat, 'cpu'):
        feat = feat.cpu().numpy()
    return np.asarray(feat).reshape(num_queries, bev_h, bev_w)


def vis_16_query(pkl_path, output_dir="bev_feature_query_maps",
                 num_queries=16, bev_h=100, bev_w=100):
    """Save one heatmap PNG per BEV query from a saved attention map.

    Loads ``pkl_path`` with ``mmcv.load`` and reads the
    ``"current_selected_attention_map"`` entry. The entry is reshaped to
    ``(num_queries, bev_h, bev_w)``. Each map is written to
    ``<output_dir>/query_<idx>.png``.

    Args:
        pkl_path: Path of the result file to load.
        output_dir: Directory for the PNGs. It is created if missing.
        num_queries: Number of query maps stored in the entry.
        bev_h: Height of each BEV map.
        bev_w: Width of each BEV map.

    Raises:
        KeyError: if the loaded object has no
            ``"current_selected_attention_map"`` key.
        ValueError: if the attention map's size does not match
            ``num_queries * bev_h * bev_w``.
    """
    bevformer_results = mmcv.load(pkl_path)
    query_maps = _as_query_maps(
        bevformer_results["current_selected_attention_map"],
        num_queries, bev_h, bev_w,
    )

    os.makedirs(output_dir, exist_ok=True)

    for query_idx, query_map in enumerate(query_maps):
        fig, ax = plt.subplots(figsize=(5, 5))
        try:
            image = ax.imshow(query_map, cmap='viridis')
            ax.set_title(f"BEV Query {query_idx}")
            fig.colorbar(image, ax=ax)
            output_path = os.path.join(output_dir, f"query_{query_idx}.png")
            fig.savefig(output_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if saving fails, so a loop
            # over many queries does not accumulate open figures.
            plt.close(fig)

    print(f"Saved {num_queries} query maps to: {output_dir}")
    
def _argmax_xy(query_maps):
    """Return the (x, y) position of the maximum of each 2-D map.

    Args:
        query_maps: Array-like of shape (num_queries, H, W).

    Returns:
        A list of ``(x, y)`` int tuples, one per map. ``x`` is the column
        index and ``y`` is the row index. Ties resolve to the first
        occurrence in row-major order, which is the same as ``np.argmax``.
    """
    query_maps = np.asarray(query_maps)
    if len(query_maps) == 0:
        return []
    # One vectorized argmax over the flattened maps instead of a Python loop.
    flat_idx = query_maps.reshape(query_maps.shape[0], -1).argmax(axis=1)
    rows, cols = np.unravel_index(flat_idx, query_maps.shape[1:])
    return [(int(col), int(row)) for row, col in zip(rows, cols)]


def extract_and_plot_bev_query_max_positions(pkl_path: str, output_path: str = "bev_16_query_max.png",
                                             num_queries: int = 16, bev_h: int = 100, bev_w: int = 100):
    """Plot the maximum-value position of every BEV query map on one figure.

    Loads ``pkl_path`` with ``mmcv.load`` and reads the
    ``"current_selected_attention_map"`` entry. The entry is reshaped to
    ``(num_queries, bev_h, bev_w)``. The function finds the argmax pixel of
    each query map and draws only those points, each annotated with its
    query index.

    Args:
        pkl_path: Path of the input .pkl file.
        output_path: Path where the final image is saved.
        num_queries: Number of query maps stored in the entry.
        bev_h: Height of each BEV map.
        bev_w: Width of each BEV map.

    Raises:
        KeyError: if the loaded object has no
            ``"current_selected_attention_map"`` key.
        ValueError: if the attention map's size does not match
            ``num_queries * bev_h * bev_w``.
    """
    # 1. Load the attention maps and convert them to a NumPy array.
    bevformer_results = mmcv.load(pkl_path)
    bev_feat = bevformer_results["current_selected_attention_map"]
    # Tensors that still track gradients refuse .numpy(); detach first.
    if hasattr(bev_feat, 'detach'):
        bev_feat = bev_feat.detach()
    if hasattr(bev_feat, 'cpu'):
        bev_feat = bev_feat.cpu().numpy()
    bev_feat = np.asarray(bev_feat).reshape(num_queries, bev_h, bev_w)

    # 2. Find the (x, y) position of each query's maximum.
    positions = _argmax_xy(bev_feat)
    xs = [x for x, _ in positions]
    ys = [y for _, y in positions]

    # 3. Draw only the points.
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        # Use imshow's pixel extent so that pixel (0, 0) is at the top-left
        # and the points line up with the per-query heatmaps.
        ax.set_xlim(-0.5, bev_w - 0.5)
        ax.set_ylim(bev_h - 0.5, -0.5)
        # clip_on=False keeps markers on the map border fully visible.
        ax.scatter(xs, ys, c='red', s=50, clip_on=False)
        for idx, (x, y) in enumerate(positions):
            ax.text(x + 1, y, str(idx), color='black', fontsize=9)

        ax.set_title(f"Max Attention Position of Each BEV Query ({num_queries})")
        ax.grid(True, linestyle='--', alpha=0.3)
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
    print(f"Saved pure point map to: {output_path}")
    
if __name__ == '__main__':
    import sys

    # Development sample. Pass a different .pkl path as the first CLI
    # argument to visualize another frame without editing this file.
    default_pkl_path = '/opt/nvme0/pengyh/projects/world-model-RL/SSR/vis_data/scene_0ac05652a4c44374998be876ba5cd6fd+frame_0.pkl'
    pkl_path = sys.argv[1] if len(sys.argv) > 1 else default_pkl_path
    extract_and_plot_bev_query_max_positions(pkl_path)
