from bigym.envs.reach_target import ReachTarget, ReachTargetDual, ReachTargetSingle
from bigym.envs.move_plates import MovePlate, MoveTwoPlates
from bigym.envs.cupboards import (
    CupboardsOpenAll,
    CupboardsCloseAll,
    WallCupboardOpen,
    WallCupboardClose,
    DrawerTopOpen,
    DrawerTopClose,
    DrawersAllOpen,
    DrawersAllClose,
)
from bigym.envs.dishwasher import (
    DishwasherOpen,
    DishwasherClose,
    DishwasherOpenTrays,
    DishwasherCloseTrays,
)
from bigym.envs.dishwasher_cups import (
    DishwasherLoadCups,
    DishwasherUnloadCups,
    DishwasherUnloadCupsLong,
)
from bigym.envs.dishwasher_cutlery import (
    DishwasherLoadCutlery,
    DishwasherUnloadCutlery,
    DishwasherUnloadCutleryLong,
)
from bigym.envs.dishwasher_plates import (
    DishwasherLoadPlates,
    DishwasherUnloadPlates,
    DishwasherUnloadPlatesLong,
)
from bigym.envs.pick_and_place import (
    PutCups,
    TakeCups,
    PickBox,
    SaucepanToHob,
    StoreKitchenware,
    ToastSandwich,
    FlipSandwich,
    RemoveSandwich,
    StoreBox,
)
from bigym.envs.manipulation import FlipCup, FlipCutlery, StackBlocks
from bigym.envs.groceries import GroceriesStoreLower, GroceriesStoreUpper
from bigym.const import HandSide

import mujoco
import numpy as np
from collections import defaultdict

# Registry mapping a task identifier string to the BiGym environment class
# that implements it.
# NOTE(review): the trailing numbers on each entry are kept verbatim from the
# original author; their meaning is not defined anywhere in this file
# (presumably per-task step budgets / settings) -- TODO confirm.
TASK_MAP = {
    "reach_target_single": ReachTargetSingle,  # 2000, 10, enable_all_floating_dofs=False
    "reach_target_multi_modal": ReachTarget,  # 3000, 10, enable_all_floating_dofs=False
    "reach_target_dual": ReachTargetDual,  # 3000, 10, enable_all_floating_dofs=False
    "stack_blocks": StackBlocks,  # 28500, 25
    "move_plate": MovePlate,  # 3000, 10
    "move_two_plates": MoveTwoPlates,  # 5500, 10
    "flip_cup": FlipCup,  # 5500, 10
    "flip_cutlery": FlipCutlery,  # 12500, 25
    "dishwasher_open": DishwasherOpen,  # 7500, 20
    "dishwasher_close": DishwasherClose,  # 7500, 20
    "dishwasher_open_trays": DishwasherOpenTrays,  # 9500, 25
    "dishwasher_close_trays": DishwasherCloseTrays,  # 7500, 25
    "dishwasher_load_cups": DishwasherLoadCups,  # 7500, 10
    "dishwasher_unload_cups": DishwasherUnloadCups,  # 10000, 25
    "dishwasher_unload_cups_long": DishwasherUnloadCupsLong,  # 18000, 25
    "dishwasher_load_cutlery": DishwasherLoadCutlery,  # 7000, 10
    "dishwasher_unload_cutlery": DishwasherUnloadCutlery,  # 15500, 25
    "dishwasher_unload_cutlery_long": DishwasherUnloadCutleryLong,  # 18000, 25
    "dishwasher_load_plates": DishwasherLoadPlates,  # 14000, 25
    "dishwasher_unload_plates": DishwasherUnloadPlates,  # 20000, 25
    "dishwasher_unload_plates_long": DishwasherUnloadPlatesLong,  # 26000, 25
    "drawer_top_open": DrawerTopOpen,  # 5000, 10
    "drawer_top_close": DrawerTopClose,  # 3000, 10
    "drawers_open_all": DrawersAllOpen,  # 12000, 25
    "drawers_close_all": DrawersAllClose,  # 5000, 25
    "wall_cupboard_open": WallCupboardOpen,  # 6000, 20
    "wall_cupboard_close": WallCupboardClose,  # 3000, 10
    "cupboards_open_all": CupboardsOpenAll,  # 22500, 25
    "cupboards_close_all": CupboardsCloseAll,  # 15500, 25
    "take_cups": TakeCups,  # 10500, 25
    "put_cups": PutCups,  # 8500, 20
    "pick_box": PickBox,  # 13500, 25
    "store_box": StoreBox,  # 15000, 25
    "saucepan_to_hob": SaucepanToHob,  # 11000, 25
    "store_kitchenware": StoreKitchenware,  # 20000, 25
    "sandwich_toast": ToastSandwich,  # 16500, 25
    "sandwich_flip": FlipSandwich,  # 15500, 25
    "sandwich_remove": RemoveSandwich,  # 13500, 25
    "store_groceries_lower": GroceriesStoreLower,  # 32000, 25
    "store_groceries_upper": GroceriesStoreUpper,  # 19000, 25
}


