from isaaclab.utils import configclass
from p4rl.rsl_rl.rl_cfg import (
    RslRlOnPolicyRunnerCfg, 
    RslRlPpoAlgorithmCfg, 
    RslRlPpoExtendableActorCriticCfg, 
    RslRlPpoHierarchicalActorCriticCfg, 
    InvDynamicsMLPConfig, 
    P4RLAsymmetricActorCriticCfg,
    RslRlPpoActorCriticForAnalysisCfg
)

from rsl_rl.rsl_rl.addons.kinematics.modules import KinematicSubmoduleConfig
from rsl_rl.rsl_rl.addons.resnet_blocks.modules import ResNetBlocksConfig
from rsl_rl.rsl_rl.addons.dynamics.modules import DynamicsSubmoduleConfig
# from rsl_rl.rsl_rl.addons.invdynamics.modules import InvDynamicsSubmoduleConfig

@configclass
class AnymalDForAnalysisRunnerCfg(RslRlOnPolicyRunnerCfg):
    """On-policy runner configuration named ``velocity_baseline_for_analysis``.

    Combines an ``RslRlPpoActorCriticForAnalysisCfg`` policy (three 256-unit
    hidden layers for both actor and critic) with an ``RslRlPpoAlgorithmCfg``
    using the Adam optimizer and an adaptive learning-rate schedule.
    """

    # ---- runner-level settings ----
    num_steps_per_env = 24
    max_iterations = 800
    save_interval = 100
    experiment_name = "velocity_baseline_for_analysis"
    empirical_normalization = False

    # ---- policy ----
    policy = RslRlPpoActorCriticForAnalysisCfg(
        # network shape
        actor_hidden_dims=[256, 256, 256],
        critic_hidden_dims=[256, 256, 256],
        activation="elu",
        init_noise_std=1.0,
        # analysis-specific dynamics settings
        # NOTE(review): presumably the indices select which hidden layers feed
        # the dynamics heads -- confirm against RslRlPpoActorCriticForAnalysisCfg.
        layer_to_dynamics=[0, 1, 2],
        dim_dynamics_hidden=64,
        dim_dynamics_prediction=12,
    )

    # ---- PPO algorithm ----
    algorithm = RslRlPpoAlgorithmCfg(
        # optimizer and learning-rate schedule
        # (the original also kept optimizer="SGD" as a commented-out alternative)
        optimizer="Adam",
        learning_rate=1e-4,
        schedule="adaptive",
        desired_kl=0.01,
        max_grad_norm=1.0,
        # update layout
        num_learning_epochs=5,
        num_mini_batches=4,
        # loss terms
        clip_param=0.2,
        use_clipped_value_loss=True,
        value_loss_coef=1.0,
        entropy_coef=0.004,
        # return / advantage estimation
        gamma=0.99,
        lam=0.95,
    )


@configclass
class AnymalDRoughPPORunnerCfg(RslRlOnPolicyRunnerCfg):
    """On-policy runner configuration named ``velocity_baseline_final``.

    Uses an ``RslRlPpoExtendableActorCriticCfg`` policy with no submodules
    (``submodule_configs=[]``) and a three-layer 128-unit final MLP, trained
    with PPO using Adam and an adaptive learning-rate schedule.
    """

    # ---- runner-level settings ----
    num_steps_per_env = 24
    max_iterations = 1500
    save_interval = 50
    experiment_name = "velocity_baseline_final"
    empirical_normalization = False
    obs_groups = {"policy": ["policy"]}

    # ---- policy ----
    policy = RslRlPpoExtendableActorCriticCfg(
        activation="elu",
        init_noise_std=1.0,
        final_mlp_dims=[128, 128, 128],
        # ATTENTION -- values for direct_pathway_dim noted by the original author:
        #   72 : with cartesian
        #   48 : all default dimensions
        #   15 : only past action and command
        #    3 : only command
        direct_pathway_dim=48,
        submodule_configs=[],
    )

    # ---- PPO algorithm ----
    algorithm = RslRlPpoAlgorithmCfg(
        # optimizer and learning-rate schedule
        optimizer="Adam",
        learning_rate=1.0e-3,
        schedule="adaptive",
        desired_kl=0.01,
        max_grad_norm=1.0,
        # update layout
        num_learning_epochs=5,
        num_mini_batches=4,
        # loss terms
        clip_param=0.2,
        use_clipped_value_loss=True,
        value_loss_coef=1.0,
        entropy_coef=0.005,
        # return / advantage estimation
        gamma=0.99,
        lam=0.95,
    )


