# Start the Kit application first: Isaac Sim's standalone-script pattern requires
# SimulationApp to exist before the other omni.* modules imported below are loaded.
# "headless": True runs without a viewport window.
from omni.isaac.kit import SimulationApp
simulation_app = SimulationApp({"headless": True})

import omni.usd
from pxr import UsdLux, Gf
from omni.isaac.sensor import Camera
from omni.isaac.core import World
from omni.isaac.core.utils.stage import open_stage
import gc
import json
import os
import numpy as np
from pathlib import Path
from PIL import Image
import argparse
from scipy.spatial.transform import Rotation as R


def usd_to_isaac(q_usd, *, input_scalar_first=False):
    """Re-express a USD-camera orientation in the +Z-forward camera-axis convention.

    A USD camera looks along its local -Z axis with +Y up. Rotating the
    camera's *local* frame by 180 degrees about its X axis yields a frame that
    looks along +Z with +Y pointing *down*. That is the ROS/optical-style
    convention, and it still describes the same physical camera pose.
    Because that flip is its own inverse, the same function also maps the
    +Z-forward convention back to USD.

    Args:
        q_usd: 4-element orientation quaternion. By default it is read in
            scipy's (x, y, z, w) order. If ``input_scalar_first`` is True, it
            is read as (w, x, y, z).
        input_scalar_first: set True when ``q_usd`` is stored scalar-first.

    Returns:
        numpy array ``(w, x, y, z)``. This is scalar-first, the quaternion
        order used by Isaac Sim's ``Camera(orientation=...)``.

    NOTE(review): confirm which camera-axis convention
    ``Camera(orientation=...)`` applies in your Isaac Sim version. If it
    uses the USD convention, a USD-frame quaternion can be passed directly
    (reordered to w, x, y, z) without this flip.
    """
    q = np.asarray(q_usd, dtype=float)
    if input_scalar_first:
        q = q[[1, 2, 3, 0]]  # (w, x, y, z) -> scipy's (x, y, z, w)

    # Apply the flip in the camera's local (body) frame by right-multiplying.
    # In scipy, ``p * q`` applies q first. Left-multiplying would instead
    # rotate the world about its X axis and, in general, aim the camera elsewhere.
    flip_x = R.from_euler("x", 180, degrees=True)
    x, y, z, w = (R.from_quat(q) * flip_x).as_quat()
    return np.array([w, x, y, z])


class CameraBatchProcessor:
    """Render one RGB image per sample point of a trajectory config in a USD scene.

    The config JSON is read from ``config_file``. Camera poses are taken from
    ``config["scenes"][0]["samples"][0]["points"]``, and each point needs
    ``"position"``, ``"rotation"`` and ``"point"`` keys. Images are written
    to ``output_dir / f"point_{point}.png"``.
    """

    # Number of rendered frames stepped after loading the world and again
    # after creating the cameras, so the rendered output can settle.
    WARMUP_STEPS = 5

    def __init__(self, config_file, scene_usd, output_dir):
        self.config_path = Path(config_file)
        self.scene_usd = Path(scene_usd)
        self.output_dir = Path(output_dir)

        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(self.config_path, "r", encoding="utf-8") as f:
            self.config = json.load(f)

    def _load_world(self):
        """Open the scene, ensure an environment light exists, and warm up a World.

        Returns:
            ``(world, stage)`` on success, or ``(None, None)`` if the USD file
            could not be opened.
        """
        # Close whatever stage is currently open before loading the new one.
        omni.usd.get_context().close_stage()
        gc.collect()

        if not open_stage(usd_path=str(self.scene_usd)):
            print(f"❌ 无法打开 {self.scene_usd}")
            return None, None

        stage = omni.usd.get_context().get_stage()
        # Add a white dome light only if nothing exists at this path yet.
        if not stage.GetPrimAtPath("/World/EnvLight"):
            dome = UsdLux.DomeLight.Define(stage, "/World/EnvLight")
            dome.CreateIntensityAttr(30000.0)
            dome.CreateColorAttr(Gf.Vec3f(1.0, 1.0, 1.0))

        world = World()
        world.reset()
        for _ in range(self.WARMUP_STEPS):
            world.step(render=True)

        return world, stage

    def create_camera(self, prim_path, position, rotation, *,
                      resolution=(640, 480), frequency=10):
        """Create and initialize a camera without any extra view-altering operations.

        Args:
            prim_path: USD path for the new camera prim.
            position: world position (3 values), converted to a numpy array.
            rotation: 4-element quaternion from the config. It is converted by
                ``usd_to_isaac`` before being passed to ``Camera``.
            resolution: image resolution passed to ``Camera``
                (presumably (width, height) — confirm against Isaac Sim docs).
            frequency: sensor frequency passed to ``Camera``.

        Returns:
            The initialized ``Camera``.
        """
        # NOTE(review): if the JSON rotations are already USD camera
        # orientations and Camera(orientation=...) applies them to the prim
        # as-is, the usd_to_isaac conversion may be unnecessary — confirm.
        orientation = usd_to_isaac(np.array(rotation))

        camera = Camera(
            prim_path=prim_path,
            position=np.array(position),
            frequency=frequency,
            resolution=resolution,
            orientation=orientation,
        )
        camera.initialize()
        return camera

    def run(self):
        """Render and save one image per sample point, then close the app.

        ``simulation_app.close()`` runs in ``finally``, so the Kit application
        is shut down even when the scene fails to load or capturing raises.
        """
        try:
            world, _ = self._load_world()
            if world is None:
                return

            points = self.config["scenes"][0]["samples"][0]["points"]
            cameras = [
                self.create_camera(f"/World/cam_{idx}", pt["position"], pt["rotation"])
                for idx, pt in enumerate(points)
            ]

            # Render a few frames to let the image stabilize.
            for _ in range(self.WARMUP_STEPS):
                world.step(render=True)

            self.output_dir.mkdir(parents=True, exist_ok=True)
            for cam, pt in zip(cameras, points):
                out_path = self.output_dir / f"point_{pt['point']}.png"
                rgba = cam.get_rgba()
                # get_rgba() may yield an empty/None buffer if no frame is
                # available yet. Skip that camera rather than crash the batch.
                if rgba is None or np.size(rgba) == 0 or np.ndim(rgba) != 3:
                    print(f"⚠️ {out_path} skipped: empty camera frame")
                    continue
                Image.fromarray(rgba[:, :, :3]).save(out_path)
                print(f"✅ {out_path} Saved")
        finally:
            simulation_app.close()


if __name__ == "__main__":
    # Command-line entry: all three paths are required; the argument names
    # match CameraBatchProcessor's constructor parameters one-to-one.
    cli = argparse.ArgumentParser()
    for flag, text in (
        ("--config_file", "轨迹 JSON 文件"),
        ("--scene_usd", "USD 场景文件"),
        ("--output_dir", "输出图片目录"),
    ):
        cli.add_argument(flag, required=True, help=text)

    CameraBatchProcessor(**vars(cli.parse_args())).run()