import torch
import numpy as np
import random
from pytorch3d.transforms.rotation_conversions import quaternion_to_matrix
from pytorch3d.transforms import euler_angles_to_matrix, matrix_to_euler_angles
import matplotlib.cm as cm
import cv2
from einops import rearrange, repeat


ColorMap = cm.Greens
ColorList= [ (0, 0, 255), (255, 255, 0), (0, 255, 255)]



EndEffectorPts = [
    [0, 0, 0, 1],
    [0.1, 0, 0, 1],
    [0, 0.1, 0, 1],
    [0, 0, 0.1, 1]
]

Gripper2EEFCvt = [
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1]
]

#==================================================================================================
# 下面是工具函数
#==================================================================================================

def get_extrinsic_matrices(extrinsics):
    """
    从相机外参矩阵推导w2c和c2w变换矩阵
    
    参数:
        extrinsics: 相机外参矩阵 [batch_size, 4, 4]
                    描述从世界坐标系到相机坐标系的变换
    
    返回:
        w2c: 世界坐标系→相机坐标系变换矩阵 [batch_size, 4, 4]
        c2w: 相机坐标系→世界坐标系变换矩阵 [batch_size, 4, 4]
    """
    # 外参矩阵就是w2c变换 (世界到相机)
    w2c = extrinsics.clone()
    
    # c2w是w2c的逆矩阵 (即相机到世界坐标系)
    # 注意: 对于齐次坐标矩阵，求逆有特殊优化
    c2w = torch.empty_like(w2c)
    for i in range(w2c.shape[0]):
        # 提取旋转和平移分量
        R = w2c[i, :3, :3]
        T = w2c[i, :3, 3]
        
        # 计算逆变换: [R^T | -R^T * T]
        R_inv = R.T
        T_inv = -R_inv @ T
        
        # 构造齐次变换矩阵
        c2w[i, :3, :3] = R_inv
        c2w[i, :3, 3] = T_inv
        c2w[i, 3, :3] = 0
        c2w[i, 3, 3] = 1
    
    return w2c, c2w

def get_transformation_matrix_from_euler(euler_angles):
    """
    使用 PyTorch3D 将欧拉角转换为 4x4 齐次变换矩阵
    
    参数:
        euler_angles: [batch_size, 6] 张量
            前3个元素: 位置 (x, y, z)
            后3个元素: 欧拉角 (弧度) (rot_x, rot_y, rot_z)
    
    返回:
        4x4 变换矩阵 [batch_size, 4, 4]
    """
    batch_size = euler_angles.shape[0]
    device = euler_angles.device
    
    # 分离位置和旋转
    position = euler_angles[:, :3]
    rotation_angles = euler_angles[:, 3:]
    
    # 使用 PyTorch3D 将欧拉角转换为旋转矩阵
    # 注意: PyTorch3D 使用 XYZ 顺序，角度单位为弧度
    rotation_matrix = euler_angles_to_matrix(rotation_angles, convention="XYZ")
    
    # 创建齐次变换矩阵
    transform_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
    transform_matrix[:, :3, :3] = rotation_matrix
    transform_matrix[:, :3, 3] = position
    
    return transform_matrix

#==================================================================================================
# 下面是数据增强函数
#==================================================================================================


