# Copyright (c) 2022-2024, The IsaacLab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import math
from dataclasses import MISSING

import isaaclab.sim as sim_utils
from isaaclab.assets import ArticulationCfg, AssetBaseCfg
from isaaclab.envs import ManagerBasedRLEnvCfg
from isaaclab.managers import CurriculumTermCfg as CurrTerm
from isaaclab.managers import ObservationGroupCfg as ObsGroup
from isaaclab.managers import ObservationTermCfg as ObsTerm
from isaaclab.managers import EventTermCfg as EventTerm
from isaaclab.managers import SceneEntityCfg
from isaaclab.managers import TerminationTermCfg as DoneTerm
from isaaclab.scene import InteractiveSceneCfg
from isaaclab.sensors import ContactSensorCfg
from isaaclab.utils import configclass
from isaaclab.utils.noise import AdditiveUniformNoiseCfg as Unoise
from isaaclab_assets.robots.anymal import ANYDRIVE_4_MLP_ACTUATOR_CFG
import os
# from p4rl import P4RL_EXT_DIR # cause "can not import from partially initialized module" error
P4RL_EXT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
import p4rl.tasks.exploration.mdp as mdp
from isaaclab.utils.assets import ISAACLAB_NUCLEUS_DIR
from isaaclab.managers import RewardTermCfg as RewTerm
import torch.nn as nn
from isaaclab.terrains import TerrainImporterCfg
from parkour.config.terrain_cfg import PARKOUR_TERRAIN_CFG
from parkour.config.terrain_cfg import WALK_SUBTERRAINS

from isaaclab_assets import G1_MINIMAL_CFG  # isort: skip



##
# Scene definition
##

@configclass
class P4RLExplorationBaseSceneCfg(InteractiveSceneCfg):
    """Base interactive scene: one G1 robot per env, whole-body contact sensing and a dome light.

    The ``terrain`` slot is deliberately left as ``None`` here; concrete scene configs
    (e.g. the flat-plane variant) assign it.
    """

    # ground terrain -- populated by subclasses
    terrain = None

    # articulated robot, spawned under each environment namespace
    robot: ArticulationCfg = G1_MINIMAL_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot")

    # contact sensing on every robot body (3-step history, air-time tracking enabled)
    contact_forces = ContactSensorCfg(
        prim_path="{ENV_REGEX_NS}/Robot/.*",
        history_length=3,
        track_air_time=True,
    )

    # global dome light
    light = AssetBaseCfg(
        prim_path="/World/light",
        spawn=sim_utils.DomeLightCfg(intensity=1000, exposure=1.5, color=(1.0, 1.0, 1.0)),
    )

    # scene layout / replication settings
    env_spacing = 5
    replicate_physics = False
    num_envs = 4096


##
# MDP settings
##

@configclass
class CommandsCfg:
    """Command terms for the MDP.

    Holds a single RLE command term. Judging by its parameters, it samples a
    ``dim_latent_space``-dimensional vector (std ``std_sampling``, optionally normalized
    to unit length) and resamples it on a fixed (6.0, 6.0) s interval.
    NOTE(review): exact sampling semantics live in ``mdp.RLECommandCfg``; confirm there.
    """

    rle_command = mdp.RLECommandCfg(
        resampling_time_range=(6.0, 6.0),  # fixed resampling interval (min == max)
        dim_latent_space=128,  # dimensionality of the sampled command vector
        std_sampling=1.0,
        normalize_vector=True,  # whether to normalize the sampled vector to unit length
    )


@configclass
class InvInputCfg(ObsGroup):
    """Input observations for the inverse-dynamics (PIDM) model.

    This group must NOT contain past actions. Corruption is switched off here so that
    the training loop can inject its own noise.
    """

    # terms are concatenated in declaration order
    base_lin_vel = ObsTerm(func=mdp.base_lin_vel)  # indices 0:3
    base_ang_vel = ObsTerm(func=mdp.base_ang_vel)  # indices 3:6
    projected_gravity = ObsTerm(func=mdp.projected_gravity)  # indices 6:9

    # joint terms: each has one entry per robot joint (the length depends on the robot)
    joint_pos = ObsTerm(func=mdp.joint_pos_rel)
    joint_vel = ObsTerm(func=mdp.joint_vel_rel)

    def __post_init__(self):
        # noise is injected by the training loop, not by the observation manager
        self.concatenate_terms = True
        self.enable_corruption = False


@configclass
class PolicyCfg(ObsGroup):
    """Observations for the policy group: noisy proprioception plus the previous action.

    Terms are concatenated in declaration order. With J = number of robot joints
    (assuming ``joint_pos_rel``/``joint_vel_rel`` cover every joint, and since
    ``ActionsCfg`` drives all joints via ``joint_names=[".*"]``), the layout is::

        base_lin_vel       0:3
        base_ang_vel       3:6
        projected_gravity  6:9
        joint_pos          9:9+J
        joint_vel          9+J:9+2J
        actions            9+2J:9+3J

    NOTE(review): earlier comments hard-coded a 12-joint layout (e.g. ``9:21``). That
    does not match the G1 joint set used by the rewards; TODO confirm J for G1_MINIMAL_CFG.
    """

    # observation terms (order preserved)
    base_lin_vel = ObsTerm(func=mdp.base_lin_vel, noise=Unoise(n_min=-0.1, n_max=0.1))  # indices 0:3
    base_ang_vel = ObsTerm(func=mdp.base_ang_vel, noise=Unoise(n_min=-0.2, n_max=0.2))  # indices 3:6
    projected_gravity = ObsTerm(
        func=mdp.projected_gravity,
        noise=Unoise(n_min=-0.05, n_max=0.05),
    )  # indices 6:9

    joint_pos = ObsTerm(func=mdp.joint_pos_rel, noise=Unoise(n_min=-0.01, n_max=0.01))  # length J, indices 9:9+J
    joint_vel = ObsTerm(func=mdp.joint_vel_rel, noise=Unoise(n_min=-1.5, n_max=1.5))  # length J, indices 9+J:9+2J

    # previous action, clipped to [-100, 100]
    actions = ObsTerm(func=mdp.last_action, clip=(-100.00, 100.00))  # length J, indices 9+2J:9+3J

    def __post_init__(self):
        # apply the per-term noise defined above
        self.enable_corruption = True
        self.concatenate_terms = True


@configclass
class ObservationsCfg:
    """Observation groups exposed by the environment."""

    # actor input: corrupted by per-term noise, includes the last action
    policy: PolicyCfg = PolicyCfg()
    # inverse-dynamics model input: corruption disabled, no actions
    inv_dynamics_input: InvInputCfg = InvInputCfg()

@configclass
class EventCfg:
    """Configuration for randomization and reset events.

    Active terms:
    - a rigid-body material applied at startup;
    - a root-pose reset and a joint reset on every episode reset.
    Optional perturbation terms are kept below as commented-out templates.
    """

    # -- startup
    # ranges are degenerate (min == max), so fixed friction/restitution values are applied
    physics_material: EventTerm | None = EventTerm(
        func=mdp.randomize_rigid_body_material,
        mode="startup",
        params={
            "asset_cfg": SceneEntityCfg("robot", body_names=".*"),
            "static_friction_range": (0.8, 0.8),  # TODO: legged_gym only has one friction term, how to choose?
            "dynamic_friction_range": (0.6, 0.6),
            "restitution_range": (0.0, 0.0),
            "num_buckets": 64,
        },
    )

    # add_base_mass: EventTerm | None = EventTerm(
    #     func=mdp.randomize_rigid_body_mass,
    #     mode="startup",
    #     params={"asset_cfg": SceneEntityCfg("robot", body_names="base"), "mass_distribution_params": (-5.0, 5.0), "operation": "add"},
    # )

    # push_foot_constant: EventTerm | None = EventTerm(
    #     func=mdp.apply_external_force_torque,
    #     mode="reset",
    #     params={
    #         "asset_cfg": SceneEntityCfg("robot", body_names=".*FOOT"),
    #         "force_range": (0.0, 12.0),
    #         "torque_range": (-0.0, 0.0),
    #     },
    # )

    # -- reset
    # random x/y offset in (-0.5, 0.5) and yaw in (-3.14, 3.14); zero initial root velocity
    reset_base = EventTerm(
        func=mdp.reset_root_state_uniform,
        mode="reset",
        params={
            "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)},
            "velocity_range": {
                "x": (0.0, 0.0),
                "y": (0.0, 0.0),
                "z": (0.0, 0.0),
                "roll": (0.0, 0.0),
                "pitch": (0.0, 0.0),
                "yaw": (0.0, 0.0),
            },
        }
    )

    # scale range (1.0, 1.0) with zero velocity -- presumably resets joints to their defaults
    reset_robot_joints = EventTerm(
        func=mdp.reset_joints_by_scale,
        mode="reset",
        params={
            "position_range": (1.0, 1.0),
            "velocity_range": (0.0, 0.0),
        },
    )

    # NOTE(review): the joint-limit table below lists only 12 joints. That does not match
    # the G1 joint set referenced elsewhere in this file; it is presumably left over from a
    # 12-DoF robot, so verify before relying on it.
    # Original note: "the joint limits (those are not enforced in articulation class)":
    #     joint_limits = torch.tensor([[-0.7854,  0.6109],
    #                                  [-0.7854,  0.6109],
    #                                  [-0.6109,  0.7854],
    #                                  [-0.6109,  0.7854],
    #                                  [-9.4248,  9.4248],
    #                                  [-9.4248,  9.4248],
    #                                  [-9.4248,  9.4248],
    #                                  [-9.4248,  9.4248],
    #                                  [-9.4248,  9.4248],
    #                                  [-9.4248,  9.4248],
    #                                  [-9.4248,  9.4248],
    #                                  [-9.4248,  9.4248]], device='cuda:0')

    # # interval
    # push_robot: EventTerm | None = EventTerm(
    #     func=mdp.push_by_setting_velocity,
    #     mode="interval",
    #     interval_range_s=(2.0, 4.0),
    #     params={"velocity_range": {"x": (-1.5, 1.5),
    #                                 "y": (-1.5, 1.5),
    #                                 "z": (-1.5, 1.5),
    #                                 "roll": (-1.5, 1.5),
    #                                 "pitch": (-1.5, 1.5),
    #                                 "yaw": (-1.5, 1.5),
    #                                 }},
    # )

    # push_foot_interval: EventTerm | None = EventTerm(
    #     func=mdp.apply_external_force_torque_to_foot,
    #     mode="interval",
    #     interval_range_s=(2.0, 4.0),
    #     params={"asset_cfg": SceneEntityCfg("robot", body_names=".*FOOT"),
    #             "force_range": (-50, 50),
    #             "torque_range": (-0.0, 0.0),
    #             },
    # )


# @configclass
# class ActionsCfg:
#     """Action specifications for the MDP."""

#     joint_pos = mdp.JointPositionActionCfg(asset_name="robot", joint_names=[".*"], scale=0.5, use_default_offset=True)

@configclass
class ActionsCfg:
    """Action terms: joint-position targets for every robot joint."""

    # actions are scaled by 0.5 and offset by the default joint positions (use_default_offset=True)
    # alternative kept for reference:
    # joint_pos = mdp.RelativeJointPositionActionCfg(asset_name="robot", joint_names=[".*"], scale=0.5)
    joint_pos = mdp.JointPositionActionCfg(
        asset_name="robot",
        joint_names=[".*"],
        scale=0.5,
        use_default_offset=True,
    )


@configclass
class TerminationsCfg:
    """Episode termination conditions."""

    # episode timeout (flagged as a time-out rather than a failure)
    time_out = DoneTerm(func=mdp.time_out, time_out=True)

    # orientation check with a 30 degree limit angle
    bad_orientation = DoneTerm(
        func=mdp.bad_orientation,
        params={
            "limit_angle": math.radians(30.0),
            "asset_cfg": SceneEntityCfg("robot"),
        },
    )

    # illegal contact on the base body above the threshold
    # NOTE(review): "base" is rebound to "torso_link" in P4RLExplorationEnvCfg.__post_init__.
    # The original inline note called 0.1 an "enlarged threshold"; confirm the intended value.
    base_contact = DoneTerm(
        func=mdp.illegal_contact,
        params={
            "sensor_cfg": SceneEntityCfg("contact_forces", body_names="base"),
            "threshold": 0.1,
        },
    )


@configclass
class TerminationsRoughCfg(TerminationsCfg):
    """Rough-terrain terminations: the base set plus an excessive-contact check on all bodies."""

    def __post_init__(self):
        """Post initialization: attach the terrain-related termination terms."""
        super().__post_init__()
        # any robot body whose contact exceeds the 5000.0 threshold ends the episode
        illegal_force_params = {
            "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*"),
            "threshold": 5000.0,
        }
        self.illegal_force = DoneTerm(func=mdp.illegal_contact, params=illegal_force_params)
        # self.illegal_force_feet = DoneTerm(
        #     func=mdp.illegal_contact,
        #     params={"sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*FOOT"), "threshold": 1500.0},
        # )


@configclass
class RewardsCfg:
    """Generic locomotion penalty terms shared by robot-specific reward configs."""

    # -- penalties (negative weights)
    lin_vel_z_l2 = RewTerm(func=mdp.lin_vel_z_l2, weight=-2.0)  # vertical base velocity
    ang_vel_xy_l2 = RewTerm(func=mdp.ang_vel_xy_l2, weight=-0.05)  # roll/pitch angular velocity
    dof_torques_l2 = RewTerm(func=mdp.joint_torques_l2, weight=-1.0e-5)  # joint torques
    dof_acc_l2 = RewTerm(func=mdp.joint_acc_l2, weight=-2.5e-7)  # joint accelerations
    action_rate_l2 = RewTerm(func=mdp.action_rate_l2, weight=-0.01)  # action changes

    # -- optional penalties (disabled by zero weight; subclasses may enable them)
    flat_orientation_l2 = RewTerm(func=mdp.flat_orientation_l2, weight=0.0)
    dof_pos_limits = RewTerm(func=mdp.joint_pos_limits, weight=0.0)

@configclass
class G1Rewards(RewardsCfg):
    """G1-specific reward terms layered on top of the shared ``RewardsCfg`` penalties.

    Velocity-command tracking terms are disabled because they depend on a velocity
    command. The reward is built from foot air time, a termination penalty, and
    posture/slip regularizers.
    """

    # penalty tied to mdp.is_terminated
    termination_penalty = RewTerm(func=mdp.is_terminated, weight=-200.0)

    # following rewards are dependent on velocity command and are thus disabled.
    # track_lin_vel_xy_exp = RewTerm(
    #     func=mdp.track_lin_vel_xy_yaw_frame_exp,
    #     weight=1.0,
    #     params={"command_name": "base_velocity", "std": 0.5},
    # )
    # track_ang_vel_z_exp = RewTerm(
    #     func=mdp.track_ang_vel_z_world_exp, weight=2.0, params={"command_name": "base_velocity", "std": 0.5}
    # )
    # feet_air_time = RewTerm(
    #     func=mdp.feet_air_time_positive_biped,
    #     weight=0.25,
    #     params={
    #         "command_name": "base_velocity",
    #         "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*_ankle_roll_link"),
    #         "threshold": 0.4,
    #     },
    # )

    # -- rewards
    # NOTE(review): 400.0 is by far the largest weight in this config; confirm it is intended.
    feet_air_time = RewTerm(
        func=mdp.feet_air_time,
        weight=400.0,
        params={
            "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*_ankle_roll_link"),
            # per the original note, the air-time reward is clipped at this upper bound
            # (semantics defined in mdp.feet_air_time -- confirm there)
            "threshold": 0.5,
        },
    )

    # -- penalties
    # foot-slide penalty on the ankle-roll links (see mdp.feet_slide)
    feet_slide = RewTerm(
        func=mdp.feet_slide,
        weight=-0.1,
        params={
            "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*_ankle_roll_link"),
            "asset_cfg": SceneEntityCfg("robot", body_names=".*_ankle_roll_link"),
        },
    )

    # ankle joint-limit penalty; overrides the zero-weight term inherited from RewardsCfg
    dof_pos_limits = RewTerm(
        func=mdp.joint_pos_limits,
        weight=-1.0,
        params={"asset_cfg": SceneEntityCfg("robot", joint_names=[".*_ankle_pitch_joint", ".*_ankle_roll_joint"])},
    )

    # deviation-from-default penalties for joints not essential to locomotion
    joint_deviation_hip = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.1,
        params={"asset_cfg": SceneEntityCfg("robot", joint_names=[".*_hip_yaw_joint", ".*_hip_roll_joint"])},
    )
    joint_deviation_arms = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.1,
        params={
            "asset_cfg": SceneEntityCfg(
                "robot",
                joint_names=[
                    ".*_shoulder_pitch_joint",
                    ".*_shoulder_roll_joint",
                    ".*_shoulder_yaw_joint",
                    ".*_elbow_pitch_joint",
                    ".*_elbow_roll_joint",
                ],
            ),
        },
    )
    joint_deviation_fingers = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.05,
        params={
            "asset_cfg": SceneEntityCfg(
                "robot",
                joint_names=[
                    ".*_five_joint", ".*_three_joint", ".*_six_joint", ".*_four_joint",
                    ".*_zero_joint", ".*_one_joint", ".*_two_joint",
                ],
            ),
        },
    )
    joint_deviation_torso = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.1,
        params={"asset_cfg": SceneEntityCfg("robot", joint_names="torso_joint")},
    )



@configclass
class P4RLExplorationFlatSceneCfg(P4RLExplorationBaseSceneCfg):
    """Exploration scene on a flat ground plane (no terrain generator)."""

    # flat plane; friction and restitution are combined by averaging
    # (max_init_terrain_level is presumably only relevant for generated terrains)
    terrain: TerrainImporterCfg = TerrainImporterCfg(
        prim_path="/World/ground",
        terrain_type="plane",
        terrain_generator=None,
        max_init_terrain_level=5,
        collision_group=-1,
        physics_material=sim_utils.RigidBodyMaterialCfg(
            static_friction=1.0,
            dynamic_friction=1.0,
            friction_combine_mode="average",
            restitution_combine_mode="average",
        ),
        debug_vis=False,
    )


##
# Environment configuration
##

@configclass
class P4RLExplorationEnvCfg(ManagerBasedRLEnvCfg):
    """Configuration for the flat terrain exploration environment.

    Uses the flat-plane scene, the G1 reward set, and G1-specific overrides applied in
    ``__post_init__``.

    NOTE(review): ``CommandsCfg`` (the RLE command) is defined in this module but is not
    assigned to ``commands`` here, so the parent default applies. Confirm whether a
    subclass or the training script wires it in.
    """

    # Scene settings (declared once; flat-ground scene)
    # Note: 3070 8GB doesn't have enough memory for 4096 envs
    scene: P4RLExplorationFlatSceneCfg = P4RLExplorationFlatSceneCfg()
    # Basic settings
    observations: ObservationsCfg = ObservationsCfg()
    actions: ActionsCfg = ActionsCfg()

    # MDP settings
    terminations: TerminationsCfg = TerminationsCfg()
    events: EventCfg = EventCfg()
    # curriculum: CurriculumCfg = CurriculumCfg() # disable curriculum for p4rl
    # Simulation settings
    decimation: int = 4
    # Episode length in seconds (time between resets). With CommandsCfg's fixed 6.0 s
    # resampling interval, a command would not be resampled within a 3.0 s episode.
    episode_length_s: float = 3.0

    # rewards
    rewards: G1Rewards = G1Rewards()

    def __post_init__(self):
        """Post initialization: simulation settings and G1-specific overrides."""

        # simulation settings; physics dt = 0.005 s, so one env step = decimation * dt = 0.02 s
        self.sim.dt = 0.005
        # reuse the ground plane's material as the simulation's default physics material
        self.sim.physics_material = self.scene.terrain.physics_material

        # Scene
        self.scene.robot = G1_MINIMAL_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot")
        # self.scene.height_scanner.prim_path = "{ENV_REGEX_NS}/Robot/torso_link"

        # Randomization
        # self.events.push_robot = None
        # EventCfg currently has no add_base_mass term (it is commented out there); setting
        # it to None keeps base-mass randomization disabled even if that term is re-enabled.
        self.events.add_base_mass = None
        # self.events.reset_robot_joints.params["position_range"] = (1.0, 1.0)
        # self.events.base_external_force_torque.params["asset_cfg"].body_names = ["torso_link"]
        # self.events.base_com = None

        # Rewards
        self.rewards.lin_vel_z_l2.weight = 0.0
        # self.rewards.undesired_contacts = None
        # self.rewards.flat_orientation_l2.weight = -1.0
        self.rewards.action_rate_l2.weight = -0.005
        # restrict acceleration / torque penalties to the leg joints
        self.rewards.dof_acc_l2.weight = -1.25e-7
        self.rewards.dof_acc_l2.params["asset_cfg"] = SceneEntityCfg(
            "robot", joint_names=[".*_hip_.*", ".*_knee_joint"]
        )
        self.rewards.dof_torques_l2.weight = -1.5e-7
        self.rewards.dof_torques_l2.params["asset_cfg"] = SceneEntityCfg(
            "robot", joint_names=[".*_hip_.*", ".*_knee_joint", ".*_ankle_.*"]
        )

        # terminations: the placeholder "base" body is replaced with the G1 torso link
        self.terminations.base_contact.params["sensor_cfg"].body_names = "torso_link"


# @configclass
# class P4RLExplorationRoughTerrainEnvCfg(P4RLExplorationEnvCfg):
#     """ Configuration for the rough terrain exploration environment."""
#     def __post_init__(self):
#         """Post initialization."""
#         super().__post_init__()
#         # Set the terrain to be a rough terrain
#         self.scene.terrain = TerrainImporterCfg(
#                 prim_path="/World/ground",
#                 terrain_type="generator",
#                 terrain_generator=PARKOUR_TERRAIN_CFG,
#                 max_init_terrain_level=None,
#                 collision_group=-1,
#                 physics_material=sim_utils.RigidBodyMaterialCfg(
#                     static_friction=1.0,
#                     dynamic_friction=1.0,
#                 ),
#                 visual_material=sim_utils.MdlFileCfg(
#                     mdl_path="{NVIDIA_NUCLEUS_DIR}/Materials/Base/Masonry/Concrete_Polished.mdl",
#                     project_uvw=True,
#                 ),
#                 debug_vis=False,
#             )
#         self.scene.terrain.terrain_generator.sub_terrains = WALK_SUBTERRAINS
#         self.scene.terrain.terrain_generator.difficulty_range = (0.0, 1.25)
#         self.scene.terrain.terrain_generator.curriculum = True # generate terrains with increasing difficulty, not using RL curriculum

#         self.events.physics_material.params["restitution_range"] = (0.0, 0.5)  # comply with parkour setting

#         self.events.reset_base=EventTerm(
#             func=mdp.reset_root_state_from_terrain,
#             mode="reset",
#             params={
#                 "pose_range": {"roll": (0.0, 0.0), "pitch": (0.0, 0.0), "yaw": (-3.14, 3.14)},
#                 "velocity_range": {
#                     "x": (-0.5, 0.5),
#                     "y": (-0.5, 0.5),
#                     "z": (-0.5, 0.5),
#                     "roll": (-0.5, 0.5),
#                     "pitch": (-0.5, 0.5),
#                     "yaw": (-0.5, 0.5),
#                 },
#             },)
        
#         self.terminations = TerminationsRoughCfg() # use the rough terrain terminations, with 2 extra terms


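# # Steps between command resamplings = resampling_time / (sim.dt * decimation).
# # With a 2.0 s resampling time: 2.0 / (0.005 * 4) = 100 steps, so collecting 200 steps in a play env
# # with num_envs = 20 yields 20 * (200 / 100) = 40 different commands.
# # NOTE: CommandsCfg above uses a 6.0 s resampling time, i.e. 6.0 / (0.005 * 4) = 300 steps.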
