

from sre_parse import State
import numpy as np
import robosuite.utils.transform_utils as T
from scipy.spatial.transform import Rotation as R
import cv2

def put_text(img, text, font_size=1, thickness=2, resize=False, position="top"):
    """Return a copy of ``img`` with ``text`` drawn near its left edge.

    Args:
        img: Image array; it is copied, so the caller's array is not modified.
        text: String to draw.
        font_size: Font scale passed to ``cv2.putText``.
        thickness: Stroke thickness passed to ``cv2.putText``.
        resize: When true, the copy is cast to ``uint8`` and resized to
            256x256 before the text is drawn.
        position: ``"top"`` anchors the text at (10, 30); ``"bottom"`` anchors
            it 10 pixels above the bottom edge of the (possibly resized) image.

    Returns:
        The annotated image.

    Raises:
        ValueError: If ``position`` is neither ``"top"`` nor ``"bottom"``.
    """
    # Validate first: an unknown position used to surface later as an
    # UnboundLocalError on the anchor point.
    if position not in ("top", "bottom"):
        raise ValueError(
            f"position must be 'top' or 'bottom', got {position!r}"
        )

    img = img.copy()
    if resize:
        img = cv2.resize(np.uint8(img), (256, 256))

    # The anchor depends on the image height *after* the optional resize.
    height = img.shape[0]
    anchor = (10, 30) if position == "top" else (10, height - 10)

    return cv2.putText(
        img,
        text,
        anchor,
        cv2.FONT_HERSHEY_SIMPLEX,
        font_size,
        (0, 255, 255),
        thickness,
        cv2.LINE_AA,
    )

# Helper: reorder [w, x, y, z] quaternions to [x, y, z, w] and build a Rotation.
def create_rot_from_wxyz(quat_wxyz):
    """Build a scipy ``Rotation`` from scalar-first quaternions.

    Args:
        quat_wxyz: Quaternion(s) ordered ``[w, x, y, z]``. Either a single
            quaternion of shape ``(4,)`` or a batch of shape ``(N, 4)``;
            array-likes such as lists are accepted.

    Returns:
        A ``Rotation`` representing the same rotation(s).
    """
    # Rotation.from_quat expects scalar-last order by default, so shift the
    # last axis left by one: [w, x, y, z] -> [x, y, z, w]. Rolling along the
    # last axis works for a single quaternion as well as for a batch.
    quat_xyzw = np.roll(np.asarray(quat_wxyz), -1, axis=-1)
    return R.from_quat(quat_xyzw)
    
def compute_eef_error(position_desired, quat_desired, position_actual, quat_actual, pos_coef=0.5, rot_coef=0.5):
    """Weighted position/rotation error between desired and actual poses.

    The first sample of every input is skipped; errors are computed from
    index 1 onward.

    Args:
        position_desired: Desired positions (array with a ``shape``).
            Presumably ``(T, 3)`` -- TODO confirm against callers.
        quat_desired: Desired orientations, handed to
            ``create_rot_from_wxyz`` (i.e. ``[w, x, y, z]`` order).
        position_actual: Actual positions, same layout as the desired ones.
        quat_actual: Actual orientations, same layout as the desired ones.
        pos_coef: Weight of the position term.
        rot_coef: Weight of the rotation term.

    Returns:
        ``(pos_coef * pos_err + rot_coef * rot_err) / (pos_coef + rot_coef)``
        where ``pos_err`` is the Euclidean distance and ``rot_err`` is the
        magnitude of the relative rotation between desired and actual.
    """
    desired_pos = position_desired[1:]
    desired_rot = create_rot_from_wxyz(quat_desired[1:])
    actual_pos = position_actual[1:]
    actual_rot = create_rot_from_wxyz(quat_actual[1:])

    # A 1-D position array is reduced to one scalar norm; otherwise one norm
    # per row. NOTE(review): the 1-D case looks unreachable in practice since
    # the quaternion helper is fed the same kind of slice -- confirm.
    norm_axis = None if position_desired.ndim == 1 else 1
    pos_err = np.linalg.norm(desired_pos - actual_pos, axis=norm_axis)

    # Angle of the rotation that takes the desired pose to the actual one.
    relative_rot = desired_rot.inv() * actual_rot
    rot_err = relative_rot.magnitude()

    weighted_sum = pos_coef * pos_err + rot_coef * rot_err
    return weighted_sum / (pos_coef + rot_coef)