def gen_batch_ray_parellel(intrinsic , extrinsics , W , H):
    """批量生成相机光线（方向、原点和单位方向向量）
    
    参数:
        intrinsic: 相机内参矩阵 [batch_size, 3, 3]
        extrinsics: 相机到世界坐标系的变换矩阵 [batch_size, 4, 4]
        W: 图像宽度 (像素)
        H: 图像高度 (像素)
        
    返回:
        rays_d: 光线方向向量 (世界坐标系) [batch_size, H, W, 3]
        rays_o: 光线起点坐标 (世界坐标系) [batch_size, H, W, 3]
        viewdir: 单位化的光线方向 [batch_size, H, W, 3]
    """
    batch_size = intrinsic.shape[0]
    w2c, c2w = get_extrinsic_matrices(extrinsics)
    # 分解内参矩阵参数 (fx, fy, cx, cy)
    # 使用unsqueeze调整维度以便后续广播 [batch_size, 1, 1]
    fx, fy, cx, cy = intrinsic[:,0,0].unsqueeze(1).unsqueeze(2), intrinsic[:,1,1].unsqueeze(1).unsqueeze(2), intrinsic[:,0,2].unsqueeze(1).unsqueeze(2), intrinsic[:,1,2].unsqueeze(1).unsqueeze(2)
    # 创建图像网格坐标系（每个像素中心坐标）
    # i: 水平坐标网格 (0.5到W-0.5) [W, H]
    # j: 垂直坐标网格 (0.5到H-0.5) [W, H]
    # 注意：pytorch默认meshgrid使用indexing='ij'（矩阵索引）
    i, j = torch.meshgrid(torch.linspace(0.5, W-0.5, W, device=c2w.device), torch.linspace(0.5, H-0.5, H, device=c2w.device))  # pytorch's meshgrid has indexing='ij'
    i = i.t()# 形状从[W,H]转为[H,W]
    j = j.t()
    # 扩展网格到批处理维度 [batch_size, H, W]
    i = i.unsqueeze(0).repeat(batch_size,1,1)
    j = j.unsqueeze(0).repeat(batch_size,1,1)
    # 构建相机坐标系中的方向向量
    # 公式：(u - cx)/fx, (v - cy)/fy, 1
    # 形状: [batch_size, H, W, 3]
    dirs = torch.stack([(i-cx)/fx, (j-cy)/fy, torch.ones_like(i)], -1)
    # 将方向向量转换到世界坐标系（应用旋转变换）
    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:,np.newaxis,np.newaxis, :3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
    # 计算光线起点（相机原点在世界坐标系中的位置）
    # 取c2w的平移部分（最后一列）[batch_size, 3]
    rays_o = c2w[:, :3, -1].unsqueeze(1).unsqueeze(2).repeat(1,H,W,1)
    # 计算单位方向向量
    viewdir = rays_d/torch.norm(rays_d,dim=-1,keepdim=True)
    return rays_d, rays_o, viewdir


