import os
from pxr import Usd, UsdGeom, UsdPhysics, Gf

def is_collision_mesh(prim):
    """Return True if *prim* is a UsdGeom.Mesh that should be treated as a collider.

    A mesh qualifies when its name contains "collision" (case-insensitive), or
    when UsdPhysics.CollisionAPI or UsdPhysics.MeshCollisionAPI is applied to it.
    Prims that are not meshes always return False.
    """
    if not prim.IsA(UsdGeom.Mesh):
        return False

    if "collision" in prim.GetName().lower():
        return True

    # Schema lookups are best-effort: any failure is treated as "no API applied".
    try:
        return bool(
            prim.HasAPI(UsdPhysics.CollisionAPI)
            or prim.HasAPI(UsdPhysics.MeshCollisionAPI)
        )
    except Exception:
        return False

def _unique_child_path(stage, parent_path, base_name):
    """Return a child path of *parent_path* that is not yet occupied on *stage*.

    ``base_name`` itself is tried first. If it is taken, numbered variants
    ``base_name_0001``, ``base_name_0002``, ... are tried in order, and the first
    free one is returned.
    """
    def _is_free(candidate):
        return not stage.GetPrimAtPath(candidate).IsValid()

    candidate = parent_path.AppendChild(base_name)
    suffix = 0
    while not _is_free(candidate):
        suffix += 1
        candidate = parent_path.AppendChild(f"{base_name}_{suffix:04d}")
    return candidate

def _copy_mesh_geometry(mesh_src, mesh_dst):
    """Copy authored geometry/topology attributes from *mesh_src* onto *mesh_dst*.

    Values are read at the default time code. Only attributes that have an
    authored opinion are written, and values that read back as None are skipped,
    so schema fallbacks are not baked into the output.
    """
    attr_pairs = (
        (mesh_src.GetPointsAttr(), mesh_dst.CreatePointsAttr),
        (mesh_src.GetFaceVertexCountsAttr(), mesh_dst.CreateFaceVertexCountsAttr),
        (mesh_src.GetFaceVertexIndicesAttr(), mesh_dst.CreateFaceVertexIndicesAttr),
        (mesh_src.GetExtentAttr(), mesh_dst.CreateExtentAttr),
        # Winding order: dropping an authored leftHanded orientation would
        # reverse the facing of every polygon in the copy.
        (mesh_src.GetOrientationAttr(), mesh_dst.CreateOrientationAttr),
        (mesh_src.GetSubdivisionSchemeAttr(), mesh_dst.CreateSubdivisionSchemeAttr),
        (mesh_src.GetDoubleSidedAttr(), mesh_dst.CreateDoubleSidedAttr),
    )
    for src_attr, create_dst in attr_pairs:
        if not src_attr.HasAuthoredValue():
            continue
        value = src_attr.Get()
        if value is not None:
            create_dst(value)

    # Normals travel together with their interpolation token.
    normals_attr = mesh_src.GetNormalsAttr()
    if normals_attr.HasAuthoredValue():
        normals = normals_attr.Get()
        if normals is not None:
            mesh_dst.CreateNormalsAttr(normals)
            mesh_dst.SetNormalsInterpolation(mesh_src.GetNormalsInterpolation())


def _copy_collision_apis(prim_src, prim_dst):
    """Re-apply the UsdPhysics collision schemas found on *prim_src* to *prim_dst*.

    For CollisionAPI, the source's resolved ``collisionEnabled`` value is copied,
    so a disabled collider stays disabled. If no value resolves, it is written
    as True. For MeshCollisionAPI, the ``approximation`` token is copied when it
    is non-empty. Failures are reported and do not abort the caller.
    """
    try:
        if prim_src.HasAPI(UsdPhysics.CollisionAPI):
            col_dst = UsdPhysics.CollisionAPI.Apply(prim_dst)
            enabled = UsdPhysics.CollisionAPI(prim_src).GetCollisionEnabledAttr().Get()
            col_dst.CreateCollisionEnabledAttr(True if enabled is None else bool(enabled))
        if prim_src.HasAPI(UsdPhysics.MeshCollisionAPI):
            mesh_col_dst = UsdPhysics.MeshCollisionAPI.Apply(prim_dst)
            approx_src = UsdPhysics.MeshCollisionAPI(prim_src).GetApproximationAttr().Get()
            if approx_src:
                mesh_col_dst.CreateApproximationAttr(approx_src)
    except Exception as exc:  # best effort: one bad schema must not abort the scene
        print(f"Warning: failed to copy collision APIs from {prim_src.GetPath()}: {exc}")


