# Start the Isaac Sim application (headless, i.e. without a GUI window) before
# anything else: in the standalone workflow the remaining omni.* / pxr modules
# imported below only become importable once the SimulationApp exists.
from omni.isaac.kit import SimulationApp
simulation_app = SimulationApp({"headless": True})

import omni.usd
from pxr import UsdLux, UsdGeom, Gf
from omni.isaac.sensor import Camera
from omni.isaac.core import World
from omni.isaac.core.utils.stage import open_stage
from PIL import Image
import json, gc, os
import numpy as np
from pathlib import Path
import argparse
import collections

class ReplayCapture:
    """Replay recorded camera trajectories in a USD scene and save an image per point.

    The trajectory JSON is expected to contain ``{"scenes": [scene, ...]}``;
    only the first scene is used.  That scene must provide ``scene_id``,
    ``scene_name`` and ``samples``, where each sample has a ``trajectory_id``
    and a list of ``points``.  Every point carries ``point``, ``position`` and
    ``rotation`` and may optionally override camera optics
    (``focal_length``, ``horizontal_aperture``, ``vertical_aperture``,
    ``focus_distance``, ``clipping_range``).

    Output layout under ``output_path``::

        <scene_name>/scene_info.json
        <scene_name>/trajectory_<id>/trajectory_info.json
        <scene_name>/trajectory_<id>/scene_<sid>_traj_<id>_point_<pid>.png
    """

    # Prim path of the sensor camera created for the replay.
    _CAMERA_PRIM_PATH = "/World/MySensorCamera"
    # Number of rendered world steps used to let the scene / a new pose settle.
    _SETTLE_STEPS = 5
    # Optional scalar point keys and the UsdGeom.Camera accessor they override,
    # in the order they are applied.
    _SCALAR_OPTICS = (
        ("focal_length", "GetFocalLengthAttr"),
        ("horizontal_aperture", "GetHorizontalApertureAttr"),
        ("vertical_aperture", "GetVerticalApertureAttr"),
        ("focus_distance", "GetFocusDistanceAttr"),
    )

    def __init__(self, scene_usd_path, traj_json_path, output_path):
        """Store the paths and load the trajectory description.

        Args:
            scene_usd_path: Path of the USD scene (.usd / .usda) to open.
            traj_json_path: Path of the trajectory JSON file.
            output_path: Root directory that receives images and JSON files.

        Raises:
            OSError: If the trajectory file cannot be opened.
            json.JSONDecodeError: If the trajectory file is not valid JSON.
            KeyError, IndexError: If the JSON has no first scene or that scene
                lacks ``scene_name``.
        """
        self.scene_usd_path = Path(scene_usd_path)
        self.traj_json_path = Path(traj_json_path)
        self.output_path = Path(output_path)

        # Explicit encoding so parsing does not depend on the platform locale.
        with open(self.traj_json_path, "r", encoding="utf-8") as f:
            self.traj_data = json.load(f)

        # Only the first scene of the file is replayed.
        self.scene_info = self.traj_data["scenes"][0]
        self.scene_name = self.scene_info["scene_name"]

    def _load_stage(self):
        """Open the USD scene and make sure it contains an environment light.

        Returns:
            The opened stage, or ``None`` when the scene could not be opened
            (a message is printed in that case).
        """
        # Drop whatever stage is currently loaded before opening the new one.
        omni.usd.get_context().close_stage()
        gc.collect()
        if not open_stage(usd_path=str(self.scene_usd_path)):
            print(f"❌ 无法打开 USD 场景 {self.scene_usd_path}")
            return None
        stage = omni.usd.get_context().get_stage()

        # Add a white dome light only when the scene does not already have a
        # prim at /World/EnvLight.
        if not stage.GetPrimAtPath("/World/EnvLight"):
            dome = UsdLux.DomeLight.Define(stage, "/World/EnvLight")
            dome.CreateIntensityAttr(30000.0)
            dome.CreateColorAttr(Gf.Vec3f(1.0, 1.0, 1.0))
        return stage

    def run(self):
        """Replay all trajectories, write the outputs and close the app.

        The simulation app is closed on every exit path, including a scene
        that fails to open and any exception raised during capture.
        """
        try:
            stage = self._load_stage()
            if stage is None:
                return
            self._save_json(self._capture_all(stage))
        finally:
            simulation_app.close()

    def _apply_optics(self, usd_cam, point):
        """Apply the optional optical overrides of ``point`` to ``usd_cam``."""
        for key, getter_name in self._SCALAR_OPTICS:
            if key in point:
                getattr(usd_cam, getter_name)().Set(point[key])
        if "clipping_range" in point:
            usd_cam.GetClippingRangeAttr().Set(Gf.Vec2f(*point["clipping_range"]))

    def _capture_all(self, stage):
        """Render and save one PNG per trajectory point.

        Args:
            stage: The opened USD stage the camera is created in.

        Returns:
            A list with one dict per saved image holding ``trajectory_id``,
            ``point_id``, ``image_id``, ``image_path`` (relative to the scene
            directory), ``position`` and ``rotation``.  Points for which the
            camera returned no image are skipped.
        """
        # Initialise the world and render a few frames before adding the camera.
        world = World()
        world.reset()
        for _ in range(self._SETTLE_STEPS):
            world.step(render=True)

        cam = Camera(prim_path=self._CAMERA_PRIM_PATH, frequency=30, resolution=(640, 480))
        cam.initialize()

        # USD-level view of the same prim, used to set the optical attributes.
        usd_cam = UsdGeom.Camera(stage.GetPrimAtPath(self._CAMERA_PRIM_PATH))

        scene_dir = self.output_path / self.scene_name
        scene_id = self.scene_info["scene_id"]
        all_results = []

        for sample in self.scene_info["samples"]:
            trajectory_id = sample["trajectory_id"]
            trajectory_output_dir = scene_dir / f"trajectory_{trajectory_id}"
            trajectory_output_dir.mkdir(parents=True, exist_ok=True)

            for p in sample["points"]:
                self._apply_optics(usd_cam, p)

                # NOTE(review): "rotation" is passed straight through as the
                # orientation; presumably a quaternion in the order
                # Camera.set_world_pose expects - confirm against the recorder.
                cam.set_world_pose(
                    position=np.array(p["position"], dtype=np.float32),
                    orientation=np.array(p["rotation"], dtype=np.float32),
                )

                # Render a few frames so the new pose is reflected in the image.
                for _ in range(self._SETTLE_STEPS):
                    world.step(render=True)

                img = cam.get_rgba()
                if img is None or img.size == 0:
                    print(f"⚠️ 轨迹 {trajectory_id} 点 {p['point']} 无图像")
                    continue

                image_id = f"scene_{scene_id}_traj_{trajectory_id}_point_{p['point']}"
                image_file = trajectory_output_dir / f"{image_id}.png"
                # Drop the alpha channel; only RGB is written.
                Image.fromarray(img[:, :, :3]).save(image_file)

                print(f"✅ 保存: {image_file}")

                all_results.append({
                    'trajectory_id': trajectory_id,
                    'point_id': p['point'],
                    'image_id': image_id,
                    'image_path': f"trajectory_{trajectory_id}/{image_id}.png",
                    'position': p['position'],
                    'rotation': p['rotation'],
                })

        return all_results

    def _save_json(self, all_results):
        """Write ``all_results`` back as trajectory_info.json and scene_info.json.

        Each result is attached to its point (in ``self.scene_info``, which is
        mutated in place) under the ``camera_images`` key.  One
        ``trajectory_info.json`` is written per sample and a single
        ``scene_info.json`` for the scene.  Target directories are created when
        missing, so this also works for a scene without samples.

        Args:
            all_results: Result dicts as produced by the capture loop.

        Raises:
            KeyError: If a result refers to a point id that is not part of its
                trajectory.
        """
        results_by_traj = collections.defaultdict(list)
        for res in all_results:
            results_by_traj[res['trajectory_id']].append(res)

        scene_dir = self.output_path / self.scene_name
        scene_dir.mkdir(parents=True, exist_ok=True)

        for sample in self.scene_info["samples"]:
            traj_id = sample["trajectory_id"]
            point_map = {p["point"]: p for p in sample["points"]}
            for res in results_by_traj.get(traj_id, []):
                point_map[res['point_id']].setdefault("camera_images", []).append({
                    "image_id": res['image_id'],
                    "image_path": res['image_path']
                })

            # Write the sample's own point list (not point_map.values()) so
            # points sharing an id are not silently dropped and the file stays
            # consistent with scene_info.json.
            trajectory_info = {
                "trajectory_id": traj_id,
                "points": sample["points"]
            }
            traj_dir = scene_dir / f"trajectory_{traj_id}"
            traj_dir.mkdir(parents=True, exist_ok=True)
            traj_info_path = traj_dir / "trajectory_info.json"
            with open(traj_info_path, 'w', encoding="utf-8") as f:
                json.dump(trajectory_info, f, indent=2)

        scene_info_result = {
            'scene_id': self.scene_info['scene_id'],
            'scene_name': self.scene_name,
            'samples': self.scene_info['samples']
        }
        scene_info_path = scene_dir / 'scene_info.json'
        with open(scene_info_path, 'w', encoding="utf-8") as f:
            json.dump(scene_info_result, f, indent=2)

        print(f"📄 已保存 scene_info.json 至 {scene_info_path}")


if __name__ == "__main__":
    # Command-line entry point: all three options are mandatory.
    cli_options = (
        ("--scene_usd_path", "USD 文件路径（.usd/.usda）"),
        ("--traj_json_path", "轨迹 JSON 文件路径"),
        ("--output_path", "输出文件夹路径"),
    )
    cli_parser = argparse.ArgumentParser()
    for flag, help_text in cli_options:
        cli_parser.add_argument(flag, required=True, help=help_text)
    cli = cli_parser.parse_args()

    # Build the replay job from the parsed paths and execute it.
    capture = ReplayCapture(cli.scene_usd_path, cli.traj_json_path, cli.output_path)
    capture.run()