def compute_eef_error_ratio(position_desired, quat_desired, position_actual, quat_actual, pos_coef=0.5, rot_coef=0.5):
    """Tracking error relative to the size of each desired step.

    For every index ``i >= 1`` the position (rotation) error between the
    desired and actual pose at ``i`` is divided by the desired position
    (rotation) change from ``i - 1`` to ``i``. The two ratios are then
    combined as a weighted average.

    Args:
        position_desired: Desired positions (array with a ``shape``).
            Presumably ``(T, 3)`` -- TODO confirm against callers.
        quat_desired: Desired orientations, handed to
            ``create_rot_from_wxyz`` (i.e. ``[w, x, y, z]`` order).
        position_actual: Actual positions, same layout as the desired ones.
        quat_actual: Actual orientations, same layout as the desired ones.
        pos_coef: Weight of the position ratio. Defaults to 0.5.
        rot_coef: Weight of the rotation ratio. Defaults to 0.5.

    Returns:
        ``(pos_coef * pos_ratio + rot_coef * rot_ratio) / (pos_coef + rot_coef)``.
        With the default weights this equals
        ``0.5 * pos_ratio + 0.5 * rot_ratio``.

    Note:
        A step in which the desired position or orientation does not change
        has a zero denominator, so the corresponding entry follows NumPy's
        division-by-zero rules (``inf``/``nan`` plus a RuntimeWarning).
    """
    desired_pos_after = position_desired[1:]
    desired_pos_before = position_desired[:-1]
    desired_rot_after = create_rot_from_wxyz(quat_desired[1:])
    desired_rot_before = create_rot_from_wxyz(quat_desired[:-1])

    actual_pos_after = position_actual[1:]
    actual_rot_after = create_rot_from_wxyz(quat_actual[1:])

    # A 1-D position array is reduced to one scalar norm; otherwise one norm
    # per row.
    norm_axis = None if position_desired.ndim == 1 else 1
    pos_miss = np.linalg.norm(desired_pos_after - actual_pos_after, axis=norm_axis)
    pos_step = np.linalg.norm(desired_pos_after - desired_pos_before, axis=norm_axis)
    pos_err = pos_miss / pos_step

    desired_rot_after_inv = desired_rot_after.inv()
    rot_miss = (desired_rot_after_inv * actual_rot_after).magnitude()
    rot_step = (desired_rot_after_inv * desired_rot_before).magnitude()
    rot_err = rot_miss / rot_step

    # Same weighting scheme as compute_eef_error; the defaults sum to 1.0, so
    # the normalisation leaves the default result unchanged.
    return (pos_coef * pos_err + rot_coef * rot_err) / (pos_coef + rot_coef)



def merge_delta_actions(actions: np.ndarray) -> np.ndarray:
    """Collapse a sequence of delta actions into one equivalent action.

    Args:
        actions: 2-D array with one action per row. Columns 0-2 are read as a
            position delta, columns 3-5 as a rotation vector, and column 6 as
            the gripper command.

    Returns:
        A 1-D array ``[summed position delta, composed rotation vector,
        gripper command of the last row]``.
    """
    # Translations simply add up.
    total_translation = actions[:, :3].sum(axis=0)

    # Rotations are treated as global-frame deltas (LIBERO convention per the
    # original author), so each later rotation is applied on the left.
    composed = R.identity()
    for step_rot in R.from_rotvec(actions[:, 3:6]):
        composed = step_rot * composed

    # The gripper command is not accumulated; the final one wins.
    final_gripper = actions[-1, 6]

    return np.hstack([total_translation, composed.as_rotvec(), final_gripper])


