# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from dataclasses import MISSING
from typing import Literal, Union, List, Any
from isaaclab_rl.rsl_rl.distillation_cfg import RslRlDistillationAlgorithmCfg, RslRlDistillationStudentTeacherCfg
from isaaclab_rl.rsl_rl.rnd_cfg import RslRlRndCfg
from isaaclab_rl.rsl_rl.symmetry_cfg import RslRlSymmetryCfg

from isaaclab.utils import configclass
from rsl_rl.rsl_rl.addons.kinematics.modules import KinematicSubmoduleConfig

# @configclass
# class ExtendableActorCriticSubmoduleConfig:
#     """Configuration for the Extendable Actor-Critic submodule."""

#     class_name: str = "ExtendableActorCriticSubmodule"
#     """The class name of the submodule."""

#     input_slice: tuple[int, int] = MISSING
#     """The slice of the input to use."""

#     trainable: bool = MISSING
#     """Whether the submodule is trainable."""

#     class_config = MISSING
#     """The configuration of the submodule."""

# TODO: complete all necessary parameters in the configs


"""
jacobian_module_cfg: dict,
        mlp_dims=[128, 128, 128],
        activation="elu",
        init_noise_std=1.0,
        noise_std_type: str = "scalar",

"""


"""
    def __init__(self, 
                 dim_states, 
                 dim_actions, 
                 input_timesteps, 
                 representation_dim, 
                 hidden_dims: list[int], 
                 mode: str = "inv",
                 activation_name: str = "elu",
                 weight_path = None, 
                 finetune_frozen = False,
                 device: str = "cpu",
                 **kwargs):
"""

@configclass
class DLStdActorCriticConfig:
    """Configuration for the ``DLStdActorCritic`` model.

    The model is configured with a PIDM (inverse-dynamics module) via ``inv_module_cfg`` and
    MLP layers via ``mlp_dims`` / ``activation``.
    """

    class_name: str = "DLStdActorCritic"
    """The class name of the model."""

    inv_module_cfg: dict = MISSING
    """Configuration for the PIDM. Must be set explicitly."""

    mlp_dims: list[int] = [128, 128, 128]
    """The dimensions of the MLP layers."""

    activation: str = "elu"
    """The activation function for the MLP layers."""

    init_noise_std: float = 1.0
    """The initial noise standard deviation for the policy."""


@configclass
class GatedMultiActorCriticConfig:
    """Configuration for the Gated Multi Actor-Critic model."""

    class_name: str = "GatedMultiActorCritic"
    """The class name of the model."""

    trained_policy_path: str = MISSING
    """Path to the previously trained policy used by the model. Must be set explicitly.

    NOTE(review): presumably a checkpoint file path; confirm the expected format against the model.
    """

    mlp_dims: list[int] = [128, 128, 128]
    """The dimensions of the MLP layers."""

    activation: str = "elu"
    """The activation function for the MLP layers."""

    init_noise_std: float = 1.0
    """The initial noise standard deviation for the policy."""



@configclass
class GatedActorCriticWithINVConfig:
    """Settings for the gated actor-critic model that embeds a PIDM."""

    class_name: str = "GatedActorCriticWithINV"
    """Name of the model class to instantiate."""

    inv_module_cfg: dict = MISSING
    """Settings of the PIDM. Required (there is no default)."""

    mlp_dims: list[int] = [128, 128, 128]
    """Layer widths of the MLP."""

    activation: str = "elu"
    """Activation function applied between the MLP layers."""

    init_noise_std: float = 1.0
    """Starting standard deviation of the policy's action noise."""


@configclass
class InvDynamicsMLPConfig:
    """Configuration for the ``InvDynamicsMLP`` dynamics model.

    The operating mode is selected with ``mode``. The model can optionally start from pretrained
    weights, be used as an ensemble, and contribute an intrinsic reward (see ``reward_scale``).
    """

    class_name: str = "InvDynamicsMLP"
    """The class name of the model."""

    dim_states: int = MISSING
    """The dimension of the state input."""

    dim_actions: int = MISSING
    """The dimension of the action output."""

    input_timesteps: int = MISSING
    """The number of timesteps in the input sequence."""

    representation_dim: int = 128
    """The dimension of the representation layer."""

    hidden_dims: list[int] = [512, 256, 128]
    """The dimensions of the hidden layers."""

    mode: str = MISSING
    """The mode of the model. Can be 'inv', 'fwd', 'jacobian', 'dl'."""

    weight_path: str | None = None
    """Path to the pretrained weights. If None, the model will be initialized from scratch."""

    finetune_frozen: bool = False
    """Whether to freeze pretrained weights during finetuning."""

    ensemble_size: int = 1
    """The size of the ensemble. If > 1, the model will be an ensemble of models."""

    reward_scale: float = 0.0
    """The scale of the intrinsic reward. Defaults to 0.0, i.e. no intrinsic reward."""

    reward_max: float = 1.0 * 0.005
    """The maximum value of the intrinsic reward. NOTE: take dt into account when setting this value."""

    retrain_interval: int = 10
    """The interval at which to retrain the model. Defaults to 10."""

