import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
import os
from einops import rearrange

OUTPUT_DIR = "./trajectory_visualizations"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# 定义所需常量 (模拟真实场景)
EndEffectorPts = np.array([
    [0, 0, 0, 1],
    [0.1, 0, 0, 1],
    [0, 0.1, 0, 1],
    [0, 0, 0.1, 1]
])

Gripper2EEFCvt = np.array([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, 0.23],
    [0, 0, 0, 1]
])

ColorMap = plt.cm.viridis  # 使用matplotlib的颜色映射
ColorList = [(255,0,0), (0,255,0), (0,0,255)]  # RGB基础颜色

# 辅助函数：欧拉角转变换矩阵
def get_transformation_matrix_from_euler(euler_angles):
    """
    将欧拉角(角度制)转换为4x4变换矩阵
    格式: [x, y, z, roll, pitch, yaw]
    """
    positions = euler_angles[:, :3]
    rotations = np.radians(euler_angles[:, 3:])  # 角度转弧度
    
    matrices = []
    for i in range(positions.shape[0]):
        # 提取位置和旋转
        x, y, z = positions[i]
        roll, pitch, yaw = rotations[i]
        
        # 创建旋转矩阵 (ZYX顺序)
        cy = np.cos(yaw)
        sy = np.sin(yaw)
        cp = np.cos(pitch)
        sp = np.sin(pitch)
        cr = np.cos(roll)
        sr = np.sin(roll)
        
        rotation_matrix = np.array([
            [cy*cp, cy*sp*sr - sy*cr, cy*sp*cr + sy*sr],
            [sy*cp, sy*sp*sr + cy*cr, sy*sp*cr - cy*sr],
            [-sp, cp*sr, cp*cr]
        ])
        
        # 创建齐次变换矩阵
        transform = np.eye(4)
        transform[:3, :3] = rotation_matrix
        transform[:3, 3] = [x, y, z]
        
        matrices.append(transform)
    
    return torch.tensor(np.stack(matrices), dtype=torch.float32)

# 辅助函数：相机参数生成
def create_camera_params(num_views=2):
    """
    创建模拟相机参数 (外参和内参)
    """
    # 外参矩阵 (世界->相机变换)
    extrinsics = []
    for i in range(num_views):
        # 不同位置的相机
        T = [i*0.5, 0, 2.0]  # 位置偏移
        # 旋转矩阵
        theta = i * 30  # 不同角度
        rad = np.radians(theta)
        R = np.array([
            [np.cos(rad), 0, np.sin(rad)],
            [0, 1, 0],
            [-np.sin(rad), 0, np.cos(rad)]
        ])
        # 齐次变换矩阵
        extrinsic = np.eye(4)
        extrinsic[:3, :3] = R
        extrinsic[:3, 3] = T
        extrinsics.append(extrinsic)
    
    # 内参矩阵 (假设所有相机相同)
    focal_length = 500
    intrinsic = np.array([
        [focal_length, 0, 256],
        [0, focal_length, 256],
        [0, 0, 1]
    ]) 
    intrinsics = [intrinsic.copy() for _ in range(num_views)]
    
    return torch.tensor(np.stack(extrinsics), dtype=torch.float32), torch.tensor(np.stack(intrinsics), dtype=torch.float32)

# 辅助函数：创建轨迹可视化
def visualize_traj(traj, view_idx=0, timestep=0, save_dir=OUTPUT_DIR):
    """
    保存轨迹图像的单个视图和时间点到文件
    """
    img = traj[:, view_idx, timestep].permute(1, 2, 0).numpy()
    
    plt.figure(figsize=(8, 8))
    plt.imshow(img)
    plt.title(f"View {view_idx}, Timestep {timestep}")
    plt.axis('off')
    
    # 保存图片
    filename = os.path.join(save_dir, f"view_{view_idx}_timestep_{timestep}.png")
    plt.savefig(filename, bbox_inches='tight', pad_inches=0)
    plt.close()
    print(f"Saved visualization to: {filename}")
    