@configclass
class AnymalDFlatPPORunnerCfg(AnymalDRoughPPORunnerCfg):
    """Variant of ``AnymalDRoughPPORunnerCfg`` with a shorter run and its own name.

    After the parent's post-init has run, the iteration budget is set to 800
    and the experiment name to ``velocity_EAC_3_layer_flat``; every other
    setting is inherited unchanged.
    """

    def __post_init__(self):
        # Let the parent finish its own initialisation before overriding fields.
        super().__post_init__()

        self.experiment_name = "velocity_EAC_3_layer_flat"
        self.max_iterations = 800

        # Quick-debug overrides left disabled by the original author:
        #   self.max_iterations = 100
        #   self.save_interval = 1


@configclass
class P4RLAsymmetricPPORunnerExplorationMixedCfg(AnymalDFlatPPORunnerCfg):
    """Flat-terrain runner using ``P4RLAsymmetricActorCriticCfg`` with loaded weights.

    Actor and critic each receive their own ``InvDynamicsMLPConfig`` with
    identical settings, whose ``weight_path`` points at
    ``p4rl_assets/inv_dynamics_new/absolute_0906_mixed.pt``. The experiment
    name is ``velocity_INV_exploration_mixed``.
    """

    def __post_init__(self):
        super().__post_init__()

        def make_inv_dynamics_cfg():
            # Called once per network so actor and critic never share an instance.
            # NOTE(review): the original annotated the critic's
            # representation_dim=256 as "for fwd" although mode is "inv" --
            # confirm 256 is the intended size for inverse mode.
            return InvDynamicsMLPConfig(
                dim_states=33,
                dim_actions=12,
                input_timesteps=5,
                representation_dim=256,
                mode="inv",
                weight_path="p4rl_assets/inv_dynamics_new/absolute_0906_mixed.pt",
                finetune_frozen=False,
            )

        self.experiment_name = "velocity_INV_exploration_mixed"
        self.policy = P4RLAsymmetricActorCriticCfg(
            actor_submodule_config=make_inv_dynamics_cfg(),
            critic_submodule_config=make_inv_dynamics_cfg(),
            actor_type="hamburger",
            critic_type="hamburger",
            # Alternative noted by the original author: [128, 128, 128].
            mlp_block_dims=[512, 256, 128],
            activation="elu",
            init_noise_std=1.0,
        )  # type: ignore


@configclass
class P4RLAsymmetricPPORunnerRandCfg(AnymalDFlatPPORunnerCfg):
    """Flat-terrain runner using ``P4RLAsymmetricActorCriticCfg`` with ``random_init``.

    Actor and critic each receive their own ``InvDynamicsMLPConfig`` with
    identical settings and ``weight_path="random_init"``. The experiment name
    is ``velocity_INV_rand``.
    """

    def __post_init__(self):
        super().__post_init__()

        def make_inv_dynamics_cfg():
            # Called once per network so actor and critic never share an instance.
            # NOTE(review): the original annotated the critic's
            # representation_dim=256 as "for fwd" although mode is "inv" --
            # confirm 256 is the intended size for inverse mode.
            return InvDynamicsMLPConfig(
                dim_states=33,
                dim_actions=12,
                input_timesteps=5,
                representation_dim=256,
                mode="inv",
                weight_path="random_init",
                finetune_frozen=False,
            )

        self.experiment_name = "velocity_INV_rand"
        self.policy = P4RLAsymmetricActorCriticCfg(
            actor_submodule_config=make_inv_dynamics_cfg(),
            critic_submodule_config=make_inv_dynamics_cfg(),
            actor_type="hamburger",
            critic_type="hamburger",
            # Alternative noted by the original author: [128, 128, 128].
            mlp_block_dims=[512, 256, 128],
            activation="elu",
            init_noise_std=1.0,
        )  # type: ignore



