"""
Real robot deployment script for H1 humanoid robot.
Handles policy execution, motion playback, and data recording.
"""

import os
import argparse
from collections import deque
from datetime import datetime
import time

import numpy as np
import torch
import faulthandler
import joblib
import h5py
from scipy.interpolate import interp1d
import transforms3d as t3d

import rclpy
from rclpy.node import Node
from unitree_hg.msg import (
    LowState,
    MotorState,
    IMUState,
    LowCmd,
    MotorCmd,
)

import mujoco
import mujoco.viewer

from crc import CRC
from gamepad import Gamepad, parse_remote_data


# Hardware configuration
HW_DOF = 27  # Number of active motors commanded by this script

# Configuration flags
WALK_STRAIGHT = False
LOG_DATA = True
USE_GRIPPPER = False  # NOTE: spelling kept as-is; the name may be referenced elsewhere
NO_MOTOR = False
DEBUG = False  # When True, DeployNode.__init__ opens the MuJoCo viewer
SIM = False

# File paths
HUMANOID_XML = 'h1_2.xml'  # MuJoCo model loaded by H1_2.init_mujoco_viewer
FILE_PATH = 'locomotions'  # Directory holding the reference motion files
# os.listdir() returns entries in arbitrary order. Sort them so that the motion
# index `count` refers to the same file on every run and on every machine.
FILE_LIST = sorted(os.listdir(FILE_PATH))

# CRC instance for message checksums
crc = CRC()


class H1_2:
    """Configuration container for the H1-2 humanoid used during deployment.

    Stores the observation/action dimensions, observation scales, per-joint PD
    gains, joint limits, the default standing pose and the observation history
    buffers. All per-joint sequences use the order: left leg (6), right leg (6),
    waist (1), left arm (7), right arm (7).
    """

    def __init__(self, task='stand'):
        """Populate every configuration attribute.

        Args:
            task: Task type ('stand', 'datacollect', etc.); stored as-is.
        """
        self.task = task
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # --- Environment dimensions ---------------------------------------
        self.num_envs = 1
        self.num_observations = 76
        self.num_actions = 12
        self.num_privileged_obs = None
        self.obs_context_len = 6

        # --- Observation scaling factors ----------------------------------
        self.scale_lin_vel = 2.0
        self.scale_ang_vel = 0.25
        self.scale_orn = 1.0
        self.scale_dof_pos = 1.0
        self.scale_dof_vel = 0.05
        self.scale_action = 0.25

        # --- Gait configuration -------------------------------------------
        self.cycle_time = 0.64
        self.gait_indices = torch.zeros(
            self.num_envs, dtype=torch.float, device=self.device, requires_grad=False
        )

        # --- PD gains, written as (kp, kd) pairs per joint type -----------
        hip = (200., 2.5)
        hip_pitch = (300., 3.0)
        knee = (300., 4.)
        ankle = (40., 2.0)
        waist = (300., 5.)
        shoulder = (99.2, 3.3)
        elbow = (98.7, 3.4)
        wrist_roll = (46.9, 1.6)
        wrist_pitch = (24, 0.8)
        wrist_yaw = (24, 0.8)

        # NOTE(review): the leg gains start with the hip-pitch gain, while the
        # default-pose layout below lists the first leg joint as hip yaw.
        # Confirm the intended joint order against the robot's motor map.
        leg_gains = [hip_pitch, hip, hip, knee, ankle, ankle]
        arm_gains = [shoulder, shoulder, shoulder, elbow,
                     wrist_roll, wrist_pitch, wrist_yaw]
        joint_gains = leg_gains + leg_gains + [waist] + arm_gains + arm_gains
        self.p_gains = np.array([kp for kp, _ in joint_gains])
        self.d_gains = np.array([kd for _, kd in joint_gains])

        # --- Joint position limits (radians) ------------------------------
        inf = np.inf
        self.joint_limit_lo = (
            [-0.43, -3.14, -0.43, -0.26, -inf, -inf]        # Left leg
            + [-0.43, -3.14, -3.14, -0.24, -inf, -inf]      # Right leg
            + [-2.618]                                      # Waist
            + [-3.0892, -1.5882, -2.618, -1.0472,           # Left arm
               -1.972222054, -1.614429558, -1.614429558]
            + [-3.0892, -2.2515, -2.618, -1.0472,           # Right arm
               -1.972222054, -1.614429558, -1.614429558]
        )
        self.joint_limit_hi = (
            [0.43, 2.5, 3.14, 2.05, inf, inf]               # Left leg
            + [0.43, 2.5, 0.43, 2.0, inf, inf]              # Right leg
            + [2.618]                                       # Waist
            + [2.6704, 2.2515, 2.618, 2.0944,               # Left arm
               1.972222054, 1.614429558, 1.614429558]
            + [2.6704, 1.5882, 2.618, 2.0944,               # Right arm
               1.972222054, 1.614429558, 1.614429558]
        )
        self.soft_dof_pos_limit = 1.0

        # Shrink each range around its midpoint by the soft-limit factor.
        # Indices 4, 5, 10 and 11 (the two ankle joints of each leg) are
        # skipped: their limits are infinite, so a midpoint is undefined.
        unbounded = {4, 5, 10, 11}
        for idx, (lo, hi) in enumerate(zip(self.joint_limit_lo, self.joint_limit_hi)):
            if idx in unbounded:
                continue
            mid = (lo + hi) / 2
            span = hi - lo
            self.joint_limit_lo[idx] = mid - 0.5 * span * self.soft_dof_pos_limit
            self.joint_limit_hi[idx] = mid + 0.5 * span * self.soft_dof_pos_limit

        # --- Default standing pose ----------------------------------------
        # Leg layout: hip yaw, hip pitch, hip roll, knee, ankle pitch, ankle roll
        leg_pose = [0.0, -0.16, 0.0, 0.36, -0.2, 0]
        arm_pose = [0.0, 0, 0, 0.0, 0, 0, 0]
        self.default_dof_pos_np = np.array(
            leg_pose + leg_pose + [0] + arm_pose + arm_pose
        )
        self.default_dof_pos = torch.tensor(
            self.default_dof_pos_np, dtype=torch.float, device=self.device,
            requires_grad=False
        ).unsqueeze(0)

        # --- Observation buffers ------------------------------------------
        # Flat buffer sized for `obs_context_len` stacked observations, plus a
        # fixed-length history pre-filled with zero observations.
        self.obs_buf = torch.zeros(
            1, self.num_observations * self.obs_context_len,
            dtype=torch.float, device=self.device, requires_grad=False
        )
        self.obs_history = deque(
            (
                torch.zeros(1, self.num_observations, dtype=torch.float, device=self.device)
                for _ in range(self.obs_context_len)
            ),
            maxlen=self.obs_context_len,
        )

    def init_mujoco_viewer(self):
        """Load the MuJoCo model and open a passive viewer for visualization."""
        model = mujoco.MjModel.from_xml_path(HUMANOID_XML)
        data = mujoco.MjData(model)
        model.opt.timestep = 0.001
        self.mj_model = model
        self.mj_data = data
        self.viewer = mujoco.viewer.launch_passive(model, data)


def pd_control(target_q, q, kp, target_dq, dq, kd):
    """Compute PD control torques.

    The torque is the position error scaled by ``kp`` plus the velocity error
    scaled by ``kd``.

    Args:
        target_q: Target joint positions
        q: Current joint positions
        kp: Position gains
        target_dq: Target joint velocities
        dq: Current joint velocities
        kd: Velocity gains

    Returns:
        Control torques
    """
    position_error = target_q - q
    velocity_error = target_dq - dq
    return position_error * kp + velocity_error * kd


