import mmcv
import numpy as np
import matplotlib.pyplot as plt

def visualize_l2_norm(bev_feat, save_path="bev_l2_norm.png", dpi=300):
    """Render the per-cell L2 norm of a BEV feature map as a heatmap and save it.

    Args:
        bev_feat: Array-like with at least 3 dimensions; the norm is taken
            over axis 2 (the channel axis for an (H, W, C) map), producing one
            value per (H, W) cell.
        save_path: Output image path. Defaults to "bev_l2_norm.png".
        dpi: Resolution of the saved image. Defaults to 300.

    Returns:
        The heatmap array that was plotted (the norm over axis 2).
    """
    feat = np.asarray(bev_feat)
    # Accumulate in at least float32: squaring and summing many channels in
    # float16 easily overflows to inf. float32/float64 inputs are unchanged.
    feat = feat.astype(np.result_type(feat.dtype, np.float32), copy=False)
    heatmap = np.linalg.norm(feat, axis=2)

    # Draw on a dedicated figure and always close it, so repeated calls
    # neither paint over the caller's current figure nor keep figures alive.
    fig, ax = plt.subplots()
    try:
        image = ax.imshow(heatmap, cmap='hot')
        ax.set_title("L2 Norm Heatmap")
        fig.colorbar(image, ax=ax)
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
    finally:
        plt.close(fig)
    return heatmap

# def visualize_channel_mean(bev_feat):
#     heatmap = np.mean(bev_feat, axis=2)
#     plt.imshow(heatmap, cmap='viridis')
#     plt.title("Channel Mean Heatmap")
#     plt.colorbar()
#     plt.savefig("bev_channel_mean.png", dpi=300, bbox_inches='tight')
#     plt.clf()

# def visualize_rgb_first3(bev_feat):
#     rgb = bev_feat[:, :, :3].copy()  # copy: the slice is a view, so the in-place ops below would otherwise mutate bev_feat
#     rgb -= rgb.min()
#     rgb /= (rgb.max() + 1e-5)
#     plt.imshow(rgb)
#     plt.title("RGB from First 3 Channels")
#     plt.savefig("bev_rgb_first3.png", dpi=300, bbox_inches='tight')
#     plt.clf()

# def visualize_rgb_split255(bev_feat):
#     rgb_feat = bev_feat[:, :, :-1]  # drop the last channel -> 255 channels
#     r = np.mean(rgb_feat[:, :, :85], axis=2)
#     g = np.mean(rgb_feat[:, :, 85:170], axis=2)
#     b = np.mean(rgb_feat[:, :, 170:], axis=2)

#     rgb = np.stack([r, g, b], axis=2)
#     rgb -= rgb.min()
#     rgb /= (rgb.max() + 1e-5)
#     plt.imshow(rgb)
#     plt.title("RGB from 255D → Avg per 85D")
#     plt.savefig("bev_rgb_split255.png", dpi=300, bbox_inches='tight')
#     plt.clf()

if __name__ == '__main__':
    import sys

    # Default input file; pass a different .pkl path as the first CLI
    # argument to visualize another frame.
    default_path = '/opt/nvme0/pengyh/projects/world-model-RL/SSR/vis_data/scene_0ac05652a4c44374998be876ba5cd6fd+frame_30.pkl'
    input_path = sys.argv[1] if len(sys.argv) > 1 else default_path

    # NOTE(review): mmcv.load presumably unpickles .pkl files, which can run
    # arbitrary code -- only load files from a trusted source.
    bevformer_feature_results = mmcv.load(input_path)

    print("Type of bevformer_feature_results:", type(bevformer_feature_results))
    if isinstance(bevformer_feature_results, dict):
        print("Available top-level keys:", bevformer_feature_results.keys())

    bev_feat = bevformer_feature_results['world_model_pred_bev_embed']

    # Tensor-like input (presumably torch.Tensor): detach first, since
    # .numpy() refuses tensors that require grad, then move to host memory.
    if hasattr(bev_feat, 'detach'):
        bev_feat = bev_feat.detach()
    if hasattr(bev_feat, 'cpu'):
        bev_feat = bev_feat.cpu().numpy()

    print("BEV feature shape:", bev_feat.shape)

    # Assumes a 100x100 BEV grid with 256 channels laid out cell-major
    # (e.g. stored as (10000, 1, 256)) -- TODO confirm; a channels-first
    # layout would have the same element count but be scrambled by reshape.
    bev_feat = bev_feat.reshape(100, 100, 256)

    # Visualization variants (the alternatives are commented out above).
    visualize_l2_norm(bev_feat)
    # visualize_channel_mean(bev_feat)
    # visualize_rgb_first3(bev_feat)
    # visualize_rgb_split255(bev_feat)

    # Message: "All BEV visualization images have been saved."
    print("所有 BEV 可视化图像已保存。")