# @configclass
# class InvDynamicsMLPConfig:

#     class_name: str = "InvDynamicsMLP"

#     dim_states: int = MISSING
#     """The dimension of the state input."""

#     dim_actions: int = MISSING
#     """The dimension of the action output."""

#     input_timesteps: int = MISSING
#     """The number of timesteps in the input sequence."""

#     representation_dim: int = 64
#     """The dimension of the representation layer."""

#     hidden_dims: list[int] = [128, 128, 128]
#     """The dimensions of the hidden layers."""

#     weight_path: Union[str, None] = None
#     """Path to the pretrained weights. If None, the model will be initialized from scratch."""

#     finetune_frozen: bool = False
#     """Whether to freeze pretrained weights during finetuning."""

#     ensemble_size: int = 1
#     """The size of the ensemble. If > 1, the model will be an ensemble of models."""

#     reward_scale: float = 0.0
#     """The scale of the intrinsic reward. Defaults to 0.0, i.e. no intrinsic reward."""

#     reward_max: float = 1.0 * 0.005
#     """The maximum value of the intrinsic reward. NOTE: take dt into account when setting this value."""

#     retrain_interval: int = 10
#     """The interval at which to retrain the model. Defaults to 10."""


@configclass
class RandomNetworkDistillationCfg:
    """Configuration for the Random Network Distillation (RND) module."""

    num_states: int = MISSING
    """Number of states/inputs to the predictor and target networks."""

    num_outputs: int = MISSING
    """Number of outputs (embedding size) of the predictor and target networks."""

    predictor_hidden_dims: list[int] = MISSING
    """List of hidden dimensions of the predictor network."""

    target_hidden_dims: list[int] = MISSING
    """List of hidden dimensions of the target network."""

    activation: str = "elu"
    """Activation function for the predictor and target networks. Defaults to 'elu'."""

    weight: float = 1.0
    """Scaling factor of the intrinsic reward. Defaults to 1.0."""

    state_normalization: bool = False
    """Whether to normalize the input state. Defaults to False."""

    reward_normalization: bool = False
    """Whether to normalize the intrinsic reward. Defaults to False."""

    weight_schedule: dict | None = None
    """The type of schedule to use for the RND weight parameter. Defaults to None."""


@configclass
class RslRlPpoActorCriticForAnalysisCfg:
    """Configuration for the ActorCriticForAnalysis model."""

    class_name: str = "ActorCriticForAnalysis"
    """The policy class name. Default is ActorCriticForAnalysis."""

    actor_hidden_dims: list[int] = MISSING
    """The hidden dimensions of the actor network."""

    critic_hidden_dims: list[int] = MISSING
    """The hidden dimensions of the critic network."""

    activation: str = MISSING
    """The activation function for the actor and critic networks."""

    init_noise_std: float = MISSING
    """The initial noise standard deviation for the policy."""

    noise_std_type: Literal["scalar", "log"] = "scalar"
    """The type of noise standard deviation. Can be 'scalar' or 'log'. Default is 'scalar'."""

    layer_to_dynamics: int = MISSING
    """The layer (a single index) of the actor/critic networks to use for dynamics predictions."""

    dim_dynamics_hidden: int = MISSING
    """The hidden dimension of the dynamics prediction layers."""

    dim_dynamics_prediction: int = MISSING
    """The output dimension of the dynamics prediction layers."""


@configclass
class RslRlPpoHierarchicalActorCriticCfg:
    """Settings for the hierarchical PPO actor-critic (a high level on top of a low-level module)."""

    class_name: str = "HierarchicalActorCritic"
    """Name of the policy class to instantiate. Defaults to HierarchicalActorCritic."""

    high_level_mlp_dims: list[int] = MISSING
    """Hidden layer widths shared by the high-level actor and critic."""

    higher_level_input_dim: int = MISSING
    """Input size of the high-level actor and critic."""

    residual_action: bool = MISSING
    """If True, a skip connection is added to the action."""

    intermediate_action_space_is_raw: bool = MISSING
    """If True, intermediate actions live in raw space; otherwise in a latent representation space."""

    init_noise_std: float = MISSING
    """Starting standard deviation of the policy's action noise."""

    num_intermediate_actions: int = MISSING
    """Size of the intermediate action space."""

    activation: str = MISSING
    """Activation function used by the actor and critic."""

    low_level_module_config: Any = MISSING
    """Settings of the low-level module."""

    low_level_module_lr_scale: float = 0.1
    """Factor applied to the learning rate of the low-level module."""

