import os
import cv2
import torch
import argparse
import numpy as np
from torch import nn, Tensor
import torch.nn.functional as F
import matplotlib.pyplot as plt

def plot_3d_comparison(ax, data, kp_3d_pred, view='overview' ,view_names = ['1', '2', '3']):
    """3D点云可视化"""
    # 获取ground truth点云（各视角先reshape为(N,3)再合并）
    gt_points = []
    num_views = len(view_names)
    for view_name in view_names:
        pc = data[f'point_map_view_{view_name}'][0]
        # print(f"pc.shape is {pc.shape}")
        
        if isinstance(pc, torch.Tensor):
            pc = pc.cpu().numpy()  # 将Tensor转为NumPy
        pc = np.transpose(pc, (1, 2, 0)) # shape: (W,H,3)
        gt_points.append(pc.reshape(-1,3))          # (16384, 3)
    gt_points = np.concatenate(gt_points, axis=0)   # (65536, 3)
    # 自动计算合理的坐标范围（保留10%边界）
    x_min, x_max = np.percentile(gt_points[:,0], [5, 95])
    y_min, y_max = np.percentile(gt_points[:,1], [5, 95]) 
    z_min, z_max = np.percentile(gt_points[:,2], [5, 95])
    bounds = [-0.3, -0.5, 0.6, 0.7, 0.5, 1.6]
    # 过滤无效点（全零点和NaN）
    valid_mask = ~np.all(gt_points == 0, axis=1) & ~np.isnan(gt_points).any(axis=1)
    gt_points = gt_points[valid_mask]
    # 绘制ground truth点云（使用更高效的绘制方式）
    ax.scatter(gt_points[:,0], gt_points[:,1], gt_points[:,2], 
              c='gray', alpha=0.1, s=5, label='Scene PointCloud')

    # 绘制关键点（过滤无效点）
    colors = ['red', 'green', 'blue', 'yellow']
    for i in range(1, num_views):
        kp = kp_3d_pred[f'kp_{i}'][0]  # (300, 3)
        if isinstance(kp, torch.Tensor):
            kp = kp.cpu().numpy() 
        valid_kp = kp[~np.all(kp == 0, axis=1)]  # 过滤全零点
        if len(valid_kp) > 0:
            ax.scatter(valid_kp[:,0], valid_kp[:,1], valid_kp[:,2],
                      c=colors[i-1], s=30, marker='o',
                      label=f'KP_{i} ({len(valid_kp)} pts)')
    # 设置坐标轴
    ax.set_xlabel('X')
    ax.set_ylabel('Y') 
    ax.set_zlabel('Z')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.legend(loc='upper right', fontsize=8)
    # 设置视角和范围
    if view == '1':
        ax.view_init(elev=0, azim=0)  # 正前方视角
    elif view == '2':
        ax.view_init(elev=90, azim=0)  # 正上方视角
    else:
        ax.view_init(elev=30, azim=60)
    if bounds:
        ax.set_xlim(bounds[0], bounds[3])
        ax.set_ylim(bounds[1], bounds[4])
        ax.set_zlim(bounds[2], bounds[5])