def quat_rotate_inverse(q, v):
    """Rotate vector v by the inverse of quaternion q.

    The conjugate of ``q`` is formed and its rotation matrix is applied to
    ``v`` row by row.

    Args:
        q: Quaternion [w, x, y, z] (indexed with ``q[..., i]``)
        v: Vector to rotate [x, y, z]

    Returns:
        Rotated vector in body frame
    """
    # Components of the conjugate quaternion (vector part negated).
    cw = q[..., 0]
    cx = -q[..., 1]
    cy = -q[..., 2]
    cz = -q[..., 3]

    ww, xx, yy, zz = cw**2, cx**2, cy**2, cz**2
    xy, xz, yz = cx * cy, cx * cz, cy * cz
    wx, wy, wz = cw * cx, cw * cy, cw * cz

    out_x = (
        v[0] * (ww + xx - yy - zz)
        + v[1] * 2 * (xy - wz)
        + v[2] * 2 * (xz + wy)
    )
    out_y = (
        v[0] * 2 * (xy + wz)
        + v[1] * (ww - xx + yy - zz)
        + v[2] * 2 * (yz - wx)
    )
    out_z = (
        v[0] * 2 * (xz - wy)
        + v[1] * 2 * (yz + wx)
        + v[2] * (ww - xx - yy + zz)
    )
    return np.array([out_x, out_y, out_z])


def get_gravity_orientation(quat):
    """Express the unit gravity direction in the body frame.

    Args:
        quat: Quaternion representing body orientation

    Returns:
        The vector [0, 0, -1] rotated by the inverse of ``quat``
    """
    return quat_rotate_inverse(quat, np.array([0.0, 0.0, -1.0]))



# Module-level motion data, shared across the file
global_dof_data = None
global_motion_info = None
count = 0  # Index into FILE_LIST of the current motion file

# Module-level state for automatic route control.
#   mode:          'manual' or 'auto'
#   current_phase: 'idle', 'forward_accel', 'forward_const', 'forward_decel',
#                  'pause1', 'backward_accel', 'backward_const',
#                  'backward_decel' or 'pause2'
#   current_pattern: 'A' (forward-backward) or 'B' (backward-forward)
#   continuous_mode: True for continuous alternating cycles
global_route_state = dict(
    mode='manual',
    current_phase='idle',
    phase_timer=0.0,
    target_velocity=0.0,
    current_velocity=0.0,
    route_active=False,
    route_cycle_count=0,
    current_pattern='A',
    continuous_mode=False,
)

def smooth_velocity_interpolation(current_vel, target_vel, max_accel=0.3, dt=0.01):
    """Move a velocity toward its target without exceeding an acceleration limit.

    Args:
        current_vel: Current velocity
        target_vel: Target velocity
        max_accel: Maximum acceleration
        dt: Time step

    Returns:
        ``target_vel`` when it is reachable within one step, otherwise
        ``current_vel`` moved by ``max_accel * dt`` toward the target
    """
    step_limit = max_accel * dt
    delta = target_vel - current_vel

    # Close enough: snap to the target.
    if abs(delta) <= step_limit:
        return target_vel

    # Otherwise take the largest permitted step in the direction of the target.
    return current_vel + np.sign(delta) * step_limit


def _enter_route_phase(state, phase):
    """Switch the route state machine to `phase` and restart its phase timer."""
    state['current_phase'] = phase
    state['phase_timer'] = 0.0


def update_automatic_route(dt=0.01, self=None):
    """Advance the automatic route state machine by one step.

    Each call adds ``dt`` to the phase timer, updates the target velocity for
    the current phase, moves to the next phase when the phase duration has
    elapsed, and finally rate-limits the commanded velocity toward the target.

    Phase order depends on ``global_route_state['current_pattern']``:
        Pattern A: forward -> pause1 -> backward -> pause2
        Pattern B: backward -> pause1 -> forward -> pause2
    When ``pause2`` ends the cycle counter is incremented. In continuous mode
    the pattern is toggled and the next cycle starts; otherwise the route
    returns to 'idle'.

    Args:
        dt: Time step in seconds. Used both for the phase timer and for the
            acceleration limit applied to the commanded velocity.
        self: DeployNode instance (optional). Accepted for interface
            compatibility; it is not used by this function.

    Returns:
        Velocity command [vx, vy, yaw] as a float32 array. All zeros when the
        route is not in 'auto' mode or is not active.
    """
    state = global_route_state

    if state['mode'] != 'auto' or not state['route_active']:
        return np.array([0, 0, 0], dtype=np.float32)

    # Route timing parameters (seconds)
    accel_time = 2.0
    const_time = 0.5
    decel_time = 2.5
    pause_time = 2.5

    max_forward_velocity = 0.3
    max_backward_velocity = 0.6

    state['phase_timer'] += dt
    elapsed = state['phase_timer']
    phase = state['current_phase']
    # Pattern A runs the forward leg first; any other pattern runs backward first.
    forward_first = state['current_pattern'] == 'A'

    # Phase state machine
    if phase == 'idle':
        state['target_velocity'] = 0.0
        state['current_velocity'] = 0.0

    elif phase == 'forward_accel':
        progress = elapsed / accel_time
        state['target_velocity'] = min(progress * max_forward_velocity, max_forward_velocity)
        if elapsed >= accel_time:
            _enter_route_phase(state, 'forward_const')

    elif phase == 'forward_const':
        state['target_velocity'] = max_forward_velocity
        if elapsed >= const_time:
            _enter_route_phase(state, 'forward_decel')

    elif phase == 'forward_decel':
        # Clamp so the last step cannot overshoot into a small reverse command.
        progress = min(elapsed / decel_time, 1.0)
        state['target_velocity'] = max_forward_velocity * (1.0 - progress)
        if elapsed >= decel_time:
            # Forward is the first leg of pattern A and the last leg of pattern B.
            _enter_route_phase(state, 'pause1' if forward_first else 'pause2')

    elif phase == 'pause1':
        state['target_velocity'] = 0.0
        if elapsed >= pause_time:
            # The mid-cycle pause is followed by the leg not yet driven.
            _enter_route_phase(state, 'backward_accel' if forward_first else 'forward_accel')

    elif phase == 'backward_accel':
        progress = elapsed / accel_time
        state['target_velocity'] = -min(progress * max_backward_velocity, max_backward_velocity)
        if elapsed >= accel_time:
            _enter_route_phase(state, 'backward_const')

    elif phase == 'backward_const':
        state['target_velocity'] = -max_backward_velocity
        if elapsed >= const_time:
            _enter_route_phase(state, 'backward_decel')

    elif phase == 'backward_decel':
        # Clamp so the last step cannot overshoot into a small forward command.
        progress = min(elapsed / decel_time, 1.0)
        state['target_velocity'] = -max_backward_velocity * (1.0 - progress)
        if elapsed >= decel_time:
            # Backward is the last leg of pattern A and the first leg of pattern B.
            _enter_route_phase(state, 'pause2' if forward_first else 'pause1')

    elif phase == 'pause2':
        state['target_velocity'] = 0.0

        if elapsed >= pause_time:
            state['route_cycle_count'] += 1
            print(f'Route cycle {state["route_cycle_count"]} completed '
                  f'(Pattern {state["current_pattern"]})')

            if state['continuous_mode']:
                # Switch to the other pattern and continue
                if state['current_pattern'] == 'A':
                    state['current_pattern'] = 'B'
                    _enter_route_phase(state, 'backward_accel')
                    print('Switching to Pattern B: Backward -> Forward')
                else:
                    state['current_pattern'] = 'A'
                    _enter_route_phase(state, 'forward_accel')
                    print('Switching to Pattern A: Forward -> Backward')
            else:
                # Return to idle (single cycle mode)
                _enter_route_phase(state, 'idle')

    # Rate-limit the commanded velocity using the caller's time step, so the
    # acceleration limit matches the rate at which this function is called.
    state['current_velocity'] = smooth_velocity_interpolation(
        state['current_velocity'],
        state['target_velocity'],
        dt=dt,
    )

    # Return velocity command: [vx, vy, yaw]
    return np.array([state['current_velocity'], 0.0, 0.0], dtype=np.float32)

