# The Kit application must be started before any other omni.* / pxr module is
# imported (standard Isaac Sim standalone-script pattern), which is why every
# other import in this file comes after these two lines.
# headless=True runs the simulator without opening a window.
from omni.isaac.kit import SimulationApp
simulation_app = SimulationApp({"headless": True})

import omni.usd
from pxr import UsdLux, UsdGeom, Gf
from omni.isaac.core import World
from omni.isaac.core.utils.stage import open_stage
from omni.kit.viewport.utility import get_active_viewport
from PIL import Image
import json, gc
import numpy as np
from pathlib import Path
import argparse
import collections

class ReplayCapture:
    """Replay recorded camera trajectories in a USD scene and save one RGB image per point.

    Only the first entry of the trajectory JSON's ``scenes`` list is used. Its
    ``samples`` each carry a ``trajectory_id`` and a list of ``points``. Each point
    is read for ``point``, ``position`` and ``rotation`` (``[x, y, z, w]``, real part
    last). A point may optionally carry ``focal_length``, ``horizontal_aperture``,
    ``vertical_aperture``, ``focus_distance`` and ``clipping_range``.

    Output layout::

        <output_path>/<scene_name>/scene_info.json
        <output_path>/<scene_name>/trajectory_<id>/trajectory_info.json
        <output_path>/<scene_name>/trajectory_<id>/scene_<scene_id>_traj_<id>_point_<n>.png
    """

    # Prim path of the capture camera that is defined (or reused) in the loaded stage.
    CAMERA_PATH = "/World/MySensorCamera"
    # Number of rendered simulation steps after the world reset and after every
    # camera move. Presumably this lets the renderer settle before a frame is read back.
    RENDER_STEPS = 5

    def __init__(self, scene_usd_path, traj_json_path, output_path):
        """Store the paths and load the trajectory JSON; nothing is rendered yet.

        Args:
            scene_usd_path: USD scene file to open in :meth:`run`.
            traj_json_path: Trajectory JSON in the format described on the class.
            output_path: Root directory for the images and JSON summaries.

        Raises:
            OSError: If the trajectory file cannot be read.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError, IndexError: If the file has no ``scenes[0].scene_name``.
        """
        self.scene_usd_path = Path(scene_usd_path)
        self.traj_json_path = Path(traj_json_path)
        self.output_path = Path(output_path)
        # Read as explicit UTF-8: the platform default encoding is not UTF-8 everywhere,
        # and scene names may be non-ASCII.
        with open(self.traj_json_path, "r", encoding="utf-8") as f:
            self.traj_data = json.load(f)
        self.scene_info = self.traj_data["scenes"][0]
        self.scene_name = self.scene_info["scene_name"]

    def _load_stage(self):
        """Open the scene USD in a fresh context and make sure it has a dome light.

        Returns:
            The opened stage, or ``None`` if the file could not be opened.
        """
        omni.usd.get_context().close_stage()
        # Presumably frees references held by the previously open stage before loading.
        gc.collect()
        if not open_stage(usd_path=str(self.scene_usd_path)):
            print(f"❌ 无法打开 USD 场景 {self.scene_usd_path}")
            return None
        stage = omni.usd.get_context().get_stage()
        # Add a white dome light only if the scene does not already have a prim at this path.
        if not stage.GetPrimAtPath("/World/EnvLight"):
            dome = UsdLux.DomeLight.Define(stage, "/World/EnvLight")
            dome.CreateIntensityAttr(30000.0)
            dome.CreateColorAttr(Gf.Vec3f(1.0, 1.0, 1.0))
        return stage

    def run(self):
        """Replay every trajectory point, save the frames, then write the JSON summaries.

        The module-level ``simulation_app`` is closed when this method finishes.
        That happens on normal completion, on an early return (scene not loadable,
        no viewport) and when an exception propagates.
        """
        try:
            self._capture_all()
        finally:
            simulation_app.close()

    def _capture_all(self):
        """Load the scene, capture one image per trajectory point and save the JSON files."""
        stage = self._load_stage()
        if stage is None:
            return
        world = World()
        world.reset()
        self._render(world)

        usd_cam = self._get_or_create_camera(stage)
        viewport = get_active_viewport()
        if viewport is None:
            print("❌ 未找到活动视口，无法抓取图像")
            return
        viewport.set_active_camera(self.CAMERA_PATH)
        # MakeMatrixXform() clears the prim's local transform stack and returns a
        # single matrix op. Creating it once and re-setting it per point avoids
        # rebuilding the op stack on every iteration.
        xform_op = UsdGeom.Xformable(usd_cam.GetPrim()).MakeMatrixXform()

        scene_dir = self.output_path / self.scene_name
        all_results = []
        for sample in self.scene_info["samples"]:
            trajectory_id = sample["trajectory_id"]
            trajectory_output_dir = scene_dir / f"trajectory_{trajectory_id}"
            trajectory_output_dir.mkdir(parents=True, exist_ok=True)

            for p in sample["points"]:
                self._apply_optics(usd_cam, p)
                xform_op.Set(self._pose_matrix(p))
                self._render(world)

                rgb = self._grab_rgb(viewport)
                if rgb is None:
                    print(f"⚠️ {trajectory_id} 点 {p['point']} 无图像")
                    continue

                image_id = f"scene_{self.scene_info['scene_id']}_traj_{trajectory_id}_point_{p['point']}"
                image_path = trajectory_output_dir / f"{image_id}.png"
                Image.fromarray(rgb).save(image_path)
                print(f"✅ 保存: {image_path}")

                all_results.append({
                    'trajectory_id': trajectory_id,
                    'point_id': p['point'],
                    'image_id': image_id,
                    'image_path': f"trajectory_{trajectory_id}/{image_id}.png",
                    'position': p['position'],
                    'rotation': p['rotation'],
                })

        self._save_json(all_results)

    def _render(self, world):
        """Advance ``world`` by ``RENDER_STEPS`` rendered simulation steps."""
        for _ in range(self.RENDER_STEPS):
            world.step(render=True)

    def _get_or_create_camera(self, stage):
        """Return the camera at ``CAMERA_PATH``, defining it first if the prim is absent."""
        prim = stage.GetPrimAtPath(self.CAMERA_PATH)
        if prim:
            return UsdGeom.Camera(prim)
        return UsdGeom.Camera.Define(stage, self.CAMERA_PATH)

    @staticmethod
    def _apply_optics(usd_cam, p):
        """Copy whichever optional optics fields are present on point ``p`` onto the camera."""
        if "focal_length" in p:
            usd_cam.GetFocalLengthAttr().Set(p["focal_length"])
        if "horizontal_aperture" in p:
            usd_cam.GetHorizontalApertureAttr().Set(p["horizontal_aperture"])
        if "vertical_aperture" in p:
            usd_cam.GetVerticalApertureAttr().Set(p["vertical_aperture"])
        if "focus_distance" in p:
            usd_cam.GetFocusDistanceAttr().Set(p["focus_distance"])
        if "clipping_range" in p:
            usd_cam.GetClippingRangeAttr().Set(Gf.Vec2f(*p["clipping_range"]))

    @staticmethod
    def _pose_matrix(p):
        """Build the camera's local transform from ``p['position']`` and ``p['rotation']``.

        ``rotation`` is read as ``[x, y, z, w]`` (real part last). The quaternion is
        normalized, because ``SetRotate`` would otherwise bake its length into the
        matrix as a scale.
        """
        rot = p["rotation"]
        quat = Gf.Quatd(rot[3], Gf.Vec3d(rot[0], rot[1], rot[2])).GetNormalized()
        m = Gf.Matrix4d().SetRotate(quat)
        m.SetTranslateOnly(Gf.Vec3d(*p["position"]))
        return m

    @staticmethod
    def _grab_rgb(viewport):
        """Read the viewport's colour texture back as a ``(height, width, 3)`` uint8 array.

        Returns ``None`` when no texture is available. It also returns ``None`` when
        the buffer size does not match the viewport resolution at 4 bytes per pixel.
        """
        img_data = viewport.get_texture("color")
        if img_data is None:
            return None
        res = viewport.get_resolution()  # indexed as (width, height) below
        pixels = np.frombuffer(img_data, dtype=np.uint8)
        if pixels.size != res[0] * res[1] * 4:
            print(f"⚠️ 图像缓冲区大小 {pixels.size} 与分辨率 {res[0]}x{res[1]} 不匹配")
            return None
        # NOTE(review): assumes a tightly packed 4-channel buffer whose first three
        # channels are RGB. Confirm against the viewport API in use.
        return pixels.reshape((res[1], res[0], 4))[:, :, :3]

    def _save_json(self, all_results):
        """Write a ``trajectory_info.json`` per trajectory plus one ``scene_info.json``.

        Each captured image is recorded under its point as a ``camera_images`` entry.
        NOTE: those entries are appended to the point dicts of ``self.scene_info`` in
        place, which is how they also appear in ``scene_info.json``. Calling this
        twice appends duplicates.

        Args:
            all_results: Result dicts from the capture loop, each with at least
                ``trajectory_id``, ``point_id``, ``image_id`` and ``image_path``.
        """
        scene_dir = self.output_path / self.scene_name
        results_by_traj = collections.defaultdict(list)
        for res in all_results:
            results_by_traj[res['trajectory_id']].append(res)

        for sample in self.scene_info["samples"]:
            traj_id = sample["trajectory_id"]
            point_map = {p["point"]: p for p in sample["points"]}
            for res in results_by_traj.get(traj_id, []):
                point_map[res['point_id']].setdefault("camera_images", []).append({
                    "image_id": res['image_id'],
                    "image_path": res['image_path'],
                })
            trajectory_info = {
                "trajectory_id": traj_id,
                "points": list(point_map.values()),
            }
            traj_dir = scene_dir / f"trajectory_{traj_id}"
            # The capture loop normally creates this directory; don't depend on it.
            traj_dir.mkdir(parents=True, exist_ok=True)
            with open(traj_dir / "trajectory_info.json", 'w', encoding="utf-8") as f:
                json.dump(trajectory_info, f, indent=2)

        scene_info_result = {
            'scene_id': self.scene_info['scene_id'],
            'scene_name': self.scene_name,
            'samples': self.scene_info['samples'],
        }
        # Needed when there are no samples (no trajectory directory was created).
        scene_dir.mkdir(parents=True, exist_ok=True)
        scene_info_path = scene_dir / 'scene_info.json'
        with open(scene_info_path, 'w', encoding="utf-8") as f:
            json.dump(scene_info_result, f, indent=2)
        print(f"📄 已保存 scene_info.json 至 {scene_info_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Replay camera trajectories in a USD scene and save the rendered images."
    )
    parser.add_argument("--scene_usd_path", required=True, help="USD scene file to open.")
    parser.add_argument(
        "--traj_json_path",
        required=False,
        default='none',
        help="Trajectory JSON to replay; 'none' (the default) exits without rendering.",
    )
    parser.add_argument(
        "--output_path", required=True, help="Root directory for the images and JSON summaries."
    )
    args = parser.parse_args()
    if args.traj_json_path == 'none':
        # Nothing to replay. Kit was already started at import time, so shut it down
        # explicitly. Exit via SystemExit rather than quit(): quit() is only injected
        # by the site module for interactive use and can be missing (e.g. `python -S`).
        simulation_app.close()
        raise SystemExit(0)
    ReplayCapture(args.scene_usd_path, args.traj_json_path, args.output_path).run()