# Create the Isaac Sim application before importing anything else. Isaac Sim expects
# SimulationApp to exist before omni/isaacsim-dependent modules are imported, which is
# presumably why this precedes every other import in the file (confirm for Env_* modules).
# {"headless": False} runs with the viewer window instead of headless.
from isaacsim import SimulationApp
simulation_app = SimulationApp({"headless": False})

import os
import sys
import json
import random
import numpy as np
from termcolor import cprint

sys.path.append(os.getcwd())
from Env_StandAlone.BaseEnv import BaseEnv
from Env_Config.Table.Table import Table
from Env_Config.Room.Real_Ground import Real_Ground
from Env_Config.Camera.Recording_Camera import Recording_Camera
from Env_Config.Room.Object_Tools import set_prim_visible_group

from Env_Config.Flat_Object.Flat_Object import Rigid

# Every flat-object asset lives under this directory, one sub-folder per type.
_FLAT_OBJECT_ROOT = "Assets/Flat_Object"

# (object type, number of asset variants) in declaration order.
_OBJECT_COUNTS = (
    # Train / Test_Seen
    ("Board", 6),
    ("Book", 8),
    ("Box", 8),
    ("Clock", 4),
    ("Disk", 6),
    ("Keyboard", 6),
    ("Photo_Album", 4),
    ("Plate", 6),
    # Test_Unseen
    ("Chessboard", 4),
    ("Cutting_Board", 6),
    ("Laptop", 2),
    ("Magazine", 6),
    ("Pad", 6),
    ("Painting", 4),
    ("Disk_Case", 4),
)

# Object type -> {"path": asset folder, "count": number of variants}.
OBJECT_CONFIGS = {
    type_name: {"path": f"{_FLAT_OBJECT_ROOT}/{type_name}", "count": variant_count}
    for type_name, variant_count in _OBJECT_COUNTS
}

# Table asset, relative to the working directory.
WOOD_TABLE_PATH = "Assets/Table/Collected_Willow/Wood.usd"

# Shared pose of both "view5" cameras. The quaternion is (cos 22.5°, sin 22.5°, 0, 0),
# i.e. a 45° tilt about x if Recording_Camera expects [w, x, y, z] — TODO confirm.
VIEW5_POS = np.array([0.0, -3.1, 3.8])
VIEW5_ORI_QUAT = np.array([0.92388, 0.38268, 0.0, 0.0])

# Fixed spawn height of the rigid object and the roll used by set_pose().
RIGID_Z = 0.83
X_ROT_DEG = 0.0


def euler_to_quaternion(euler_deg, order: str = 'xyz') -> np.ndarray:
    """Convert Euler angles in degrees to a scalar-first ``[w, x, y, z]`` quaternion.

    The result is the Hamilton product ``q = qx * qy * qz`` of the elementary
    rotations about x by ``euler_deg[0]``, about y by ``euler_deg[1]`` and about
    z by ``euler_deg[2]``.

    Args:
        euler_deg: Three angles in degrees, indexed by axis (x, y, z).
        order: Composition order. Only ``'xyz'`` (case-insensitive) is
            implemented. Any other value raises instead of being silently
            treated as ``'xyz'``.

    Returns:
        ``np.ndarray`` of four floats ``[w, x, y, z]``.

    Raises:
        ValueError: If ``order`` is not ``'xyz'``.
    """
    if order.lower() != 'xyz':
        raise ValueError(f"Unsupported Euler order {order!r}; only 'xyz' is implemented")

    roll_deg, pitch_deg, yaw_deg = float(euler_deg[0]), float(euler_deg[1]), float(euler_deg[2])
    # Half angles: an axis-angle rotation by t has quaternion (cos t/2, sin t/2 * axis).
    rx = np.radians(roll_deg) / 2.0
    ry = np.radians(pitch_deg) / 2.0
    rz = np.radians(yaw_deg) / 2.0

    cx, sx = np.cos(rx), np.sin(rx)
    cy, sy = np.cos(ry), np.sin(ry)
    cz, sz = np.cos(rz), np.sin(rz)

    # Expanded product q = qx * qy * qz.
    w = cx*cy*cz - sx*sy*sz
    x = sx*cy*cz + cx*sy*sz
    y = cx*sy*cz - sx*cy*sz
    z = cx*cy*sz + sx*sy*cz
    return np.array([w, x, y, z], dtype=float)


def round2(v):
    """Return ``v`` rounded to two decimals, via a ``'.2f'`` string round-trip."""
    text = format(v, ".2f")
    return float(text)


