# simulation_app.close()
from isaacsim import SimulationApp
simulation_app = SimulationApp({"headless": True})

# load external package
import os
import sys
import time
import json
import numpy as np
import open3d as o3d
from termcolor import cprint
import re
import threading

# load isaac-relevant package
import isaacsim.core.utils.prims as prims_utils
from pxr import UsdGeom, UsdPhysics
from isaacsim.core.api import World
from isaacsim.core.api.objects import FixedCuboid
from isaacsim.core.utils.prims import set_prim_visibility
from isaacsim.core.utils.rotations import euler_angles_to_quat

# load custom package
sys.path.append(os.getcwd())
from Env_StandAlone.BaseEnv import BaseEnv
from Env_Config.Robot.Bimanual_Franka import Bimanual_Franka
from Env_Config.Camera.Recording_Camera import Recording_Camera
from Env_Config.Room.Real_Ground import Real_Ground
from Env_Config.Room.Object_Tools import set_prim_visible_group
from Env_Config.Utils_Project.Code_Tools import get_unique_filename
from Env_Config.Utils_Project.Parse import parse_args_record
from Env_Config.Utils_Project.Point_Cloud_Manip import furthest_point_sampling
from Env_Config.Utils_Project.Object_Contact_Config import get_contact_config
from Env_Config.Utils_Project.Transforms import Rotation
from Env_Config.Table.Table import Table
from Env_Config.Flat_Object.Flat_Object import Rigid_Flat_Object
from Env_Config.Utils_Project.Pos_Ori import calculate_target_from_world_contact


def _to_serializable(value):
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, (np.integer, np.floating)):
        return value.item()
    if isinstance(value, (list, tuple)):
        return [_to_serializable(v) for v in value]
    if isinstance(value, dict):
        return {k: _to_serializable(v) for k, v in value.items()}
    return value


def _write_point_cloud_ply3(
    file_path: str,
    points_cloud,
    colors=None,
    save_or_not: bool = True,
    sample_flag: bool = True,
    sampled_point_num: int = 8192,
    real_time_watch: bool = False,
):
    """Optionally downsample a point cloud, optionally show it, and write it as PLY.

    Args:
        file_path: Destination file; its parent directory is created if missing.
        points_cloud: Point coordinates, presumably an (N, 3) array — TODO confirm
            against ``get_point_cloud_data_from_segment``. ``None`` makes this a no-op.
        colors: Per-point colors matching ``points_cloud`` (assumed (N, 3)), or
            ``None`` to write an uncolored cloud.
        save_or_not: Write the file when True.
        sample_flag: Downsample with ``furthest_point_sampling`` before anything else.
        sampled_point_num: Target point count passed to ``furthest_point_sampling``.
        real_time_watch: Open a (blocking) Open3D viewer on the cloud when True.
    """
    if points_cloud is None:
        return
    if sample_flag:
        points_cloud, colors = furthest_point_sampling(points_cloud, colors, sampled_point_num)
    pcd = o3d.geometry.PointCloud()
    # Open3D's Vector3dVector works on float64 arrays.
    pcd.points = o3d.utility.Vector3dVector(np.asarray(points_cloud, dtype=np.float64))
    # Vector3dVector(None) raises, so only attach colors when they were supplied.
    if colors is not None:
        pcd.colors = o3d.utility.Vector3dVector(np.asarray(colors, dtype=np.float64))
    if real_time_watch:
        o3d.visualization.draw_geometries([pcd])
    if save_or_not:
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # write_point_cloud reports failure through its return value, not an exception.
        if not o3d.io.write_point_cloud(file_path, pcd):
            cprint(f"[Save] Failed to write point cloud to {file_path}", "red")


def _write_pose_txt(file_path: str, sample: dict):
    """Write the arm-pose fields of *sample* to *file_path* as indented JSON.

    Only ``joint_pos_L``, ``joint_pos_R``, ``ee_pose_L`` and ``ee_pose_R`` are
    written (in that order); a missing key is stored as ``null``. Values pass
    through ``_to_serializable`` so NumPy arrays/scalars become plain JSON types.
    The parent directory is created if it does not exist yet.
    """
    pose_keys = ("joint_pos_L", "joint_pos_R", "ee_pose_L", "ee_pose_R")
    payload = {key: _to_serializable(sample.get(key)) for key in pose_keys}
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=True, indent=2)