def validate_keypoints(traj, view_idx=0, timestep=0, pose=None, save_dir=OUTPUT_DIR):
    """
    保存关键点验证对比图
    """
    img = traj[:, view_idx, timestep].permute(1, 2, 0).numpy()
    
    # 红色通道分析
    red_channel = img[:, :, 0]
    y, x = np.unravel_index(red_channel.argmax(), red_channel.shape)
    
    # 创建对比图
    marked_img = img.copy()
    cv2.circle(marked_img, (x, y), 5, (1, 1, 0), -1)  # 黄色标记
    
    fig = plt.figure(figsize=(10, 5))
    plt.subplot(121)
    plt.imshow(img)
    plt.title("Original Trajectory")
    plt.subplot(122)
    plt.imshow(marked_img)
    plt.title("Keypoint Detection")
    
    # 保存验证结果
    validation_path = os.path.join(save_dir, f"validation_view{view_idx}_step{timestep}.png")
    plt.savefig(validation_path, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved validation result to: {validation_path}")

    # 返回检测到的位置和预期位置
    if pose is not None:
        expected_x = pose[timestep, 0].item()
        return (x, y), expected_x
    return (x, y), None

# 测试主函数
def test_get_traj():
    print("===== 开始测试 get_traj 函数 =====")
    
    # 1. 创建模拟数据
    sample_size = (512, 512)  # 图像尺寸 HxW
    
    
    pose_sequence = []
    for t in range(3):  
        # 位置分量 (单位：米)
        x = 0.5 + t * 0.1     # 沿X轴每次移动0.1米
        y = 0 + t * 0.1         # Y轴固定偏移0.2米
        z = 0.5 + t * 0.1 # Z轴从0.5米开始逐渐升高
        
        # 旋转分量 (单位：度)
        roll = 90 + t * 10          # 绕X轴旋转
        pitch = 90 + t * 10   # 绕Y轴每次增加5度
        yaw = 0 + t * 10         # 绕Z轴固定偏转10度
        
        # 夹爪开合度 (0-100%)
        gripper = 30 + t * 10 # 从30%逐步增加到110%
        
        pose_sequence.append([x, y, z, roll, pitch, yaw, gripper])
        
    pose = torch.tensor(pose_sequence, dtype=torch.float32)
    # 创建相机参数
    extrinsics, intrinsics = create_camera_params(num_views=2)
    
    # 2. 调用待测函数
    traj = get_traj(
        sample_size=sample_size, #[512, 512]
        pose=pose, #[1, 7]
        extrinsics=extrinsics,  #[num_views, 4, 4]
        intrinsic=intrinsics,   #[num_views, 3, 3]
        radius=50
    )
    
    # 3. 验证输出形状
    # 期望形状: [channels, num_views, timesteps, height, width]
    expected_shape = (3, extrinsics.shape[0], pose.shape[0], sample_size[0], sample_size[1])
    actual_shape = tuple(traj.shape)
    
    print(f"期望形状: {expected_shape}")
    print(f"实际形状: {actual_shape}")
    
    if actual_shape == expected_shape:
        print("✅ 形状验证通过")
    else:
        print("❌ 形状验证失败")
        return
    
    # 4. 可视化轨迹
    # 显示第一个视图的所有时间步轨迹
    print("\n====== 轨迹可视化 ======")
    for t in range(pose.shape[0]):
        visualize_traj(traj, view_idx=0, timestep=t)
    
    # 显示最后一个时间步的所有视图
    last_timestep = pose.shape[0] - 1
    for v in range(extrinsics.shape[0]):
        visualize_traj(traj, view_idx=v, timestep=last_timestep)
    
    # 5. 保存关键点验证结果
    print("\n===== Saving keypoint validation =====")
    detected_pos, expected_pos = validate_keypoints(traj, view_idx=0, timestep=0, pose=pose)
    
    # 保存位置对比结果
    with open(os.path.join(OUTPUT_DIR, "position_report.txt"), "w") as f:
        f.write(f"Detected position (x,y): {detected_pos}\n")
        if expected_pos is not None:
            f.write(f"Expected X position: {expected_pos}\n")
            f.write(f"X position error: {abs(detected_pos[0] - expected_pos)} pixels\n")

# 运行测试
if __name__ == "__main__":
    from rvt.utils.feature_aug import get_traj  
    
    test_get_traj()