@configclass
class RslRlPpoExtendableActorCriticCfg:
    """Configuration for the PPO ResNet-inspired, extendable actor-critic networks.

    This class assumes a symmetrical architecture for the actor and critic networks.
    """

    class_name: str = "ExtendableActorCritic"
    """The policy class name. Default is ExtendableActorCritic."""

    direct_pathway_dim: int = MISSING
    """The number of input dimensions that bypass the ResNet-like blocks and are fed directly
    into the final MLP.

    IMPORTANT: These dimensions must be at the end of the observation vector.
    """

    final_mlp_dims: list[int] = MISSING
    """The final MLP dimensions of the actor and critic networks."""

    init_noise_std: float = MISSING
    """The initial noise standard deviation for the policy."""

    activation: str = MISSING
    """The activation function for the actor and critic networks."""

    submodule_configs: list | None = MISSING
    """The configurations of the submodules (e.g. pretrained kinematic modules), or None.

    There is no default: this field must be set explicitly, either to a list of submodule
    configurations or to None.
    """



@configclass
class RslRlPpoActorCriticCfg:
    """Settings for the standard PPO actor-critic networks."""

    class_name: str = "ActorCritic"
    """Name of the policy class to instantiate. Defaults to ActorCritic."""

    init_noise_std: float = MISSING
    """Starting standard deviation of the policy's action noise."""

    noise_std_type: Literal["scalar", "log"] = "scalar"
    """How the policy's noise standard deviation is parameterized. Defaults to scalar."""

    actor_obs_normalization: bool = False
    """If True, the actor's observations are normalized. Defaults to False."""

    critic_obs_normalization: bool = False
    """If True, the critic's observations are normalized. Defaults to False."""

    actor_hidden_dims: list[int] = MISSING
    """Hidden layer widths of the actor."""

    critic_hidden_dims: list[int] = MISSING
    """Hidden layer widths of the critic."""

    activation: str = MISSING
    """Activation function used by the actor and critic."""

@configclass
class RslRlPpoActorCriticConstrainedStdCfg(RslRlPpoActorCriticCfg):
    """Configuration for the PPO actor-critic networks with a constrained standard deviation.

    All fields of :class:`RslRlPpoActorCriticCfg` are inherited with the same names, defaults and
    order. This class overrides ``class_name`` and appends the lower/upper bounds for the policy's
    standard deviation.
    """

    class_name: str = "ActorCriticConstrainedStd"
    """The policy class name. Default is ActorCriticConstrainedStd."""

    noise_lower_bound: float = 0.6
    """The lower bound for the standard deviation of the policy. Default is 0.6."""

    noise_upper_bound: float = 1.0
    """The upper bound for the standard deviation of the policy. Default is 1.0."""
    


@configclass
class RslRlPpoActorCriticRecurrentCfg(RslRlPpoActorCriticCfg):
    """Settings for PPO actor-critic networks that include recurrent layers."""

    class_name: str = "ActorCriticRecurrent"
    """Name of the policy class to instantiate. Defaults to ActorCriticRecurrent."""

    rnn_type: str = MISSING
    """Kind of recurrent cell: "lstm" or "gru"."""

    rnn_hidden_dim: int = MISSING
    """Width of each recurrent layer."""

    rnn_num_layers: int = MISSING
    """How many recurrent layers are stacked."""