def extract_collision_meshes(src_usd, dst_usd):
    """Copy every collision mesh in *src_usd* into a new flat stage written to *dst_usd*.

    A prim is selected when ``is_collision_mesh`` returns True for it. Each match
    is re-created as a direct child of ``/Root``, which is made the default prim.
    The copy carries:
      * the mesh geometry/topology attributes,
      * the UsdPhysics collision schemas of the source prim,
      * the source prim's local-to-world matrix (default time code) as a single
        transform op.
    The destination stage inherits the source stage's up axis and metersPerUnit.

    Args:
        src_usd: Path of the USD file to read.
        dst_usd: Path of the USD file to create with ``Usd.Stage.CreateNew``.

    Raises:
        RuntimeError: If the source stage cannot be opened or the destination
            stage cannot be created.
    """
    stage_src = Usd.Stage.Open(src_usd)
    if stage_src is None:
        raise RuntimeError(f"Failed to open source stage: {src_usd}")

    up_axis = UsdGeom.GetStageUpAxis(stage_src)
    meters_per_unit = UsdGeom.GetStageMetersPerUnit(stage_src)

    stage_dst = Usd.Stage.CreateNew(dst_usd)
    if stage_dst is None:
        raise RuntimeError(f"Failed to create destination stage: {dst_usd}")
    UsdGeom.SetStageUpAxis(stage_dst, up_axis)
    UsdGeom.SetStageMetersPerUnit(stage_dst, meters_per_unit)

    root_dst = stage_dst.DefinePrim("/Root", "Xform")
    stage_dst.SetDefaultPrim(root_dst)

    xform_cache = UsdGeom.XformCache()  # default time code

    count = 0
    for prim in stage_src.Traverse():
        if not is_collision_mesh(prim):
            continue

        mesh_src = UsdGeom.Mesh(prim)
        mesh_dst_path = _unique_child_path(stage_dst, root_dst.GetPath(), prim.GetName())
        mesh_dst = UsdGeom.Mesh.Define(stage_dst, mesh_dst_path)

        _copy_mesh_geometry(mesh_src, mesh_dst)
        _copy_collision_apis(prim, mesh_dst.GetPrim())

        # /Root is authored here without any xform ops, so writing the source
        # prim's local-to-world matrix on the child reproduces its world
        # placement. The xform stack is intentionally not reset.
        world_xf = xform_cache.GetLocalToWorldTransform(prim)
        UsdGeom.Xformable(mesh_dst.GetPrim()).AddTransformOp().Set(world_xf)

        count += 1
        print(f"Copied collision mesh: {prim.GetPath()} -> {mesh_dst.GetPath()}")

    stage_dst.GetRootLayer().Save()
    print(f"Saved {count} collision meshes to: {dst_usd}")

def batch_extract_collision_meshes(input_root_dir, output_root_dir):
    """Run ``extract_collision_meshes`` for every scene folder under *input_root_dir*.

    input_root_dir: contains one sub-directory per scene, each expected to hold a
        ``result.usd``. Sub-directories without one are skipped with a message;
        plain files are ignored.
    output_root_dir: for each processed scene ``<name>``, the result is written to
        ``<output_root_dir>/<name>/<name>_collision.usd``. Directories are created
        as needed.

    Scenes are processed in sorted name order so runs are reproducible. A
    failure in one scene is reported and does not stop the remaining scenes.

    Returns:
        The list of scene names whose extraction raised an exception (empty on
        full success).
    """
    failed = []
    for subdir in sorted(os.listdir(input_root_dir)):
        scene_dir = os.path.join(input_root_dir, subdir)
        if not os.path.isdir(scene_dir):
            continue
        src_usd = os.path.join(scene_dir, "result.usd")
        if not os.path.exists(src_usd):
            print(f"跳过 {subdir}，没有 result.usd")
            continue

        # Mirror the scene folder under the output root.
        dst_scene_dir = os.path.join(output_root_dir, subdir)
        os.makedirs(dst_scene_dir, exist_ok=True)
        dst_usd = os.path.join(dst_scene_dir, f"{subdir}_collision.usd")

        print(f"\n=== 处理 {subdir} ===")
        try:
            extract_collision_meshes(src_usd, dst_usd)
        except Exception as exc:  # batch boundary: report and keep going
            print(f"Failed to process {subdir}: {exc}")
            failed.append(subdir)

    if failed:
        print(f"\n{len(failed)} scene(s) failed: {', '.join(failed)}")
    return failed

if __name__ == "__main__":
    import argparse

    # The defaults are the original hard-coded locations; both can be overridden
    # positionally on the command line.
    parser = argparse.ArgumentParser(
        description="Extract collision meshes from <input_root>/<scene>/result.usd "
                    "into <output_root>/<scene>/<scene>_collision.usd."
    )
    parser.add_argument(
        "input_root",
        nargs="?",
        default=r"/home/sig/sig/qianluo/qianluo/3DGS-WAIC/3DGS_usd_simplified/3DGS_usd_simplified",
        help="Directory containing one sub-directory per scene, each with a result.usd.",
    )
    parser.add_argument(
        "output_root",
        nargs="?",
        default=r"/home/sig/sig/qianluo/qianluo/3DGS_VLN_Benchmark/Dataset/Collision",
        help="Directory that receives one <scene>/<scene>_collision.usd per scene.",
    )
    args = parser.parse_args()

    batch_extract_collision_meshes(args.input_root, args.output_root)