# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

import torch

from isaaclab.managers import RewardTermCfg as RewTerm
from isaaclab.managers import SceneEntityCfg
from isaaclab.managers import SuccessTermCfg as SucTerm
from isaaclab.utils import configclass

import isaaclab_tasks.manager_based.locomotion.velocity.mdp as mdp
from isaaclab_tasks.manager_based.locomotion.velocity.velocity_env_cfg import LocomotionVelocityRoughEnvCfg, RewardsCfg

##
# Pre-defined configs
##
from isaaclab_assets import G1_MINIMAL_CFG  # isort: skip


@configclass
class G1Rewards(RewardsCfg):
    """Reward terms for the G1 locomotion MDP.

    Extends the base ``RewardsCfg`` with:

    * a termination term (``mdp.is_terminated``, weight -200),
    * linear-velocity (yaw frame) and angular-velocity (world frame) command tracking,
    * foot air-time and foot-slide terms on the ``*_ankle_roll_link`` bodies,
    * an ankle joint-limit penalty, and
    * L1 penalties pulling non-essential joints (hips, arms, fingers, torso)
      toward their default positions.
    """

    # Termination term with a large negative weight.
    termination_penalty = RewTerm(func=mdp.is_terminated, weight=-200.0)

    # Command tracking on the "base_velocity" command.
    track_lin_vel_xy_exp = RewTerm(
        func=mdp.track_lin_vel_xy_yaw_frame_exp,
        weight=1.0,
        params=dict(command_name="base_velocity", std=0.5),
    )
    track_ang_vel_z_exp = RewTerm(
        func=mdp.track_ang_vel_z_world_exp,
        weight=2.0,
        params=dict(command_name="base_velocity", std=0.5),
    )

    # Gait shaping on the ankle-roll links (feet).
    feet_air_time = RewTerm(
        func=mdp.feet_air_time_positive_biped,
        weight=0.25,
        params=dict(
            command_name="base_velocity",
            sensor_cfg=SceneEntityCfg("contact_forces", body_names=".*_ankle_roll_link"),
            threshold=0.4,
        ),
    )
    feet_slide = RewTerm(
        func=mdp.feet_slide,
        weight=-0.1,
        params=dict(
            sensor_cfg=SceneEntityCfg("contact_forces", body_names=".*_ankle_roll_link"),
            asset_cfg=SceneEntityCfg("robot", body_names=".*_ankle_roll_link"),
        ),
    )

    # Penalize the ankle joints for approaching/exceeding their position limits.
    dof_pos_limits = RewTerm(
        func=mdp.joint_pos_limits,
        weight=-1.0,
        params=dict(
            asset_cfg=SceneEntityCfg("robot", joint_names=[".*_ankle_pitch_joint", ".*_ankle_roll_joint"]),
        ),
    )

    # Keep joints that are not essential for locomotion near their defaults (L1 deviation).
    joint_deviation_hip = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.1,
        params=dict(
            asset_cfg=SceneEntityCfg("robot", joint_names=[".*_hip_yaw_joint", ".*_hip_roll_joint"]),
        ),
    )
    joint_deviation_arms = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.1,
        params=dict(
            asset_cfg=SceneEntityCfg(
                "robot",
                joint_names=[
                    ".*_shoulder_pitch_joint",
                    ".*_shoulder_roll_joint",
                    ".*_shoulder_yaw_joint",
                    ".*_elbow_pitch_joint",
                    ".*_elbow_roll_joint",
                ],
            ),
        ),
    )
    joint_deviation_fingers = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.05,
        params=dict(
            asset_cfg=SceneEntityCfg(
                "robot",
                joint_names=[
                    ".*_five_joint",
                    ".*_three_joint",
                    ".*_six_joint",
                    ".*_four_joint",
                    ".*_zero_joint",
                    ".*_one_joint",
                    ".*_two_joint",
                ],
            ),
        ),
    )
    joint_deviation_torso = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.1,
        params=dict(asset_cfg=SceneEntityCfg("robot", joint_names="torso_joint")),
    )