def start_automatic_route():
    """Switch to automatic mode and begin continuous cycles with pattern A."""
    global global_route_state
    global_route_state.update({
        'mode': 'auto',
        'route_active': True,
        'continuous_mode': True,
        'current_pattern': 'A',  # Always begin with pattern A
        'current_phase': 'forward_accel',
        'phase_timer': 0.0,
        'target_velocity': 0.0,
        'current_velocity': 0.0,
    })
    for line in (
        'Continuous automatic route started!',
        'Pattern A: Forward -> Pause -> Backward -> Pause -> Switch to Pattern B',
        'Pattern B: Backward -> Pause -> Forward -> Pause -> Switch to Pattern A',
        'Press X again to stop continuous mode',
    ):
        print(line)

def stop_automatic_route():
    """Deactivate the automatic route and reset its state to manual/idle."""
    global global_route_state
    global_route_state.update({
        'mode': 'manual',
        'route_active': False,
        'continuous_mode': False,
        'current_phase': 'idle',
        'phase_timer': 0.0,
        'target_velocity': 0.0,
        'current_velocity': 0.0,
        'current_pattern': 'A',  # Next start begins with pattern A again
    })
    print('Automatic route stopped, returned to manual mode')

def save_data_automatically(self):
    """Write the recorded histories to an HDF5 file and empty the buffers.

    Does nothing unless ``self.is_recording`` is set. The file is written to
    ``data_output/`` and named after the current motion file and a timestamp.

    Args:
        self: Object carrying the recording state and ``*_hist`` buffers
            (the DeployNode instance).
    """
    if not self.is_recording:
        return

    if self.motion_info is not None:
        source_name = FILE_LIST[count]
    else:
        source_name = "unknown_motion"
    motion_name, _ = os.path.splitext(source_name)
    current_time = datetime.now().strftime('_%Y-%m-%d_%H-%M-%S')
    save_path = f'data_output/{motion_name}_auto_route{current_time}.h5'

    os.makedirs('data_output', exist_ok=True)

    # dataset name -> history attribute, in the order the datasets are written
    command_datasets = (
        ('command_time_list', 'time_hist'),
        ('command_val_list', 'action_hist'),
    )
    robot_datasets = (
        ('joint_time_list', 'time_hist'),
        ('joint_angle_list', 'dof_pos_hist'),
        ('joint_velocity_list', 'dof_vel_hist'),
        ('joint_current_list', 'tau_hist'),
        ('joint_temperature_list', 'temp_hist'),
        ('imu_list', 'imu_hist'),
        ('ang_vel_list', 'ang_vel_hist'),
    )

    with h5py.File(save_path, 'w') as f:
        for dataset_name, attr in command_datasets:
            f.create_dataset(dataset_name, data=np.array(getattr(self, attr)))
        f.create_dataset('motion_name', data=motion_name.encode('utf-8'))
        f.create_dataset('current_time', data=current_time.encode('utf-8'))

        robot_group = f.create_group('robot')
        for dataset_name, attr in robot_datasets:
            robot_group.create_dataset(dataset_name, data=np.array(getattr(self, attr)))

    print(f"Auto-saved data to {save_path}")

    # Empty the buffers so the next recording starts fresh
    for attr in (
        'time_hist', 'action_hist', 'dof_pos_hist', 'dof_vel_hist', 'imu_hist',
        'ang_vel_hist', 'tau_hist', 'obs_hist', 'temp_hist',
    ):
        getattr(self, attr).clear()


def load_reference_motions():
    """Load the current motion file and resample its DOF data to 50 Hz.

    The file is ``FILE_LIST[count]`` inside ``FILE_PATH``. The DOF trajectory
    of the first entry in the file is linearly interpolated from its own frame
    rate to 50 Hz, and a fresh playback-state dictionary is created for it.

    Returns:
        tuple: (interpolated_dof_data, motion_info), or (None, None) when the
        file is missing, cannot be read, or does not have the expected layout.
    """
    try:
        motion_file_name = FILE_LIST[count]
    except IndexError:
        print(f"Warning: No motion file at index {count} in '{FILE_PATH}'")
        return None, None

    motion_file_path = os.path.join(FILE_PATH, motion_file_name)
    if not os.path.exists(motion_file_path):
        print(f"Warning: Motion file not found at {motion_file_path}")
        return None, None

    try:
        # NOTE(security): joblib.load unpickles the file and can execute
        # arbitrary code. Only load motion files from a trusted source.
        # Reading happens inside the try block so that an unreadable or corrupt
        # file yields (None, None) instead of an uncaught exception.
        with open(motion_file_path, 'rb') as f:
            motion_data = joblib.load(f)

        first_key = list(motion_data.keys())[0]
        dof_data = motion_data[first_key]['dof']
        fps = motion_data[first_key]['fps']

        # Interpolate from original FPS to 50Hz
        target_fps = 50
        original_frames = dof_data.shape[0]
        target_frames = int(original_frames * target_fps / fps)

        # Both time axes span the same duration; only the sample count differs.
        original_time = np.linspace(0, (original_frames - 1) / fps, original_frames)
        target_time = np.linspace(0, (original_frames - 1) / fps, target_frames)

        interpolated_dof_data = np.zeros((target_frames, dof_data.shape[1]), dtype=np.float32)
        for dof_idx in range(dof_data.shape[1]):
            interpolator = interp1d(
                original_time, dof_data[:, dof_idx], kind='linear',
                bounds_error=False, fill_value='extrapolate'
            )
            interpolated_dof_data[:, dof_idx] = interpolator(target_time)

        # Derive the step count from the duration so the two stay consistent.
        return_duration = 1.0
        return_steps = int(round(return_duration * target_fps))

        motion_info = {
            'total_frames': target_frames,
            'fps': target_fps,
            'frame_dt': 1.0 / target_fps,
            'frame_counter': 0,
            'motion_playing': False,
            'motion_finished': False,
            'return_to_zero': False,
            'return_counter': 0,
            'return_duration': return_duration,
            'return_steps': return_steps,
            'last_motion_frame': None,
            'file_num': count,
        }

        print(f"Original motion: {original_frames} frames at {fps}Hz")
        print(f"Interpolated to: {target_frames} frames at {target_fps}Hz")
        print(f"Motion duration: {(target_frames - 1) / target_fps:.3f}s")
        print(f"Return to zero duration: {motion_info['return_duration']}s "
              f"({motion_info['return_steps']} steps)")

        return interpolated_dof_data, motion_info

    except Exception as e:
        # Best-effort loader: any failure is reported and signalled with
        # (None, None) so the caller can continue without motion data.
        print(f"Error loading motion file: {e}")
        return None, None

def get_upper_body_actions(dof_data, motion_info, default_upper_pos, action_scale):
    """Get upper body actions for the current timestep.

    While a motion is playing, each call consumes one frame of ``dof_data`` and
    returns the action for its upper-body joints (columns 12-26). The final
    frame is returned as well; the call that consumes it also arms the
    return phase. During the return phase the last frame's action is blended
    linearly to zero over ``motion_info['return_steps']`` calls, after which
    playback stops. ``motion_info`` is updated in place.

    Args:
        dof_data: Motion data array (num_frames, num_dofs)
        motion_info: Motion information dictionary
        default_upper_pos: Default positions for upper body joints (15 dims)
        action_scale: Action scaling factor (unused, kept for compatibility)

    Returns:
        numpy.ndarray: Upper body actions (15 dims). Zeros when no motion data
        is available or no motion is playing.
    """
    # Fixed scale applied to the offset from the default pose.
    upper_action_gain = 0.7

    if motion_info is None or dof_data is None:
        return np.zeros(15, dtype=np.float32)

    if not motion_info['motion_playing']:
        return np.zeros(15, dtype=np.float32)

    # Return to zero phase after motion completion
    if motion_info['motion_finished'] and motion_info['return_to_zero']:
        motion_info['return_counter'] += 1

        if motion_info['return_counter'] >= motion_info['return_steps']:
            motion_info['return_to_zero'] = False
            motion_info['motion_playing'] = False
            motion_info['motion_finished'] = False
            motion_info['frame_counter'] = 0
            print("Return to zero complete, motion stopped")
            return np.zeros(15, dtype=np.float32)

        if motion_info['last_motion_frame'] is None:
            return np.zeros(15, dtype=np.float32)

        # Blend the offset from the default pose (not the raw joint angles)
        # toward zero, so the action is continuous with the zero action
        # returned when the return phase completes.
        progress = motion_info['return_counter'] / motion_info['return_steps']
        last_offset = motion_info['last_motion_frame'] - default_upper_pos
        return (last_offset * (1 - progress)) * upper_action_gain

    # Play motion sequence
    current_frame_dof = dof_data[motion_info['frame_counter']]
    motion_info['frame_counter'] += 1

    # Extract upper body DOFs (indices 12-26)
    upper_body_dof = current_frame_dof[12:27]

    if motion_info['frame_counter'] >= motion_info['total_frames']:
        # Last frame consumed: arm the return phase for the next call.
        motion_info['motion_finished'] = True
        motion_info['return_to_zero'] = True
        motion_info['return_counter'] = 0
        motion_info['last_motion_frame'] = upper_body_dof
        print("Motion sequence complete, returning to zero position")

    # The final frame's action is returned too. Emitting zeros here would make
    # the command jump to the default pose for one step and back again.
    return (upper_body_dof - default_upper_pos) * upper_action_gain
    