def plot_2d_projection(ax, data, kp_3d_pred, view_name, overlay_type='rgb', original_kps=None):
    """
    绘制2D投影对比图，包括原始关键点和预测关键点。
    
    :param ax: matplotlib的axes对象，用于绘图
    :param data: 测试数据字典，包含RGB、深度图、相机内外参等信息
    :param kp_3d_pred: 预测的3D关键点字典
    :param view_name: 视角名称（例如'1', '2'等）
    :param overlay_type: 可选'rgb'或'depth'，决定背景是RGB图像还是深度图
    :param original_kps: 原始2D关键点坐标列表，如果提供，则会画在图像上
    """
    # 获取当前视角参数（添加分辨率处理）
    rgb = data[f'rgb_{view_name}']#.transpose(1,2,0)
    depth = data[f'depth_{view_name}'].squeeze(0)
    intrinsics = data[f'camera_intrinsics_{view_name}']
    extrinsics = data[f'camera_extrinsics_{view_name}'] 
    print(f"rgb shape is {rgb.shape}")
    print(f"depth shape is {depth.shape}")
    print(f"intrinsics shape is {intrinsics.shape}")
    print(f"extrinsics shape is {extrinsics.shape}")
    
    H, W = depth.shape
    num_views = 3

    # 显示背景（添加分辨率适配）
    if overlay_type == 'rgb':
        img = rgb.copy()
        if original_kps is not None:
            original_kps = np.clip(original_kps, [0, 0], [W-1, H-1])
            ax.scatter(original_kps[:,0], original_kps[:,1], 
                      c='blue', s=20, marker='x', label='Original KPs')
        ax.imshow(img)
    else:
        ax.imshow(depth, cmap='gray', vmin=0, vmax=2)
    
    colors = plt.get_cmap('tab10').colors
    markers = ['*', 's', 'D', 'o']
    # 投影关键点
    for kp_id in range(1, num_views):
        kp_world = kp_3d_pred[f'kp_{kp_id}'][0]
        # 世界→相机坐标系 (使用外参的逆变换)
        R = extrinsics[:3,:3]
        t = extrinsics[:3,3]
        kp_cam = (kp_world - t) @ R  # 等价于 R.T @ (kp_world - t)
        # 相机→像素坐标
        kp_pixel = kp_cam @ intrinsics.T
        kp_pixel = kp_pixel[:,:2] / kp_pixel[:,2:]
        # 过滤有效点
        valid_mask = (
            (kp_pixel[:,0] >= 0) & (kp_pixel[:,0] < W) &
            (kp_pixel[:,1] >= 0) & (kp_pixel[:,1] < H) &
            (kp_cam[:,2] > 0)  # z值必须为正（在相机前方）
        )
        kp_pixel = kp_pixel[valid_mask]
        # 绘制当前视角的关键点，使用特殊标记
        current_view = (kp_id-1 == ['1', '2', '3'].index(view_name))
        label = f'KP_{kp_id}' + (' (current)' if current_view else '')
        
        if len(kp_pixel) > 0:
            ax.scatter(kp_pixel[:,1], kp_pixel[:,0], c=[colors[kp_id-1]], 
                       s=50 if current_view else 30, marker=markers[kp_id-1], alpha=0.8,
                      edgecolors='white', linewidth=0.5, label=label)

    ax.set_xlim(0, W)
    ax.set_ylim(H, 0)
    ax.legend(loc='upper right', fontsize=8)
    ax.set_title(f'{view_name} Projection')


def visualize_comparison(data, kp_3d_pred, save_dir="debug_runs/comparison_results"):
    """
    多模态比较预测关键点与原始点云
    包含：3D点云对比、2D投影对比、深度图叠加、误差分析
    """
    os.makedirs(save_dir, exist_ok=True)
    print("================================")
    print(f"save to dir {save_dir}!")
    print("================================")
    
    # ================= 1. 3D点云对比可视化 =================
    fig = plt.figure(figsize=(18, 10))
    # 3D视图1
    ax1 = fig.add_subplot(121, projection='3d')
    plot_3d_comparison(ax1, data, kp_3d_pred, view='overview')
    # 3D视图2（不同视角）
    ax2 = fig.add_subplot(122, projection='3d')
    plot_3d_comparison(ax2, data, kp_3d_pred, view='top_down')
    plt.tight_layout()
    plt.savefig(f"{save_dir}/3d_comparison.png", dpi=300)
    plt.close()
    # ================= 2. 各相机视角2D投影对比 =================
    view_mapping = {
        '1': '1',
        '2': '2',
        '3': '3',

    }
    original_keypoints = {
        '1': data['kp_1'].squeeze(0) * 128 / 518,
        '2': data['kp_2'].squeeze(0) * 128 / 518,
        '3': data['kp_3'].squeeze(0) * 128 / 518,
   
    }    
    for view_name in view_mapping.keys():
        fig = plt.figure(figsize=(15, 6))
        # RGB图像
        ax1 = fig.add_subplot(131)
        plot_2d_projection(ax1, data, kp_3d_pred, view_name, overlay_type='rgb', original_kps=original_keypoints[view_name])
        ax1.set_title(f'{view_name} RGB Projection')
        # 深度图
        ax2 = fig.add_subplot(132)
        plot_2d_projection(ax2, data, kp_3d_pred, view_name, overlay_type='depth')
        ax2.set_title(f'{view_name} Depth Projection')
        # 误差热力图
        ax3 = fig.add_subplot(133)
        plot_error_heatmap(ax3, data, kp_3d_pred, view_name)
        ax3.set_title(f'{view_name} Error Heatmap')
        plt.tight_layout()
        plt.savefig(f"{save_dir}/2d_{view_name}_comparison.png", dpi=300)
        plt.close()