import numpy as np
from scipy.linalg import logm,expm

def SE3_to_se3(T):
    """
    将 SE(3) 矩阵转换为 se(3) 李代数的6维向量表示
    :param T: 4x4 的 SE(3) 矩阵
    :return: 6维向量 [v_x, v_y, v_z, w_x, w_y, w_z]
    """
    se3_mat = logm(T) # 计算矩阵对数
    v = se3_mat[0:3, 3] # 提取平移部分
    w_hat = se3_mat[0:3, 0:3] # 提取旋转部分（反对称矩阵）
    
    # 将反对称矩阵转换为向量
    w = np.array([
        w_hat[2, 1],
        w_hat[0, 2],
        w_hat[1, 0]
    ])
    
    se3_vec = np.concatenate((v, w))
    return se3_vec

def se3_to_SE3(se3_vec):
    """
    将 se(3) 的6维向量表示转换为 SE(3) 矩阵
    :param se3_vec: 6维向量 [v_x, v_y, v_z, w_x, w_y, w_z]
    :return: 4x4 的 SE(3) 矩阵
    """
    v = se3_vec[0:3]  # 平移部分
    w = se3_vec[3:6]  # 旋转部分

    # 构建反对称矩阵 w_hat
    w_hat = np.array([
        [ 0,     -w[2],  w[1]],
        [ w[2],   0,    -w[0]],
        [-w[1],  w[0],   0   ]
    ])

    # 构建 se(3) 矩阵
    se3_mat = np.zeros((4, 4))
    se3_mat[0:3, 0:3] = w_hat
    se3_mat[0:3, 3] = v

    T = expm(se3_mat)
    return T


def rotate_matrix(angle, axis, R):
    """
    旋转矩阵R绕指定轴旋转指定角度

    参数:
    angle -- 旋转角度（弧度）
    axis -- 旋转轴（'x', 'y' 或 'z'）
    R -- 3x3旋转矩阵

    返回:
    旋转后的3x3矩阵
    """
    if axis == 'x':
        Rx = np.array([[1, 0, 0],
                       [0, np.cos(angle), -np.sin(angle)],
                       [0, np.sin(angle), np.cos(angle)]])
        return Rx @ R
    elif axis == 'y':
        Ry = np.array([[np.cos(angle), 0, np.sin(angle)],
                       [0, 1, 0],
                       [-np.sin(angle), 0, np.cos(angle)]])
        return Ry @ R
    elif axis == 'z':
        Rz = np.array([[np.cos(angle), -np.sin(angle), 0],
                       [np.sin(angle), np.cos(angle), 0],
                       [0, 0, 1]])
        return Rz @ R
    else:
        raise ValueError("轴必须是 'x', 'y' 或 'z'")


if __name__ == "__main__":
    
    import json
    json_file = "./datasets/DLD3DV/DL3DV-10K/1K/0a45aa466f114e6daa13a82775e7bd5fc295e37a1e3d61deb3741a7e7a1b1f8a/transforms.json"
    with open(json_file, "r") as f:
        data = json.load(f)
    
    # max_item=0
    # with open('./cache/se3.txt', 'w') as f:
    #     for i in range(len(data["frames"])):
    #         SE3 = np.array(data["frames"][i]["transform_matrix"])
    #         se3_vector = SE3_to_se3(SE3)
    #         print("se(3) 向量:\n", se3_vector)
    #         f.write(str(se3_vector)+'\n')
    #         max_item = max(max_item, np.max(np.abs(se3_vector)))
    #         # SE3_recover = se3_to_SE3(se3_vector)
    #         # print("恢复的 SE(3) 矩阵:\n", SE3_recover)
            
    # print(max_item)
    
    SE3_recover = se3_to_SE3(np.array([1, 2, 3, 1, 0, 0]))
    print("恢复的 SE(3) 矩阵:\n", SE3_recover)