from pathlib import Path

import open3d as o3d
import torch
import yaml

from furniture_bench_api.utils.pose_utils import normalize_pose_to_z, transform_pose_in_local_coords, transform_pose_in_world_coords
from python_utils.transformations import pose_to_affine

# Which furniture part to visualize. The mesh is read from the furniture-bench
# asset tree vendored under 3rdparty/ (relative to this file's grandparent dir).
furniture = "round-table"
part = "leg"

mesh_dir = Path(__file__).parent.parent / f"3rdparty/furniture-bench/furniture_bench/assets/furniture/mesh/{furniture}"
obj_file = mesh_dir / f"{furniture}_{part}.obj"

# Open3D's reader does not raise on a missing/unreadable file; it logs a warning
# and returns an empty mesh, which would just render an empty scene. Fail loudly.
if not obj_file.is_file():
    raise FileNotFoundError(f"Mesh file not found: {obj_file}")
mesh = o3d.io.read_triangle_mesh(str(obj_file))
if mesh.is_empty():
    raise ValueError(f"Mesh file could not be parsed or contains no geometry: {obj_file}")
# Vertex normals let the viewer shade/light the mesh instead of drawing it flat.
mesh.compute_vertex_normals()


# Sequence of pose operations, applied in order to `pose` below. Supported ops:
#   tf_local       - apply each mapping in args.tfs via transform_pose_in_local_coords
#   tf_world       - apply each mapping in args.tfs via transform_pose_in_world_coords
#   normalize_to_z - replace pose with normalize_pose_to_z(pose)
#   transform      - not implemented yet
yaml_str = """
- op: "tf_local"
  args:
      tfs:
      - rx: 90
"""

pose = torch.as_tensor([0, 0, 0, 1, 0, 0, 0], dtype=torch.float32)  # [x, y, z, qw, qx, qy, qz]
# An empty YAML document loads as None; treat it as "no operations".
config = yaml.safe_load(yaml_str) or []
for operation in config:
    op_name = operation["op"]
    # A bare `args:` line loads as None; treat it the same as a missing key.
    op_args = operation.get("args")
    if op_args is None:
        op_args = {}
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not isinstance(op_args, dict):
        raise TypeError("args must be a dict")

    if op_name == "tf_local":
        # Each tf mapping is forwarded as keyword arguments (e.g. {"rx": 90}).
        for tf in op_args["tfs"]:
            pose = transform_pose_in_local_coords(pose, **tf)
    elif op_name == "tf_world":
        for tf in op_args["tfs"]:
            pose = transform_pose_in_world_coords(pose, **tf)
    elif op_name == "transform":
        # TODO: apply a transform relative to another part; previously sketched as
        # `self.get_transform(part=part, pose=op_args["to"])(pose)`.
        raise NotImplementedError("op 'transform' is not supported in this script")
    elif op_name == "normalize_to_z":
        pose = normalize_pose_to_z(pose)
    else:
        # Reject unknown ops instead of silently skipping them (e.g. config typos).
        raise ValueError(f"Unknown op: {op_name!r}")



# Visualize the part mesh together with two coordinate frames: one at the world
# origin, and a slightly larger one placed at the final `pose` so the two can be
# told apart.
frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
posed_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.13, origin=[0, 0, 0])
# NOTE(review): assumes pose_to_affine returns a 4x4 homogeneous matrix, as
# TriangleMesh.transform expects — confirm against python_utils.
posed_frame.transform(pose_to_affine(pose))

# Render back faces too so open or single-sided regions of the mesh stay visible.
o3d.visualization.draw_geometries([mesh, frame, posed_frame], mesh_show_back_face=True)