def split_action(action):
    """Split one action into two identical half-step actions.

    Args:
        action: Sequence whose entries 0-2 are a position delta, entries 3-5
            are 'xyz' Euler angles, and entry 6 is the gripper command.
            NOTE(review): ``merge_delta_actions`` reads entries 3-5 as a
            rotation vector instead -- confirm which convention is intended.

    Returns:
        Array of shape ``(2, 7)``. Each row holds half the position delta,
        the Euler angles of the half rotation, and the unchanged gripper
        command.
    """
    action = np.asarray(action)

    # Halve the rotation in rotation-vector space (same axis, half the angle),
    # then express the result as 'xyz' Euler angles again.
    full_rotvec = R.from_euler('xyz', action[3:6]).as_rotvec()
    half_euler = R.from_rotvec(full_rotvec / 2.0).as_euler('xyz')

    half_step = np.concatenate([action[:3] / 2.0, half_euler, [action[6]]])

    # Both halves are equal because the action is split evenly.
    return np.stack([half_step, half_step])

def interpolate_pos_quat(position, quat, gripper, time_indices, target_indices):
    """Resample a pose/gripper trajectory at new time indices.

    Positions and gripper values are linearly interpolated; quaternions use
    NLERP (component-wise linear interpolation followed by normalisation).

    Args:
        position: Key positions; the first three columns are used.
        quat: Key quaternions; four columns are used. The input is copied,
            so sign corrections do not touch the caller's data.
        gripper: Key gripper values, one per key frame.
        time_indices: Sample locations of the key frames (the ``xp`` argument
            of ``np.interp``).
        target_indices: Locations at which to evaluate the interpolation.

    Returns:
        Tuple ``(pos_interp, quat_interp, gripper_interp)`` with shapes
        ``(M, 3)``, ``(M, 4)`` and ``(M,)`` for ``M = len(target_indices)``.
    """
    pos_knots = np.array(position)
    quat_knots = np.array(quat)
    grip_knots = np.array(gripper)

    # q and -q encode the same rotation. Flip signs so that consecutive keys
    # have a non-negative dot product; otherwise NLERP would take the long
    # way round. Each comparison uses the already-corrected predecessor.
    for idx in range(1, len(quat_knots)):
        if np.dot(quat_knots[idx - 1], quat_knots[idx]) < 0:
            quat_knots[idx] = -quat_knots[idx]

    n_out = len(target_indices)

    def _lerp_columns(knots, width):
        # Linearly interpolate the first `width` columns independently.
        out = np.empty((n_out, width))
        for col in range(width):
            out[:, col] = np.interp(target_indices, time_indices, knots[:, col])
        return out

    pos_interp = _lerp_columns(pos_knots, 3)

    # NLERP: component-wise lerp, then renormalise to unit length. The small
    # epsilon guards against division by zero.
    quat_lerp = _lerp_columns(quat_knots, 4)
    lengths = np.linalg.norm(quat_lerp, axis=1, keepdims=True)
    quat_interp = quat_lerp / (lengths + 1e-12)

    gripper_interp = np.interp(target_indices, time_indices, grip_knots)

    return pos_interp, quat_interp, gripper_interp

def get_expected_eef_from_action(current_eef_pos, current_eef_quat, action):
    """Predict the end-effector pose that results from applying ``action``.

    Args:
        current_eef_pos: Current end-effector position.
        current_eef_quat: Current end-effector quaternion; it is passed
            through ``T.convert_quat(..., to="wxyz")`` before being turned
            into a rotation matrix.
        action: Sequence whose entries 0-2 are a position delta and entries
            3-5 are Euler angles (roll, pitch, yaw deltas).

    Returns:
        Tuple ``(expected_eef_pos, expected_eef_quat)``.

    NOTE(review): robosuite documents ``quat2mat``/``mat2quat`` as working
    with (x, y, z, w) quaternions, which would make the two ``convert_quat``
    calls here suspicious -- confirm the intended quaternion convention.
    NOTE(review): the delta rotation is applied on the right (local frame),
    whereas ``merge_delta_actions`` composes deltas on the left -- confirm.
    """
    # Position: plain additive delta.
    translation = np.array(action[:3])
    expected_eef_pos = current_eef_pos + translation

    # Orientation: compose the current rotation with the Euler-angle delta.
    euler_delta = np.array(action[3:6])
    base_rot = T.quat2mat(T.convert_quat(current_eef_quat, to="wxyz"))
    target_rot = base_rot @ T.euler2mat(euler_delta)
    expected_eef_quat = T.convert_quat(T.mat2quat(target_rot), to="xyzw")

    return expected_eef_pos, expected_eef_quat