def _resolved(env, cfg):
    """Resolve ``cfg``'s joint/body name patterns against ``env.scene`` and return it.

    A ``SceneEntityCfg`` constructed inside a term function is never seen by a
    manager, so nobody resolves it automatically. Until ``resolve`` runs, its
    ``joint_ids``/``body_ids`` stay at the default ``slice(None)``, and the MDP
    functions would silently operate on *every* joint/body of the entity
    instead of the named subset.
    """
    cfg.resolve(env.scene)
    return cfg


def constraint(
    env, joint_pos_limits, joint_deviation_l1, feet_slide, track_ang_vel_z_world_exp, feet_air_time_positive_biped
):
    """Evaluate the locomotion constraints per environment and collect diagnostics.

    The MDP term functions are injected (see ``G1SuccessCfg``) rather than
    imported here, so this function only depends on their call signatures.

    Args:
        env: The environment. ``env.scene`` is used to resolve entity names to indices.
        joint_pos_limits: Joint position-limit penalty function (e.g. ``mdp.joint_pos_limits``).
        joint_deviation_l1: L1 joint-deviation function (e.g. ``mdp.joint_deviation_l1``).
        feet_slide: Foot-slide penalty function (e.g. ``mdp.feet_slide``).
        track_ang_vel_z_world_exp: Yaw-rate tracking function (e.g. ``mdp.track_ang_vel_z_world_exp``).
        feet_air_time_positive_biped: Biped air-time function (e.g. ``mdp.feet_air_time_positive_biped``).

    Returns:
        ``(satisfy_constraint, info)``.

        * ``satisfy_constraint`` is the element-wise result of requiring the ankle-limit
          violation, the hip/arm/finger/torso deviations and the foot slide all to be
          below 0.2.
        * ``info`` maps metric names to the ``.max()`` of each computed term.

        The angular-velocity tracking and air-time scores are reported in ``info`` only;
        they do not enter the constraint.
    """
    # Constraint 1: ankle joints must not exceed their position limits.
    ankle_limit = joint_pos_limits(
        env,
        asset_cfg=_resolved(
            env, SceneEntityCfg("robot", joint_names=[".*_ankle_pitch_joint", ".*_ankle_roll_joint"])
        ),
    )

    # Constraint 2: hips, arms, fingers and torso must not deviate too far from their defaults.
    hip_dev = joint_deviation_l1(
        env,
        asset_cfg=_resolved(env, SceneEntityCfg("robot", joint_names=[".*_hip_yaw_joint", ".*_hip_roll_joint"])),
    )
    arm_dev = joint_deviation_l1(
        env,
        asset_cfg=_resolved(
            env,
            SceneEntityCfg(
                "robot",
                joint_names=[
                    ".*_shoulder_pitch_joint",
                    ".*_shoulder_roll_joint",
                    ".*_shoulder_yaw_joint",
                    ".*_elbow_pitch_joint",
                    ".*_elbow_roll_joint",
                ],
            ),
        ),
    )
    finger_dev = joint_deviation_l1(
        env,
        asset_cfg=_resolved(
            env,
            SceneEntityCfg(
                "robot",
                joint_names=[
                    ".*_five_joint",
                    ".*_three_joint",
                    ".*_six_joint",
                    ".*_four_joint",
                    ".*_zero_joint",
                    ".*_one_joint",
                    ".*_two_joint",
                ],
            ),
        ),
    )
    torso_dev = joint_deviation_l1(
        env, asset_cfg=_resolved(env, SceneEntityCfg("robot", joint_names="torso_joint"))
    )

    # Contact sensor restricted to the feet; shared by the slide and air-time terms.
    feet_sensor_cfg = _resolved(env, SceneEntityCfg("contact_forces", body_names=".*_ankle_roll_link"))

    # Constraint 3: avoid foot sliding.
    slide = feet_slide(
        env,
        sensor_cfg=feet_sensor_cfg,
        asset_cfg=_resolved(env, SceneEntityCfg("robot", body_names=".*_ankle_roll_link")),
    )

    # Diagnostics only (not part of the constraint): yaw-rate tracking and feet air time.
    ang_vel_score = track_ang_vel_z_world_exp(env, command_name="base_velocity", std=0.5)
    air_time_score = feet_air_time_positive_biped(
        env,
        command_name="base_velocity",
        sensor_cfg=feet_sensor_cfg,
        threshold=0.4,
    )

    # An environment satisfies the constraint only if every penalty is below 0.2.
    satisfy_constraint = (
        (ankle_limit < 0.2)
        & (hip_dev < 0.2)
        & (arm_dev < 0.2)
        & (finger_dev < 0.2)
        & (torso_dev < 0.2)
        & (slide < 0.2)
    )

    return satisfy_constraint, {
        "ankle_limit": ankle_limit.max(),
        "hip_dev": hip_dev.max(),
        "arm_dev": arm_dev.max(),
        "finger_dev": finger_dev.max(),
        "torso_dev": torso_dev.max(),
        "feet_slide": slide.max(),
        "ang_vel_score": ang_vel_score.max(),
        "air_time_score": air_time_score.max(),
    }