# Types sampled in the "xp" region by rand_pose_train(). Every other type (e.g. Board,
# Box, Clock, Keyboard, Photo_Album, Cutting_Board, Chessboard, Painting, Laptop, and
# also unknown types) uses the "xy" region.
_TRAIN_GROUP_XP = frozenset({"Book", "Disk", "Plate", "Magazine", "Pad", "Disk_Case"})


def rand_pose_train(object_type: str):
    """Sample a random training pose for ``object_type``.

    Types in ``_TRAIN_GROUP_XP`` get x in [0.15, 0.50] and y in [0.00, 0.21].
    All other types get x in [-0.14, 0.20] and y in [-0.05, 0.11].
    x and y are rounded to two decimals, and z is fixed at ``RIGID_Z``.

    Returns:
        ``([x, y, RIGID_Z], yaw)`` where ``yaw`` is an integer in [-10, 10]
        (passed to ``set_pose`` as degrees).
    """
    if object_type in _TRAIN_GROUP_XP:
        x_range, y_range = (0.15, 0.50), (0.00, 0.21)
    else:
        x_range, y_range = (-0.14, 0.20), (-0.05, 0.11)
    # Draw order (x, y, yaw) matches the original so seeded runs are reproducible.
    x = round2(random.uniform(*x_range))
    y = round2(random.uniform(*y_range))
    yaw = random.randint(-10, 10)
    return [x, y, RIGID_Z], yaw


def rand_pose_yaw_only():
    """Return a pose at the origin (height ``RIGID_Z``) with a random integer yaw in [-10, 10]."""
    position = [0.0, 0.0, RIGID_Z]
    yaw_deg = random.randint(-10, 10)
    return position, yaw_deg


def normalize_scale_entry(entry):
    """Coerce one scale entry from the config into ``[sx, sy, sz]`` floats.

    Accepted forms:
      * a number ``s``                      -> ``[s, s, s]``
      * a 1-element list/tuple ``[s]``      -> ``[s, s, s]`` (same as a bare number)
      * a 2-element list/tuple ``[sx, sy]`` -> ``[sx, sy, 0.8]``
      * a list/tuple with 3+ elements       -> its first three values

    Anything else yields ``[0.8, 0.8, 0.8]``. That covers an empty list,
    non-numeric elements and unsupported types.
    """
    default = 0.8
    try:
        if isinstance(entry, (int, float)):
            s = float(entry)
            return [s, s, s]
        if isinstance(entry, (list, tuple)):
            values = [float(v) for v in entry]
            if len(values) == 1:
                return values * 3
            if len(values) == 2:
                return [values[0], values[1], default]
            if len(values) >= 3:
                return values[:3]
    except (TypeError, ValueError, OverflowError):
        # Malformed entry: deliberately fall back to the default instead of aborting.
        pass
    return [default, default, default]


def _scales_from_node(node, split: str):
    """Return normalized scales from one config entry, or ``None`` if it has none usable.

    A dict entry is checked for a non-empty list under ``split``, then under
    ``"scales"``. A non-empty list entry is used directly.
    """
    if isinstance(node, dict):
        for list_key in (split, "scales"):
            entries = node.get(list_key)
            if isinstance(entries, list) and entries:
                return [normalize_scale_entry(x) for x in entries]
        return None
    if isinstance(node, list) and node:
        return [normalize_scale_entry(x) for x in node]
    return None