def save_mid_samples(samples, target_dir: str):
    """Write each callback sample to ``target_dir/Mid{idx}_pos.txt`` (idx is 1-based).

    *target_dir* is always created, even when *samples* is empty, because
    ``WearScarf`` writes further files (``initial_env.ply``, ``start_pos.txt``,
    ``end_pos.txt``) into the same directory afterwards.

    Before writing, existing ``Mid*_env.ply`` / ``Mid*_pos.txt`` files in the
    directory are removed. This keeps stale indices from an earlier run in the
    same directory from surviving. Removal is best-effort; failures are reported
    and skipped.
    """
    os.makedirs(target_dir, exist_ok=True)
    if not samples:
        cprint("[Save] No callback samples captured; skip mid outputs", "yellow")
        return
    for entry in os.listdir(target_dir):
        if entry.startswith("Mid") and entry.endswith(("_env.ply", "_pos.txt")):
            try:
                os.remove(os.path.join(target_dir, entry))
            except OSError as err:
                cprint(f"[Save] Could not remove stale {entry}: {err}", "yellow")
    for idx, sample in enumerate(samples, start=1):
        _write_pose_txt(os.path.join(target_dir, f"Mid{idx}_pos.txt"), sample)


class WearScarf_Env(BaseEnv):
    """Isaac Sim scene for the flat-object push-then-grasp episode driven by ``WearScarf``.

    The constructor does the following, in order:

    - adds a ground plane, a bimanual Franka rig, a table and a rigid flat object;
    - adds a segmenting environment camera (object + table) and a
      high-resolution recording camera;
    - calls ``self.reset()``, initializes both cameras and steps the world 100 times.

    NOTE(review): ``pos``, ``ori``, ``usd_path``, ``env_dx`` and ``env_dy`` are
    accepted but never read here; confirm whether they are kept only for call
    compatibility.
    """

    def __init__(
        self,
        pos: np.ndarray = None,
        ori: np.ndarray = None,
        usd_path: str = None,
        env_dx: float = 0.0,
        env_dy: float = 0.0,
        object_name: str = "Book1",
        object_random_init: bool = False,
        ground_material_usd: str = None,
        record_video_flag: bool = False,
    ):
        """Build all scene assets and cameras, reset, and warm the world up.

        Args:
            pos, ori, usd_path, env_dx, env_dy: Not used by this constructor.
            object_name: Asset name such as ``"Book1"``. Trailing digits are
                stripped to find the category folder under ``Assets/Flat_Object``.
            object_random_init: Forwarded to ``Rigid_Flat_Object`` as ``random_init``.
            ground_material_usd: Optional visual material for the ground.
            record_video_flag: When True, creates (but does not start) a daemon
                thread running ``recording_camera.collect_rgb_graph_for_video``,
                stored as ``self.thread_record``.
        """
        super().__init__()

        self.ground = Real_Ground(
            self.scene,
            visual_material_usd=ground_material_usd,
        )
        self.bimanual_franka = Bimanual_Franka(
            self.world,
            left_pos=np.array([0.57325, 0.52302, 0.52954]),
            left_ori=np.array([0.0, 0.0, -135.0]),
            left_franka_scale=np.array([1.0, 1.0, 1.0]),
            right_pos=np.array([-0.57325, 0.52302, 0.52954]),
            right_ori=np.array([0.0, 0.0, -45.0]),
            right_franka_scale=np.array([1.0, 1.0, 1.0]),
        )
        self.Table = Table(
            self.world,
            path=os.path.join(os.getcwd(), "Assets/Table/Collected_Willow/Wood.usd"),
            position=[0.0, 0.02152, 0.0],
            orientation=[0.0, 0.0, 0.0],
            scale=np.array([0.0088, 0.0104, 0.01]),
        )

        self.object_name = object_name
        # Category folder is the object name without trailing digits ("Book1" -> "Book").
        category_name = re.sub(r"\d+$", "", self.object_name)

        self.object = Rigid_Flat_Object(
            self.world,
            name=self.object_name,
            path=os.path.join(
                os.getcwd(), "Assets/Flat_Object", category_name, f"{self.object_name}.usd"
            ),
            random_init=object_random_init,
            scale=np.array([0.8, 0.8, 1]),
        )

        self.env_camera = Recording_Camera(
            camera_position=np.array([0.0, -3.1, 3.8]),
            camera_orientation=np.array([0.92388, 0.38268, 0.0, 0.0]),
            resolution=(600, 400),
            prim_path="/World/env_camera",
        )

        self.recording_camera = Recording_Camera(
            camera_position=np.array([0.0, -5.04154, 3.61856]),
            camera_orientation=np.array([0.85717, 0.51504, 0.0, 0.0]),
            resolution=(2048, 1080),
            prim_path="/World/recording_camera",
        )
        # Small fixed cuboid at (-0.7, 0, 0.9); no Python reference to it is kept.
        # NOTE(review): its purpose (visual marker?) is not evident here — confirm.
        FixedCuboid(
            prim_path="/World/Cube",
            name="my_cube",
            position=np.array([-0.7, 0.0, 0.9]),
            scale=np.array([0.01, 0.01, 0.01]),
            color=np.array([0.0, 0.5, 1.0]),
        )

        self.env_pcd = None
        self.flat_object_pcd = None
        self.saved_callback_files = []
        self.name_slug = (self.object_name or "").lower()
        self.pos_index = None
        self.callback_save_count = 0

        cprint("self.object_name: {}".format(self.object_name), "green", attrs=["bold"])

        self.reset()

        if record_video_flag:
            # Started later by WearScarf; daemon so it cannot keep the interpreter alive.
            self.thread_record = threading.Thread(
                target=self.recording_camera.collect_rgb_graph_for_video
            )
            self.thread_record.daemon = True

        self.recording_camera.initialize()

        self.env_camera.initialize(
            segment_pc_enable=True,
            segment_prim_path_list=[
                f"/World/Rigid_Flat_Object/{self.object_name}",
                "/World/Table",
            ],
        )

        # Warm-up steps before the scene is handed to the caller.
        for _ in range(100):
            self.step()

        cprint("World Ready!", "green", "on_green")

    def record_callback(self, step_size):
        """Append a pose sample to ``self.saving_data`` on every 60th call.

        ``step_size`` is not used. ``self.step_num`` and ``self.saving_data`` are
        not created in this class; they are presumably set up by ``BaseEnv`` when
        recording starts (TODO confirm). Every call increments ``self.step_num``.

        On sampled calls the method also does three other things:

        - bumps ``self.callback_save_count``;
        - refreshes ``self.env_pcd`` from the segmented environment camera (that
          cloud and its colors are not stored in the sample);
        - appends a dict with ``joint_pos_L``, ``joint_pos_R``, ``ee_pose_L`` and
          ``ee_pose_R``, where each ee pose is ``[*position, *quaternion]``.
        """
        if self.step_num % 60 == 0:
            joint_pos_L = self.bimanual_franka.left_franka.get_joint_positions()
            joint_pos_R = self.bimanual_franka.right_franka.get_joint_positions()
            self.callback_save_count += 1
            self.env_pcd, _env_color = self.env_camera.get_point_cloud_data_from_segment(
                save_or_not=False,
                sample_flag=True,
                sampled_point_num=8192,
                real_time_watch=False,
            )

            pos_L, quat_L = self.bimanual_franka.left_franka.get_cur_ee_pos()
            pos_R, quat_R = self.bimanual_franka.right_franka.get_cur_ee_pos()
            self.saving_data.append({
                "joint_pos_L": joint_pos_L,
                "joint_pos_R": joint_pos_R,
                "ee_pose_L": np.array([*pos_L, *quat_L]),
                "ee_pose_R": np.array([*pos_R, *quat_R]),
            })

        self.step_num += 1


