# Isaac Sim requires the SimulationApp (Kit runtime) to be created before any
# other isaacsim/omni modules are imported; that is why this comes first and
# the Isaac imports below follow it. Configured to run without a GUI.
from isaacsim import SimulationApp
simulation_app = SimulationApp({"headless": True})

import os
import sys
import random
import numpy as np
from termcolor import cprint

# Isaac / USD
from isaacsim.core.api import World
from isaacsim.core.api import SimulationContext
from isaacsim.core.utils.prims import delete_prim
from isaacsim.core.utils.rotations import euler_angles_to_quat

# Project modules
sys.path.append(os.getcwd())
from Env_StandAlone.BaseEnv import BaseEnv
from Env_Config.Camera.Recording_Camera import Recording_Camera
from Env_Config.Room.Real_Ground import Real_Ground
from Env_Config.Table.Table import Table
from Env_Config.Garment.Particle_Garment import Particle_Garment
from Env_Config.Garment.Rigid_Garment import Rigid_Garment

import argparse

# -------------------- Args --------------------
# Positional arguments select the asset (<type><index>.usd); the options pick
# the capture mode and the output data split.
parser = argparse.ArgumentParser()
parser.add_argument('object_type', type=str, help='Object type, e.g. Fabric, Painting, etc')
parser.add_argument('object_index', type=int, help='Object index, starting from 1')
parser.add_argument(
    '--mode',
    type=str,
    choices=['Deformable', 'Rigid'],
    default='Deformable',
    help='Capture mode: Deformable or Rigid',
)
parser.add_argument(
    '--split',
    type=str,
    default='train',
    choices=['train', 'test_seen', 'test_unseen'],
    help='Data partition directory name',
)
args = parser.parse_args()

OBJECT_TYPE, OBJECT_INDEX = args.object_type, args.object_index
MODE = args.mode    # 'Deformable' | 'Rigid'
SPLIT = args.split  # 'train' | 'test_seen' | 'test_unseen'

# -------------------- Constants --------------------
# Camera "view5": position and orientation quaternion in [w, x, y, z] order.
VIEW_NAME = 'view5'
VIEW5_POS = np.array([0.0, -3.1, 3.8])
VIEW5_ORI = np.array([0.92388, 0.38268, 0.0, 0.0])
TABLE_WOOD_USD = os.path.join(os.getcwd(), 'Assets/Table/Collected_Willow/Wood.usd')

# Every output of this script is written under the Strategy_C sub-directory:
#   Data/FlatLab/<modality dir>/<split>/Strategy_C/
# where <modality dir> is one of Data_op_RGB, Data_op_Depth,
# Data_op_PointCloud_Env or Data_op_PointCloud_Obj (created if missing).
strategy = 'Strategy_C'
RGB_DIR, DEPTH_DIR, PC_ENV_DIR, PC_OBJ_DIR = (
    os.path.join('Data', 'FlatLab', modality_dir, SPLIT, strategy)
    for modality_dir in (
        'Data_op_RGB',
        'Data_op_Depth',
        'Data_op_PointCloud_Env',
        'Data_op_PointCloud_Obj',
    )
)
for out_dir in (RGB_DIR, DEPTH_DIR, PC_ENV_DIR, PC_OBJ_DIR):
    os.makedirs(out_dir, exist_ok=True)

# Object asset: Assets/Flat_Object/<Type>/<Type><Index>.usd
USD_PATH = os.path.join(os.getcwd(), f"Assets/Flat_Object/{OBJECT_TYPE}/{OBJECT_TYPE}{OBJECT_INDEX}.usd")

# Sampling parameters. POS_COUNT is the number of captured poses per object;
# yaw bounds are in degrees; Z_FIXED is the placement height for deformable objects.
# NOTE(review): the X/Y ranges appear unused (the pose is fixed at x = y = 0) — confirm before removing.
POS_COUNT = 5
X_MIN, X_MAX = 0.15, 0.50
Y_MIN, Y_MAX = 0.00, 0.21
YAW_MIN, YAW_MAX = -10, 10
Z_FIXED = 0.90

def _quat_mul(p, q):
    """Return the Hamilton product ``p * q`` of two ``(w, x, y, z)`` quaternions."""
    pw, px, py, pz = p
    qw, qx, qy, qz = q
    return (
        pw * qw - px * qx - py * qy - pz * qz,
        pw * qx + px * qw + py * qz - pz * qy,
        pw * qy - px * qz + py * qw + pz * qx,
        pw * qz + px * qy - py * qx + pz * qw,
    )