@configclass
class RslRlPpoAlgorithmCfg:
    """Settings for the PPO algorithm."""

    class_name: str = "PPO"
    """Name of the algorithm class to instantiate. Defaults to PPO."""

    num_learning_epochs: int = MISSING
    """How many learning epochs are run per update."""

    num_mini_batches: int = MISSING
    """How many mini-batches each update is split into."""

    learning_rate: float = MISSING
    """Learning rate of the policy."""

    schedule: str = MISSING
    """Learning-rate schedule."""

    gamma: float = MISSING
    """Discount factor."""

    lam: float = MISSING
    """Lambda of Generalized Advantage Estimation (GAE)."""

    entropy_coef: float = MISSING
    """Weight of the entropy loss."""

    desired_kl: float = MISSING
    """Target KL divergence."""

    max_grad_norm: float = MISSING
    """Upper limit on the gradient norm."""

    value_loss_coef: float = MISSING
    """Weight of the value loss."""

    use_clipped_value_loss: bool = MISSING
    """If True, the clipped value loss is used."""

    clip_param: float = MISSING
    """Clipping parameter of the policy objective."""

    normalize_advantage_per_mini_batch: bool = False
    """If True, advantages are normalized per mini-batch. Defaults to False.

    When False, advantages are instead normalized over all collected trajectories.
    """

    optimizer: Literal["Adam", "SGD"] = "Adam"
    """Optimizer to use. Defaults to Adam."""

    pretrained_module_lr_factor: float | None = None
    """Learning-rate factor for pretrained modules. Defaults to None.

    When None, pretrained modules share the learning rate of the rest of the model.
    """

    rnd_cfg: RslRlRndCfg | None = None
    """RND settings. Defaults to None, which disables RND."""

    symmetry_cfg: RslRlSymmetryCfg | None = None
    """Symmetry settings. Defaults to None, which disables symmetry."""

    inv_dynamics_cfg: InvDynamicsMLPConfig | None = None
    """Inverse-dynamics settings. Defaults to None, which disables inverse dynamics."""



@configclass
class RslRlBaseRunnerCfg:
    """Settings shared by all runners."""

    seed: int = 42
    """Random seed of the experiment. Defaults to 42."""

    device: str = "cuda:0"
    """Device on which the rl-agent runs. Defaults to cuda:0."""

    num_steps_per_env: int = MISSING
    """Steps collected in each environment per update."""

    max_iterations: int = MISSING
    """Upper limit on the number of training iterations."""

    obs_groups: dict[str, list[str]] = MISSING
    """Mapping from algorithmic observation sets to environment observation groups.

    Each key is an observation set predefined by the underlying algorithm. Each value lists the
    observation groups, provided by the environment, that make up that set.

    For example, an environment that exposes the groups "policy", "images" and "privileged" could
    be mapped like this:

    .. code-block:: python

        obs_groups = {
            "policy": ["policy", "images"],
            "critic": ["policy", "privileged"],
        }

    With this mapping the policy receives the "policy" and "images" observations, and the critic
    receives the "policy" and "privileged" observations.

    See ``vec_env.py`` in the rsl_rl library for more details.
    """

    clip_actions: float | None = None
    """Clipping value for actions; None disables clipping. Defaults to None.

    .. note::
        The clipping is applied inside the :class:`RslRlVecEnvWrapper` wrapper.
    """

    save_interval: int = MISSING
    """Iterations between consecutive saves."""

    experiment_name: str = MISSING
    """Name of the experiment."""

    run_name: str = ""
    """Name of the run. Defaults to an empty string.

    The run directory is usually named after the execution time-stamp. A non-empty run name is
    appended to it, giving a logging directory named ``{time-stamp}_{run_name}``.
    """

    logger: Literal["tensorboard", "neptune", "wandb"] = "tensorboard"
    """Logging backend. Defaults to tensorboard."""

    neptune_project: str = "isaaclab"
    """Project name used with neptune. Defaults to "isaaclab"."""

    wandb_project: str = "isaaclab"
    """Project name used with wandb. Defaults to "isaaclab"."""

    resume: bool = False
    """If True, a previous training run is resumed. Defaults to False.

    Distillation ignores this flag.
    """

    load_run: str = ".*"
    """Run directory to load from. Defaults to ".*" (all)."""

    load_checkpoint: str = "model_.*.pt"
    """Checkpoint file to load. Defaults to ``"model_.*.pt"`` (all).

    For a regex, the last matching file in alphabetical order is loaded.
    """


@configclass
class RslRlOnPolicyRunnerCfg(RslRlBaseRunnerCfg):
    """Configuration of the runner for on-policy algorithms."""

    class_name: str = "OnPolicyRunner"
    """The runner class name. Default is OnPolicyRunner."""

    start_actor_RL_at_iteration: int = 0
    """The iteration until which the actor is frozen. Default is 0, i.e. the actor is not frozen.

    This is useful for burning in the critic.
    """

    start_critic_RL_at_iteration: int = 0
    """The iteration until which the critic is frozen. Default is 0, i.e. the critic is not frozen.

    This is useful when the initially collected data has abnormally large reward magnitudes.
    """

    actor_zero_output_pretrain: int | None = None
    """The number of iterations to pretrain the actor with zero output. Default is None, i.e. no pretraining.

    Specify a positive integer N to pretrain the actor with zero output using observations collected
    in the first N iterations. This number must be less than ``start_actor_RL_at_iteration``; this
    config class does not validate that constraint.
    """

    policy: RslRlPpoActorCriticCfg = MISSING
    """The policy configuration."""

    algorithm: RslRlPpoAlgorithmCfg = MISSING
    """The algorithm configuration."""