def load_scales_from_config(cfg_path: str, object_type: str, split: str, object_index: int | None = None):
    """Load the candidate scales for one object from a JSON config.

    Entries are tried in order, and the first one yielding scales wins:
      1. ``"<object_type><object_index>"``, only when ``object_index`` is given
      2. ``"<object_type>"``
      3. ``"Default"``

    Each entry may be a dict (with a per-split list or a generic ``"scales"``
    list) or a plain list of scale entries.

    Returns:
        A non-empty list of ``[sx, sy, sz]`` lists. Falls back to
        ``[[0.8, 0.8, 0.8]]`` if the file is missing or unreadable, the JSON is
        invalid or not an object, or no entry supplies scales.
    """
    default_only = [[0.8, 0.8, 0.8]]
    try:
        with open(cfg_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return default_only
    if not isinstance(cfg, dict):
        return default_only

    candidate_keys = []
    if object_index is not None:
        candidate_keys.append(f"{object_type}{int(object_index)}")
    candidate_keys += [object_type, "Default"]

    for key in candidate_keys:
        scales = _scales_from_node(cfg.get(key), split)
        if scales:
            return scales
    return default_only


def parse_scale_arg(scale_str: str | None, cfg_path: str, object_type: str, split: str, object_index: int | None = None):
    if scale_str:
        parts = [p.strip() for p in scale_str.split(',')]
        if len(parts) == 1:
            s = float(parts[0])
            return [s, s, s]
        if len(parts) == 2:
            sx, sy = map(float, parts)
            return [sx, sy, 0.8]
        sx, sy, sz = map(float, parts[:3])
        return [sx, sy, sz]
    scales = load_scales_from_config(cfg_path, object_type, split, object_index)
    return scales[0]


class RigidVisionEnv(BaseEnv):
    """Scene with a wood table, one flat rigid object and two co-located "view5" cameras.

    ``object_camera5`` segments only ``/World/Rigid`` and supplies the RGB image,
    the depth image and the object point cloud. ``env_camera5`` segments
    ``/World/Rigid`` and ``/World/Table`` and supplies the environment point cloud.

    Args:
        object_type: Key of ``OBJECT_CONFIGS``.
        object_index: Asset number; ``<cfg path>/<object_type><object_index>.usd`` is loaded.
        scale_vec: ``[sx, sy, sz]`` scale applied to the object.
        split: ``"Train"``, ``"Test_Seen"`` or ``"Test_Unseen"``. It selects the
            output sub-directory; unknown values map to ``"train"``.

    Raises:
        ValueError: If ``object_type`` is not a key of ``OBJECT_CONFIGS``.
    """

    # Object types whose captures go under "Strategy_A"; all other types go under "Strategy_B".
    STRATEGY_A_TYPES = frozenset({"Book", "Disk", "Plate", "Magazine", "Pad", "Disk_Case"})
    # CLI split name -> output sub-directory name.
    SPLIT_DIRS = {"Train": "train", "Test_Seen": "test_seen", "Test_Unseen": "test_unseen"}

    def __init__(self, object_type: str, object_index: int, scale_vec, split: str):
        # Validate before BaseEnv.__init__() so an unknown type fails fast.
        if object_type not in OBJECT_CONFIGS:
            raise ValueError(
                f"Unknown object type {object_type!r}; expected one of {sorted(OBJECT_CONFIGS)}"
            )
        super().__init__()
        self.object_type = object_type
        self.object_index = object_index
        self.object_name = f"{object_type}/{object_type}{object_index}"
        # Output sub-directory read by capture_all(); main() may still overwrite it.
        self.split_dir = self.SPLIT_DIRS.get(split, "train")

        self.ground = Real_Ground(self.scene, visual_material_usd=None)
        self.table = Table(
            world=self.world,
            path=os.path.join(os.getcwd(), WOOD_TABLE_PATH),
            position=[0.0, 0.0, 0.0],
            orientation=[0.0, 0.0, 0.0],
            scale=np.array([0.0088, 0.0104, 0.01]),
        )

        # Both cameras share the same pose; they differ only in what they segment.
        self.object_camera5 = Recording_Camera(
            camera_position=VIEW5_POS,
            camera_orientation=VIEW5_ORI_QUAT,
            resolution=(600, 400),
            prim_path="/World/object_camera5",
        )
        self.env_camera5 = Recording_Camera(
            camera_position=VIEW5_POS,
            camera_orientation=VIEW5_ORI_QUAT,
            resolution=(600, 400),
            prim_path="/World/env_camera5",
        )

        cfg = OBJECT_CONFIGS[object_type]
        usd_path = os.path.join(os.getcwd(), f"{cfg['path']}/{object_type}{object_index}.usd")
        # NOTE(review): the object is spawned with orientation [90, 0, 0], but set_pose()
        # rebuilds the orientation with roll X_ROT_DEG (0.0) — confirm this is intended.
        self.current_object = Rigid(
            world=self.world,
            path=usd_path,
            position=[0.0, 0.0, RIGID_Z],
            orientation=[90.0, 0.0, 0.0],
            scale=np.array(scale_vec, dtype=float),
        )
        # Used only in output file names; main() overwrites it from --scale-idx.
        self.scale_idx = 0

        self.reset()
        self.object_camera5.initialize(
            depth_enable=True,
            segment_pc_enable=True,
            segment_prim_path_list=["/World/Rigid"]
        )
        self.env_camera5.initialize(
            segment_pc_enable=True,
            segment_prim_path_list=["/World/Rigid", "/World/Table"]
        )
        for _ in range(30):
            self.step()

    def set_pose(self, position, yaw_deg):
        """Move the object and step the simulation 60 times.

        The object is placed at ``position`` with roll ``X_ROT_DEG``, zero pitch
        and yaw ``yaw_deg`` (degrees).
        """
        euler_deg = [X_ROT_DEG, 0.0, float(yaw_deg)]
        quat = euler_to_quaternion(euler_deg, order='xyz')
        self.current_object.set_world_poses(position=position, orientation=quat)
        for _ in range(60):
            self.step()

    def capture_all(self, filename_base: str):
        """Step the simulation, then save RGB, depth, environment and object point clouds.

        Files are written to
        ``Data/FlatLab/<modality>/<split_dir>/<Strategy_A|Strategy_B>/<filename_base>_Wood_<suffix>``
        under the current working directory. The table is hidden only while the
        object point cloud is captured, and it is always made visible again,
        even if that capture raises.
        """
        for _ in range(100):
            self.step()

        split_dir = getattr(self, "split_dir", "train")
        obj_type = getattr(self, "object_type", "")
        strategy = "Strategy_A" if obj_type in self.STRATEGY_A_TYPES else "Strategy_B"

        def _out_path(modality, suffix):
            # Create the per-modality directory and return the file path inside it.
            root = os.path.join(os.getcwd(), "Data/FlatLab", modality, split_dir, strategy)
            os.makedirs(root, exist_ok=True)
            return os.path.join(root, f"{filename_base}_Wood_{suffix}")

        rgb_save_path = _out_path("Data_op_RGB", "RGB.png")
        depth_save_path = _out_path("Data_op_Depth", "Depth.png")
        env_pc_save_path = _out_path("Data_op_PointCloud_Env", "Env.ply")
        obj_pc_save_path = _out_path("Data_op_PointCloud_Obj", "Obj.ply")

        self.object_camera5.get_rgb_graph(save_or_not=True, save_path=rgb_save_path)
        self.object_camera5.get_depth_graph(save_or_not=True, save_path=depth_save_path)

        self.env_camera5.get_point_cloud_data_from_segment(
            save_or_not=True,
            save_path=env_pc_save_path,
            sample_flag=True,
            sampled_point_num=8192,
            real_time_watch=False
        )

        # Hide the table so only the object is visible for the object point cloud.
        set_prim_visible_group(prim_path_list=["/World/Table"], visible=False)
        try:
            for _ in range(50):
                self.step()
            self.object_camera5.get_point_cloud_data_from_segment(
                save_or_not=True,
                save_path=obj_pc_save_path,
                sample_flag=True,
                sampled_point_num=2048,
                real_time_watch=False
            )
        finally:
            # Always restore the table so later captures are not taken without it.
            set_prim_visible_group(prim_path_list=["/World/Table"], visible=True)
        for _ in range(100):
            self.step()


def main():
    """CLI entry point: capture ``--positions`` random poses of one object and save them.

    ``simulation_app`` is always closed on exit. That includes argparse exits
    (e.g. ``--help`` or invalid arguments) and exceptions raised during capture.
    """
    import argparse

    try:
        parser = argparse.ArgumentParser()
        parser.add_argument("object_type", type=str, help="Object type")
        parser.add_argument("object_index", type=int, help="Object index (start from 1)")
        parser.add_argument("--split", type=str, default="Train", choices=["Train", "Test_Seen", "Test_Unseen"], help="Data partition")
        parser.add_argument("--config", type=str, default="Env_Config/Flat_Object/rigid_scales.json", help="Path to scale configuration JSON file")
        parser.add_argument("--positions", type=int, default=5, help="Random pose times per scale")
        parser.add_argument("--scale", type=str, default=None, help="Single scale, format: '0.8' or '0.8,0.9' or '0.8,0.9,0.8'")
        parser.add_argument("--scale-idx", type=int, default=0, help="Scale index (for naming: Scale<idx>)")
        args = parser.parse_args()

        object_type = args.object_type
        object_index = args.object_index  # already an int via type=int
        split = args.split
        scale_vec = parse_scale_arg(args.scale, args.config, object_type, split, object_index)

        # Initialize environment and object (object created in __init__)
        env = RigidVisionEnv(object_type, object_index, scale_vec, split)
        # --scale-idx only affects file naming; negative values are clamped to 0.
        env.scale_idx = max(0, args.scale_idx)
        env.split_dir = {"Train": "train", "Test_Seen": "test_seen", "Test_Unseen": "test_unseen"}.get(split, "train")

        # At least one pose is captured even when --positions <= 0.
        for i in range(1, max(1, args.positions) + 1):
            if split == "Train":
                pos, yaw = rand_pose_train(object_type)
            else:
                pos, yaw = rand_pose_yaw_only()
            env.set_pose(pos, yaw)
            # {Type}_view5_{Type}{idx}_Scale{scale_idx}_Position{i}
            filename_base = f"{object_type}_view5_{object_type}{object_index}_Scale{env.scale_idx}_Position{i}"
            env.capture_all(filename_base)
            cprint(f"[SAVE] {filename_base}", "green")
    finally:
        simulation_app.close()


if __name__ == "__main__":
    main()