def _capture_arm_state(env):
    """Read joint positions and end-effector poses of both Franka arms.

    Returns:
        ``(joint_pos_L, joint_pos_R, ee_pose_L, ee_pose_R)``, where each ee pose is
        ``np.array([*position, *quaternion])`` built from ``get_cur_ee_pos()``.
    """
    joint_pos_L = env.bimanual_franka.left_franka.get_joint_positions()
    joint_pos_R = env.bimanual_franka.right_franka.get_joint_positions()
    pos_L, quat_L = env.bimanual_franka.left_franka.get_cur_ee_pos()
    pos_R, quat_R = env.bimanual_franka.right_franka.get_cur_ee_pos()
    return joint_pos_L, joint_pos_R, np.array([*pos_L, *quat_L]), np.array([*pos_R, *quat_R])


def _store_arm_state(env, prefix):
    """Snapshot both arms into ``env.saving_data_replay`` under ``{prefix}_*`` keys.

    Returns the same snapshot as a JSON-serializable dict with keys
    ``joint_pos_L``, ``joint_pos_R``, ``ee_pose_L`` and ``ee_pose_R``.
    """
    joint_pos_L, joint_pos_R, ee_pose_L, ee_pose_R = _capture_arm_state(env)
    env.saving_data_replay[f"{prefix}_joint_pos_L"] = joint_pos_L
    env.saving_data_replay[f"{prefix}_joint_pos_R"] = joint_pos_R
    env.saving_data_replay[f"{prefix}_ee_pose_L"] = ee_pose_L
    env.saving_data_replay[f"{prefix}_ee_pose_R"] = ee_pose_R
    return {
        "joint_pos_L": _to_serializable(joint_pos_L),
        "joint_pos_R": _to_serializable(joint_pos_R),
        "ee_pose_L": _to_serializable(ee_pose_L),
        "ee_pose_R": _to_serializable(ee_pose_R),
    }