def criteria(env, track_lin_vel_xy_yaw_frame_exp):
    """Score each environment on surviving the full episode while tracking the velocity command.

    Relies on the module-level ``import torch``. The original file never imported it,
    so ``torch.clamp`` raised ``NameError``.

    Args:
        env: The environment. Reads ``episode_length_buf`` and ``max_episode_length``.
        track_lin_vel_xy_yaw_frame_exp: Linear-velocity tracking function
            (e.g. ``mdp.track_lin_vel_xy_yaw_frame_exp``). It is called with the
            ``"base_velocity"`` command and ``std=0.5``.

    Returns:
        ``survive * score``. Environments whose episode counter has reached
        ``max_episode_length`` get the clamped tracking score; all others get 0.
    """
    # Criterion 1: the environment survived until the end of the episode.
    survive = env.episode_length_buf >= env.max_episode_length

    # Criterion 2: linear-velocity tracking quality.
    lin_vel_score = track_lin_vel_xy_yaw_frame_exp(env, command_name="base_velocity", std=0.5)

    # Combined score, clamped to [0, 1].
    # NOTE(review): only one term is scored here, but the divisor is 4.0. If the tracking
    # term is bounded by 1, the result can never exceed 0.25. This looks like a leftover
    # from a four-term average; confirm the intended normalization.
    score = torch.clamp(lin_vel_score / 4.0, min=0.0, max=1.0)

    return survive * score


@configclass
class G1SuccessCfg:
    """Success terms for the G1 locomotion MDP.

    Each term wraps the module-level function of the same name. The MDP functions
    that function evaluates are passed in through ``params``, so the success
    functions themselves stay decoupled from the ``mdp`` module.
    """

    # Inside the class body, ``constraint`` on the right-hand side still refers to the
    # module-level function: the class attribute is bound only after the call returns.
    constraint = SucTerm(
        func=constraint,
        params=dict(
            joint_pos_limits=mdp.joint_pos_limits,
            joint_deviation_l1=mdp.joint_deviation_l1,
            feet_slide=mdp.feet_slide,
            track_ang_vel_z_world_exp=mdp.track_ang_vel_z_world_exp,
            feet_air_time_positive_biped=mdp.feet_air_time_positive_biped,
        ),
    )

    # Same name-resolution reasoning applies to ``criteria``.
    criteria = SucTerm(
        func=criteria,
        params=dict(track_lin_vel_xy_yaw_frame_exp=mdp.track_lin_vel_xy_yaw_frame_exp),
    )

