# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from isaaclab.utils import configclass
from p4rl.rsl_rl.rl_cfg import (
    RslRlOnPolicyRunnerCfg, 
    RslRlPpoActorCriticCfg, 
    RslRlPpoAlgorithmCfg, 
    RslRlPpoExtendableActorCriticCfg, 
    RslRlPpoHierarchicalActorCriticCfg, 
    RslRlPpoActorCriticForAnalysisCfg,
    RandomNetworkDistillationCfg,
    InvDynamicsMLPConfig, 
    RslRlPpoActorCriticConstrainedStdCfg
)
from rsl_rl.rsl_rl.addons.kinematics.modules import KinematicSubmoduleConfig
from rsl_rl.rsl_rl.addons.dynamics.modules import DynamicsSubmoduleConfig, DynamicsSubmoduleForHACConfig
from rsl_rl.rsl_rl.addons.dynamics.modules_recurrent import DynamicsSubmoduleConfigRNN
# from rsl_rl.rsl_rl.addons.invdynamics.modules import InvDynamicsSubmoduleConfig
from rsl_rl.rsl_rl.addons.resnet_blocks.modules import ResNetBlocksConfig



@configclass
class AnymalDFlatPPORunnerCfg(RslRlOnPolicyRunnerCfg):
    """On-policy PPO runner configuration for the ``exploration_base`` experiment.

    Uses an MLP actor-critic (``RslRlPpoActorCriticConstrainedStdCfg``) in which
    both the actor and the critic read the ``"policy"`` observation group, and a
    PPO algorithm config with an adaptive learning-rate schedule.
    """

    # --- runner schedule -----------------------------------------------------
    num_steps_per_env = 24
    max_iterations = 1000
    save_interval = 20
    # For a quick debugging run use: max_iterations = 100, save_interval = 1
    experiment_name = "exploration_base"
    empirical_normalization = False
    # Actor and critic are both fed from the "policy" observation group.
    obs_groups = {"policy": ["policy"], "critic": ["policy"]}

    # --- actor-critic network -------------------------------------------------
    policy = RslRlPpoActorCriticConstrainedStdCfg(
        # network architecture
        actor_hidden_dims=[512, 256, 128],
        critic_hidden_dims=[512, 256, 128],
        activation="elu",
        # observation normalization is disabled for both networks
        actor_obs_normalization=False,
        critic_obs_normalization=False,
        # exploration noise: initial std 0.8; the lower/upper bounds presumably
        # constrain the std during training (semantics are defined by
        # RslRlPpoActorCriticConstrainedStdCfg -- confirm there).
        noise_std_type="scalar",
        init_noise_std=0.8,
        noise_lower_bound=0.6,
        noise_upper_bound=1.0,
    )  # type: ignore

    # --- PPO algorithm ----------------------------------------------------------
    algorithm = RslRlPpoAlgorithmCfg(
        # surrogate / value loss terms
        clip_param=0.2,
        use_clipped_value_loss=True,
        value_loss_coef=1.0,
        entropy_coef=0.004,
        # optimization (alternative previously tried: optimizer="SGD")
        optimizer="Adam",
        learning_rate=1e-4,
        schedule="adaptive",
        desired_kl=0.01,
        max_grad_norm=1.0,
        num_learning_epochs=5,
        num_mini_batches=4,
        # return / advantage estimation
        gamma=0.99,
        lam=0.95,
        # NOTE: the following inv_dynamics_cfg should normally stay disabled;
        # enable it only for sample-visualization purposes:
        # inv_dynamics_cfg=InvDynamicsMLPConfig(
        #     reward_scale=0.0
        # ),
    )

@configclass
class P4RLINVExplorationDataCollectionRunnerCfg(AnymalDFlatPPORunnerCfg):
    """Variant of ``AnymalDFlatPPORunnerCfg`` that adds an inverse-dynamics model ensemble.

    Runs under the ``exploration_INV_ensemble`` experiment name for 1500
    iterations. It attaches an ``InvDynamicsMLPConfig`` (mode ``"inv"``,
    ensemble of 5 models) to the PPO algorithm config.
    """

    def __post_init__(self):
        """Apply the INV data-collection overrides on top of the parent config."""
        super().__post_init__()

        # Environment step interval (sim.dt * decimation), usually
        # 0.005 * 4 = 0.02 for Anymal-D. Both INV reward values below are
        # multiplied by it, so keep it in sync with the environment timing.
        step_dt = 0.02
        inv_cfg = InvDynamicsMLPConfig(
            mode="inv",
            input_timesteps=5,
            reward_scale=10 * step_dt,
            reward_max=30 * step_dt,
            ensemble_size=5,
            retrain_interval=20,
        )

        self.experiment_name = "exploration_INV_ensemble"
        self.max_iterations = 1500
        # Staged start of RL updates: critic at iteration 100, actor at 150.
        # These are presumably iteration indices consumed by the runner -- confirm.
        self.start_critic_RL_at_iteration = 100
        self.start_actor_RL_at_iteration = 150
        self.algorithm.inv_dynamics_cfg = inv_cfg