def _left_gripper_tip(env):
    """Return the left EE position plus ``Rotation(quat, [0, 0, 0.1])``.

    Presumably this rotates a 0.1 offset along the EE's local z axis into the
    world frame, giving the gripper tip (TODO confirm ``Rotation`` semantics).
    """
    ee_pos, ee_quat = env.bimanual_franka.left_franka.get_cur_ee_pos()
    return ee_pos + Rotation(ee_quat, np.array([0.0, 0.0, 0.1]))


def WearScarf(pos, ori, usd_path, env_dx, env_dy, object_name, ground_material_usd, data_collection_flag, record_video_flag, object_random_init=False, pos_index=None):
    """Run one push-to-edge then grasp-and-lift episode; optionally save data/video.

    Phases:

    1. Capture an arm-free segmented cloud of the scene.
    2. Push the object along +x with the left gripper in ``e_mx`` steps until its
       x reaches ``secure_min`` or the next target reaches ``secure_max``. This
       phase is recorded when ``data_collection_flag`` is set.
    3. If no fatal condition was hit, re-orient the gripper, grasp the object
       edge and lift it.
    4. Judge success from the final object, gripper and table poses.

    With ``data_collection_flag``, samples and poses go to
    ``Data/StrategyA_V2/train_data/<object_name>/record_<idx>_<T|F>/``.
    With ``record_video_flag``, an MP4 goes to ``Data/StrategyA_V2/Video/<success|fail>/``.

    Returns:
        ``True`` when, with no fatal error, all of the following hold (else ``False``):

        - the object ends more than 0.5 above the table;
        - the gripper tip is within 0.8 of the object;
        - the gripper tip is more than 0.4 above the table.

    NOTE: calls ``simulation_app.close()`` before returning.
    """
    env = WearScarf_Env(pos, ori, usd_path, env_dx, env_dy, object_name, object_random_init, ground_material_usd, record_video_flag)
    cprint(f"[Info]: data_collection_flag: {data_collection_flag}", "green", attrs=["bold"])
    cprint(f"[Info]: record_video_flag: {record_video_flag}", "green", attrs=["bold"])
    cprint(f"[Info]: object_name: {object_name}", "green", attrs=["bold"])

    name_slug = object_name or getattr(env, "object_name", None) or getattr(env.object, "name", None)
    if not name_slug:
        raise ValueError("object_name is not provided and cannot be inferred from the environment; please pass it explicitly via --object_name")
    name_slug = name_slug.lower()

    env.name_slug = name_slug
    env.pos_index = pos_index if pos_index is not None else 1
    env.callback_save_count = 0

    env.bimanual_franka.Gripper_Left_Open()
    env.object.set_mass(0.05)
    # Disk objects use a slightly lower pushing height.
    z_height = 0.7731 if "Disk" in object_name else 0.776

    if record_video_flag:
        env.thread_record.start()
        cprint("[Info]: Video recording started", "green", attrs=["bold"])

    task_name = "StrategyA_V2"
    os.makedirs(f"Data/{task_name}/train_data/", exist_ok=True)
    os.makedirs(f"Data/{task_name}/pointcloud/", exist_ok=True)

    # Hide both arms while capturing the initial scene cloud (presumably so they
    # do not occlude the object/table), then show them again.
    set_prim_visible_group(prim_path_list=["/World/Franka_Left", "/World/Franka_Right"], visible=False)
    for _ in range(2):
        env.step()
    whole_point_cloud_pcd, whole_point_cloud_pcd_color = env.env_camera.get_point_cloud_data_from_segment(
        save_or_not=False,
        sampled_point_num=8192,
        real_time_watch=False,
    )
    set_prim_visible_group(prim_path_list=["/World/Franka_Left", "/World/Franka_Right"], visible=True)

    contact_cfg = get_contact_config(getattr(env.object, "name", "None"))
    c_target_position, c_target_orientation = calculate_target_from_world_contact(
        contact_world_init=contact_cfg["contact_world_init"],
        initial_board_position=contact_cfg["initial_board_position"],
        initial_board_euler=contact_cfg.get("initial_board_euler"),
        initial_gripper_orientation_euler=contact_cfg.get("initial_gripper_orientation_euler"),
        current_board_position=env.object.my_position,
        current_board_euler=env.object.my_orientation,
    )
    # NOTE(review): the computed orientation is discarded in favour of this fixed one.
    c_target_orientation = np.array([-179.631, 0.308, 2.640])

    obj_pos, _ = env.object.rigid_form.get_world_pose()
    cprint(f"[Info]: Detected current object X position : {obj_pos}", "white")

    # Pushing stops once the object x passes secure_min or the next target x
    # would reach secure_max.
    secure_min = 0.61462 - 0.03
    secure_max = 0.61462 - 0.01
    e_mx = 0.05  # x increment per push step
    cprint(f"[Info]: Step distance for each move: e_mx:{e_mx}", "white")

    c_x = c_target_position[0] - 0.6 * e_mx
    cprint(f"[Info]: First move distance to contact point: c_x:{c_x}", "white")

    if data_collection_flag:
        for _ in range(2):
            env.step()
        env.record(task_name=task_name, stage_index=1)

    env.bimanual_franka.Rmpflow_Left_Move(
        target_position=np.array([c_x, c_target_position[1], 0.85]),
        target_orientation=c_target_orientation,
    )

    # The wide draw always happens first (Disk objects then redraw in a narrower
    # range), so the RNG sequence matches for every object type.
    y_noise = round(np.random.uniform(-0.1, 0.1), 2)
    if "Disk" in object_name:
        y_noise = round(np.random.uniform(-0.03, 0.03), 2)
    c_target_position[1] += y_noise
    cprint(f"[Info]: Random perturbation added to c_target_position Y: {y_noise}, new value: {c_target_position[1]}", "yellow")

    gripper_current_position = np.array([c_x, c_target_position[1], z_height])
    env.bimanual_franka.Rmpflow_Left_Move(
        target_position=gripper_current_position,
        target_orientation=c_target_orientation,
    )

    current_obj_x = obj_pos[0]
    push_step = 1  # dedicated counter; never shadowed by the inner step loops
    fatal_error_flag = False
    while current_obj_x < secure_min:
        if push_step == 1:
            cprint("[Info]: Collect start pose info separately", "white")
            env.start_pose_payload = _store_arm_state(env, "start")
            for _ in range(2):
                env.step()

        obj_pos, _ = env.object.rigid_form.get_world_pose()
        gripper_pos_to_object = _left_gripper_tip(env)
        table_pos, _ = env.Table.rigid_form.get_world_pose()

        cprint(f"[Info]: -------In Loop-------")
        cprint(f"[Info]: Current object position: obj_pos:{obj_pos}")
        cprint(f"[Info]: Current gripper tip position: pos_e:{gripper_pos_to_object}")
        cprint(f"[Info]: Current table position: table_pos:{table_pos}")

        if table_pos[2] - obj_pos[2] > 0.4:
            cprint(f"[Error]: Object left the table, exit current run", "red")
            fatal_error_flag = True
            break
        if gripper_pos_to_object[1] - table_pos[1] > 0.8 or gripper_pos_to_object[0] - table_pos[0] > 0.8:
            cprint(f"[Error]: Gripper was bounced off, exit current run", "red")
            fatal_error_flag = True
            break
        if gripper_pos_to_object[1] - obj_pos[1] > 0.8 or gripper_pos_to_object[0] - obj_pos[0] > 0.8:
            cprint(f"[Error]: Gripper is far from object, exit current run", "red")
            fatal_error_flag = True
            break

        gripper_current_position = np.array([c_x + push_step * e_mx, c_target_position[1], z_height])
        cprint(f"[Info]: Target position for next move:{gripper_current_position}", "white")

        if gripper_current_position[0] >= secure_max:
            break
        env.bimanual_franka.Rmpflow_Left_Move(
            target_position=gripper_current_position,
            target_orientation=c_target_orientation,
        )

        obj_pos, _ = env.object.rigid_form.get_world_pose()
        current_obj_x = obj_pos[0]
        cprint(f"[Info]: Current object X position after move: current_obj_x:{current_obj_x}", "white")

        gripper_pos_to_object = _left_gripper_tip(env)
        if abs(gripper_pos_to_object[1] - obj_pos[1]) > 0.3 or abs(gripper_pos_to_object[0] - obj_pos[0]) > 0.3:
            cprint(f"[Error]: Gripper is far from object, exit current run", "red")
            fatal_error_flag = True
            break

        push_step += 1

    env.end_pose_payload = _store_arm_state(env, "end")
    for _ in range(2):
        env.step()
    # Recording covers only the push phase and is stopped exactly once, here;
    # the grasp/lift below runs after it has stopped.
    if data_collection_flag:
        env.stop_record()
    for _ in range(2):
        env.step()

    if not fatal_error_flag:
        obj_pos, _ = env.object.rigid_form.get_world_pose()

        env.bimanual_franka.Rmpflow_Left_Move(
            target_position=np.array([gripper_current_position[0] - 0.008, gripper_current_position[1] - 0.008, z_height]),
            target_orientation=c_target_orientation,
        )
        env.bimanual_franka.Rmpflow_Left_Move(
            target_position=np.array([gripper_current_position[0] - 0.005, gripper_current_position[1] - 0.005, 0.85]),
            target_orientation=c_target_orientation,
        )

        env.bimanual_franka.Gripper_Left_Open()
        env.bimanual_franka.Rmpflow_Left_Move(target_position=np.array([0.85, 0.0, 0.85]), target_orientation=c_target_orientation)
        cprint("[Info]: Left_Move; Lateral movement completed", "white")

        def _grasp_move(x, y, z):
            # Grasp orientation; env.object.my_orientation is read on every call, not cached.
            env.bimanual_franka.Rmpflow_Left_Move(
                target_position=np.array([x, y, z]),
                target_orientation=np.array([-91.314, 1.700, 88.261 + env.object.my_orientation[2]]),
            )

        _grasp_move(0.9, 0.0, 0.85)
        cprint("[Info]: Left_Move; Rotation completed", "white")
        _grasp_move(0.9, 0.0, 0.77)
        cprint("[Info]: Left_Move; Downward movement completed", "white")
        _grasp_move(0.8, obj_pos[1], 0.77)
        _grasp_move(0.7, obj_pos[1], 0.77)
        _grasp_move(0.63276, obj_pos[1], 0.77)
        cprint("[Info]: Left_Move; Pre-grasp movement completed", "white")
        env.bimanual_franka.Gripper_Left_Close()
        cprint("[Info]: Gripper_Left_Move: Grasp completed", "green")
        _grasp_move(0.63276, obj_pos[1], 1.0)
    else:
        cprint("[Info]: Skip execution due to failure", "green")

    obj_pos, _ = env.object.rigid_form.get_world_pose()
    gripper_pos_to_object = _left_gripper_tip(env)
    table_pos, _ = env.Table.rigid_form.get_world_pose()

    dist_gripper_to_obj = np.linalg.norm(gripper_pos_to_object - obj_pos)
    cprint(f"[Info]: obj_pos:{obj_pos}", "white")
    cprint(f"[Info]: pos_e:{gripper_pos_to_object}", "white")
    cprint(f"[Info]: table_pos:{table_pos}", "white")
    cprint(f"[Info]: dist_gripper_to_obj:{dist_gripper_to_obj}", "white")

    # bool() so callers get a plain Python bool rather than np.bool_.
    success_flag = bool(
        not fatal_error_flag
        and obj_pos[2] - table_pos[2] > 0.5
        and dist_gripper_to_obj < 0.8
        and gripper_pos_to_object[2] - table_pos[2] > 0.4
    )
    if success_flag:
        cprint("[Info]: WearScarf Task Success!", "green")
    else:
        cprint("[Info]: WearScarf Task Fail!", "red")

    if data_collection_flag:
        pos_idx_final = pos_index if pos_index is not None else (getattr(env, "pos_index", None) or 1)
        outcome_flag = "T" if success_flag else "F"
        base_dir = os.path.join("Data", task_name, "train_data", object_name)
        record_dir = os.path.join(base_dir, f"record_{pos_idx_final}_{outcome_flag}")
        # Ensure the directory exists even when no mid samples were captured,
        # since the files below are written into it regardless.
        os.makedirs(record_dir, exist_ok=True)
        save_mid_samples(getattr(env, "last_recorded_samples", []), record_dir)

        _write_point_cloud_ply3(
            os.path.join(record_dir, "initial_env.ply"),
            whole_point_cloud_pcd,
            whole_point_cloud_pcd_color,
        )

        start_payload = getattr(env, "start_pose_payload", None)
        if start_payload:
            _write_pose_txt(os.path.join(record_dir, "start_pos.txt"), start_payload)
        end_payload = getattr(env, "end_pose_payload", None)
        if end_payload:
            _write_pose_txt(os.path.join(record_dir, "end_pos.txt"), end_payload)
        cprint(f"[Info] Mid samples saved to {record_dir}", "green")

    if record_video_flag:
        try:
            env.recording_camera.capture = False
            if hasattr(env, "thread_record") and env.thread_record.is_alive():
                env.thread_record.join(timeout=2)
        except Exception as e:
            cprint(f"[video-thread] stop failed: {e}", "red")

        sub_folder = "success" if success_flag else "fail"
        video_dir = f"Data/{task_name}/Video/{sub_folder}"
        os.makedirs(video_dir, exist_ok=True)

        if pos_index is not None:
            video_path = f"{video_dir}/{object_name}_pos{pos_index}.mp4"
        else:
            video_path = get_unique_filename(f"{video_dir}/{name_slug}", ".mp4")
        env.recording_camera.create_mp4(video_path)

    # Let the scene run for a while before shutting down.
    for _ in range(170):
        env.step()

    # Log completion before closing the app, so the message is emitted while Kit
    # is still running.
    cprint(f"[Info]: Collection Finished!", "green", attrs=["bold"])
    simulation_app.close()
    return success_flag