@torch.no_grad()
def get_traj(sample_size, pose, extrinsics, intrinsic, radius=50):
    """
    生成单臂机器人末端执行器(EE)在相机视图中的轨迹可视化图
    
    参数:
        sample_size: 输出图像尺寸 (宽度, 高度)
        pose: 机械臂位姿信息 [batch_size, 7]
                格式: [位置_x, 位置_y, 位置_z, 旋转_x, 旋转_y, 旋转_z, 开合度]
        extrinsics: 相机外参矩阵 [batch_size, 4, 4]
        intrinsic: 相机内参矩阵 [batch_size, 3, 3]
        radius: 可视化关键点的半径 (默认50像素)
    
    返回:
        all_img_list: 轨迹图张量 [channels, batch_size, batch_size, height, width]

    示例：
        traj = get_traj(
        sample_size=sample_size, #[512, 512]
        pose=pose, #[1, 7] 目前只支持batch=1的输入
        extrinsics=extrinsics,  #[batch_size, 4, 4]
        intrinsic=intrinsics,   #[batch_size, 3, 3]
        radius=50

        提取图片
        img = traj[:, view_idx, batch_idx].permute(1, 2, 0).numpy()
    )

    """
    # 解构输出图像尺寸
    w, h = sample_size
    w2c, c2w = get_extrinsic_matrices(extrinsics) #[batch_size, 4, 4]
    print(f"pose.shape is {pose.shape}") 
    
    
    # 如果输入位姿是NumPy数组，转换为PyTorch张量
    if isinstance(pose, np.ndarray):
        pose = torch.tensor(pose, dtype=torch.float32)
    
    # 定义末端执行器的关键点坐标 (EndEffectorPts是预定义的常量)
    # 形状: [1, 1, 4, 4] -> (批次, 时间步, 点坐标+齐次, 点数量)
    # 转置后: [1, 1, 点数量, 4]
    ee_key_pts = torch.tensor(EndEffectorPts, dtype=torch.float32, device=pose.device)
    ee_key_pts = ee_key_pts.view(1, 1, 4, 4).permute(0, 1, 3, 2)

    # 将位姿 (位置+旋转角) 转换为齐次变换矩阵
    # 假设旋转角是欧拉角 (XYZ顺序)
    pose_mat = get_transformation_matrix_from_euler(pose[:, :6]).unsqueeze(0)  # [1, num_timesteps, 4, 4]
    print(f"w2c.shape is {w2c.shape}")
    print(f"pose_mat.shape is {pose_mat.shape}") 
    # 将末端执行器从世界坐标系转换到相机坐标系
    ee2cam = torch.matmul(w2c.unsqueeze(1) , pose_mat)  # [batch_size, num_timesteps, 4, 4]


    # 应用从夹爪坐标系到末端执行器坐标系的转换
    cvt_matrix = torch.tensor(Gripper2EEFCvt, dtype=torch.float32, device=pose.device).view(1, 1, 4, 4)
    ee2cam = torch.matmul(ee2cam, cvt_matrix)
    
    # 计算关键点在相机坐标系中的位置
    pts = torch.matmul(ee2cam, ee_key_pts)  # [batch_size, num_timesteps, num_points, 4]
    
    # 准备内参矩阵用于投影
    intrinsic = intrinsic.unsqueeze(1)  # [batch_size, 1, 3, 3]
    
    # 将3D点投影到2D图像平面 (透视投影)
    # uvs: [batch_size, num_timesteps, 3, num_points] (齐次坐标)
    uvs = torch.matmul(intrinsic, pts[:, :, :3, :])
    # 透视除法并提取前两个坐标 (u,v)
    uvs = (uvs / pts[:, :, 2:3, :])[:, :, :2, :]
    # 调整维度顺序: [batch_size, num_timesteps, num_points, 2]
    uvs = uvs.permute(0, 1, 3, 2).to(dtype=torch.int64)

    # 创建可视化图像
    all_img_list = []
    for iv in range(w2c.shape[0]):  # 遍历每个视图
        img_list = []
        for i in range(pose.shape[0]):  # 遍历每个时间步
            # 创建空白图像 (RGB格式)
            img = np.zeros((h, w, 3), dtype=np.uint8) + 50  # 浅灰色背景
            
            # 获取当前时间步的夹爪开合度 (第7个元素)
            normalized_value = pose[i, 6].item() / 120  # 假设最大开合度为120
            color = ColorMap(normalized_value)[:3]  # 获取RGB值
            color = tuple(int(c * 255) for c in color)  # 转换为0-255范围

            # 获取当前视图和时间步的关键点
            points = uvs[iv, i]
            
            # 获取第一个关键点作为基座点
            base = np.array(points[0])
            
            # 检查基座点是否在图像范围内
            if base[0] >= 0 and base[0] < w and base[1] >= 0 and base[1] < h:
                # 绘制末端执行器基座（实心圆）
                cv2.circle(img, tuple(base[:2]), radius, color, -1)
                
                # 绘制关键点之间的连接线
                for p_idx, point in enumerate(points):
                    point = np.array(point[:2])
                    if p_idx == 0:  # 跳过基座点
                        continue
                    # 绘制从基座到当前点的线段
                    cv2.line(img, tuple(base), tuple(point), ColorList[p_idx-1], 8)
            
            # 将当前时间步的图像添加到列表
            img_list.append(img/255.)  # 归一化到[0,1]范围
        
        # 将当前视图的所有时间步图像堆叠
        img_list = np.stack(img_list, axis=0)
        all_img_list.append(img_list)
    
    # 将所有视图和所有时间步的图像堆叠
    all_img_list = np.stack(all_img_list, axis=0)
    # 重组维度顺序: [channels, views, timesteps, height, width]
    all_img_list = rearrange(torch.tensor(all_img_list), "v t h w c -> c v t h w").float()
    print(f"all_img_list.shape is {all_img_list.shape}") 
    return all_img_list