class DeployNode(Node):
    """ROS2 node for deploying policy on real robot."""
    
    class WirelessButtons:
        """Bit masks for the wireless controller buttons, one bit per button."""
        R1 = 1 << 0
        L1 = 1 << 1
        start = 1 << 2
        select = 1 << 3
        R2 = 1 << 4
        L2 = 1 << 5
        F1 = 1 << 6
        F2 = 1 << 7
        A = 1 << 8
        B = 1 << 9
        X = 1 << 10
        Y = 1 << 11
        up = 1 << 12
        right = 1 << 13
        down = 1 << 14
        left = 1 << 15

    def __init__(self, task='stand'):
        """Initialize deployment node.

        Creates the low-level state subscription and motor command publisher,
        builds the command message, initializes the policy, loads the initial
        reference motion, and sets up command state and recording buffers.

        Args:
            task: Task type ('stand', 'datacollect', etc.). Not referenced in
                this method's body.
        """
        super().__init__("deploy_node")  # type: ignore

        # Initialize subscribers and publishers
        self.lowlevel_state_sub = self.create_subscription(
            LowState, "lowstate", self.lowlevel_state_cb, 1
        )
        self.lowlevel_state_sub  # Prevent unused variable warning

        self.low_state = LowState()
        self.joint_pos = np.zeros(HW_DOF)
        self.joint_vel = np.zeros(HW_DOF)

        # Motor command publisher
        self.motor_pub = self.create_publisher(LowCmd, "lowcmd_buffer", 1)
        self.motor_pub_freq = 50
        self.dt = 1 / self.motor_pub_freq

        # Initialize motor command message
        self.cmd_msg = LowCmd()
        self.cmd_msg.mode_pr = 0
        self.cmd_msg.mode_machine = 6

        # Initialize motor commands (27 active motors + 8 dummy).
        # Active motors use mode=1, the dummy slots mode=0.
        self.motor_cmd = []
        for _ in range(HW_DOF):
            cmd = MotorCmd(q=0.0, dq=0.0, tau=0.0, kp=0.0, kd=0.0, mode=1, reserve=0)
            self.motor_cmd.append(cmd)
        for _ in range(HW_DOF, 35):
            cmd = MotorCmd(q=0.0, dq=0.0, tau=0.0, kp=0.0, kd=0.0, mode=0, reserve=0)
            self.motor_cmd.append(cmd)
        self.cmd_msg.motor_cmd = self.motor_cmd.copy()

        # Initialize policy
        # NOTE(review): init_policy is defined outside this view; it is presumably
        # what sets self.env and self.angles used below — confirm.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.init_policy()
        self.prev_action = np.zeros(self.env.num_actions)
        self.start_policy = False

        if DEBUG:
            # Show the initial pose in the MuJoCo viewer and take one PD-controlled step
            self.env.init_mujoco_viewer()
            self.env.mj_data.qpos[7:] = self.angles
            self.env.mj_data.qpos[:3] = [0, 0, 1.03]
            mujoco.mj_forward(self.env.mj_model, self.env.mj_data)

            tau = pd_control(
                self.angles,
                self.env.mj_data.qpos[7:],
                self.env.p_gains,
                np.zeros(self.env.num_actions + 15),  # Zero target velocity for every joint
                self.env.mj_data.qvel[6:],
                self.env.d_gains
            )
            self.env.mj_data.ctrl[:] = tau
            mujoco.mj_step(self.env.mj_model, self.env.mj_data)
            self.env.viewer.sync()

        # Standing up sequence
        self.get_logger().info("Standing up")
        self.stand_up = True

        # Velocity command configuration
        self.lin_vel_deadband = 0.1
        self.ang_vel_deadband = 0.1
        self.move_by_wireless_remote = True
        self.cmd_px_range = [0.1, 0.4]
        self.cmd_nx_range = [0.1, 0.4]
        self.cmd_py_range = [0.1, 0.5]
        self.cmd_ny_range = [0.1, 0.5]
        self.cmd_pyaw_range = [0.2, 1.0]
        self.cmd_nyaw_range = [0.2, 1.0]
        self.commands = np.array([0, 0, 0, 0, 0.95], dtype=np.float32)
        self.commands_scale = np.array([2.0, 2.0, 0.25], dtype=np.float32)

        # Upper body motion control
        self.start_load_upper_body_motion = False
        self.upper_body_motion_start_time = None
        self.dof_data = None
        self.motion_info = None
        self.is_recording = False

        # Initialize global motion data (loaded once per process)
        global global_dof_data, global_motion_info
        if global_dof_data is None or global_motion_info is None:
            global_dof_data, global_motion_info = load_reference_motions()
            if global_dof_data is not None and global_motion_info is not None:
                print(f"Initial motion loaded: {FILE_LIST[global_motion_info['file_num']]}")
            else:
                print("Failed to load initial motion data")

        # Adopt the shared motion data whether it was loaded just now or before
        # this node was created. Otherwise a later node would keep None here.
        if global_dof_data is not None and global_motion_info is not None:
            self.dof_data = global_dof_data
            self.motion_info = global_motion_info

        # Log control instructions
        self.start_time = time.monotonic()
        self.get_logger().info("Press L1 for start policy")
        self.get_logger().info("Press L2 to emergent stop")
        self.get_logger().info("Press A to start upper body motion")
        self.get_logger().info("Press B to switch motion file")
        self.get_logger().info("Press R2 to reset motion")
        self.get_logger().info("Press X to start/stop automatic lower body route")
        self.get_logger().info("Automatic route: Pattern A (Forward->Backward) -> Pattern B (Backward->Forward)")
        self.get_logger().info("Data will be auto-saved after each route cycle")
        if self.motion_info is not None:
            self.get_logger().info(f"Current motion: {FILE_LIST[self.motion_info['file_num']]}")

        # Data recording buffers
        self.init_buffer = 0
        self.foot_contact_buffer = []
        self.time_hist = []
        self.obs_time_hist = []
        self.angle_hist = []
        self.action_hist = []
        self.dof_pos_hist = []
        self.dof_vel_hist = []
        self.imu_hist = []
        self.ang_vel_hist = []
        self.foot_contact_hist = []
        self.tau_hist = []
        self.obs_hist = []
        self.temp_hist = []

        # Command and observation setup
        self.xyyaw_command = np.array([0, 0., 0.], dtype=np.float32)
        self.up_axis_idx = 2  # Z-axis is up
        self.gravity_vec = torch.zeros((1, 3), device=self.device, dtype=torch.float32)
        self.gravity_vec[:, self.up_axis_idx] = -1

        self.episode_length_buf = torch.zeros(1, device=self.device, dtype=torch.long)
        self.phase = torch.zeros(1, device=self.device, dtype=torch.float)

        self.Emergency_stop = False
        self.stop = False

        self.gamepad = Gamepad()
        time.sleep(0.02)

        
    def _joy_stick_callback(self, msg):
        """Legacy joystick callback (unused)."""
        global count
        self.joy_stick_buffer = msg
        if self.move_by_wireless_remote:
            # left-y for forward/backward
            ly = msg.ly
            if ly > self.lin_vel_deadband:
                vx = (ly - self.lin_vel_deadband) / (1 - self.lin_vel_deadband) # (0, 1)
                vx = vx * (self.cmd_px_range[1] - self.cmd_px_range[0]) + self.cmd_px_range[0]
            elif ly < -self.lin_vel_deadband:
                vx = (ly + self.lin_vel_deadband) / (1 - self.lin_vel_deadband) # (-1, 0)
                vx = vx * (self.cmd_nx_range[1] - self.cmd_nx_range[0]) - self.cmd_nx_range[0]
            else:
                vx = 0
            # left-x for turning left/right
            rx = -msg.rx
            if rx > self.ang_vel_deadband:
                yaw = (rx - self.ang_vel_deadband) / (1 - self.ang_vel_deadband)
                yaw = yaw * (self.cmd_pyaw_range[1] - self.cmd_pyaw_range[0]) + self.cmd_pyaw_range[0]
            elif rx < -self.ang_vel_deadband:
                yaw = (rx + self.ang_vel_deadband) / (1 - self.ang_vel_deadband)
                yaw = yaw * (self.cmd_nyaw_range[1] - self.cmd_nyaw_range[0]) - self.cmd_nyaw_range[0]
            else:
                yaw = 0
            # right-x for side moving left/right
            lx = -msg.lx
            if lx > self.lin_vel_deadband:
                vy = (lx - self.lin_vel_deadband) / (1 - self.lin_vel_deadband)
                vy = vy * (self.cmd_py_range[1] - self.cmd_py_range[0]) + self.cmd_py_range[0]
            elif lx < -self.lin_vel_deadband:
                vy = (lx + self.lin_vel_deadband) / (1 - self.lin_vel_deadband)
                vy = vy * (self.cmd_ny_range[1] - self.cmd_ny_range[0]) - self.cmd_ny_range[0]
            else:
                vy = 0
            self.xyyaw_command = np.array([vx, vy, yaw], dtype= np.float32)
            print(self.xyyaw_command)

        # refer to Unitree Remote Control data structure, msg.keys is a bit mask
        # 00000000 00000001 means pressing the 0-th button (R1)
        # 00000000 00000010 means pressing the 1-th button (L1)
        # 10000000 00000000 means pressing the 15-th button (left)

        if msg.keys & self.WirelessButtons.L2:
            if self.stand_up:
                self.get_logger().info("Start policy")
                self.start_policy = True
                self.policy_start_time = time.monotonic()
            else:
                self.get_logger().info("Wait for standing up first")

        if msg.keys & self.WirelessButtons.L1:
            self.get_logger().info("Emergency stop")
            self.set_gains(np.array([0.0]*HW_DOF), self.env.d_gains)
            self.set_motor_position(q=self.env.default_dof_pos_np)
            if LOG_DATA and self.is_recording:
                print("Emergency stop during data recording - data will be saved on next B press")
            Warning("Emergency stop")
            self.Emergency_stop = True
        
        if msg.keys & self.WirelessButtons.R1:
            self.get_logger().info("Program exiting")
            self.stop = True
        
        if msg.keys & self.WirelessButtons.A:
            # start upper body motion playback and data recording
            if not self.start_load_upper_body_motion:
                self.get_logger().info("Starting upper body motion playback and data recording")
                self.start_load_upper_body_motion = True
                self.upper_body_motion_start_time = time.monotonic()
                self.is_recording = True
                
                # Clear previous data
                self.time_hist.clear()
                self.action_hist.clear()
                self.dof_pos_hist.clear()
                self.dof_vel_hist.clear()
                self.imu_hist.clear()
                self.ang_vel_hist.clear()
                self.tau_hist.clear()
                self.obs_hist.clear()
                self.temp_hist.clear()
                
                # Load motion data if not already loaded
                if self.dof_data is None or self.motion_info is None:
                    self.dof_data, self.motion_info = load_reference_motions()
                    print(self.dof_data)
                    if self.dof_data is not None and self.motion_info is not None:
                        # Start motion playback
                        self.motion_info['motion_playing'] = True
                        self.motion_info['frame_counter'] = 0
                        print("Motion playback started!")
                    else:
                        self.get_logger().warning("Failed to load motion data")
                        self.start_load_upper_body_motion = False
                        self.is_recording = False
                else:
                    # Restart motion playback
                    self.motion_info['motion_playing'] = True
                    self.motion_info['frame_counter'] = 0
                    print("Motion playback restarted!")
            else:
                self.get_logger().info("Motion already playing or loading")
        
        if msg.keys & self.WirelessButtons.B:
            # Save data and change to next motion file
            if self.is_recording and self.start_load_upper_body_motion:
                # Save current data before switching
                self.get_logger().info("Saving data and switching motion")
                
                motion_name = FILE_LIST[count] if self.motion_info is not None else "unknown_motion"
                motion_name = os.path.splitext(motion_name)[0]
                current_time = datetime.now().strftime('_%Y-%m-%d_%H-%M-%S')
                save_path = f'data_output/{motion_name}{current_time}.h5'
                
                os.makedirs('data_output', exist_ok=True)
                
                with h5py.File(save_path, 'w') as f:
                    f.create_dataset('command_time_list', data=np.array(self.time_hist))
                    f.create_dataset('command_val_list', data=np.array(self.action_hist))
                    f.create_dataset('motion_name', data=motion_name.encode('utf-8'))
                    f.create_dataset('current_time', data=current_time.encode('utf-8'))
                    
                    g = f.create_group('robot')
                    g.create_dataset('joint_time_list', data=np.array(self.time_hist))
                    g.create_dataset('joint_angle_list', data=np.array(self.dof_pos_hist))
                    g.create_dataset('joint_velocity_list', data=np.array(self.dof_vel_hist))
                    g.create_dataset('joint_current_list', data=np.array(self.tau_hist))
                    g.create_dataset('joint_temperature_list', data=np.array(self.temp_hist))
                    
                    g.create_dataset('imu_list', data=np.array(self.imu_hist))
                    g.create_dataset('ang_vel_list', data=np.array(self.ang_vel_hist))
                    g.create_dataset('obs_list', data=np.array(self.obs_hist))
                
                print(f"File saved to {save_path}")
                
                # Stop recording and motion
                self.is_recording = False
                self.start_load_upper_body_motion = False
                if self.motion_info is not None:
                    self.motion_info['motion_playing'] = False
            
            # Change to next motion file
            count = (count + 1) % len(FILE_LIST)
            
            # Load new motion data
            new_dof_data, new_motion_info = load_reference_motions()
            if new_dof_data is not None and new_motion_info is not None:
                # Update the global variables
                global global_dof_data, global_motion_info
                global_dof_data = new_dof_data
                global_motion_info = new_motion_info
                global_motion_info['file_num'] = count  # Preserve the file number
                
                # Update local variables
                self.dof_data = global_dof_data
                self.motion_info = global_motion_info
                
                print(f"Switched to motion: {FILE_LIST[count]}")
                print(f"Motion file: {FILE_LIST[count]}")
            else:
                print(f"Failed to load motion file: {FILE_LIST[count]}")
                # Revert to previous file if loading failed
                count = (count - 1) % len(FILE_LIST)
        
        if msg.keys & self.WirelessButtons.R2:
            # Reset motion playback
            if self.motion_info is not None:
                self.motion_info['motion_playing'] = False
                self.motion_info['motion_finished'] = False
                self.motion_info['return_to_zero'] = False
                self.motion_info['return_counter'] = 0
                self.motion_info['frame_counter'] = 0
                self.start_load_upper_body_motion = False
                self.is_recording = False
                print("Motion reset to beginning")

    def lowlevel_state_cb(self, msg: LowState):
        """Handle one low-level state message from the robot.

        The callback (1) decodes the wireless-remote payload and reacts to
        button presses, (2) optionally converts stick deflections into a
        ``[vx, vy, yaw]`` command, and (3) caches IMU and motor readings,
        flagging an emergency stop when a joint leaves its limits while the
        policy is running.

        Button handling, as implemented below:
            L1: start the policy.
            L2: stop the policy and flag an emergency stop.
            R1: request program exit.
            A:  start (or restart) upper-body motion playback and recording,
                then start the automatic route.
            B:  save the current recording (if one is active) and switch to
                the next motion file in ``FILE_LIST``.
            R2: reset motion playback, stop recording and the automatic route.

        Args:
            msg: Incoming ``LowState`` message; its ``wireless_remote``,
                ``imu_state``, ``tick`` and ``motor_state`` fields are read.
        """
        global count, global_dof_data, global_motion_info

        joystick_data = msg.wireless_remote
        parsed_data = parse_remote_data(joystick_data)
        self.gamepad.update(parsed_data)

        if self.gamepad.L1.pressed:
            print('Policy start!')
            self.start_policy = True
        if self.gamepad.L2.pressed:
            self.start_policy = False
            self.Emergency_stop = True
            print('Manual emergency stop!!!')
        if self.gamepad.R1.pressed:
            self.get_logger().info("Program exiting")
            self.stop = True

        if self.gamepad.A.pressed:
            # Start upper body motion playback and data recording.
            if not self.start_load_upper_body_motion:
                self.get_logger().info("Starting upper body motion playback and data recording")
                self.start_load_upper_body_motion = True
                self.upper_body_motion_start_time = time.monotonic()
                self.is_recording = True

                # Drop whatever was recorded in a previous session.
                for history in (
                    self.time_hist,
                    self.action_hist,
                    self.dof_pos_hist,
                    self.dof_vel_hist,
                    self.imu_hist,
                    self.ang_vel_hist,
                    self.tau_hist,
                    self.obs_hist,
                    self.temp_hist,
                ):
                    history.clear()

                if self.dof_data is None or self.motion_info is None:
                    # Nothing loaded yet: load the reference motion first.
                    self.dof_data, self.motion_info = load_reference_motions()
                    if self.dof_data is not None and self.motion_info is not None:
                        self.motion_info['motion_playing'] = True
                        self.motion_info['frame_counter'] = 0
                        print("Motion playback started!")
                    else:
                        # Loading failed: undo the flags set above.
                        self.get_logger().warning("Failed to load motion data")
                        self.start_load_upper_body_motion = False
                        self.is_recording = False
                else:
                    # Motion already loaded: restart it from the first frame.
                    self.motion_info['motion_playing'] = True
                    self.motion_info['frame_counter'] = 0
                    print("Motion playback restarted!")
            else:
                self.get_logger().info("Motion already playing or loading")

            # Runs on every A press, whether or not playback was (re)started.
            start_automatic_route()

        if self.gamepad.B.pressed:
            # Save data and change to next motion file.
            if self.is_recording and self.start_load_upper_body_motion:
                self.get_logger().info("Saving data and switching motion")

                motion_name = FILE_LIST[count] if self.motion_info is not None else "unknown_motion"
                motion_name = os.path.splitext(motion_name)[0]
                current_time = datetime.now().strftime('_%Y-%m-%d_%H-%M-%S')
                save_path = f'data_output/{motion_name}{current_time}.h5'

                os.makedirs('data_output', exist_ok=True)

                with h5py.File(save_path, 'w') as f:
                    f.create_dataset('command_time_list', data=np.array(self.time_hist))
                    f.create_dataset('command_val_list', data=np.array(self.action_hist))
                    f.create_dataset('motion_name', data=motion_name.encode('utf-8'))
                    f.create_dataset('current_time', data=current_time.encode('utf-8'))

                    # Robot-side measurements live in their own group.
                    g = f.create_group('robot')
                    g.create_dataset('joint_time_list', data=np.array(self.time_hist))
                    g.create_dataset('joint_angle_list', data=np.array(self.dof_pos_hist))
                    g.create_dataset('joint_velocity_list', data=np.array(self.dof_vel_hist))
                    g.create_dataset('joint_current_list', data=np.array(self.tau_hist))
                    g.create_dataset('joint_temperature_list', data=np.array(self.temp_hist))

                    g.create_dataset('imu_list', data=np.array(self.imu_hist))
                    g.create_dataset('ang_vel_list', data=np.array(self.ang_vel_hist))
                    # NOTE: self.obs_hist is recorded but intentionally not
                    # written to the file here.

                print(f"Save data to {save_path}")

                # Stop recording and motion.
                self.is_recording = False
                self.start_load_upper_body_motion = False
                if self.motion_info is not None:
                    self.motion_info['motion_playing'] = False

            # Advance to the next motion file (wrapping around).
            count = (count + 1) % len(FILE_LIST)

            new_dof_data, new_motion_info = load_reference_motions()
            if new_dof_data is not None and new_motion_info is not None:
                # Publish the new motion through the module-level globals.
                global_dof_data = new_dof_data
                global_motion_info = new_motion_info
                global_motion_info['file_num'] = count  # Preserve the file number

                # Keep the node's own references in sync.
                self.dof_data = global_dof_data
                self.motion_info = global_motion_info

                print(f"Switched to motion: {FILE_LIST[count]}")
                print(f"Motion file: {FILE_LIST[count]}")
            else:
                print(f"Failed to load motion file: {FILE_LIST[count]}")
                # Revert to previous file if loading failed.
                count = (count - 1) % len(FILE_LIST)

        if self.gamepad.R2.pressed:
            # Reset motion playback.
            if self.motion_info is not None:
                self.motion_info['motion_playing'] = False
                self.motion_info['motion_finished'] = False
                self.motion_info['return_to_zero'] = False
                self.motion_info['return_counter'] = 0
                self.motion_info['frame_counter'] = 0
                global_route_state['current_velocity'] = 0.0
                self.start_load_upper_body_motion = False
                stop_automatic_route()
                self.is_recording = False  # stop recording data
                print("Motion reset to beginning")

        if self.move_by_wireless_remote:
            # Left stick Y-axis: forward/backward.
            vx = self._stick_axis_to_velocity(
                self.gamepad.ly, self.lin_vel_deadband,
                self.cmd_px_range, self.cmd_nx_range,
            )
            # Right stick X-axis (sign flipped): yaw rotation.
            yaw = self._stick_axis_to_velocity(
                -self.gamepad.rx, self.ang_vel_deadband,
                self.cmd_pyaw_range, self.cmd_nyaw_range,
            )
            # Left stick X-axis (sign flipped): lateral movement.
            vy = self._stick_axis_to_velocity(
                -self.gamepad.lx, self.lin_vel_deadband,
                self.cmd_py_range, self.cmd_ny_range,
            )
            self.xyyaw_command = np.array([vx, vy, yaw], dtype=np.float32)

        # Process IMU data.
        imu_data = msg.imu_state
        self.imu_data = imu_data
        self.msg_tick = msg.tick/1000
        self.roll, self.pitch, self.yaw = imu_data.rpy
        self.obs_ang_vel = np.array(imu_data.gyroscope)*self.env.scale_ang_vel
        self.obs_imu = np.array([self.roll, self.pitch, self.yaw])*self.env.scale_orn
        self.obs_root_rot = np.array(imu_data.quaternion, dtype=np.float32)

        # Tilt check: only logs a warning, it does not stop the robot.
        if abs(self.roll) > 0.6 or abs(self.pitch) > 0.6:
            self.get_logger().warning("Roll or pitch threshold reached")

        # Motor data for the first HW_DOF motors.
        motors = [msg.motor_state[i] for i in range(HW_DOF)]
        self.joint_tau = [motor.tau_est for motor in motors]
        self.joint_pos = [motor.q for motor in motors]
        self.obs_joint_pos = (np.array(self.joint_pos) - self.env.default_dof_pos_np) * self.env.scale_dof_pos
        joint_vel = [motor.dq for motor in motors]
        self.obs_joint_vel = np.array(joint_vel) * self.env.scale_dof_vel
        self.joint_temp = [motor.temperature for motor in motors]

        # Joint limit check (only while the policy is driving the robot).
        if self.start_policy:
            joint_pos = np.array(self.joint_pos)
            below_limit = (joint_pos - np.array(self.env.joint_limit_lo)) < 0
            above_limit = (joint_pos - np.array(self.env.joint_limit_hi)) > 0
            if below_limit.any() or above_limit.any():
                print("Joint limit reached")
                print(self.joint_pos)
                print("Low limit Joint index: ", np.where(below_limit))
                print("High limit Joint index: ", np.where(above_limit))
                # The previous bare ``Warning(...)`` expression only built an
                # exception object and discarded it; log the event instead.
                self.get_logger().warning("Emergency stop")
                self.Emergency_stop = True

    @staticmethod
    def _stick_axis_to_velocity(value, deadband, pos_range, neg_range):
        """Map one stick deflection to a velocity command with a deadband.

        Args:
            value: Stick deflection (already sign-corrected by the caller).
            deadband: Deflections with magnitude up to this value map to 0.
            pos_range: ``(min, max)`` pair used for positive deflections.
            neg_range: ``(min, max)`` pair used for negative deflections; the
                result is negated relative to this range.

        Returns:
            The scaled command, or ``0`` inside the deadband.
        """
        if value > deadband:
            fraction = (value - deadband) / (1 - deadband)
            return fraction * (pos_range[1] - pos_range[0]) + pos_range[0]
        if value < -deadband:
            fraction = (value + deadband) / (1 - deadband)
            return fraction * (neg_range[1] - neg_range[0]) - neg_range[0]
        return 0
    
    def lowlevel_state_mujoco(self):
        """Refresh the cached observations from the MuJoCo simulation state.

        Does nothing unless ``DEBUG`` and ``SIM`` are both set and the policy
        has been started.
        """
        if not (DEBUG and self.start_policy and SIM):
            return

        sim = self.env.mj_data
        env = self.env

        # Base orientation and angular velocity are read from fixed qpos/qvel slices.
        quat = sim.qpos[3:7]
        self.obs_ang_vel = np.array(sim.qvel[3:6]) * env.scale_ang_vel
        self.obs_root_rot = sim.qpos[3:7]

        self.roll, self.pitch, self.yaw = t3d.euler.quat2euler(quat)[:3]
        self.obs_imu = np.array([self.roll, self.pitch, self.yaw]) * env.scale_orn

        # Joint state: everything after the 7 base qpos / 6 base qvel entries.
        self.joint_pos = sim.qpos[7:].copy()
        self.obs_joint_pos = (np.array(self.joint_pos) - env.default_dof_pos_np) * env.scale_dof_pos
        self.joint_vel = sim.qvel[6:].copy()
        self.obs_joint_vel = np.array(self.joint_vel) * env.scale_dof_vel

    def set_gains(self, kp: np.ndarray, kd: np.ndarray):
        """Store the PD gains and copy them into every motor command.

        Args:
            kp: Position gains, indexed by motor (first ``HW_DOF`` entries used).
            kd: Velocity gains, indexed by motor (first ``HW_DOF`` entries used).
        """
        self.kp, self.kd = kp, kd
        for joint in range(HW_DOF):
            motor = self.motor_cmd[joint]
            motor.kp = kp[joint]
            motor.kd = kd[joint]

    def set_motor_position(self, q: np.ndarray):
        """Write target joint positions into the outgoing command message.

        The per-motor targets are updated, the motor command list is copied
        into ``self.cmd_msg`` and the message checksum is recomputed. Nothing
        is published here.

        Args:
            q: Target joint positions, indexed by motor (first ``HW_DOF``
                entries used).
        """
        for joint in range(HW_DOF):
            self.motor_cmd[joint].q = q[joint]

        message = self.cmd_msg
        message.motor_cmd = self.motor_cmd.copy()
        # The checksum must be computed after the motor commands are in place.
        message.crc = crc.Crc(message)
    
    def init_policy(self):
        """Create the environment config, load the policy and reset motor commands.

        Builds the ``H1_2`` configuration, loads the TorchScript policy from
        ``72000_loco_policy.pt`` onto the environment's device, runs one
        warm-up inference, and initializes every motor command to the default
        joint position with zero velocity, torque and gains.
        """
        self.get_logger().info("Preparing policy")
        faulthandler.enable()

        # Forward the node's task rather than the literal string 'self.task'.
        # Falls back to H1_2's own default ('stand') if the node has no
        # ``task`` attribute. NOTE(review): confirm the attribute name that
        # the constructor stores the task under.
        task = getattr(self, 'task', 'stand')
        self.env = H1_2(task=task)

        # map_location lets a policy saved on a CUDA machine load on a
        # CPU-only host as well.
        self.policy = torch.jit.load('72000_loco_policy.pt', map_location=self.env.device)
        self.policy.to(self.env.device)
        # Warm up policy (first inference takes longer).
        _ = self.policy(self.env.obs_buf.detach().reshape(1, -1))

        # Initialize motor commands to default positions with zero gains.
        for i in range(HW_DOF):
            motor = self.motor_cmd[i]
            motor.q = self.env.default_dof_pos[0][i].item()
            motor.dq = 0.0
            motor.tau = 0.0
            motor.kp = 0.0
            motor.kd = 0.0
        self.cmd_msg.motor_cmd = self.motor_cmd.copy()
        self.angles = self.env.default_dof_pos_np
    
    def get_walking_cmd_mask(self):
        """Decide whether the robot should currently be treated as walking.

        Walking is reported when any of the following holds: a velocity
        command exceeds its threshold, the gait phase has advanced by at
        least one control step, or the scaled roll/pitch observation exceeds
        its threshold.

        Returns:
            NumPy boolean scalar, true while walking.
        """
        cmd = self.xyyaw_command
        command_active = (
            (np.abs(cmd[0]) > 0.1)
            | (np.abs(cmd[1]) > 0.1)
            | (np.abs(cmd[2]) > 0.2)
        )

        # A gait cycle that has already started counts as walking.
        phase_step = self.dt / self.env.cycle_time
        mid_cycle = (self.env.gait_indices.cpu() >= phase_step).numpy()[0]

        # obs_imu holds [roll, pitch, yaw]; pitch and roll have separate limits.
        tilted = np.logical_or(
            np.abs(self.obs_imu[1]) > 0.1,
            np.abs(self.obs_imu[0]) > 0.05,
        )

        result = command_active | mid_cycle
        result |= tilted
        return result
    
    def _get_phase(self):
        """Return the environment's gait phase indices (the stored object itself)."""
        phase = self.env.gait_indices
        return phase
    
    def step_contact_targets(self):
        """Advance the gait phase by one control step, or reset it when standing.

        The standing decision is taken before the phase is advanced; when the
        robot is standing the phase is zeroed in place.
        """
        is_standing = ~self.get_walking_cmd_mask()

        phase_increment = self.dt / self.env.cycle_time
        advanced = self.env.gait_indices + phase_increment
        # Phase wraps around within [0, 1).
        self.env.gait_indices = torch.remainder(advanced, 1.0)

        if is_standing:
            self.env.gait_indices[:] = 0
    
    def compute_observations(self):
        """Build the current observation and refresh the stacked policy input.

        The single-step observation is appended to ``self.env.obs_history``
        and the whole history is concatenated into ``self.env.obs_buf``.
        """
        ang_vel = self.obs_ang_vel
        gravity = get_gravity_orientation(self.obs_root_rot)

        # Pick the command source: automatic route when active, else joystick.
        auto_route_on = (
            global_route_state['mode'] == 'auto' and global_route_state['route_active']
        )
        if auto_route_on:
            route_cmd = update_automatic_route(dt=self.dt, self=self)
            cmd = self.commands.copy()
            cmd[:3] = route_cmd
            # Slot 2 keeps the joystick command's third entry rather than
            # the route's value.
            cmd[2] = self.xyyaw_command[2]
        else:
            cmd = self.commands.copy()
            cmd[:3] = self.xyyaw_command

        # Layout: [commands(3), height(1), ang_vel(3), gravity(3),
        #          joint_pos(27), joint_vel(27), prev_action(12)]
        pieces = (
            cmd[:3] * self.commands_scale,
            cmd[4, None],
            ang_vel,
            gravity,
            self.obs_joint_pos[:27],
            self.obs_joint_vel[:27],
            self.prev_action,
        )
        flat = np.concatenate(pieces, axis=-1)
        current_obs = torch.tensor(flat, dtype=torch.float, device=self.device).unsqueeze(0)

        history = self.env.obs_history
        history.append(current_obs.clone())

        # Index every slot up to maxlen so a partially filled history fails loudly.
        stacked = torch.cat(
            [history[slot] for slot in range(history.maxlen)],
            dim=-1,
        )
        self.env.obs_buf = stacked


    @torch.no_grad()
    def main_loop(self):
        """Run the stand-up ramp and then the 50 Hz policy control loop.

        Phase 1 (while ``self.stand_up`` is set and the policy has not been
        started) linearly interpolates from the first measured joint
        positions to the default pose, one step per spin.

        Phase 2 runs until ROS shuts down or ``self.stop`` is set. Each
        iteration spins the node once and, when the policy is active,
        records data (if enabled), computes observations, queries the policy,
        merges leg and upper-body targets, clips them to the joint limits and
        sends them to the motors (or to the MuJoCo viewer/simulation in debug
        mode).

        Raises:
            SystemExit: If the policy output contains NaN. A zero-stiffness
                command is published before exiting.
        """
        # ---- Phase 1: stand-up ramp ----
        ramp_progress = 0
        ramp_steps = 500
        first_iteration = True
        init_success = False
        while self.stand_up and not self.start_policy:
            if first_iteration:
                # Spin once so a state message provides the starting pose.
                first_iteration = False
                rclpy.spin_once(self)
                start_pos = self.joint_pos
            else:
                self.set_gains(kp=self.env.p_gains, kd=self.env.d_gains)
                if ramp_progress < 1:
                    self.set_motor_position(
                        q=(1 - ramp_progress) * np.array(start_pos) +
                        ramp_progress * np.array(self.env.default_dof_pos_np)
                    )
                    ramp_progress = min(1, ramp_progress + 1 / ramp_steps)
                if ramp_progress == 1 and not init_success:
                    init_success = True
                    print("---Initialized---")
                if not NO_MOTOR:
                    self.motor_pub.publish(self.cmd_msg)
                rclpy.spin_once(self)

        # ---- Phase 2: control loop ----
        cnt = 0
        fps_ckt = time.monotonic()
        self.get_logger().info("start main loop")

        while rclpy.ok():
            loop_start_time = time.monotonic()

            if self.Emergency_stop:
                # NOTE(review): this drops into the debugger and halts the
                # loop; no further motor command is sent from here.
                breakpoint()
            if self.stop:
                # Return to default pose before exiting.
                ramp_progress = 0
                ramp_steps = 1000
                ramp_step_period = 0.002  # seconds per step, ~2 s in total
                start_pos = self.joint_pos
                while ramp_progress < 1:
                    self.set_motor_position(
                        q=(1 - ramp_progress) * np.array(start_pos) +
                        ramp_progress * np.array(self.env.default_dof_pos_np)
                    )
                    ramp_progress = min(1, ramp_progress + 1 / ramp_steps)
                    if not NO_MOTOR:
                        self.motor_pub.publish(self.cmd_msg)
                    # Pace the ramp: without a delay every step is published
                    # back-to-back and the interpolation has no effect.
                    time.sleep(ramp_step_period)
                self.get_logger().info("Program exit")
                break

            rclpy.spin_once(self, timeout_sec=0.001)

            if self.start_policy:
                # Record data if enabled.
                if LOG_DATA and self.is_recording:
                    current_time = time.monotonic() - self.start_time
                    self.time_hist.append(current_time)
                    self.dof_pos_hist.append(self.obs_joint_pos)
                    self.dof_vel_hist.append(self.obs_joint_vel)
                    self.imu_hist.append(self.obs_imu)
                    self.ang_vel_hist.append(self.obs_ang_vel)
                    self.tau_hist.append(self.joint_tau)
                    self.obs_hist.append(self.env.obs_buf.cpu().detach().reshape(1, -1))
                    self.temp_hist.append(self.joint_temp)

                if DEBUG and SIM:
                    self.lowlevel_state_mujoco()

                # Update gait phase and compute observations.
                self.step_contact_targets()
                self.compute_observations()
                self.episode_length_buf += 1

                # Get policy actions.
                raw_actions = self.policy(self.env.obs_buf.detach().reshape(1, -1))
                if torch.any(torch.isnan(raw_actions)):
                    self.get_logger().info("Emergency stop due to NaN")
                    self.set_gains(np.array([0.0]*HW_DOF), self.env.d_gains)
                    self.set_motor_position(q=self.env.default_dof_pos_np)
                    # Actually send the zero-stiffness command; previously it
                    # was prepared but never published before exiting.
                    if not NO_MOTOR and not DEBUG:
                        self.motor_pub.publish(self.cmd_msg)
                    raise SystemExit

                # One device->host conversion serves both uses; leg_action is
                # only read below, never modified in place.
                leg_action = raw_actions.clone().detach().cpu().numpy().squeeze(0)
                self.prev_action = leg_action

                # Get upper body actions if motion is playing.
                if (self.start_load_upper_body_motion and
                        global_dof_data is not None and global_motion_info is not None):
                    upper_body_actions = get_upper_body_actions(
                        global_dof_data, global_motion_info,
                        self.env.default_dof_pos_np[12:27], self.env.scale_action
                    )
                    self.dof_data = global_dof_data
                    self.motion_info = global_motion_info
                else:
                    upper_body_actions = np.zeros(15, dtype=np.float32)

                # Combine leg targets (scaled offsets from the default pose)
                # with the upper body targets.
                whole_body_action = np.concatenate((
                    leg_action * self.env.scale_action + self.env.default_dof_pos_np[:12],
                    upper_body_actions
                ))

                # Apply joint limits.
                self.angles = np.clip(whole_body_action, self.env.joint_limit_lo, self.env.joint_limit_hi)

                # The unclipped action is what gets logged.
                if LOG_DATA and self.is_recording:
                    self.action_hist.append(whole_body_action)

                # Send motor commands.
                self.set_motor_position(self.angles)
                if not NO_MOTOR and not DEBUG:
                    self.motor_pub.publish(self.cmd_msg)
                else:
                    if not SIM:
                        # Kinematic preview: write the targets straight into qpos.
                        self.env.mj_data.qpos[7:] = self.angles
                        mujoco.mj_forward(self.env.mj_model, self.env.mj_data)
                        self.env.viewer.sync()
                    else:
                        # Simulate multiple steps for smoother visualization.
                        for _ in range(20):
                            self.env.viewer.sync()
                            tau = pd_control(
                                self.angles,
                                self.env.mj_data.qpos[7:],
                                self.env.p_gains,
                                np.zeros(self.env.num_actions + 15),
                                self.env.mj_data.qvel[6:],
                                self.env.d_gains
                            )
                            self.env.mj_data.ctrl[:] = tau
                            mujoco.mj_step(self.env.mj_model, self.env.mj_data)

            # Maintain 50Hz control loop (busy-wait until 20 ms have elapsed).
            while 0.02 - time.monotonic() + loop_start_time > 0:
                pass

            # Report the average loop rate every 500 iterations.
            cnt += 1
            if cnt == 500:
                dt = (time.monotonic() - fps_ckt) / cnt
                cnt = 0
                fps_ckt = time.monotonic()
                print(f"FPS: {1/dt}")


if __name__ == "__main__":
    # Parse the task name, start the ROS node and run the control loop.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--task_name', action='store', type=str,
        help='Task name: stand, datacollect', required=False, default='stand',
    )
    args = parser.parse_args()

    rclpy.init(args=None)
    dp_node = None
    try:
        dp_node = DeployNode(args.task_name)
        dp_node.get_logger().info("Deploy node started")
        dp_node.main_loop()
    finally:
        # Always release ROS resources, including when main_loop exits via
        # SystemExit or an unexpected exception.
        if dp_node is not None:
            dp_node.destroy_node()
        # The context may already be shut down (e.g. by a signal handler);
        # shutting it down twice raises.
        if rclpy.ok():
            rclpy.shutdown()
