from pxr import Usd, UsdGeom, UsdPhysics, Gf

def is_collision_mesh(prim):
    """Return whether ``prim`` is a mesh that should be treated as collision geometry.

    A prim qualifies when it is a ``UsdGeom.Mesh`` and at least one of the
    following holds:

    * its name contains ``"collision"`` (case-insensitive), or
    * it has ``UsdPhysics.CollisionAPI`` or ``UsdPhysics.MeshCollisionAPI``.

    Any exception raised while querying the physics schemas is treated as
    "no schema present", so the decision then rests on the name alone.
    """
    if not prim.IsA(UsdGeom.Mesh):
        return False

    named_like_collision = "collision" in prim.GetName().lower()

    try:
        # Checked one after the other so the second schema is only looked up
        # when the first one is absent.
        if prim.HasAPI(UsdPhysics.CollisionAPI):
            has_physics_api = True
        else:
            has_physics_api = prim.HasAPI(UsdPhysics.MeshCollisionAPI)
    except Exception:
        has_physics_api = False

    if named_like_collision:
        return True
    return has_physics_api

def _unique_child_path(stage, parent_path, base_name):
    """Return a child path of ``parent_path`` that holds no valid prim on ``stage``.

    ``base_name`` is tried first. If a valid prim already lives at that path,
    the suffixes ``_0001``, ``_0002``, ... are appended to ``base_name`` in
    order until a free path is found.
    """
    candidate = parent_path.AppendChild(base_name)
    suffix = 0
    while stage.GetPrimAtPath(candidate).IsValid():
        suffix += 1
        candidate = parent_path.AppendChild(f"{base_name}_{suffix:04d}")
    return candidate

def _copy_mesh_geometry(mesh_src, mesh_dst):
    """Copy points, topology, normals and extent from ``mesh_src`` to ``mesh_dst``.

    Each attribute is copied only when the source attribute reports a value.
    Values are read with a no-argument ``Get()`` and authored through the
    matching ``Create*Attr`` call on the destination mesh.
    """
    if mesh_src.GetPointsAttr().HasValue():
        mesh_dst.CreatePointsAttr(mesh_src.GetPointsAttr().Get())
    if mesh_src.GetFaceVertexCountsAttr().HasValue():
        mesh_dst.CreateFaceVertexCountsAttr(mesh_src.GetFaceVertexCountsAttr().Get())
    if mesh_src.GetFaceVertexIndicesAttr().HasValue():
        mesh_dst.CreateFaceVertexIndicesAttr(mesh_src.GetFaceVertexIndicesAttr().Get())
    if mesh_src.GetNormalsAttr().HasValue():
        mesh_dst.CreateNormalsAttr(mesh_src.GetNormalsAttr().Get())
        # Normals are meaningless without their interpolation mode.
        mesh_dst.SetNormalsInterpolation(mesh_src.GetNormalsInterpolation())
    if mesh_src.GetExtentAttr().HasValue():
        mesh_dst.CreateExtentAttr(mesh_src.GetExtentAttr().Get())


def _copy_collision_apis(prim_src, prim_dst):
    """Re-apply the collision schemas found on ``prim_src`` to ``prim_dst``.

    ``UsdPhysics.CollisionAPI`` is applied with collision explicitly enabled;
    ``UsdPhysics.MeshCollisionAPI`` is applied and its approximation copied
    when the source has one. A prim that carries neither schema is left
    untouched.

    This step is best-effort: a failure is reported on stdout and the caller
    continues, instead of being swallowed silently.
    """
    try:
        if prim_src.HasAPI(UsdPhysics.CollisionAPI):
            col_api = UsdPhysics.CollisionAPI.Apply(prim_dst)
            # Create-then-Set authors the attribute even if it did not exist
            # yet, mirroring how the approximation attribute is handled below.
            col_api.CreateCollisionEnabledAttr().Set(True)
        if prim_src.HasAPI(UsdPhysics.MeshCollisionAPI):
            mesh_col_dst = UsdPhysics.MeshCollisionAPI.Apply(prim_dst)
            approx_src = UsdPhysics.MeshCollisionAPI(prim_src).GetApproximationAttr().Get()
            if approx_src:
                mesh_col_dst.CreateApproximationAttr().Set(approx_src)
    except Exception as exc:
        print(
            f"Warning: could not copy collision API from {prim_src.GetPath()} "
            f"to {prim_dst.GetPath()}: {exc!r}"
        )