def euler_to_quaternion(euler_deg, order: str = 'xyz') -> np.ndarray:
    """Convert Euler angles in degrees to a unit quaternion ``[w, x, y, z]``.

    Args:
        euler_deg: Indexable ``(roll, pitch, yaw)`` in degrees, i.e. the
            rotation about the x, y and z axes respectively, whatever ``order`` is.
        order: Permutation of ``"xyz"`` (case-insensitive) giving the order in
            which the elementary rotations are composed:
            ``q = q[order[0]] * q[order[1]] * q[order[2]]`` (Hamilton product),
            i.e. intrinsic rotations applied in the listed order. The default
            ``"xyz"`` gives ``q = qx * qy * qz``.

    Returns:
        Float ``np.ndarray`` of shape ``(4,)``, scalar-first.

    Raises:
        ValueError: if ``order`` is not a permutation of ``"xyz"``.
        IndexError: if ``euler_deg`` has fewer than three entries.
    """
    axes = str(order).lower()
    if sorted(axes) != ['x', 'y', 'z']:
        raise ValueError(f"order must be a permutation of 'xyz', got {order!r}")

    result = (1.0, 0.0, 0.0, 0.0)  # identity rotation
    for axis in axes:
        idx = 'xyz'.index(axis)
        half = np.radians(float(euler_deg[idx])) / 2.0
        # Elementary rotation about a single axis: (cos(a/2), sin(a/2) * axis).
        elementary = [np.cos(half), 0.0, 0.0, 0.0]
        elementary[idx + 1] = np.sin(half)
        result = _quat_mul(result, elementary)
    return np.array(result, dtype=float)

def rand_pos_angle():
    """Sample a placement pose for the object under test.

    The horizontal position is fixed at x = y = 0; only the height depends on
    ``MODE``. Yaw is drawn uniformly from the integers in
    ``[YAW_MIN, YAW_MAX]`` (both ends inclusive).

    Returns:
        tuple: ``([x, y, z], [roll, pitch, yaw])``, angles in degrees (they
        are later converted with ``euler_to_quaternion``). Rigid objects get
        zero roll/pitch; deformable ones a fixed 0.8-degree roll and pitch.
    """
    is_rigid = MODE == 'Rigid'
    # Rigid objects are placed at z = 0.725; deformable ones at Z_FIXED.
    z = 0.725 if is_rigid else Z_FIXED
    # randint is inclusive and uniform over the integers, unlike
    # int(round(random.uniform(...))), which gives each endpoint half the weight.
    yaw = float(random.randint(YAW_MIN, YAW_MAX))
    # NOTE(review): 0.8 is interpreted as degrees downstream (a very slight tilt) — confirm intended.
    tilt = 0.0 if is_rigid else 0.8
    return [0.0, 0.0, z], [tilt, tilt, yaw]

class DeformableVisionEnv(BaseEnv):
    """Single-object capture scene for the view5 dataset.

    The scene contains a ground plane, one Wood table, the object under study
    (a ``Rigid_Garment`` or ``Particle_Garment`` depending on the ``--mode``
    CLI flag) and two cameras sharing the view5 pose:

    * ``object_camera5`` segments only the object (RGB, depth, object point cloud);
    * ``env_camera5`` segments the object plus the table (environment point cloud).
    """

    def __init__(self):
        super().__init__()
        # Ground plane and a single Wood table at the origin.
        self.ground = Real_Ground(self.scene, visual_material_usd=None)
        self.table = Table(
            path=TABLE_WOOD_USD,
            position=[0.0, 0.0, 0.0],
            orientation=[0.0, 0.0, 0.0],
            scale=np.array([0.0088, 0.0104, 0.01]),
            world=self.world,
        )
        # Two cameras at the same view5 pose, each with its own segmentation targets.
        self.object_camera5 = Recording_Camera(
            camera_position=VIEW5_POS,
            camera_orientation=VIEW5_ORI,
            resolution=(600, 400),
            prim_path="/World/object_camera5",
        )
        self.env_camera5 = Recording_Camera(
            camera_position=VIEW5_POS,
            camera_orientation=VIEW5_ORI,
            resolution=(600, 400),
            prim_path="/World/env_camera5",
        )

        # Create the object now so its prim exists before segmentation is
        # configured. The prim path used for segmentation depends on the mode.
        if MODE == 'Rigid':
            self.current_object = Rigid_Garment(
                world=self.world,
                path=USD_PATH,
                position=np.array([0.0, 0.0, 0.725]),
                orientation=np.array([0.0, 0.0, 0.0]),
                scale=np.array([0.8, 0.8, 0.8]),
            )
            garment_prim = "/World/Garment"
        else:
            self.current_object = Particle_Garment(
                self.world,
                pos=np.array([0.0, 0.0, 0.9]),
                # ori [1.2, 1.2, 0.0] is tuned for Shorts / Skirt;
                # the general default is ori=np.array([0.8, 0.8, 0.0]).
                ori=np.array([1.2, 1.2, 0.0]),
                scale=np.array([0.8, 0.8, 0.8]),
                usd_path=USD_PATH,
                contact_offset=0.010,           # important: controls particle size
                rest_offset=0.0075,             # important: controls particle size
                particle_contact_offset=0.010,  # important
                adhesion=0.5,                   # important
                adhesion_offset_scale=0.0,      # important
                cohesion=0.0,                   # important
                particle_adhesion_scale=0.5,    # important
                particle_friction_scale=0.1,    # important
                friction=5.0,
                lift=0.7,
                gravity_scale=1.5,              # important
                particle_mass=1e-2,             # important
            )
            garment_prim = "/World/Garment/garment"

        # Reset first so the object prim exists on the stage, then bind the
        # segmentation targets of both cameras.
        self.reset()
        self.object_camera5.initialize(depth_enable=True, segment_pc_enable=True, segment_prim_path_list=[garment_prim])
        self.env_camera5.initialize(segment_pc_enable=True, segment_prim_path_list=[garment_prim, "/World/Table"])
        # Step a few frames before declaring the scene ready
        # (presumably so physics and cameras warm up).
        for _ in range(30):
            self.step()
        cprint('[Env] Ready with Wood table and view5 camera', 'green')

    def load_object(self, mode: str):
        """No-op: the object and its mode are fixed in ``__init__`` from the CLI.

        NOTE(review): retained presumably for interface compatibility; nothing
        in this script calls it.
        """
        pass

    def set_object_pose(self, position, orientation):
        """Move the current object to ``position`` with Euler ``orientation``.

        ``orientation`` is ``(roll, pitch, yaw)`` in degrees. It is always
        converted to a ``[w, x, y, z]`` quaternion with ``euler_to_quaternion``
        before being passed to ``set_world_poses``, for rigid and particle
        objects alike.
        """
        self.current_object.set_world_poses(position=position, orientation=euler_to_quaternion(orientation))