if __name__ == "__main__":
    # Script entry point: parse CLI args, optionally randomize, run one episode.
    # Everything, including the final simulation_app.close(), stays inside this
    # guard so that importing the module does not shut down the app.
    args = parse_args_record()

    pos = np.array([0.0, 0.30, 0.65])
    ori = np.array([90.0, 0.0, 0.0])
    usd_path = None
    env_dx = 0.0
    env_dy = 0.0

    if args.env_random_flag or args.garment_random_flag:
        np.random.seed(int(time.time()))
        if args.env_random_flag:
            env_dx = np.random.uniform(-0.05, 0.1)
            env_dy = np.random.uniform(-0.05, 0.05)
        if args.garment_random_flag:
            x = np.random.uniform(-0.05, 0.05)
            y = np.random.uniform(0.30, 0.40)
            pos = np.array([x, y, 0.0])
            ori = np.array([90.0, 0.0, 0.0])
            base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            assets_list_path = os.path.join(base_dir, "Model_HALO/GAM/checkpoints/Scarf/assets_training_list.txt")
            with open(assets_list_path, "r", encoding="utf-8") as f:
                # Skip blank lines (e.g. a trailing newline) so "" is never chosen.
                assets_list = [line.rstrip("\r\n") for line in f if line.strip()]
            usd_path = os.getcwd() + "/" + np.random.choice(assets_list)

    # Optional integer episode index from the environment (e.g. set by a batch driver).
    pos_index = None
    _pos_env = os.environ.get("POS_INDEX", "").strip()
    if _pos_env.isdigit():
        pos_index = int(_pos_env)

    cprint(
        f"[Args] object_name={args.object_name} data_collection_flag={args.data_collection_flag} record_video_flag={args.record_video_flag} POS_INDEX={pos_index}",
        "yellow",
        attrs=["bold"],
    )

    ok = WearScarf(
        pos, ori,
        usd_path,
        env_dx, env_dy,
        args.object_name,
        args.ground_material_usd,
        args.data_collection_flag,
        args.record_video_flag,
        object_random_init=args.object_random_init,
        pos_index=pos_index,
    )

    if args.data_collection_flag:
        simulation_app.close()
        # Exit code reports episode success to the calling process.
        sys.exit(0 if ok else 1)
    else:
        while simulation_app.is_running():
            simulation_app.update()
        simulation_app.close()