# @configclass
# class P4RLAsymmetricPPORunnerCfg(AnymalDFlatPPORunnerCfg):
#     def __post_init__(self):
#         super().__post_init__()
#         # for debugging/visualization purposes only:
#         # self.save_interval = 1
#         # self.max_iterations = 200
#         # . 

#         # p4rl: enforce the training stages: actor-zero-output-pretrain, critic burn-in, RL training
#         # self.actor_zero_output_pretrain = 10
#         # self.start_critic_RL_at_iteration = 10
#         # self.start_actor_RL_at_iteration = 100

#         # without zero-output-pretrain, with critic burn-in
#         # self.actor_zero_output_pretrain = None
#         # self.start_critic_RL_at_iteration = 0
#         # self.start_actor_RL_at_iteration = 100
#         # self.max_iterations = 190

#         # without zero-output-pretrain, without critic burn-in
#         # self.actor_zero_output_pretrain = None
#         # self.start_critic_RL_at_iteration = 0
#         # self.start_actor_RL_at_iteration = 0
#         # self.max_iterations = 150

#         self.experiment_name = "velocity_INV_exploration_mixed"
#         # self.experiment_name = "pedipulation_Hamburger_INV_test"
#         self.policy = P4RLAsymmetricActorCriticCfg(
#             actor_submodule_config=InvDynamicsMLPConfig(
#                 dim_states=33, 
#                 dim_actions=12, 
#                 input_timesteps=5,
#                 representation_dim=256, # for fwd
#                 # representation_dim=128, # for inv
#                 mode="inv",
#                 # mode='fwd',
#                 # weight_path="random_init",
#                 # weight_path="p4rl_assets/inv_dynamics_new/history_5_with_noise.pt",
#                 # weight_path="p4rl_assets/inv_dynamics_new/absolute_0811_pedi_output_clamped.pt",
#                 # weight_path="p4rl_assets/inv_dynamics_new/absolute_exploration_0813.pt",
#                 # weight_path="p4rl_assets/inv_dynamics_new/absolute_0831_exploration_rough.pt",
#                 weight_path="p4rl_assets/inv_dynamics_new/absolute_0906_mixed.pt",

#                 # weight_path="p4rl_assets/inv_dynamics_new/foward.pt",
#                 # weight_path="p4rl_assets/inv_dynamics_new/foward_out_dim_21.pt",
#                 finetune_frozen=False,
#             ),
#             critic_submodule_config=InvDynamicsMLPConfig(
#                 dim_states=33, 
#                 dim_actions=12, 
#                 input_timesteps=5,
#                 representation_dim=256, # for fwd
#                 # representation_dim=128, # for inv
#                 mode='inv',
#                 # mode="fwd",
#                 # weight_path="random_init",
#                 # weight_path="p4rl_assets/inv_dynamics_new/history_5_with_noise.pt",
#                 # weight_path="p4rl_assets/inv_dynamics_new/absolute_0811_pedi_output_clamped.pt",
#                 # weight_path="p4rl_assets/inv_dynamics_new/absolute_exploration_0813.pt",
#                 # weight_path="p4rl_assets/inv_dynamics_new/absolute_0831_exploration_rough.pt",
#                 weight_path="p4rl_assets/inv_dynamics_new/absolute_0906_mixed.pt",

#                 # weight_path="p4rl_assets/inv_dynamics_new/foward.pt",
#                 # weight_path="p4rl_assets/inv_dynamics_new/foward_out_dim_21.pt",
#                 finetune_frozen=False,
#             ),
#             actor_type="hamburger",  # "hamburger", "residual", "gated",  "spliced", "mlp"
#             critic_type="hamburger",
#             mlp_block_dims=[512, 256, 128],
#             # mlp_block_dims=[128, 128, 128],
#             activation="elu",
#             init_noise_std=1.0,
#         ) # type: ignore