# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from isaaclab.utils import configclass

from p4rl.rsl_rl.rl_cfg import (
    RslRlOnPolicyRunnerCfg, 
    RslRlPpoAlgorithmCfg, 
    RslRlPpoExtendableActorCriticCfg, 
    RslRlPpoHierarchicalActorCriticCfg, 
    InvDynamicsMLPConfig, 
    P4RLAsymmetricActorCriticCfg,
    RslRlPpoActorCriticCfg
)

@configclass
class UnitreeGo1RoughPPORunnerCfg(RslRlOnPolicyRunnerCfg):
    """PPO runner configuration for the Unitree Go1 on rough terrain.

    Uses a plain MLP actor-critic (three hidden layers of 512/256/128 units, ELU)
    trained with an adaptive-learning-rate PPO schedule. The flat-terrain runner
    in this module derives from this class.
    """

    num_steps_per_env = 24
    max_iterations = 1500
    save_interval = 50
    experiment_name = "unitree_go1_rough"

    policy = RslRlPpoActorCriticCfg(
        # Network shapes (identical for actor and critic).
        activation="elu",
        actor_hidden_dims=[512, 256, 128],
        critic_hidden_dims=[512, 256, 128],
        # Observation normalization is disabled on both sides.
        actor_obs_normalization=False,
        critic_obs_normalization=False,
        init_noise_std=1.0,
    )

    algorithm = RslRlPpoAlgorithmCfg(
        # Optimizer / step-size control.
        learning_rate=1e-3,
        schedule="adaptive",
        desired_kl=0.01,
        max_grad_norm=1.0,
        # PPO surrogate, value, and entropy terms.
        clip_param=0.2,
        value_loss_coef=1.0,
        use_clipped_value_loss=True,
        entropy_coef=0.01,
        # Update sweeps over each collected batch.
        num_learning_epochs=5,
        num_mini_batches=4,
        # Return / advantage discounting.
        gamma=0.99,
        lam=0.95,
    )


@configclass
class UnitreeGo1FlatPPORunnerCfg(UnitreeGo1RoughPPORunnerCfg):
    """Flat-terrain Go1 runner with a narrower MLP and a shorter training budget.

    Inherits every rough-terrain setting. It changes only the experiment name and
    the iteration count (850 instead of 1500), and shrinks both actor and critic to
    three hidden layers of 128 units.
    """

    def __post_init__(self):
        super().__post_init__()

        # Logging name and training budget. A 300-iteration budget is the shorter
        # alternative that was previously kept here as commented-out code.
        self.experiment_name = "unitree_go1_flat_vanilla_mlp"
        self.max_iterations = 850

        # Each expression below builds its own list, so the actor and critic
        # never share (and accidentally co-mutate) the same dims object.
        hidden_width = 128
        hidden_layers = 3
        self.policy.actor_hidden_dims = [hidden_width] * hidden_layers
        self.policy.critic_hidden_dims = [hidden_width] * hidden_layers


@configclass
class P4RLAsymmetricPPORunnerRandCfg(UnitreeGo1FlatPPORunnerCfg):
    """Flat-terrain Go1 runner with an asymmetric P4RL actor-critic and randomly initialised submodules.

    The actor and critic each get an ``InvDynamicsMLPConfig`` submodule with
    ``mode="inv"`` and ``weight_path="random_init"``. The two are configured
    identically, so the shared settings are written once to keep them from drifting
    apart.
    """

    def __post_init__(self):
        super().__post_init__()
        self.experiment_name = "go1_flat_pidm_rand"

        # Settings shared by the actor and critic inverse-dynamics submodules.
        # Both sides run in mode="inv", so representation_dim applies to the
        # inverse-dynamics submodule on both sides.
        submodule_kwargs = dict(
            dim_states=33,
            dim_actions=12,
            input_timesteps=5,
            representation_dim=256,
            mode="inv",
            weight_path="random_init",
            finetune_frozen=False,
        )

        # This replaces the parent's policy config wholesale, so the 128-unit hidden
        # dims set by the flat-terrain parent do not carry over. The trailing
        # `type: ignore` presumably silences a checker complaint about assigning a
        # different config type to `policy`; verify against the base annotation.
        self.policy = P4RLAsymmetricActorCriticCfg(
            # Two separate calls -> two independent submodule config instances.
            actor_submodule_config=InvDynamicsMLPConfig(**submodule_kwargs),
            critic_submodule_config=InvDynamicsMLPConfig(**submodule_kwargs),
            actor_type="hamburger",
            critic_type="hamburger",
            mlp_block_dims=[512, 256, 128],  # smaller alternative: [128, 128, 128]
            activation="elu",
            init_noise_std=1.0,
        )  # type: ignore


@configclass
class P4RLAsymmetricPPORunnerPretrainCfg(UnitreeGo1FlatPPORunnerCfg):
    """Flat-terrain Go1 runner with an asymmetric P4RL actor-critic whose submodules point at a checkpoint.

    Both the actor and the critic ``InvDynamicsMLPConfig`` use the same
    ``weight_path``. The path and the submodule settings are therefore defined once.
    Previously the two sides had to be edited in lockstep. Presumably the dims must
    also match the checkpoint's; confirm against the loader.
    """

    def __post_init__(self):
        super().__post_init__()

        # Checkpoint used by both submodules, plus the experiment name paired with it.
        # Alternative pair that was kept commented out in this config:
        #   weight_path     "p4rl_assets/inv_dynamics_new/rebuttal/go1_locomotion_100it_100epochs.pt"
        #   experiment_name "go1_flat_pidm_pretrain_locomotion_100"
        pretrained_weights = "p4rl_assets/inv_dynamics_new/rebuttal/go1_exploration_240it_100epochs.pt"
        self.experiment_name = "go1_flat_pidm_pretrain_exploration_240it_100epochs"

        # Settings shared by the actor and critic inverse-dynamics submodules.
        # Both sides run in mode="inv".
        submodule_kwargs = dict(
            dim_states=33,
            dim_actions=12,
            input_timesteps=5,
            representation_dim=256,
            mode="inv",
            weight_path=pretrained_weights,
            finetune_frozen=False,
        )

        # This replaces the parent's policy config wholesale, so the 128-unit hidden
        # dims set by the flat-terrain parent do not carry over. The trailing
        # `type: ignore` presumably silences a checker complaint about assigning a
        # different config type to `policy`; verify against the base annotation.
        self.policy = P4RLAsymmetricActorCriticCfg(
            # Two separate calls -> two independent submodule config instances.
            actor_submodule_config=InvDynamicsMLPConfig(**submodule_kwargs),
            critic_submodule_config=InvDynamicsMLPConfig(**submodule_kwargs),
            actor_type="hamburger",
            critic_type="hamburger",
            mlp_block_dims=[512, 256, 128],  # smaller alternative: [128, 128, 128]
            activation="elu",
            init_noise_std=1.0,
        )  # type: ignore