def snapshot_env(env):
    """Return independent copies of the simulator state arrays of ``env``.

    The result is a dict with the keys ``"qpos"``, ``"qvel"``, ``"act"`` and
    ``"ctrl"`` (in that order), each holding a fresh array copied from
    ``env.unwrapped.mojo.data``. It is the counterpart of ``restore_env``.
    """
    sim = env.unwrapped.mojo.data
    # Only these four fields are captured. If a task keeps additional cached
    # state (e.g. accumulated floating-base actions) it must be added here.
    captured = ("qpos", "qvel", "act", "ctrl")
    return {field: np.array(getattr(sim, field)) for field in captured}

def restore_env(env, state):
    """Write a state dict produced by ``snapshot_env`` back into ``env``.

    ``state["qpos"]`` and ``state["qvel"]`` are required; ``"act"`` and
    ``"ctrl"`` are copied only when present. The arrays are written in place
    into ``env.unwrapped.mojo.data`` and ``mujoco.mj_forward`` is then called
    so derived quantities match the restored state.
    """
    sim = env.unwrapped.mojo.data
    for required in ("qpos", "qvel"):
        getattr(sim, required)[:] = state[required]
    for optional in ("act", "ctrl"):
        if optional in state:
            getattr(sim, optional)[:] = state[optional]
    mujoco.mj_forward(env.unwrapped.mojo.model, sim)

def _yaw_to_quat(yaw):
    """Return the ``[w, x, y, z]`` quaternion for a rotation of ``yaw`` radians about z."""
    half = yaw / 2.0
    return np.array([np.cos(half), 0.0, 0.0, np.sin(half)], dtype=np.float64)


def apply_joint_targets(env, joint_vector):
    """Write a joint-space target directly into the simulator state of ``env``.

    ``joint_vector`` must be one-dimensional with 16 or 15 entries.

    16-entry layout::

        [pelvis_x, pelvis_y, pelvis_z, pelvis_rz,
         left_shoulder_pitch, left_shoulder_roll, left_shoulder_yaw, left_elbow, left_wrist,
         right_shoulder_pitch, right_shoulder_roll, right_shoulder_yaw, right_elbow, right_wrist,
         left_gripper, right_gripper]

    15-entry layout: identical, but without ``pelvis_z`` (the base is placed
    at z = 0.0).

    The base pose is only written when ``robot.floating_base`` is truthy.
    Arm joints that do not exist in the current model are skipped. Gripper
    values are passed to ``gripper.set_control``. ``mujoco.mj_forward`` is
    called at the end to refresh derived quantities.

    Note: this mutates the live environment. Call ``snapshot_env(env)`` first
    and ``restore_env(env, snapshot)`` afterwards if the real state must be
    preserved.

    Raises:
        ValueError: if ``joint_vector`` is not 1-D or does not have 15 or 16
            entries.
    """
    joint_vector = np.asarray(joint_vector, dtype=np.float64)
    # Reject scalars / matrices up front; previously a 0-d input crashed with
    # an IndexError on ``shape[0]`` instead of the documented ValueError.
    if joint_vector.ndim != 1:
        raise ValueError(f"Expected a 1-D joint vector, got shape {joint_vector.shape}.")
    n_values = joint_vector.shape[0]
    if n_values not in (15, 16):
        raise ValueError(f"Expected 16 or 15 joint values, got {n_values}.")
    has_base_z = n_values == 16

    uenv = env.unwrapped
    robot = uenv.robot
    mojo = uenv.mojo
    physics = mojo.physics

    # 1) Floating base: x, y, (z,) yaw.
    if robot.floating_base:
        if has_base_z:
            base_pos = joint_vector[:3]
            base_offset = np.zeros(3, dtype=np.float64)
            if getattr(robot, "config", None) and getattr(robot.config, "floating_base", None):
                base_offset = np.array(
                    getattr(robot.config.floating_base, "offset_position", np.zeros(3)),
                    dtype=np.float64,
                )
            robot.set_pose(base_pos - base_offset, _yaw_to_quat(joint_vector[3]))
        else:
            # NOTE(review): the original code computed the configured base
            # offset here but never subtracted it (unlike the 16-value
            # branch). That behaviour is kept; confirm it is intentional.
            base_pos = np.append(joint_vector[:2], 0.0)
            robot.set_pose(base_pos, _yaw_to_quat(joint_vector[2]))

    # 2) Five joints per arm, written back into qpos by joint name.
    #    robot._joints is used as a list of joint elements whose mjcf names
    #    are matched against the names below.
    rest_vector = joint_vector[4:] if has_base_z else joint_vector[3:]
    joint_map = {
        joint.mjcf.name: joint
        for joint in getattr(robot, "_joints", [])
        if joint.mjcf is not None
    }
    arm_joint_names = (
        "left_shoulder_pitch",
        "left_shoulder_roll",
        "left_shoulder_yaw",
        "left_elbow",
        "left_wrist",
        "right_shoulder_pitch",
        "right_shoulder_roll",
        "right_shoulder_yaw",
        "right_elbow",
        "right_wrist",
    )
    for name, value in zip(arm_joint_names, rest_vector[:10]):
        joint = joint_map.get(name)
        if joint is None:
            continue  # joint does not exist in the current model
        bound = physics.bind(joint.mjcf)
        bound.qpos = value  # written as a scalar (hinge/slide joint)
        bound.qvel = 0.0
        if hasattr(bound, "qacc"):
            bound.qacc = 0.0

    # 3) Gripper opening, applied through gripper.set_control.
    left_grip = float(rest_vector[10])
    right_grip = float(rest_vector[11])
    if HandSide.LEFT in robot.grippers:
        robot.grippers[HandSide.LEFT].set_control(left_grip)
    if HandSide.RIGHT in robot.grippers:
        robot.grippers[HandSide.RIGHT].set_control(right_grip)

    # 4) Refresh derived quantities.
    mujoco.mj_forward(mojo.model, mojo.data)

