# Copyright (c) 2022-2024, The IsaacLab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import math
from dataclasses import MISSING

import isaaclab.sim as sim_utils
from isaaclab.assets import ArticulationCfg, AssetBaseCfg
from isaaclab.envs import ManagerBasedRLEnvCfg
from isaaclab.managers import CurriculumTermCfg as CurrTerm
from isaaclab.managers import ObservationGroupCfg as ObsGroup
from isaaclab.managers import ObservationTermCfg as ObsTerm
from isaaclab.managers import EventTermCfg as EventTerm
from isaaclab.managers import SceneEntityCfg
from isaaclab.managers import TerminationTermCfg as DoneTerm
from isaaclab.scene import InteractiveSceneCfg
from isaaclab.sensors import ContactSensorCfg
from isaaclab.utils import configclass
from isaaclab.utils.noise import AdditiveUniformNoiseCfg as Unoise
from isaaclab_assets.robots.anymal import ANYDRIVE_4_MLP_ACTUATOR_CFG
import os
# from p4rl import P4RL_EXT_DIR # cause "can not import from partially initialized module" error
P4RL_EXT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
import p4rl.tasks.exploration.mdp as mdp
from isaaclab.utils.assets import ISAACLAB_NUCLEUS_DIR
from isaaclab.managers import RewardTermCfg as RewTerm
import torch.nn as nn
from isaaclab.terrains import TerrainImporterCfg
from parkour.config.terrain_cfg import PARKOUR_TERRAIN_CFG
from parkour.config.terrain_cfg import WALK_SUBTERRAINS
from isaaclab_assets.robots.unitree import UNITREE_GO1_CFG  # isort: skip



##
# Scene definition
##

@configclass
class P4RLExplorationBaseSceneCfg(InteractiveSceneCfg):
    """Configuration for the terrain scene with a legged robot.

    The terrain is left unset here (``None``). This class declares the robot, a contact
    sensor whose prim path matches everything under the robot, and a dome light.
    """

    # ground terrain
    terrain = None # defined in child classes
    # robots
    robot: ArticulationCfg = UNITREE_GO1_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot")
    # contact sensors (prim path regex covers every prim below the robot root)
    contact_forces = ContactSensorCfg(prim_path="{ENV_REGEX_NS}/Robot/.*", history_length=3, track_air_time=True)
    # lights
    light = AssetBaseCfg(
        prim_path="/World/light",
        spawn=sim_utils.DomeLightCfg(
            intensity=1000,
            exposure=1.5,
            color=(1.0, 1.0, 1.0)
        )
    )
    # scene-level settings (environment count, spacing between environments, physics replication)
    env_spacing = 5
    replicate_physics = False
    num_envs = 4096


##
# MDP settings
##

@configclass
class CommandsCfg:
    """Command terms for the MDP: a single ``RLECommandCfg`` latent-vector command."""

    rle_command = mdp.RLECommandCfg(
        resampling_time_range=(6.0, 6.0),  # identical lower/upper bound, i.e. a fixed value
        dim_latent_space=128,
        std_sampling=1.0,
        normalize_vector=True,  # whether to normalize the sampled vector to unit length
    )


@configclass
class InvInputCfg(ObsGroup):
    """Observations for PIDM group. This space can not include past actions!

    The terms mirror the proprioceptive terms of the policy group, but without noise
    models and without the last-action term.

    NOTE: corruption is not enabled, so that we can control the noise in training loop.
    """

    # observation terms (order preserved)
    base_lin_vel = ObsTerm(func=mdp.base_lin_vel) # indices 0:3
    base_ang_vel = ObsTerm(func=mdp.base_ang_vel) # indices 3:6
    projected_gravity = ObsTerm(
        func=mdp.projected_gravity,
    ) # indices 6:9

    joint_pos = ObsTerm(func=mdp.joint_pos_rel) # length 12, indices 9:21
    joint_vel = ObsTerm(func=mdp.joint_vel_rel) # length 12, indices 21:33

    def __post_init__(self):
        """Disable observation corruption and request concatenated terms."""
        self.enable_corruption = False # NOTE corruption is not enabled, so that we can control the noise in training loop
        self.concatenate_terms = True


@configclass
class PolicyCfg(ObsGroup):
    """Observations for policy group.

    Each proprioceptive term carries an additive uniform noise model; the last action
    is appended without noise and clipped to [-100, 100].
    """

    # observation terms (order preserved)
    base_lin_vel = ObsTerm(func=mdp.base_lin_vel, noise=Unoise(n_min=-0.1, n_max=0.1)) # indices 0:3
    base_ang_vel = ObsTerm(func=mdp.base_ang_vel, noise=Unoise(n_min=-0.2, n_max=0.2)) # indices 3:6
    projected_gravity = ObsTerm(
        func=mdp.projected_gravity,
        noise=Unoise(n_min=-0.05, n_max=0.05),
    ) # indices 6:9

    joint_pos = ObsTerm(func=mdp.joint_pos_rel, noise=Unoise(n_min=-0.01, n_max=0.01)) # length 12, indices 9:21
    joint_vel = ObsTerm(func=mdp.joint_vel_rel, noise=Unoise(n_min=-1.5, n_max=1.5)) # length 12, indices 21:33

    actions = ObsTerm(func=mdp.last_action, clip=(-100.00, 100.00)) # length 12, indices 33:45

    def __post_init__(self):
        """Enable observation corruption and request concatenated terms."""
        self.enable_corruption = True
        self.concatenate_terms = True


@configclass
class ObservationsCfg:
    """Observation specifications for the MDP.

    Two groups are exposed: the noisy ``policy`` group and the noise-free
    ``inv_dynamics_input`` group.
    """

    # observation groups
    policy: PolicyCfg = PolicyCfg()
    inv_dynamics_input: InvInputCfg = InvInputCfg()

@configclass
class EventCfg:
    """Configuration for randomization.

    Terms with ``mode="startup"`` randomize material and mass properties; terms with
    ``mode="reset"`` apply an external force to the feet and reset the root state and
    the joint state.
    """
    # startup
    # friction / restitution randomization over all robot bodies
    physics_material: EventTerm | None = EventTerm(
        func=mdp.randomize_rigid_body_material,
        mode="startup",
        params={
            "asset_cfg": SceneEntityCfg("robot", body_names=".*"),
            "static_friction_range": (0.0, 1.5),  # TODO: legged_gym only has one friction term, how to choose?
            "dynamic_friction_range": (0.0, 1.5),
            "restitution_range": (0.0, 0.0),
            "num_buckets": 64,
        },
    )

    # additive mass randomization on the body named "base"
    add_base_mass: EventTerm | None = EventTerm(
        func=mdp.randomize_rigid_body_mass,
        mode="startup",
        params={"asset_cfg": SceneEntityCfg("robot", body_names="base"), "mass_distribution_params": (-5.0, 5.0), "operation": "add"},
    )

    # reset
    # external force (zero torque range) on bodies matching ".*foot"
    push_foot_constant: EventTerm | None = EventTerm(
        func=mdp.apply_external_force_torque,
        mode="reset",
        params={
            "asset_cfg": SceneEntityCfg("robot", body_names=".*foot"),
            "force_range": (0.0, 12.0),
            "torque_range": (-0.0, 0.0),
        },
    )

    # root state reset: only yaw is randomized, all velocity ranges are zero
    reset_base = EventTerm(
        func=mdp.reset_root_state_uniform,
        mode="reset",
        params={
            "pose_range": {"x": (0.0, 0.0), "y": (0.0, 0.0), "yaw": (-3.14, 3.14)},
            "velocity_range": {
                "x": (0.0, 0.0),
                "y": (0.0, 0.0),
                "z": (0.0, 0.0),
                "roll": (0.0, 0.0),
                "pitch": (0.0, 0.0),
                "yaw": (0.0, 0.0),
            },
        },
    )

    # joint state reset: position range (0.5, 1.5), zero velocity range
    reset_robot_joints = EventTerm(
        func=mdp.reset_joints_by_scale,
        mode="reset",
        params={
            "position_range": (0.5, 1.5),
            "velocity_range": (0.0, 0.0),
        },
    )

    # NOTE: the triple-quoted block below is a bare string statement used as a note; it is not a field.
    '''

    note the joint limits (those are not enforced in articulation class)
        joint_limits = torch.tensor([[-0.7854,  0.6109],
                                [-0.7854,  0.6109],
                                [-0.6109,  0.7854],
                                [-0.6109,  0.7854],
                                [-9.4248,  9.4248],
                                [-9.4248,  9.4248],
                                [-9.4248,  9.4248],
                                [-9.4248,  9.4248],
                                [-9.4248,  9.4248],
                                [-9.4248,  9.4248],
                                [-9.4248,  9.4248],
                                [-9.4248,  9.4248]], device='cuda:0') 
    '''

    # # interval
    # push_robot: EventTerm | None = EventTerm(
    #     func=mdp.push_by_setting_velocity,
    #     mode="interval",
    #     interval_range_s=(2.0, 4.0),
    #     params={"velocity_range": {"x": (-1.5, 1.5),
    #                                 "y": (-1.5, 1.5),
    #                                 "z": (-1.5, 1.5),
    #                                 "roll": (-1.5, 1.5),
    #                                 "pitch": (-1.5, 1.5),
    #                                 "yaw": (-1.5, 1.5),
    #                                 }},
    # )

    # push_foot_interval: EventTerm | None = EventTerm(
    #     func=mdp.apply_external_force_torque_to_foot,
    #     mode="interval",
    #     interval_range_s=(2.0, 4.0),
    #     params={"asset_cfg": SceneEntityCfg("robot", body_names=".*FOOT"),
    #             "force_range": (-50, 50),
    #             "torque_range": (-0.0, 0.0),
    #             },
    # )


# @configclass
# class ActionsCfg:
#     """Action specifications for the MDP."""

#     joint_pos = mdp.JointPositionActionCfg(asset_name="robot", joint_names=[".*"], scale=0.5, use_default_offset=True)

@configclass
class ActionsCfg:
    """Action specifications for the MDP.

    A single joint-position action term over all joints of the ``robot`` asset,
    with scale 0.5 and ``use_default_offset`` enabled.
    """

    # joint_pos = mdp.RelativeJointPositionActionCfg(asset_name="robot", joint_names=[".*"], scale=0.5)
    joint_pos = mdp.JointPositionActionCfg(asset_name="robot", joint_names=[".*"], scale=0.5, use_default_offset=True)


@configclass
class TerminationsCfg:
    """Termination terms for the MDP.

    Three terms: the episode time-out, a bad-orientation check with a 90 degree limit
    angle, and an illegal-contact check on the body named ``base``.
    """

    # flagged as a time-out via ``time_out=True``
    time_out = DoneTerm(func=mdp.time_out, time_out=True)

    # limit angle is given in radians (90 degrees)
    bad_orientation = DoneTerm(func=mdp.bad_orientation, 
                               params={"limit_angle": math.radians(90.0), "asset_cfg": SceneEntityCfg("robot")})

    base_contact = DoneTerm( # terminate if the base contacts the ground hard enough
        func=mdp.illegal_contact,
        params={"sensor_cfg": SceneEntityCfg("contact_forces", body_names="base"), "threshold": 0.1}, # enlarged threshold
    )


@configclass
class TerminationsRoughCfg(TerminationsCfg):
    """Termination terms for rough terrain: the base terms plus two contact-force limits."""

    def __post_init__(self):
        """Run the parent post-init, then attach the two terrain-related terminations."""
        super().__post_init__()

        # contact limit applied to every body of the contact sensor
        all_bodies_sensor = SceneEntityCfg("contact_forces", body_names=".*")
        self.illegal_force = DoneTerm(
            func=mdp.illegal_contact,
            params={"sensor_cfg": all_bodies_sensor, "threshold": 5000.0},
        )

        # tighter contact limit applied to the bodies matching ".*FOOT"
        # NOTE(review): the scene robot's body names elsewhere in this file end in lowercase "_foot" - confirm this pattern.
        feet_sensor = SceneEntityCfg("contact_forces", body_names=".*FOOT")
        self.illegal_force_feet = DoneTerm(
            func=mdp.illegal_contact,
            params={"sensor_cfg": feet_sensor, "threshold": 1500.0},
        )


@configclass
class RewardsCfg:
    """Reward terms for the MDP.

    One positive term (feet air time) and a set of negatively weighted penalty terms.
    """
    # -- rewards
    # NOTE(review): ".*FOOT" is the default pattern; P4RLExplorationEnvGo1FlatCfg replaces it with ".*_foot".
    feet_air_time = RewTerm(
        func=mdp.feet_air_time,
        weight=400.0,
        params={
            "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*FOOT"),
            "threshold": 0.5, # the feet air time reward is clipped at this upper bound
        },
    )

    # -- penalties
    dof_torques_l2 = RewTerm(func=mdp.joint_torques_l2, weight=-2.0e-5)
    dof_acc_l2 = RewTerm(func=mdp.joint_acc_l2, weight=-5e-6)
    action_rate_l2 = RewTerm(func=mdp.action_rate_l2, weight=-0.01)
    action_l2 = RewTerm(func=mdp.action_l2, weight=-0.01)
    dof_vel = RewTerm(func=mdp.joint_vel_l2, weight=-0.05)
    # contacts on bodies whose names end in hip, thigh or trunk
    collisions = RewTerm(
        func=mdp.undesired_contacts,
        weight=-5.0,
        params={"sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*(hip|thigh|trunk)"), "threshold": 0.1},
    )
    # penalty weight applied through ``mdp.is_terminated``
    terminations = RewTerm(func=mdp.is_terminated, weight=-80.0)

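# Robot body names, kept for reference when writing `body_names` patterns (presumably the Go1 articulation):
#  ['trunk', 'FL_hip', 'FL_thigh', 'FL_calf', 'FL_foot', 'FR_hip', 'FR_thigh', 'FR_calf', 'FR_foot', 'RL_hip', 'RL_thigh', 'RL_calf', 'RL_foot', 'RR_hip', 'RR_thigh', 'RR_calf', 'RR_foot']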


@configclass
class RewardsRLECfg:
    """Reward terms for RLE: a single ``rle_reward`` term attached at post-init."""

    def __post_init__(self):
        """Post initialization."""
        # The term is set here instead of being declared as a class field.
        # NOTE(review): this class does not derive from RewardsCfg, so it carries only this
        # one term - confirm that dropping the other reward terms is intended.
        self.rle_reward = RewTerm(func=mdp.rle_reward, weight=1.5)


@configclass
class P4RLExplorationFlatSceneCfg(P4RLExplorationBaseSceneCfg):
    """Configuration for the flat scene.

    Fills in the ``terrain`` field of the base scene with a plane terrain and a
    rigid-body physics material.
    """
    # Set the terrain to be a flat plane (no terrain generator is used)
    terrain: TerrainImporterCfg = TerrainImporterCfg(
            prim_path="/World/ground",
            terrain_type="plane", 
            terrain_generator=None,
            max_init_terrain_level=5,
            collision_group=-1,
            # this material is also copied to ``sim.physics_material`` by P4RLExplorationEnvCfg.__post_init__
            physics_material=sim_utils.RigidBodyMaterialCfg(
                friction_combine_mode="average",
                restitution_combine_mode="average",
                static_friction=1.0,
                dynamic_friction=1.0),
            debug_vis=False)


##
# Environment configuration
##

@configclass
class P4RLExplorationEnvCfg(ManagerBasedRLEnvCfg):
    """Configuration for the flat terrain exploration environment.

    Wires together the flat scene, the observation/action specifications and the
    reward, termination and event terms defined in this module. No curriculum is
    configured.
    """

    # Scene settings
    # Note: 3070 8GB doesn't have enough memory for 4096 envs
    # The flat scene is used because ``__post_init__`` reads ``scene.terrain.physics_material``,
    # which requires a terrain to be set.
    scene: P4RLExplorationFlatSceneCfg = P4RLExplorationFlatSceneCfg()
    # Basic settings
    observations: ObservationsCfg = ObservationsCfg()
    actions: ActionsCfg = ActionsCfg()

    # MDP settings
    terminations: TerminationsCfg = TerminationsCfg()
    events: EventCfg = EventCfg()
    # curriculum: CurriculumCfg = CurriculumCfg() # disable curriculum for p4rl
    # Simulation settings
    decimation: int = 4
    episode_length_s: float = 3.0 # time between resets. There may be multiple command resamplings within this time

    # rewards
    rewards: RewardsCfg = RewardsCfg()

    def __post_init__(self):
        """Post initialization: set the physics step and the simulation physics material."""

        # simulation settings
        self.sim.dt = 0.005
        # reuse the terrain's physics material for the simulation
        self.sim.physics_material = self.scene.terrain.physics_material


@configclass
class P4RLExplorationEnvGo1FlatCfg(P4RLExplorationEnvCfg):
    """Flat-terrain exploration environment with body names and weights adapted for Go1."""

    def __post_init__(self):
        """Apply the parent settings, then override action, event, reward and termination values."""
        super().__post_init__()

        # -- actions: smaller scale than the 0.5 declared in ActionsCfg
        self.actions.joint_pos.scale = 0.25

        # -- events: narrower mass offset, applied to the "trunk" body
        base_mass_params = self.events.add_base_mass.params
        base_mass_params["mass_distribution_params"] = (-1.0, 3.0)
        base_mass_params["asset_cfg"].body_names = "trunk"

        # -- rewards: foot pattern with lowercase suffix and re-tuned penalty weights
        reward_terms = self.rewards
        reward_terms.feet_air_time.params["sensor_cfg"].body_names = ".*_foot"
        # reward_terms.feet_air_time.weight = 0.01
        reward_terms.dof_torques_l2.weight = -0.0002
        reward_terms.dof_acc_l2.weight = -2.5e-7

        # -- terminations: base contact is checked on the "trunk" body
        base_contact_params = self.terminations.base_contact.params
        base_contact_params["sensor_cfg"].body_names = "trunk"


@configclass
class P4RLExplorationEnvGo1Cfg(P4RLExplorationEnvGo1FlatCfg):
    """Placeholder for the Go1 rough-terrain environment; creating an instance always fails."""

    def __post_init__(self):
        """Reject instantiation because the rough-terrain variant is not available yet.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Rough terrain config not implemented yet for Go1.")
    

@configclass
class P4RLExplorationRoughTerrainEnvCfg(P4RLExplorationEnvCfg):
    """Configuration for the rough terrain exploration environment."""

    def __post_init__(self):
        """Apply the parent settings, then switch terrain, events and terminations to the rough variants."""
        super().__post_init__()

        # -- terrain: replace the flat plane with a generated terrain
        # NOTE(review): the parent post-init has already copied the flat terrain's physics
        # material into ``sim.physics_material``; it is not refreshed here - confirm intended.
        self.scene.terrain = TerrainImporterCfg(
            prim_path="/World/ground",
            terrain_type="generator",
            terrain_generator=PARKOUR_TERRAIN_CFG,
            max_init_terrain_level=None,
            collision_group=-1,
            physics_material=sim_utils.RigidBodyMaterialCfg(static_friction=1.0, dynamic_friction=1.0),
            visual_material=sim_utils.MdlFileCfg(
                mdl_path="{NVIDIA_NUCLEUS_DIR}/Materials/Base/Masonry/Concrete_Polished.mdl",
                project_uvw=True,
            ),
            debug_vis=False,
        )
        generator_cfg = self.scene.terrain.terrain_generator
        generator_cfg.sub_terrains = WALK_SUBTERRAINS
        generator_cfg.difficulty_range = (0.0, 1.25)
        # generate terrains with increasing difficulty, not using RL curriculum
        generator_cfg.curriculum = True

        # -- events
        # comply with parkour setting
        self.events.physics_material.params["restitution_range"] = (0.0, 0.5)

        # reset from the terrain with randomized yaw and a small random initial velocity on every axis
        zero_tilt = {axis: (0.0, 0.0) for axis in ("roll", "pitch")}
        velocity_axes = ("x", "y", "z", "roll", "pitch", "yaw")
        self.events.reset_base = EventTerm(
            func=mdp.reset_root_state_from_terrain,
            mode="reset",
            params={
                "pose_range": {**zero_tilt, "yaw": (-3.14, 3.14)},
                "velocity_range": {axis: (-0.5, 0.5) for axis in velocity_axes},
            },
        )

        # -- terminations: use the rough terrain terminations, with 2 extra terms
        self.terminations = TerminationsRoughCfg()


# # in this environment, the steps between resampling can be calculated, the result is (2.0)/(0.005*4) = 100 steps.
# # so if we collect 200 steps in the play env, we will have num_envs(20) * (200/100) = 40 different commands



# ---------------------- Not used implementations. Correctness not guaranteed ----------------------

@configclass
class ObservationsRNDCfg:
    """Observation specifications for the MDP, with an additional ``rnd_state`` group.

    The ``rnd_state`` group uses the same configuration class as the ``policy`` group.
    """

    # observation groups
    policy: PolicyCfg = PolicyCfg()
    rnd_state: PolicyCfg = PolicyCfg()
    inv_dynamics_input: InvInputCfg = InvInputCfg()


@configclass
class ObservationsRLECfg:
    """Observation specifications for the MDP, with the RLE command appended to the policy group."""

    @configclass
    class PolicyRLECfg(PolicyCfg):
        """Policy observation group extended by the ``rle_command`` term."""

        def __post_init__(self):
            """Attach the RLE command term and set the corruption/concatenation flags."""
            self.rle_command = ObsTerm(func=mdp.rle_command) # indices 45:45+dim_latent
            # same flag values as PolicyCfg.__post_init__, repeated here because the parent hook is not called
            self.enable_corruption = True
            self.concatenate_terms = True

    # observation groups
    policy: PolicyRLECfg = PolicyRLECfg()
    inv_dynamics_input: InvInputCfg = InvInputCfg()



# @configclass
# class P4RLRLEExplorationFlatEnvCfg(P4RLExplorationEnvCfg):
#     """Configuration for the pedipulation foot position tracking environment in flat terrain."""

#     # We inherit from the pedipulation base environment configuration and modify as needed
    
#     def __post_init__(self) -> None:
#         super().__post_init__() # call parent post init (PedipulationPositionEnvCfg)
#         self.rewards = RewardsRLECfg() # use the RLE rewards configuration
#         self.observations = ObservationsRLECfg()
#         self.commands = CommandsCfg() # use the RLE command configuration


# @configclass
# class P4RLExplorationPositionEnvCfg_PLAY(P4RLExplorationEnvCfg):
#     def __post_init__(self) -> None:
#         # post init of parent
#         super().__post_init__()

#         # make a smaller scene for play
#         self.scene.num_envs = 20
#         self.scene.env_spacing = 2.5
#         # P4RL: in play, locoma suite sets corruption to False by default, but this removes all sensory noise and is 
#         # less practical. in contrast here we still put it to true.
#         self.observations.policy.enable_corruption = True
#         # remove random pushing
#         self.events.push_robot = None
#         self.events.push_foot_interval = None
#         # remove randomization
#         self.events.physics_material = None
#         self.events.add_base_mass = None
#         ###########################################
#         # p4rl: disable push foot constant was not in the main branch config
#         self.events.push_foot_constant = None
#         ###########################################
#         # enable debug visualization
#         # decrease resample interval
#         # self.commands.foot_position.resampling_time_range = (2.0, 2.0)
#         # decrease episode lenght
#         self.episode_length_s = 6.0

#         # increase command domain
#         # self.commands.foot_position.ranges.pos_x = (0.0, 2.0)
#         # self.commands.foot_position.ranges.pos_y = (-1.4, 1.2)