@configclass
class G1RoughEnvCfg(LocomotionVelocityRoughEnvCfg):
    """Rough-terrain velocity-tracking environment configured for the G1 robot (``G1_MINIMAL_CFG``)."""

    rewards: G1Rewards = G1Rewards()

    def __post_init__(self):
        # Let the base rough-terrain config populate its defaults first.
        super().__post_init__()

        # --- Scene: use the G1 asset and mount the height scanner on its torso link.
        scene = self.scene
        scene.robot = G1_MINIMAL_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot")
        scene.height_scanner.prim_path = "{ENV_REGEX_NS}/Robot/torso_link"

        # --- Events / randomization.
        events = self.events
        events.push_robot = None
        events.add_base_mass = None
        events.reset_robot_joints.params["position_range"] = (1.0, 1.0)
        events.base_external_force_torque.params["asset_cfg"].body_names = ["torso_link"]
        # Reset: randomized x/y position and yaw, zero velocity on every axis.
        still = (0.0, 0.0)
        events.reset_base.params = {
            "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)},
            "velocity_range": {axis: still for axis in ("x", "y", "z", "roll", "pitch", "yaw")},
        }

        # --- Rewards: retune the inherited terms.
        rewards = self.rewards
        rewards.lin_vel_z_l2.weight = 0.0
        rewards.undesired_contacts = None
        rewards.flat_orientation_l2.weight = -1.0
        rewards.action_rate_l2.weight = -0.005
        rewards.dof_acc_l2.weight = -1.25e-7
        rewards.dof_acc_l2.params["asset_cfg"] = SceneEntityCfg(
            "robot", joint_names=[".*_hip_.*", ".*_knee_joint"]
        )
        rewards.dof_torques_l2.weight = -1.5e-7
        rewards.dof_torques_l2.params["asset_cfg"] = SceneEntityCfg(
            "robot", joint_names=[".*_hip_.*", ".*_knee_joint", ".*_ankle_.*"]
        )

        # --- Commands: forward speed in [0, 1], no lateral command, yaw rate in [-1, 1].
        ranges = self.commands.base_velocity.ranges
        ranges.lin_vel_x = (0.0, 1.0)
        ranges.lin_vel_y = (-0.0, 0.0)
        ranges.ang_vel_z = (-1.0, 1.0)

        # --- Terminations: base-contact check is applied to the torso link.
        self.terminations.base_contact.params["sensor_cfg"].body_names = "torso_link"


@configclass
class G1RoughEnvCfg_PLAY(G1RoughEnvCfg):
    """Play/evaluation variant of ``G1RoughEnvCfg`` with a small scene and no perturbations."""

    def __post_init__(self):
        # Start from the full training configuration.
        super().__post_init__()

        # Small scene with longer episodes.
        self.scene.num_envs = 50
        self.scene.env_spacing = 2.5
        self.episode_length_s = 40.0

        terrain = self.scene.terrain
        # Spawn robots randomly across the grid instead of by terrain level.
        terrain.max_init_terrain_level = None
        # Shrink the generated terrain grid to save memory and disable the curriculum.
        generator = terrain.terrain_generator
        if generator is not None:
            generator.num_rows = 5
            generator.num_cols = 5
            generator.curriculum = False

        # Fixed forward speed of 1.0, no lateral command, yaw rate in [-1, 1], zero heading range.
        ranges = self.commands.base_velocity.ranges
        ranges.lin_vel_x = (1.0, 1.0)
        ranges.lin_vel_y = (0.0, 0.0)
        ranges.ang_vel_z = (-1.0, 1.0)
        ranges.heading = (0.0, 0.0)

        # Turn off observation corruption for play.
        self.observations.policy.enable_corruption = False
        # Remove external force/torque disturbances and random pushes.
        self.events.base_external_force_torque = None
        self.events.push_robot = None