class ErrorCalculator:
    """Utility to compute EEF-based errors from an environment.

    An end-effector (EEF) state is either a dict with the keys
    ``{left,right}_{position,quat,grip}`` or, in array form, a 16-vector laid
    out as ``[left_pos(3), left_quat(4), left_grip(1), right_pos(3),
    right_quat(4), right_grip(1)]`` (see ``_read_eef``). Quaternions are
    handled as ``[w, x, y, z]``.

    Attributes:
        pos_coef: weight applied to position errors when aggregating.
        ang_coef: weight applied to rotation errors when aggregating.
        qpos_dict / qvel_dict: joint name -> index into ``data.qpos`` /
            ``data.qvel``, read from the MuJoCo model.
        joint_names: names of ``robot._joints`` in order.
        prop_to_actuated_idx: for every actuated (floating base + limb)
            name, its index in ``joint_names``.
    """

    def __init__(self, env):
        self.pos_coef = 1.0
        self.ang_coef = 1.0

        robot = env.unwrapped.robot
        model = env.unwrapped.mojo.model
        self.qpos_dict = {}
        self.qvel_dict = {}
        for i in range(model.njnt):
            # Names are stored NUL-separated in one buffer; take the entry
            # starting at this joint's address and drop any "scope/" prefix.
            name = model.names[model.name_jntadr[i]:].split(b'\x00')[0].decode()
            name = name.split('/')[-1]
            self.qpos_dict[name] = model.jnt_qposadr[i]
            self.qvel_dict[name] = model.jnt_dofadr[i]  # index into qvel
        self.joint_names = [joint.mjcf.name for joint in robot._joints]

        # All actuated names except the grippers.
        limb_names = [actuator.name for actuator in robot.limb_actuators]
        # The floating base may be absent (apply_joint_targets also treats it
        # as optional), so do not dereference it unconditionally.
        floating_base = robot.floating_base
        fb_names = (
            [actuator.name for actuator in floating_base.all_actuators]
            if floating_base
            else []
        )
        actuated_names = fb_names + limb_names
        prop_names = {joint.mjcf.name: i for i, joint in enumerate(robot._joints)}
        self.prop_to_actuated_idx = [prop_names[a_n] for a_n in actuated_names]

    def apply_proprioception_and_get_eef_error(self, env, proprioception):
        """Return the EEF state (dict) reached when ``proprioception`` is applied.

        Despite the name, this returns the EEF dict from ``_read_eef``, not an
        error. The first 30 entries of ``proprioception`` are written as joint
        positions and entries 30..59 as joint velocities, in ``joint_names``
        order. The environment state is restored before returning, even if an
        exception occurs.

        NOTE(review): entries from index 60 on (presumably floating-base and
        gripper readings) are not applied -- confirm this is intended.
        """
        state = snapshot_env(env)
        data = env.unwrapped.mojo.data
        model = env.unwrapped.mojo.model
        qpos = proprioception[:30]
        qvel = proprioception[30:60]
        try:
            for i, name in enumerate(self.joint_names):
                data.qpos[self.qpos_dict[name]] = qpos[i]
                data.qvel[self.qvel_dict[name]] = qvel[i]
            mujoco.mj_forward(model, data)
            return self._read_eef(env)
        finally:
            restore_env(env, state)

    def _weighted_sum(self, metric):
        """Sum one error dict, weighting position and rotation entries."""
        total = 0
        for key, value in metric.items():
            if 'position' in key:
                total += value * self.pos_coef
            elif 'angle' in key or 'quat' in key:
                # `_eef_error` names rotation errors "<side>_quat_error";
                # "angle" is still accepted for metrics built elsewhere.
                total += value * self.ang_coef
        return total

    def scalar_error_from_metric(self, metric):
        """Collapse one error dict, or a list/tuple of them, into a scalar.

        For a single dict the weighted sum of its position and rotation
        entries is returned. For a sequence the mean of the per-dict sums is
        returned; an empty sequence yields ``0.0``.
        """
        if isinstance(metric, (list, tuple)):
            if not metric:
                return 0.0
            return sum(self._weighted_sum(m) for m in metric) / len(metric)
        return self._weighted_sum(metric)

    def compute_chunk_sim(self, env, proprioceptions, actions):
        """Return ``compute_sim`` errors for each ``(proprioception, action)`` pair.

        One error dict is produced per entry of ``actions``. The environment
        state is restored afterwards.
        """
        errors = []
        data = snapshot_env(env)
        try:
            for i, action in enumerate(actions):
                errors.append(self.compute_sim(env, proprioceptions[i], action))
        finally:
            restore_env(env, data)
        return errors

    def compute_sim(self, env, proprioception, action):
        """Error between the EEF implied by ``action`` and by ``proprioception``."""
        actual = self.apply_proprioception_and_get_eef_error(env, proprioception)
        desired = self._desired_eef(env, action)
        return self._eef_error(desired, actual)

    def compute_chunk(self, env, actions):
        """Step ``env`` through ``actions`` and record the error before each step.

        NOTE(review): only the fields captured by ``snapshot_env`` are rolled
        back afterwards; any other state changed by ``env.step`` is not.
        """
        errors = []
        data = snapshot_env(env)
        try:
            for action in actions:
                err = self.compute(env, action)
                env.step(action, fast=True)
                errors.append(err)
        finally:
            restore_env(env, data)
        return errors

    def compute(self, env, action):
        """Error between the EEF implied by ``action`` and the current EEF."""
        action = np.array(action, dtype=np.float64)
        actual = self._read_eef(env)
        desired = self._desired_eef(env, action)
        return self._eef_error(desired, actual)

    @staticmethod
    def _desired_eef(env, action, return_array=False):
        """Return the EEF state obtained by applying ``action`` as joint targets.

        The environment state is snapshotted first and restored afterwards,
        even if applying the targets fails.
        """
        backup = snapshot_env(env)
        try:
            apply_joint_targets(env, action)
            mujoco.mj_forward(env.unwrapped.mojo.model, env.unwrapped.mojo.data)
            return ErrorCalculator._read_eef(env, return_array=return_array)
        finally:
            restore_env(env, backup)

    @staticmethod
    def _desired_eef_array(env, actions):
        """Stack the array-form desired EEF of every action into ``(len(actions), 16)``."""
        desireds = [
            ErrorCalculator._desired_eef(env, action, return_array=True)
            for action in actions
        ]
        return np.stack(desireds, axis=0)

    @staticmethod
    def _actual_eef_array(env, actions):
        """Step ``env`` through ``actions`` and stack the EEF read after each step.

        The snapshotted state is restored afterwards.
        """
        backup = snapshot_env(env)
        actuals = []
        try:
            for action in actions:
                env.step(action)
                actuals.append(ErrorCalculator._read_eef(env, return_array=True))
        finally:
            restore_env(env, backup)
        return np.stack(actuals, axis=0)

    @staticmethod
    def _read_eef(env, return_array=False):
        """Read the current end-effector state of both hands.

        Returns a dict with ``{left,right}_{position,quat,grip}`` or, when
        ``return_array`` is true, the same values concatenated in the order
        left position, left quat, left grip, right position, right quat,
        right grip. For a side without a gripper, the quaternion is identity
        and the grip value is NaN.
        """
        robot = env.unwrapped.robot

        def _position(side):
            return np.array(robot.get_hand_pos(side), dtype=np.float64)

        def _quat(side):
            if side not in robot.grippers:
                return np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64)
            site = robot.grippers[side].wrist_site
            return np.asarray(site.get_quaternion(), dtype=np.float64)

        def _grip(side):
            if side in robot.grippers:
                return float(robot.grippers[side].qpos)
            return float("nan")

        if return_array:
            return np.concatenate([
                _position(HandSide.LEFT),
                _quat(HandSide.LEFT),
                [_grip(HandSide.LEFT)],
                _position(HandSide.RIGHT),
                _quat(HandSide.RIGHT),
                [_grip(HandSide.RIGHT)],
            ])

        return {
            "left_position": _position(HandSide.LEFT),
            "left_quat": _quat(HandSide.LEFT),
            "left_grip": _grip(HandSide.LEFT),
            "right_position": _position(HandSide.RIGHT),
            "right_quat": _quat(HandSide.RIGHT),
            "right_grip": _grip(HandSide.RIGHT),
        }

    @staticmethod
    def _eef_error(desired, actual):
        """Compare two EEF dicts and return ``{"<key>_error": float}``.

        * keys containing ``position``: Euclidean distance;
        * keys containing ``quat`` (or ``angle``): angle in radians of the
          relative rotation ``desired^-1 * actual``, with quaternions given
          as ``[w, x, y, z]``;
        * any other key (e.g. grip): no entry is produced.

        Keys missing from ``actual`` are skipped. Previously the rotation
        branch only matched ``angle`` and so never ran for the ``*_quat``
        keys, and quat/grip entries silently reused the preceding position
        error.
        """
        from scipy.spatial.transform import Rotation as R

        def _as_rotation(quat_wxyz):
            # scipy expects scalar-last [x, y, z, w].
            return R.from_quat(
                [quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]]
            )

        errors = {}
        for key, desired_state in desired.items():
            if key not in actual:
                continue
            actual_state = actual[key]

            if 'position' in key:
                diff = np.asarray(desired_state) - np.asarray(actual_state)
                err = float(np.linalg.norm(diff))
            elif 'quat' in key or 'angle' in key:
                relative = _as_rotation(desired_state).inv() * _as_rotation(actual_state)
                err = float(relative.magnitude())
            else:
                # No metric is defined for this entry (grip error is
                # intentionally not computed), so do not emit a value.
                continue

            errors[key + '_error'] = err

        return errors

    def compute_eef_error_from_array(self, desired, actual):
        """Per-step scalar EEF error for array-form states.

        ``desired`` and ``actual`` have shape ``(seq_len, 16)`` in the layout
        produced by ``_read_eef(..., return_array=True)``. The result has
        shape ``(seq_len,)`` and equals
        ``(pos_coef * (left_pos + right_pos) + ang_coef * (left_rot + right_rot)) / 4``.
        Grip columns are ignored.
        """
        from scipy.spatial.transform import Rotation as R

        def _rotation_from_wxyz(quat_wxyz):
            # Reorder [w, x, y, z] -> [x, y, z, w] as expected by scipy.
            quat_xyzw = np.concatenate([quat_wxyz[:, 1:], quat_wxyz[:, 0:1]], axis=1)
            return R.from_quat(quat_xyzw)

        def _position_error(cols):
            return np.linalg.norm(desired[:, cols] - actual[:, cols], axis=1)

        def _rotation_error(cols):
            # Relative rotation R_desired^-1 * R_actual; magnitude() is the
            # rotation angle in radians, in [0, pi].
            relative = (
                _rotation_from_wxyz(desired[:, cols]).inv()
                * _rotation_from_wxyz(actual[:, cols])
            )
            return relative.magnitude()

        left_pos_err = _position_error(slice(0, 3))
        right_pos_err = _position_error(slice(8, 11))
        left_rot_err = _rotation_error(slice(3, 7))
        right_rot_err = _rotation_error(slice(11, 15))

        # Combine and average over the four error terms.
        err = self.pos_coef * (left_pos_err + right_pos_err) + \
              self.ang_coef * (left_rot_err + right_rot_err)
        return err / 4.0