def extract_collision_meshes(src_usd: str, dst_usd: str) -> None:
    """Copy every collision mesh of ``src_usd`` into a new flat stage ``dst_usd``.

    The source stage is traversed and every prim accepted by
    ``is_collision_mesh`` is re-created as a ``UsdGeom.Mesh`` directly under
    ``/Root`` (the default prim) of the destination stage. For each mesh:

    * points, face topology, normals (with interpolation) and extent are copied;
    * collision schemas present on the source prim are re-applied (meshes that
      matched by name only therefore get no physics schema);
    * the source prim's local-to-world matrix is written as a single
      ``xformOp:transform`` so the mesh keeps its world placement.

    The destination stage inherits the source up-axis and meters-per-unit.
    Progress is printed to stdout and the destination root layer is saved.

    Args:
        src_usd: Path of the USD file to read.
        dst_usd: Path of the USD file to create.

    Raises:
        RuntimeError: If the source stage cannot be opened or the destination
            stage cannot be created (when the USD call returns a null stage;
            the USD bindings may also raise their own errors first).
    """
    stage_src = Usd.Stage.Open(src_usd)
    if not stage_src:
        raise RuntimeError(f"Failed to open source stage: {src_usd}")

    # Mirror the source up-axis and meters-per-unit so scale and orientation
    # of the copied geometry stay correct.
    up_axis = UsdGeom.GetStageUpAxis(stage_src)
    meters_per_unit = UsdGeom.GetStageMetersPerUnit(stage_src)

    # Destination stage.
    stage_dst = Usd.Stage.CreateNew(dst_usd)
    if not stage_dst:
        raise RuntimeError(f"Failed to create destination stage: {dst_usd}")
    UsdGeom.SetStageUpAxis(stage_dst, up_axis)
    UsdGeom.SetStageMetersPerUnit(stage_dst, meters_per_unit)

    root_dst = stage_dst.DefinePrim("/Root", "Xform")
    stage_dst.SetDefaultPrim(root_dst)

    # To turn the whole collection into one dynamic rigid body, enable the two
    # lines below; leave them commented out if the meshes only need to be
    # editable/movable.
    # rb = UsdPhysics.RigidBodyAPI.Apply(root_dst)
    # rb.GetKinematicEnabledAttr().Set(False)  # dynamic rigid body

    # Default-constructed cache used to resolve each prim's world transform.
    xfc = UsdGeom.XformCache()

    count = 0
    for prim in stage_src.Traverse():
        if not is_collision_mesh(prim):
            continue

        mesh_src = UsdGeom.Mesh(prim)

        # All meshes are flattened under /Root, so equal source names from
        # different parents must be disambiguated.
        mesh_dst_path = _unique_child_path(stage_dst, root_dst.GetPath(), prim.GetName())
        mesh_dst = UsdGeom.Mesh.Define(stage_dst, mesh_dst_path)

        _copy_mesh_geometry(mesh_src, mesh_dst)

        # Purpose is intentionally left at its default (not "guide"); per the
        # original author's note this keeps the meshes directly selectable and
        # movable in Isaac Sim.

        _copy_collision_apis(prim, mesh_dst.GetPrim())

        # Bake the source world transform into one transform op. The xform
        # stack is not reset, so moving /Root still carries the child meshes.
        world_xf = xfc.GetLocalToWorldTransform(prim)  # Gf.Matrix4d
        xf_dst = UsdGeom.Xformable(mesh_dst.GetPrim())
        op = xf_dst.AddTransformOp()  # xformOp:transform
        op.Set(world_xf)

        count += 1
        print(f"Copied collision mesh: {prim.GetPath()} -> {mesh_dst.GetPath()}")

    stage_dst.GetRootLayer().Save()
    print(f"Saved {count} collision meshes to: {dst_usd}")

if __name__ == "__main__":
    # Script entry point. Source and destination may be given on the command
    # line; when omitted, the original hard-coded dataset paths are used, so
    # running the script without arguments behaves as before.
    import argparse  # local import: only needed when run as a script

    parser = argparse.ArgumentParser(
        description="Extract collision meshes from a USD file into a new USD file."
    )
    parser.add_argument(
        "src",
        nargs="?",
        default=r"/home/sig/sig/qianluo/qianluo/3DGS-WAIC/3DGS_usd_simplified/3DGS_usd_simplified/839920/result.usd",
        help="path of the source USD file",
    )
    parser.add_argument(
        "dst",
        nargs="?",
        default=r"/home/sig/sig/qianluo/qianluo/3DGS_VLN_Benchmark/Dataset/Collision/839920_collision.usd",
        help="path of the USD file to create",
    )
    args = parser.parse_args()
    extract_collision_meshes(args.src, args.dst)