def collect(env: DeformableVisionEnv):
    """Capture ``POS_COUNT`` view5 samples of the current object (mode from the CLI).

    For each sample the object is re-posed with ``rand_pos_angle``, the
    simulation is stepped, and four files are written:

    * RGB and depth images from ``object_camera5``;
    * an environment point cloud (object + table, 8192 sampled points) from ``env_camera5``;
    * an object-only point cloud (table hidden, 2048 sampled points) from ``object_camera5``.

    Files are named ``{Type}_{VIEW_NAME}_{Type}{Idx}_{Mode}_Position{i}_Wood_<Modality>``.
    """
    # Function-local import (hoisted out of the capture loop); keeps this
    # project/Isaac-dependent module from being imported at file load time.
    from Env_Config.Room.Object_Tools import set_prim_visible_group

    table_prim = "/World/Table"
    base_name = f"{OBJECT_TYPE}{OBJECT_INDEX}"

    for pos_idx in range(1, POS_COUNT + 1):
        pos, ang_deg = rand_pos_angle()
        env.set_object_pose(pos, ang_deg)

        # Stabilize the simulation before capturing.
        for _ in range(200):
            env.step()

        stem = f"{OBJECT_TYPE}_{VIEW_NAME}_{base_name}_{MODE}_Position{pos_idx}"
        rgb_path = os.path.join(RGB_DIR, f"{stem}_Wood_RGB.png")
        depth_path = os.path.join(DEPTH_DIR, f"{stem}_Wood_Depth.png")
        env_ply_path = os.path.join(PC_ENV_DIR, f"{stem}_Wood_Env.ply")
        obj_ply_path = os.path.join(PC_OBJ_DIR, f"{stem}_Wood_Obj.ply")

        env.object_camera5.get_rgb_graph(save_or_not=True, save_path=rgb_path)
        env.object_camera5.get_depth_graph(save_or_not=True, save_path=depth_path)
        # Environment point cloud: object + table.
        env.env_camera5.get_point_cloud_data_from_segment(
            save_or_not=True,
            save_path=env_ply_path,
            sample_flag=True,
            sampled_point_num=8192,
            real_time_watch=False,
        )

        # Hide the table to get an object-only point cloud. Restore it even if
        # the capture fails so the scene is never left without its table.
        set_prim_visible_group([table_prim], visible=False)
        try:
            for _ in range(50):
                env.step()
            env.object_camera5.get_point_cloud_data_from_segment(
                save_or_not=True,
                save_path=obj_ply_path,
                sample_flag=True,
                sampled_point_num=2048,
                real_time_watch=False,
            )
        finally:
            set_prim_visible_group([table_prim], visible=True)
        for _ in range(100):
            env.step()
        cprint(f"[SAVE] {MODE} P{pos_idx} -> RGB:{rgb_path} Depth:{depth_path} Env/Obj PLY", 'cyan')

    for _ in range(5):
        env.step()


if __name__ == '__main__':
    try:
        # Validate the asset path before building the (expensive) scene.
        usd_present = os.path.isfile(USD_PATH)
        if not usd_present:
            raise FileNotFoundError(f"USD file not found: {USD_PATH}")

        env = DeformableVisionEnv()
        collect(env)

        simulation_app.close()
    except Exception as err:
        import traceback
        cprint(f"[Error] {err}", 'red')
        traceback.print_exc()
        # Best-effort shutdown: a failure while closing must not mask the
        # original error, and the process exits with status 1 regardless.
        try:
            simulation_app.close()
        except Exception:
            pass
        sys.exit(1)
