# Isaac Sim's standalone workflow requires the SimulationApp (Kit runtime) to be
# created before the omni.* / pxr modules imported below, so these two lines
# must stay ahead of every other import.
# "headless": True runs without opening a viewport window.
# The module-level `simulation_app` is closed by ReplayCapture.run().
from omni.isaac.kit import SimulationApp
simulation_app = SimulationApp({"headless": True})

import omni.usd
from pxr import UsdLux, UsdGeom, Gf
from omni.isaac.sensor import Camera
from omni.isaac.core import World
from omni.isaac.core.utils.stage import open_stage
import omni.replicator.core as rep
from PIL import Image
import json, gc
import numpy as np
from pathlib import Path
import argparse

class ReplayCapture:
    """Replay recorded camera poses inside a USD scene and save one RGB frame per pose.

    The pose file is a JSON object whose ``"points"`` entry is a list of dicts.
    Each dict must provide the keys read by :meth:`_apply_pose`:
    ``focal_length``, ``horizontal_aperture``, ``vertical_aperture``,
    ``focus_distance``, ``clipping_range``, ``position`` and ``rotation``.
    For every pose that yields an image, ``pose_<index>.png`` is written to
    ``output_dir``.
    """

    # Prim path at which the replay sensor camera is created.
    CAMERA_PATH = "/World/MySensorCamera"
    # Sensor resolution passed to the Camera wrapper.
    RESOLUTION = (640, 480)
    # Rendered world steps used for warm-up and to let each new pose take effect.
    SETTLE_STEPS = 5

    def __init__(self, scene_usd, pose_json, output_dir):
        """Store the input/output paths and load the pose list.

        Args:
            scene_usd: Path to the USD scene to open.
            pose_json: Path to a JSON file of the form ``{"points": [...]}``.
            output_dir: Directory that receives the PNG files (created by
                :meth:`run` if missing).

        Raises:
            OSError: if ``pose_json`` cannot be read.
            json.JSONDecodeError: if ``pose_json`` is not valid JSON.
            KeyError: if the top-level JSON object has no ``"points"`` key.
        """
        self.scene_usd = Path(scene_usd)
        self.pose_json = Path(pose_json)
        self.output_dir = Path(output_dir)
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(self.pose_json, "r", encoding="utf-8") as f:
            self.points = json.load(f)["points"]

    def _load_stage(self):
        """Open ``self.scene_usd`` as the current stage, adding a dome light if absent.

        Returns:
            The opened stage, or ``None`` if the scene could not be opened.
        """
        # Close whatever stage is open; gc.collect() presumably frees Python
        # objects still referencing the old stage before the new one loads.
        omni.usd.get_context().close_stage()
        gc.collect()
        if not open_stage(usd_path=str(self.scene_usd)):
            print(f"❌ 无法打开 USD 场景 {self.scene_usd}")
            return None
        stage = omni.usd.get_context().get_stage()
        # A missing prim is returned as an invalid (falsy) prim, so the light is
        # only added when the scene does not already define one at this path.
        if not stage.GetPrimAtPath("/World/EnvLight"):
            dome = UsdLux.DomeLight.Define(stage, "/World/EnvLight")
            dome.CreateIntensityAttr(30000.0)
            dome.CreateColorAttr(Gf.Vec3f(1.0, 1.0, 1.0))
        return stage

    @staticmethod
    def _step(world, count):
        """Advance ``world`` by ``count`` steps, rendering each one."""
        for _ in range(count):
            world.step(render=True)

    def _create_camera(self, stage):
        """Create and initialize the sensor camera.

        Returns:
            ``(cam, usd_cam)``: the Isaac Sim Camera wrapper and the
            ``UsdGeom.Camera`` schema on the same prim (used to set optics).
        """
        cam = Camera(prim_path=self.CAMERA_PATH, frequency=30, resolution=self.RESOLUTION)
        # The Camera wrapper manages its own render product and annotators,
        # which is what get_rgba() reads. A separate rep.create.render_product()
        # would only add an unused render product.
        cam.initialize()
        usd_cam = UsdGeom.Camera(stage.GetPrimAtPath(self.CAMERA_PATH))
        return cam, usd_cam

    @staticmethod
    def _apply_pose(usd_cam, cam, p):
        """Restore the optical parameters and world pose recorded in pose dict ``p``.

        Raises:
            KeyError: if ``p`` lacks one of the required keys.
        """
        # Optical parameters are written directly onto the USD camera schema.
        usd_cam.GetFocalLengthAttr().Set(p["focal_length"])
        usd_cam.GetHorizontalApertureAttr().Set(p["horizontal_aperture"])
        usd_cam.GetVerticalApertureAttr().Set(p["vertical_aperture"])
        usd_cam.GetFocusDistanceAttr().Set(p["focus_distance"])
        usd_cam.GetClippingRangeAttr().Set(Gf.Vec2f(*p["clipping_range"]))

        # NOTE(review): Camera.set_world_pose interprets `orientation` as a
        # quaternion in its default camera-axes convention. Confirm that the
        # recorder wrote "rotation" in that same convention.
        cam.set_world_pose(
            position=np.array(p["position"], dtype=np.float32),
            orientation=np.array(p["rotation"], dtype=np.float32),
        )

    def _save_frame(self, cam, idx):
        """Save the camera's current RGB frame as ``pose_<idx>.png``, or warn if empty."""
        img = cam.get_rgba()
        if img is None or img.size == 0:
            print(f"⚠️ pose {idx} 无图像")
            return
        # Drop the alpha channel before writing the PNG.
        Image.fromarray(img[:, :, :3]).save(self.output_dir / f"pose_{idx}.png")
        print(f"✅ 保存: pose_{idx}.png")

    def run(self):
        """Capture one image per pose, then close the module-level ``simulation_app``.

        The app is closed on every exit path. This includes a scene that fails
        to open and an exception raised while replaying, instead of only after
        a fully successful run.
        """
        try:
            stage = self._load_stage()
            if stage is None:
                return
            world = World()
            world.reset()
            # Warm-up renders before the sensor camera is created.
            self._step(world, self.SETTLE_STEPS)

            self.output_dir.mkdir(parents=True, exist_ok=True)
            cam, usd_cam = self._create_camera(stage)

            for idx, p in enumerate(self.points):
                self._apply_pose(usd_cam, cam, p)
                # Render a few frames so the new pose/optics reach the sensor output.
                self._step(world, self.SETTLE_STEPS)
                self._save_frame(cam, idx)
        finally:
            simulation_app.close()

if __name__ == "__main__":
    # Command-line entry point: all three paths are mandatory options.
    cli = argparse.ArgumentParser()
    for flag in ("--scene_usd", "--pose_json", "--output_dir"):
        cli.add_argument(flag, required=True)
    opts = cli.parse_args()

    capture = ReplayCapture(
        scene_usd=opts.scene_usd,
        pose_json=opts.pose_json,
        output_dir=opts.output_dir,
    )
    capture.run()