@configclass
class RslRlDistillationRunnerCfg(RslRlBaseRunnerCfg):
    """Runner settings for distillation algorithms."""

    class_name: str = "DistillationRunner"
    """Name of the runner class to instantiate. Defaults to DistillationRunner."""

    policy: RslRlDistillationStudentTeacherCfg = MISSING
    """Settings of the student-teacher policy."""

    algorithm: RslRlDistillationAlgorithmCfg = MISSING
    """Settings of the distillation algorithm."""



"""
JacobianSubmoduleConfig(
                                            dim_states=33,  # state dimension
                                            dim_actions=12,  # action dimension
                                            dim_states_output=21, 
                                            input_timesteps=window_size,
                                            hidden_dims=[dim_hidden, dim_hidden//2, dim_hidden//4],
                                            representation_dim=dim_hidden//8,
                                            backbone_output_dim=dim_hidden//2,
                                            weight_path=None,  # path to the pretrained weights if available
                                            finetune_frozen=False,
                                            activation_name="elu"
                                        )
"""

@configclass
class JacobianSubmoduleConfig:
    """Configuration for the Jacobian submodule (class ``JacobianMLP`` by default)."""

    class_name: str = "JacobianMLP"
    """The class name of the submodule."""

    dim_states: int = MISSING
    """The dimension of the state input."""

    dim_actions: int = MISSING
    """The dimension of the actions."""

    dim_states_output: int = MISSING
    """The dimension of the state output."""

    input_timesteps: int = MISSING
    """The number of timesteps in the input sequence."""

    hidden_dims: list[int] = MISSING
    """The dimensions of the hidden layers."""

    representation_dim: int = MISSING
    """The dimension of the representation layer."""

    backbone_output_dim: int = MISSING
    """The output dimension of the backbone network."""

    weight_path: str | None = None
    """Path to the pretrained weights. If None, the model will be initialized from scratch."""

    finetune_frozen: bool = False
    """Whether to freeze pretrained weights during finetuning."""

    activation_name: str = "elu"
    """The activation function for the submodule."""


@configclass
class JacobianActorCriticCfg:
    """Settings for the actor-critic model built around a Jacobian submodule."""

    class_name: str = "JacobianActorCritic"
    """Name of the model class to instantiate."""

    jacobian_module_cfg: JacobianSubmoduleConfig = MISSING
    """Settings of the Jacobian submodule. Required (there is no default)."""

    mlp_dims: list[int] = [128, 128, 128]
    """Layer widths of the MLP."""

    activation: str = "elu"
    """Activation function applied between the MLP layers."""

    init_noise_std: float = 1.0
    """Starting standard deviation of the policy's action noise."""

@configclass
class P4RLAsymmetricActorCriticCfg:
    """Configuration for the P4RL Asymmetric Actor-Critic model."""

    class_name: str = "P4RLAsymmetricActorCritic"
    """The class name of the model."""

    actor_type: Literal["hamburger", "residual", "gated", "spliced", "mlp"] = "hamburger"
    """The type of actor network. Can be 'hamburger', 'residual', 'gated', 'spliced' or 'mlp'.
    Default is 'hamburger'."""

    critic_type: Literal["hamburger", "mlp"] = "mlp"
    """The type of critic network. Can be 'hamburger' or 'mlp'. Default is 'mlp'."""

    actor_submodule_config: InvDynamicsMLPConfig = MISSING
    """The dynamics-submodule configuration for the actor network. Must be set explicitly."""

    critic_submodule_config: InvDynamicsMLPConfig = MISSING
    """The dynamics-submodule configuration for the critic network. Must be set explicitly."""

    critic_obs_normalization: bool = False
    """Whether to normalize the observation for the critic network. Default is False."""

    mlp_block_dims: list[int] = [128, 128, 128]
    """The dimensions of the MLP blocks."""

    activation: str = "elu"
    """The activation function for the MLP blocks."""

    init_noise_std: float = 1.0
    """The initial noise standard deviation for the policy."""

    noise_std_type: Literal["scalar", "log"] = "scalar"
    """The type of noise standard deviation for the policy. Default is scalar."""