def normalize_action(action, input_max, input_min, output_max, output_min):
    """Affinely map the first six action dimensions between two ranges.

    A value at the centre of ``[input_min, input_max]`` maps to the centre of
    ``[output_min, output_max]``, and distances from the centre are scaled by
    the ratio of the range widths. Entries from index 6 on are passed through
    unchanged.

    Args:
        action: A single action of shape ``(D,)`` or a batch of shape
            ``(N, D)``; array-likes are accepted.
        input_max: Upper bound of the source range (scalar or broadcastable).
        input_min: Lower bound of the source range.
        output_max: Upper bound of the target range.
        output_min: Lower bound of the target range.

    Returns:
        Array with the same shape as ``action``.
    """
    action = np.asarray(action)

    scale = abs(output_max - output_min) / abs(input_max - input_min)
    out_center = (output_max + output_min) / 2.0
    in_center = (input_max + input_min) / 2.0

    # Index along the last axis so one action and a batch are both handled.
    scaled = (action[..., :6] - in_center) * scale + out_center
    return np.concatenate([scaled, action[..., 6:]], axis=-1)

def unnormalize_action(action, input_max, input_min, output_max, output_min):
    """Invert ``normalize_action`` for the first six action dimensions.

    Values expressed in ``[output_min, output_max]`` are mapped back to
    ``[input_min, input_max]``. Entries from index 6 on are passed through
    unchanged.

    Args:
        action: A single action of shape ``(D,)`` or a batch of shape
            ``(N, D)``; array-likes are accepted.
        input_max: Upper bound of the original (un-normalised) range.
        input_min: Lower bound of the original range.
        output_max: Upper bound of the normalised range.
        output_min: Lower bound of the normalised range.

    Returns:
        Array with the same shape as ``action``.
    """
    action = np.asarray(action)

    scale = abs(output_max - output_min) / abs(input_max - input_min)
    out_center = (output_max + output_min) / 2.0
    in_center = (input_max + input_min) / 2.0

    # Index along the last axis so one action and a batch are both handled.
    restored = (action[..., :6] - out_center) / scale + in_center
    return np.concatenate([restored, action[..., 6:]], axis=-1)

def snapshot(env):
    """
    Capture the environment's simulator state so it can be restored later.

    The returned dict contains:
    1. 'sim_state': the flattened MuJoCo sim state (qpos, qvel, act, time).
    2. 'ctrl': the MuJoCo control input, which affects actuator_force.
    3. 'act': the actuator activations.

    Notes:
    - qacc (joint acceleration) is a derived quantity rather than a state
      variable, so it is not stored here; it can be recomputed by forward().
    - sim.get_state() already includes act, so storing act separately is
      redundant; it is kept to make sure the restored state matches exactly.
    - In principle, restoring qpos/qvel/ctrl and calling forward() recomputes
      qacc and actuator_force.

    Effect on task execution:
    - The extra values mainly serve state-consistency checks and do not
      directly affect task execution.
    - After a restore, the next step() recomputes ctrl, qacc and
      actuator_force.
    - As long as qpos and qvel are correct, task execution should be
      unaffected.
    """
    sim = env.sim
    # Every entry is copied so later simulation steps cannot mutate the
    # snapshot.
    flat_state = sim.get_state().flatten().copy()  # qpos, qvel, act, time
    ctrl_copy = sim.data.ctrl.copy()  # control input
    act_copy = sim.data.act.copy()  # activations (also inside flat_state)
    return {
        'sim_state': flat_state,
        'ctrl': ctrl_copy,
        'act': act_copy,
    }
    
def restore(env, state):
    """Load a state dict produced by ``snapshot`` back into ``env``.

    Args:
        env: Environment whose ``sim`` is overwritten.
        state: Dict with a required 'sim_state' entry and optional 'ctrl'
            and 'act' entries.
    """
    sim = env.sim
    sim.set_state_from_flattened(state['sim_state'])

    # 'ctrl' and 'act' are optional; when present they are written in place
    # over the simulator's buffers, in that order.
    for key in ('ctrl', 'act'):
        if key in state:
            getattr(sim.data, key)[:] = state[key]

    # Recompute derived quantities from the restored state.
    sim.forward()
