# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
import torch.nn as nn
from torch.distributions import Normal
from typing import Union, List, Any, Dict, Generator
from torch.distributions.distribution import Distribution
from rsl_rl.networks import MLP, EmpiricalNormalization
from tensordict import TensorDict
import math


from rsl_rl.utils import resolve_nn_activation
# from rsl_rl.addons.kinematics.modules import *
# from rsl_rl.addons.resnet_blocks.modules import *
# from rsl_rl.addons.dynamics.modules import *
# from rsl_rl.addons.dynamics.modules_recurrent import *
# from rsl_rl.addons.invdynamics.modules import *
from rsl_rl.addons.invdynamics.inv_dynamics_module import InvDynamicsMLP
# from rsl_rl.addons.invdynamics.jacobian_utils import JacobianMLP
# from rsl_rl.addons.resolve_submodule import resolve_pretrained_module
import gc

def build_mlp(input_dims: int, hidden_dims: list[int], output_dims: int, activation_name: str = "elu"):
    """Assemble a fully connected network as an ``nn.Sequential``.

    The stack is ``Linear -> activation`` for every hidden width, followed by a final
    ``Linear`` that projects to ``output_dims`` (no activation after the output layer).

    Args:
        input_dims: Number of input features.
        hidden_dims: Hidden layer widths. Entries equal to ``-1`` are replaced by ``input_dims``.
        output_dims: Number of output features.
        activation_name: Name passed to ``resolve_nn_activation``.

    Returns:
        The assembled ``nn.Sequential``. One activation module instance is reused after
        every hidden layer.
    """
    # a width of -1 is shorthand for "same as the input"
    widths = [input_dims if width == -1 else width for width in hidden_dims]
    nonlinearity = resolve_nn_activation(activation_name)

    # input projection
    modules = [nn.Linear(input_dims, widths[0]), nonlinearity]
    # hidden-to-hidden projections, one per pair of consecutive widths
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules.append(nn.Linear(fan_in, fan_out))
        modules.append(nonlinearity)
    # output projection, left without an activation
    modules.append(nn.Linear(widths[-1], output_dims))
    return nn.Sequential(*modules)


def print_model_params(model):
    """Print the total and the trainable parameter counts of ``model``.

    Args:
        model: Object exposing ``parameters()`` whose items provide ``numel()`` and
            ``requires_grad``.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    print(f"Total parameters: {total}")
    print(f"Trainable parameters: {trainable}")
    

def resolve_submodule(config: dict | Any, device: str = "cuda"):
    """Instantiate the module described by ``config`` and move it to ``device``.

    Args:
        config: Either a dict or an object providing ``to_dict()``. It must contain a
            ``"class_name"`` entry; all remaining entries are forwarded to the class
            constructor as keyword arguments. The caller's config is not modified, a
            shallow copy is used.
        device: Device passed to ``.to()`` on the freshly built module.

    Returns:
        The constructed module, already moved to ``device``.
    """
    settings = config if isinstance(config, dict) else config.to_dict()
    settings = settings.copy()
    class_name = settings.pop("class_name")
    # NOTE(review): ``eval`` executes arbitrary expressions and resolves names against this
    # module's globals. Only pass configs from trusted sources.
    module_class = eval(class_name)
    return module_class(**settings).to(device)


def count_trainable_params(model: nn.Module) -> int:
    """Return the number of scalar entries in the parameters of ``model`` that require gradients."""
    count = 0
    for param in model.parameters():
        if param.requires_grad:
            count += param.numel()
    return count

class ActorCriticBase(nn.Module):
    """Common behaviour shared by the actor-critic modules in this file.

    The base class itself builds no networks. Subclasses are expected to populate
    ``actor``, ``critic``, ``obs_groups``, both observation normalizers together with
    their enable flags, and either ``std`` or ``log_std`` (matching ``noise_std_type``).
    """

    is_recurrent = False
    actor: nn.Module
    critic: nn.Module
    noise_std_type: str = "scalar"
    std: torch.Tensor
    log_std: torch.Tensor
    distribution: Distribution
    actor_obs_normalizer: EmpiricalNormalization
    critic_obs_normalizer: EmpiricalNormalization
    # flags read by ``update_normalization``; assigned by the subclasses
    actor_obs_normalization: bool
    critic_obs_normalization: bool
    obs_groups: dict

    def __init__(self):
        super().__init__()

    def reset(self, dones=None):
        """No-op; accepts ``dones`` and ignores it."""
        pass

    def forward(self):
        """Not supported; use ``act``, ``act_inference`` or ``evaluate`` instead."""
        raise NotImplementedError

    @property
    def action_mean(self):
        """Mean of the current action distribution."""
        return self.distribution.mean

    @property
    def action_std(self):
        """Standard deviation of the current action distribution."""
        return self.distribution.stddev

    @property
    def entropy(self):
        """Entropy of the current action distribution, summed over the last dimension."""
        return self.distribution.entropy().sum(dim=-1)

    @property
    def exploration_params(self) -> List[torch.Tensor]:
        """Return the exploration parameters as a list, typically for optimizer initialization.

        Raises:
            ValueError: If ``noise_std_type`` is neither ``"scalar"`` nor ``"log"``.
        """
        if self.noise_std_type == "scalar":
            return [self.std]
        elif self.noise_std_type == "log":
            return [self.log_std]
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

    def update_distribution(self, obs):
        """Rebuild ``self.distribution`` as a ``Normal`` around the actor output for ``obs``.

        Raises:
            ValueError: If ``noise_std_type`` is neither ``"scalar"`` nor ``"log"``.
        """
        # compute mean
        mean = self.actor(obs)
        # compute standard deviation, broadcast to the shape of the mean
        if self.noise_std_type == "scalar":
            std = self.std.expand_as(mean)
        elif self.noise_std_type == "log":
            std = torch.exp(self.log_std).expand_as(mean)
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
        # create distribution
        self.distribution = Normal(mean, std)

    def act(self, obs, **kwargs):
        """Sample actions from the distribution for the (normalized) policy observations."""
        obs = self.get_actor_obs(obs)
        obs = self.actor_obs_normalizer(obs)
        self.update_distribution(obs)
        return self.distribution.sample()

    def act_inference(self, obs):
        """Return the actor output for the (normalized) policy observations, without sampling."""
        obs = self.get_actor_obs(obs)
        obs = self.actor_obs_normalizer(obs)
        return self.actor(obs)

    def evaluate(self, obs, **kwargs):
        """Return the critic output for the (normalized) critic observations."""
        obs = self.get_critic_obs(obs)
        obs = self.critic_obs_normalizer(obs)
        return self.critic(obs)

    def _concat_obs_groups(self, obs, group_set: str):
        """Concatenate, along the last dimension, all groups listed in ``obs_groups[group_set]``.

        A group stored as a ``TensorDict`` is first flattened by concatenating its values
        along the last dimension.
        """
        parts = []
        for group_name in self.obs_groups[group_set]:
            group = obs[group_name]
            if isinstance(group, TensorDict):
                group = torch.cat(list(group.values()), dim=-1)
            parts.append(group)
        return torch.cat(parts, dim=-1)

    def get_actor_obs(self, obs):
        """Concatenate the observation groups listed under ``obs_groups["policy"]``."""
        return self._concat_obs_groups(obs, "policy")

    def get_critic_obs(self, obs):
        """Concatenate the observation groups listed under ``obs_groups["critic"]``."""
        return self._concat_obs_groups(obs, "critic")

    def get_actions_log_prob(self, actions):
        """Log-probability of ``actions`` under the current distribution, summed over the last dim."""
        return self.distribution.log_prob(actions).sum(dim=-1)

    def update_normalization(self, obs):
        """Update the enabled observation normalizers with the given observations."""
        if self.actor_obs_normalization:
            actor_obs = self.get_actor_obs(obs)
            self.actor_obs_normalizer.update(actor_obs)
        if self.critic_obs_normalization:
            critic_obs = self.get_critic_obs(obs)
            self.critic_obs_normalizer.update(critic_obs)

    def load_state_dict(self, state_dict, strict=True):
        """Load the parameters of the actor-critic model.

        Args:
            state_dict (dict): State dictionary of the model.
            strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this
                           module's state_dict() function.

        Returns:
            bool: Whether this training resumes a previous training. This flag is used by the `load()` function of
                  `OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation).
        """

        super().load_state_dict(state_dict, strict=strict)
        return True  # training resumes

"""

if not observation.concatenate_terms:

obs = {
"policy": 
TensorDict(
    fields={
        actions: Tensor(shape=torch.Size([4096, 72]), device=cuda:0, dtype=torch.float32, is_shared=True),
        base_ang_vel: Tensor(shape=torch.Size([4096, 18]), device=cuda:0, dtype=torch.float32, is_shared=True),
        base_lin_vel: Tensor(shape=torch.Size([4096, 18]), device=cuda:0, dtype=torch.float32, is_shared=True),
        height_scan: Tensor(shape=torch.Size([4096, 231]), device=cuda:0, dtype=torch.float32, is_shared=True),
        joint_pos: Tensor(shape=torch.Size([4096, 72]), device=cuda:0, dtype=torch.float32, is_shared=True),
        joint_vel: Tensor(shape=torch.Size([4096, 72]), device=cuda:0, dtype=torch.float32, is_shared=True),
        pos_command: Tensor(shape=torch.Size([4096, 4]), device=cuda:0, dtype=torch.float32, is_shared=True),
        projected_gravity: Tensor(shape=torch.Size([4096, 18]), device=cuda:0, dtype=torch.float32, is_shared=True),
        time_to_target: Tensor(shape=torch.Size([4096, 1]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([4096]),
    device=None,
    ), 

"critic":
    ...
    }

Note that histories of one observation term are always concatenated.
"""


class ActorCritic(ActorCriticBase):
    """Feed-forward actor-critic with MLP actor and critic and a learnable action noise.

    Args:
        obs: Example observations, indexable by observation group name. Only used to read
            the feature sizes; it is not modified.
        obs_groups: Mapping with the keys ``"policy"`` and ``"critic"``, each listing the
            observation group names that feed the actor and the critic respectively.
        num_actions: Size of the action vector.
        actor_obs_normalization: Whether to normalize the actor input with
            ``EmpiricalNormalization``.
        critic_obs_normalization: Whether to normalize the critic input with
            ``EmpiricalNormalization``.
        actor_hidden_dims: Hidden layer sizes handed to the actor ``MLP``.
        critic_hidden_dims: Hidden layer sizes handed to the critic ``MLP``.
        activation: Activation name handed to ``MLP``.
        init_noise_std: Initial standard deviation of the action noise.
        noise_std_type: ``"scalar"`` to learn the std directly, ``"log"`` to learn its logarithm.
        **kwargs: Ignored; a message listing the ignored keys is printed.

    Raises:
        ValueError: If ``noise_std_type`` is neither ``"scalar"`` nor ``"log"``.
    """

    is_recurrent = False

    def __init__(
        self,
        obs,
        obs_groups,
        num_actions,
        actor_obs_normalization=False,
        critic_obs_normalization=False,
        actor_hidden_dims=[256, 256, 256],
        critic_hidden_dims=[256, 256, 256],
        activation="elu",
        init_noise_std=1.0,
        noise_std_type: str = "scalar",
        **kwargs,
    ):
        if kwargs:
            print(
                "ActorCritic.__init__ got unexpected arguments, which will be ignored: "
                + str([key for key in kwargs.keys()])
            )
        super().__init__()

        # get the observation dimensions
        self.obs_groups = obs_groups
        num_actor_obs = self._count_obs_features(obs, obs_groups["policy"])
        num_critic_obs = self._count_obs_features(obs, obs_groups["critic"])

        # actor
        self.actor = MLP(num_actor_obs, num_actions, actor_hidden_dims, activation)
        # actor observation normalization
        self.actor_obs_normalization = actor_obs_normalization
        if actor_obs_normalization:
            self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs)
        else:
            self.actor_obs_normalizer = torch.nn.Identity()
        print(f"Actor MLP: {self.actor}")

        # critic
        self.critic = MLP(num_critic_obs, 1, critic_hidden_dims, activation)
        # critic observation normalization
        self.critic_obs_normalization = critic_obs_normalization
        if critic_obs_normalization:
            self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs)
        else:
            self.critic_obs_normalizer = torch.nn.Identity()
        print(f"Critic MLP: {self.critic}")

        # Action noise
        self.noise_std_type = noise_std_type
        if self.noise_std_type == "scalar":
            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        elif self.noise_std_type == "log":
            self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        # Action distribution (populated in update_distribution)
        self.distribution = None
        # disable args validation for speedup
        Normal.set_default_validate_args(False)

    @staticmethod
    def _count_obs_features(obs, group_names):
        """Return the summed last-dimension size of the given observation groups.

        A group stored as a ``TensorDict`` is flattened into a local tensor first; the
        caller's ``obs`` container is left untouched.
        """
        total = 0
        for group_name in group_names:
            group = obs[group_name]
            if isinstance(group, TensorDict):
                group = torch.cat(list(group.values()), dim=-1)
            assert len(group.shape) == 2, "The ActorCritic module only supports 1D observations."
            total += group.shape[-1]
        return total

    @staticmethod
    # not used at the moment
    def init_weights(sequential, scales):
        """Orthogonally initialize every ``nn.Linear`` in ``sequential``.

        ``scales[i]`` is used as the gain of the i-th linear layer (counting linear layers only).
        """
        linear_layers = (module for module in sequential if isinstance(module, nn.Linear))
        for idx, module in enumerate(linear_layers):
            torch.nn.init.orthogonal_(module.weight, gain=scales[idx])


class ActorCriticConstrainedStd(ActorCritic):
    """``ActorCritic`` variant whose action std is confined to an open interval.

    The learnable ``std_param`` is unconstrained; the std used for the action distribution
    is ``sigmoid(std_param) * (upper - lower) + lower`` and therefore always lies strictly
    between ``noise_lower_bound`` and ``noise_upper_bound``.

    Only ``noise_std_type="scalar"`` is supported. The ``std`` parameter created by the
    parent constructor stays registered but is not used by ``update_distribution`` here.

    Args:
        noise_lower_bound: Lower limit of the action std.
        noise_upper_bound: Upper limit of the action std.
        init_noise_std: Initial std; must lie strictly between the two bounds.

    All other arguments are forwarded unchanged to ``ActorCritic``.

    Raises:
        AssertionError: If ``init_noise_std`` is not strictly inside the bounds.
        ValueError: If ``noise_std_type`` is not ``"scalar"``.
    """

    def __init__(
        self,
        obs,
        obs_groups,
        num_actions,
        actor_obs_normalization=False,
        critic_obs_normalization=False,
        actor_hidden_dims=[256, 256, 256],
        critic_hidden_dims=[256, 256, 256],
        activation="elu",
        init_noise_std=0.8,
        noise_std_type: str = "scalar",
        noise_lower_bound: float = 0.6,
        noise_upper_bound: float = 1.0,
        **kwargs,
    ):
        super().__init__(
            obs, obs_groups, num_actions, actor_obs_normalization, critic_obs_normalization,
            actor_hidden_dims, critic_hidden_dims, activation, init_noise_std, noise_std_type, **kwargs
        )
        self.noise_upper_bound = noise_upper_bound
        self.noise_lower_bound = noise_lower_bound
        assert noise_lower_bound < init_noise_std < noise_upper_bound, (
            f"init_noise_std ({init_noise_std}) must lie strictly between noise_lower_bound "
            f"({noise_lower_bound}) and noise_upper_bound ({noise_upper_bound})"
        )
        if self.noise_std_type != "scalar":
            raise ValueError(
                f"ActorCriticConstrainedStd only supports noise_std_type 'scalar', got: {self.noise_std_type}"
            )
        # position of the initial std inside (lower, upper), mapped to the unit interval
        unit_position = (init_noise_std - noise_lower_bound) * 1.0 / (noise_upper_bound - noise_lower_bound)
        # inverse sigmoid, so that the constrained std starts exactly at init_noise_std
        init_param = torch.logit(torch.tensor([unit_position]))
        self.std_param = nn.Parameter(init_param * torch.ones(num_actions))

    @property
    def exploration_params(self) -> List[torch.Tensor]:
        """Return the unconstrained std parameter as a single-element list."""
        return [self.std_param]

    def update_distribution(self, obs):
        """Rebuild ``self.distribution`` around the actor output using the bounded std.

        Raises:
            ValueError: If ``noise_std_type`` is not ``"scalar"``.
        """
        # compute mean
        mean = self.actor(obs)
        # compute standard deviation
        if self.noise_std_type != "scalar":
            raise ValueError(
                f"ActorCriticConstrainedStd only supports noise_std_type 'scalar', got: {self.noise_std_type}"
            )
        std_param = self.std_param.expand_as(mean)

        # squash into (noise_lower_bound, noise_upper_bound)
        std = torch.sigmoid(std_param) * (self.noise_upper_bound - self.noise_lower_bound) + self.noise_lower_bound

        # create distribution
        self.distribution = Normal(mean, std)




class ActorCriticForAnalysis(ActorCriticBase):
    """Actor-critic with an auxiliary head that predicts dynamics from actor features.

    Actor and critic are plain ``nn.Sequential`` stacks of ``Linear``/activation layers.
    ``get_dynamic_predictions`` taps the output of one hidden activation of the actor
    (selected by ``layer_to_dynamics``), concatenates it with the raw observations and
    feeds the result to ``actor_mu_pred``.

    Args:
        obs: Example observations, indexable by observation group name; only used to read
            the feature sizes of the first policy and the first critic group.
        obs_groups: Mapping with the keys ``"policy"`` and ``"critic"``. Exactly one policy
            group is supported.
        num_actions: Size of the action vector.
        actor_hidden_dims: Hidden layer sizes of the actor. The first three entries must be equal.
        critic_hidden_dims: Hidden layer sizes of the critic.
        activation: Activation name handed to ``resolve_nn_activation``.
        init_noise_std: Initial standard deviation of the action noise.
        noise_std_type: ``"scalar"`` to learn the std directly, ``"log"`` to learn its logarithm.
        layer_to_dynamics: Index of the hidden actor layer whose activation output feeds the
            dynamics head. ``-1`` replaces that representation with zeros (baseline).
        dim_dynamics_hidden: Hidden size of the dynamics prediction heads.
        dim_dynamics_prediction: Output size of the dynamics prediction heads.
        actor_obs_normalization: Whether to normalize the actor input.
        critic_obs_normalization: Whether to normalize the critic input.
        **kwargs: Ignored; a message listing the ignored keys is printed.
    """

    is_recurrent = False

    def __init__(
        self,
        obs,
        obs_groups,
        num_actions,
        actor_hidden_dims=[256, 256, 256],
        critic_hidden_dims=[256, 256, 256],
        activation="elu",
        init_noise_std=1.0,
        noise_std_type: str = "scalar",
        layer_to_dynamics: int = 0,
        dim_dynamics_hidden=64,
        dim_dynamics_prediction=18,
        actor_obs_normalization=False,
        critic_obs_normalization=False,
        **kwargs,
    ):
        if kwargs:
            print(
                "ActorCriticForAnalysis.__init__ got unexpected arguments, which will be ignored: "
                + str([key for key in kwargs.keys()])
            )
        super().__init__()
        activation = resolve_nn_activation(activation)

        self.actor_hidden_dims = actor_hidden_dims
        self.critic_hidden_dims = critic_hidden_dims
        # bool is a subclass of int, so it is excluded explicitly
        assert isinstance(layer_to_dynamics, int) and not isinstance(layer_to_dynamics, bool)
        self.layer_to_dynamics = layer_to_dynamics

        self.obs_groups = obs_groups

        assert len(obs_groups["policy"]) == 1, "Only single observation group is supported for ActorCriticForAnalysis"
        mlp_input_dim_a = obs[obs_groups["policy"][0]].shape[-1]
        mlp_input_dim_c = obs[obs_groups["critic"][0]].shape[-1]

        # Policy
        self.actor = self._build_trunk(mlp_input_dim_a, actor_hidden_dims, num_actions, activation)
        # Value function
        self.critic = self._build_trunk(mlp_input_dim_c, critic_hidden_dims, 1, activation)

        self.actor_obs_normalization = actor_obs_normalization
        if actor_obs_normalization:
            self.actor_obs_normalizer = EmpiricalNormalization(mlp_input_dim_a)
        else:
            self.actor_obs_normalizer = torch.nn.Identity()
        self.critic_obs_normalization = critic_obs_normalization
        if critic_obs_normalization:
            self.critic_obs_normalizer = EmpiricalNormalization(mlp_input_dim_c)
        else:
            self.critic_obs_normalizer = torch.nn.Identity()

        print(f"Actor MLP: {self.actor}")
        print(f"Critic MLP: {self.critic}")

        # Action noise
        self.noise_std_type = noise_std_type
        if self.noise_std_type == "scalar":
            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        elif self.noise_std_type == "log":
            self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        # Action distribution (populated in update_distribution)
        self.distribution = Normal(0, 1)  # dummy distribution to initialize. I assume this won't be used.
        # disable args validation for speedup
        Normal.set_default_validate_args(False)

        # add the output layers to get dynamic predictions
        assert actor_hidden_dims[0] == actor_hidden_dims[1] and actor_hidden_dims[1] == actor_hidden_dims[2], "For simplicity, we assume all actor hidden dims are the same."
        # each head consumes one hidden representation concatenated with the raw observations
        self.actor_mu_pred = build_mlp(
            actor_hidden_dims[0] + mlp_input_dim_a, [dim_dynamics_hidden], dim_dynamics_prediction, activation_name="elu"
        )
        self.critic_mu_pred = build_mlp(
            critic_hidden_dims[0] + mlp_input_dim_c, [dim_dynamics_hidden], dim_dynamics_prediction, activation_name="elu"
        )
        # These preprocessing layers are not used by get_dynamic_predictions; they are still
        # created so that the set of state-dict keys stays the same.
        self.dynamics_prediction_raw_obs_preprocessing_layers_actor = nn.Sequential(
            nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]),
            activation,
        )
        # NOTE(review): the critic variant also uses actor_hidden_dims[0] as its width; kept
        # as is so parameter shapes do not change.
        self.dynamics_prediction_raw_obs_preprocessing_layers_critic = nn.Sequential(
            nn.Linear(mlp_input_dim_c, actor_hidden_dims[0]),
            activation,
        )

    @staticmethod
    def _build_trunk(input_dim, hidden_dims, output_dim, activation):
        """Build ``Linear -> activation`` per hidden size, followed by a final ``Linear``."""
        layers = [nn.Linear(input_dim, hidden_dims[0]), activation]
        for fan_in, fan_out in zip(hidden_dims[:-1], hidden_dims[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(activation)
        layers.append(nn.Linear(hidden_dims[-1], output_dim))
        return nn.Sequential(*layers)

    @staticmethod
    # not used at the moment
    def init_weights(sequential, scales):
        """Orthogonally initialize every ``nn.Linear`` in ``sequential``.

        ``scales[i]`` is used as the gain of the i-th linear layer (counting linear layers only).
        """
        linear_layers = (module for module in sequential if isinstance(module, nn.Linear))
        for idx, module in enumerate(linear_layers):
            torch.nn.init.orthogonal_(module.weight, gain=scales[idx])

    def get_dynamic_predictions(self, observations):
        """Predict dynamics from an intermediate actor representation.

        The observations are passed to the actor layers as given; no normalizer is applied
        here. The intermediate representation is computed under ``torch.no_grad()``, so only
        the prediction head receives gradients.

        Args:
            observations: Tensor of shape (batch_size, num_obs).

        Returns:
            The output of ``actor_mu_pred`` for the concatenation of the intermediate
            representation and ``observations``.

        Raises:
            ValueError: If ``layer_to_dynamics`` is below ``-1`` or does not address a hidden
                activation of the actor.
        """
        trunk: nn.Sequential = self.actor
        mu_pred_layer = self.actor_mu_pred

        if self.layer_to_dynamics >= 0:
            # layers alternate Linear/activation, so hidden layer k ends at index 2k + 1
            activation_index = 2 * self.layer_to_dynamics + 1
            if activation_index >= len(trunk):
                raise ValueError(
                    f"layer_to_dynamics={self.layer_to_dynamics} does not address a hidden layer of the actor "
                    f"(actor has {len(trunk)} modules)."
                )
            with torch.no_grad():
                x = observations
                for i, layer in enumerate(trunk):
                    x = layer(x)
                    if i == activation_index:
                        # the remaining layers are not needed for the prediction
                        break
            input_to_dynamics_from_intermediate_rep = x
        elif self.layer_to_dynamics == -1:
            # use all zero representation, as a baseline of dynamics prediction
            input_to_dynamics_from_intermediate_rep = torch.zeros(
                [observations.shape[0], self.actor_hidden_dims[0]], device=observations.device
            )
        else:
            raise ValueError(f"layer_to_dynamics must be >= -1, got {self.layer_to_dynamics}.")

        input_to_dynamics = torch.cat([input_to_dynamics_from_intermediate_rep, observations], dim=-1)

        # get the predictions with gradient
        mu_pred = mu_pred_layer(input_to_dynamics)
        return mu_pred

    def load_all(self, path: str):
        """Load every parameter from ``path``; the file must contain a state dict matching this model exactly."""
        loaded_dict = torch.load(path, weights_only=True)
        self.load_state_dict(loaded_dict, strict=True)

    def load_trunk(self, path: str):
        """Load the trunk of the model (actor, critic) from ``path``, ignoring the dynamics prediction layers.

        The checkpoint must be a dict with the entries ``"model_state_dict"`` and ``"infos"``.
        Checkpoint keys containing ``"mu"`` or ``"sigma"`` are skipped. A warning is printed
        when the checkpoint lacks keys other than those of the prediction and preprocessing
        layers, or when it contains keys this model does not have.

        Returns:
            The ``"infos"`` entry of the checkpoint.
        """
        loaded_dict = torch.load(path, weights_only=True)
        filtered_dict = {
            key: value
            for key, value in loaded_dict["model_state_dict"].items()
            if "mu" not in key and "sigma" not in key
        }
        # call nn.Module's implementation directly to obtain the missing/unexpected key report
        load_result = nn.Module.load_state_dict(self, filtered_dict, strict=False)

        ignorable_markers = ("mu", "sigma", "dynamics_prediction_raw_obs_preprocessing_layers")
        missing_keys_to_notice = [
            key for key in load_result.missing_keys if not any(marker in key for marker in ignorable_markers)
        ]
        if len(missing_keys_to_notice) > 0:
            print("[WARNING]: Noticeable missing keys found in model. Please check the architecture of current instantiated model!.")
        if len(load_result.unexpected_keys) > 0:
            print("[WARNING]: unexpected keys found in checkpoint. Please check the model architecture of the checkpoint!.")

        return loaded_dict["infos"]


class ExtendableModel(nn.Module):
    """MLP head fed by a direct slice of the observations plus the outputs of optional submodules.

    Args:
        input_dim: Stored as ``self.input_dim``; not used to size any layer here.
        output_dim: Output size of the final MLP.
        final_mlp_dims: Hidden layer sizes of the final MLP.
        direct_pathway_dim: Number of trailing observation features passed straight to the
            final MLP. ``0`` disables the direct pathway.
        activation: Activation name handed to ``resolve_nn_activation``.
        submodule_configs: Optional list of configs, each resolved with ``resolve_submodule``.
            Every resolved submodule must provide ``module_type``, ``backbone_output_dim``
            and ``forward_RL``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        final_mlp_dims=[128, 128, 128],
        direct_pathway_dim: int = 15,
        activation="elu",
        submodule_configs: List[dict] | None = None
    ):
        # initialize nn.Module before assigning any attribute
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.direct_pathway_dim = direct_pathway_dim
        activation = resolve_nn_activation(activation)

        # submodules are keyed by their own ``module_type``
        self.submodules = nn.ModuleDict()
        if submodule_configs is not None:
            for submodule_config in submodule_configs:
                submodule = resolve_submodule(submodule_config)
                self.submodules[submodule.module_type] = submodule

        self.final_mlp_input_dim = direct_pathway_dim + sum(
            submodule.backbone_output_dim for submodule in self.submodules.values()
        )

        # Linear -> activation per hidden size, then a final Linear without activation
        final_layers = [nn.Linear(self.final_mlp_input_dim, final_mlp_dims[0]), activation]
        for fan_in, fan_out in zip(final_mlp_dims[:-1], final_mlp_dims[1:]):
            final_layers.append(nn.Linear(fan_in, fan_out))
            final_layers.append(activation)
        final_layers.append(nn.Linear(final_mlp_dims[-1], output_dim))
        self.final_layers = nn.Sequential(*final_layers)

    def forward(self, observations):
        """Run the final MLP on the direct-pathway slice concatenated with all submodule outputs.

        Args:
            observations: Tensor indexed as ``[:, features]``.

        Returns:
            The output of the final MLP.
        """
        # the slice is skipped for direct_pathway_dim == 0, since ``[-0:]`` would select everything
        final_mlp_inputs = [observations[:, -self.direct_pathway_dim:]] if self.direct_pathway_dim > 0 else []
        for submodule in self.submodules.values():
            final_mlp_inputs.append(submodule.forward_RL(observations))
        final_layer_input = torch.cat(final_mlp_inputs, dim=-1)
        return self.final_layers(final_layer_input)
    

class ExtendableActorCritic(ActorCriticBase):
    """Actor-critic whose actor and critic are both ``ExtendableModel`` instances.

    Actor and critic are built from the same ``direct_pathway_dim``, ``final_mlp_dims``,
    ``activation`` and ``submodule_configs``; they differ only in input and output size.
    Observation normalization is disabled (see ``normalization_setup``).

    Args:
        obs: Example observations, indexable by observation group name; only used to read
            the feature sizes. Every group must be 2-dimensional.
        obs_groups: Mapping with the keys ``"policy"`` and ``"critic"``.
        num_actions: Size of the action vector.
        direct_pathway_dim: Forwarded to both ``ExtendableModel`` instances.
        final_mlp_dims: Forwarded to both ``ExtendableModel`` instances.
        activation: Forwarded to both ``ExtendableModel`` instances.
        init_noise_std: Initial standard deviation of the action noise.
        noise_std_type: ``"scalar"`` to learn the std directly, ``"log"`` to learn its logarithm.
        submodule_configs: Forwarded to both ``ExtendableModel`` instances.
        **kwargs: Ignored; a message listing the ignored keys is printed.
    """

    is_recurrent = False

    def __init__(
        self,
        obs,
        obs_groups,
        num_actions,
        direct_pathway_dim=15,
        final_mlp_dims=[128, 128, 128],
        activation="elu",
        init_noise_std=1.0,
        noise_std_type: str = "scalar",
        submodule_configs: List[dict] | None = None,
        **kwargs,
    ):
        if kwargs:
            ignored = list(kwargs)
            print(
                "ExtendableActorCritic.__init__ got unexpected arguments, which will be ignored: "
                + str(ignored)
            )
        super().__init__()

        def flat_width(group_names):
            # sum of the feature sizes of the listed groups; each group must be (batch, features)
            width = 0
            for name in group_names:
                assert len(obs[name].shape) == 2, "The ActorCritic module only supports 1D observations."
                width += obs[name].shape[-1]
            return width

        num_actor_obs = flat_width(obs_groups["policy"])
        num_critic_obs = flat_width(obs_groups["critic"])

        self.obs_groups = obs_groups
        self.direct_pathway_dim = direct_pathway_dim
        self.normalization_setup()

        # settings common to the policy and the value network
        shared_settings = dict(
            direct_pathway_dim=direct_pathway_dim,
            final_mlp_dims=final_mlp_dims,
            activation=activation,
            submodule_configs=submodule_configs,
        )
        # Policy
        self.actor = ExtendableModel(input_dim=num_actor_obs, output_dim=num_actions, **shared_settings)
        # Value function
        self.critic = ExtendableModel(input_dim=num_critic_obs, output_dim=1, **shared_settings)

        # Action noise
        self.noise_std_type = noise_std_type
        initial_std = init_noise_std * torch.ones(num_actions)
        if self.noise_std_type == "scalar":
            self.std = nn.Parameter(initial_std)
        elif self.noise_std_type == "log":
            self.log_std = nn.Parameter(torch.log(initial_std))
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        # Placeholder distribution; update_distribution replaces it.
        self.distribution = Normal(0, 1)
        # disable args validation for speedup
        Normal.set_default_validate_args(False)

    def unfreeze_all(self):
        """Set ``requires_grad = True`` on every parameter of the actor and of the critic."""
        for network in (self.actor, self.critic):
            for param in network.parameters():
                param.requires_grad = True

    def normalization_setup(self):
        """Disable observation normalization: both normalizers are identity modules."""
        self.actor_obs_normalization = False
        self.critic_obs_normalization = False
        self.actor_obs_normalizer = torch.nn.Identity()
        self.critic_obs_normalizer = torch.nn.Identity()


######### Actor Integration for InvDynamicsMLP #########

# def hamburger_prepare_inv_input_from_obs(obs):
#     """
#         obs_t: [batch_size, 273]. See P4RL-Pre-Dynamic-Pedipulation-Flat-Blind-Anymal-D-v0's config for details.
#     """
#     b=obs.shape[0]
#     history_length_in_one_obs = 6
#     lin_vel, ang_vel, grav, jp, jv, a, command_and_exteroception = obs[:, 0:18], obs[:, 18:36], obs[:, 36:54], obs[:, 54:126], obs[:, 126:198], obs[:, 198:270], obs[:, 270:]
#     jp, a, lin_vel, ang_vel, grav, jv = jp.reshape(b, history_length_in_one_obs, -1), a.reshape(b, history_length_in_one_obs, -1), lin_vel.reshape(b, history_length_in_one_obs, -1), ang_vel.reshape(b, history_length_in_one_obs, -1), grav.reshape(b, history_length_in_one_obs, -1), jv.reshape(b, history_length_in_one_obs, -1)
#     obs_segment = torch.cat([lin_vel, ang_vel, grav, jp, jv], dim=-1) # [b, 6, 33]

#     current_step_input = torch.cat([obs_segment[:, -1, :], a[:, -1, :], command_and_exteroception], dim=-1)  # [b, 48]
#                                    # mimic the default policy input
#     return obs_segment[:, 2:, :], a[:, 3:, :], current_step_input

def hamburger_prepare_inv_input_from_obs(obs, hist_len=6, action_dim=12):
    """Split a flat observation into history segments for the inverse dynamics input.

    ``obs`` is read as six consecutive flattened histories of ``hist_len`` frames each,
    with per-frame widths ``3, 3, 3, action_dim, action_dim, action_dim`` (named
    ``lin_vel``, ``ang_vel``, ``grav``, ``jp``, ``jv`` and ``a`` below), followed by the
    remaining features (``command_and_exteroception``).

    Args:
        obs: Tensor of shape [batch_size, num_features]; with the defaults the histories
            occupy the first 270 features.
        hist_len: Number of frames stored per history.
        action_dim: Per-frame width of ``jp``, ``jv`` and ``a``.

    Returns:
        A tuple ``(obs_segment_cut, action_cut, current_step_input)``:
        the per-frame concatenation of ``lin_vel, ang_vel, grav, jp, jv`` from frame index 2
        onwards, the action history from frame index 3 onwards, and the last frame of the
        former concatenated with the last action and the remaining features.
    """
    batch = obs.shape[0]
    per_frame_widths = (3, 3, 3, action_dim, action_dim, action_dim)

    # walk through the flat layout, turning every history into [batch, hist_len, width]
    histories = []
    cursor = 0
    for width in per_frame_widths:
        span = width * hist_len
        histories.append(obs[:, cursor:cursor + span].reshape(batch, hist_len, -1))
        cursor += span
    lin_vel, ang_vel, grav, jp, jv, a = histories
    command_and_exteroception = obs[:, cursor:]

    obs_segment = torch.cat((lin_vel, ang_vel, grav, jp, jv), dim=-1)

    # mimic the default policy input: newest frame, newest action, then the remaining features
    newest_frame = obs_segment[:, -1, :]
    newest_action = a[:, -1, :]
    current_step_input = torch.cat((newest_frame, newest_action, command_and_exteroception), dim=-1)
    return obs_segment[:, 2:, :], a[:, 3:, :], current_step_input

class P4RLHamburgerCritic(nn.Module):
    """Value network that sandwiches a pretrained inverse-dynamics module between two MLPs.

    Data flow of :meth:`forward`:

    1. ``hamburger_prepare_inv_input_from_obs`` splits the observation into a state history, an action
       history and the current-step input.
    2. ``top_mlp`` maps the current-step input to a ``representation_dim`` vector.
    3. ``pretrained_module.forward_RL`` receives both histories together with that vector.
    4. ``bottom_mlp`` maps the concatenation of the top output and the pretrained output to one value.
       It is built with ``2 * representation_dim`` inputs, so ``forward_RL`` is expected to return
       ``representation_dim`` features.

    Args:
        top_mlp_in_features: width of the current-step input fed to ``top_mlp``.
        representation_dim: accepted for interface compatibility only; the value actually used is
            ``submodule_config["representation_dim"]`` so that it always matches the pretrained module.
        submodule_config: configuration handed to ``resolve_submodule``. The keys read here are
            ``"representation_dim"``, ``"dim_actions"`` and ``"input_timesteps"``.
        mlp_block_dims: hidden layer sizes shared by ``top_mlp`` and ``bottom_mlp`` (never mutated here).
        activation: activation name forwarded to ``MLP``.
    """

    def __init__(
        self,
        top_mlp_in_features: int,
        representation_dim: int,
        submodule_config: dict, 
        mlp_block_dims=[128, 128, 128],
        activation="elu",
    ):
        super().__init__()
        # The config is authoritative: the argument is deliberately overridden.
        representation_dim = submodule_config["representation_dim"]
        self.action_dim = submodule_config["dim_actions"]
        # This "+1" is purely due to convention, that the observation of the MDP contains one more frame than
        # the need of the inv dynamics module input.
        self.obs_history_len = submodule_config["input_timesteps"] + 1
        self.pretrained_module: InvDynamicsMLP = resolve_submodule(submodule_config)

        self.top_mlp = MLP(top_mlp_in_features, representation_dim, mlp_block_dims, activation)
        self.bottom_mlp = MLP(representation_dim*2, 1, mlp_block_dims, activation)

        # Disabled experiments (not part of the active model): separate linear projections of the top
        # output for the pretrained path and the skip path, and a reduced-scale weight initialisation
        # of both MLPs.

    def forward(self, observations):
        """Return the critic output for a batch of flat observations ``[batch_size, num_obs]``."""
        x_cut, a_cut, current_step_input = hamburger_prepare_inv_input_from_obs(
            observations, self.obs_history_len, self.action_dim
        )
        top_out = self.top_mlp(current_step_input)

        # The top output is used twice: as input of the pretrained module and as a skip connection
        # around it.
        pre_out = self.pretrained_module.forward_RL(x_cut, a_cut, top_out)
        bottom_in = torch.cat([top_out, pre_out], dim=-1)
        return self.bottom_mlp(bottom_in)

    def freeze_pretrained_module(self):
        """Save the requires_grad state of each parameter for later restoration purpose, and freeze the pretrained module.

        Calling this again while the module is still frozen by a previous call keeps the first
        snapshot. Re-recording it would store an all-``False`` state and make
        :meth:`reset_trainable_state_of_pretrained_module` unable to unfreeze anything.
        """
        named_params = list(self.pretrained_module.named_parameters())

        still_frozen_by_previous_call = (
            getattr(self, "_pretrained_module_frozen", False)
            and hasattr(self, "requires_grad_state")
            and not any(param.requires_grad for _, param in named_params)
        )
        if not still_frozen_by_previous_call:
            self.requires_grad_state = {name: param.requires_grad for name, param in named_params}

        for _, param in named_params:
            param.requires_grad = False
        self._pretrained_module_frozen = True
        print("Frozen the pretrained module in P4RLHamburgerCritic.")

    def reset_trainable_state_of_pretrained_module(self):
        """Restore the requires_grad state of each parameter in the pretrained module.

        Raises:
            RuntimeError: if :meth:`freeze_pretrained_module` has never been called, i.e. there is no
                saved state to restore.
        """
        if not hasattr(self, "requires_grad_state"):
            raise RuntimeError(
                "No requires_grad_state found. Call freeze_pretrained_module() before restoring the "
                "pretrained module's trainable state."
            )

        for name, param in self.pretrained_module.named_parameters():
            # Parameters that did not exist when the snapshot was taken are left untouched.
            if name in self.requires_grad_state:
                param.requires_grad = self.requires_grad_state[name]
        self._pretrained_module_frozen = False
        print("Restored the trainable state of the pretrained module in P4RLHamburgerCritic.")


class P4RLHamburgerActor(nn.Module):
    """Policy network that wraps a pretrained inverse-dynamics module between two MLPs.

    ``top_mlp`` turns the current-step input into a ``representation_dim`` vector, the pretrained
    module consumes that vector together with the truncated state / action histories, and
    ``bottom_mlp`` maps the concatenation of both results to ``action_dim`` outputs.

    Args:
        top_mlp_in_features: width of the current-step input fed to ``top_mlp``.
        action_dim: number of outputs of ``bottom_mlp``.
        submodule_config: configuration handed to ``resolve_submodule``. The keys read here are
            ``"representation_dim"``, ``"dim_actions"`` and ``"input_timesteps"``.
        mlp_block_dims: hidden layer sizes used by both MLPs.
        activation: activation name forwarded to ``MLP``.

    Attributes set here:
        action_dim: ``submodule_config["dim_actions"]`` (used to slice the observation; note that this
            is the config value, not the ``action_dim`` argument).
        obs_history_len: ``submodule_config["input_timesteps"] + 1``.
        input_dim: full observation width, needed by the ONNX export.
    """

    def __init__(
        self,
        top_mlp_in_features: int,
        action_dim: int,
        submodule_config: dict, 
        mlp_block_dims=[128, 128, 128],
        activation="elu",
    ):
        super().__init__()
        latent_dim = submodule_config["representation_dim"]
        self.action_dim = submodule_config["dim_actions"]
        # The MDP observation carries one frame more than the inverse-dynamics module consumes;
        # the "+1" is purely that convention.
        self.obs_history_len = submodule_config["input_timesteps"] + 1

        # export to onnx needs the attribute "input_dim":
        # one frame holds 9 + 3 * action_dim values, the rest of the current-step input is
        # command / exteroception, and the full observation stacks obs_history_len frames.
        per_frame_dim = 3*self.action_dim + 9
        command_and_exteroception_dim = top_mlp_in_features - per_frame_dim
        self.input_dim = command_and_exteroception_dim + per_frame_dim*self.obs_history_len

        self.pretrained_module: InvDynamicsMLP = resolve_submodule(submodule_config)

        self.top_mlp = MLP(top_mlp_in_features, latent_dim, mlp_block_dims, activation)
        self.bottom_mlp = MLP(latent_dim*2, action_dim, mlp_block_dims, activation)

        # Disabled experiments (not part of the active model): zero-initialised last layer of
        # bottom_mlp, separate linear projections of the top output for the pretrained and the skip
        # path, and a reduced-scale weight initialisation of both MLPs.

    def forward(self, observations):
        """Return the ``bottom_mlp`` output for a batch of flat observations ``[batch_size, num_obs]``."""
        state_history, action_history, policy_input = hamburger_prepare_inv_input_from_obs(
            observations, self.obs_history_len, self.action_dim
        )
        latent = self.top_mlp(policy_input)

        # The latent feeds the pretrained module and also skips around it.
        pretrained_out = self.pretrained_module.forward_RL(state_history, action_history, latent)
        merged = torch.cat([latent, pretrained_out], dim=-1)

        # No output clamping is applied (a tanh-based clamp was tried and is disabled).
        return self.bottom_mlp(merged)
    

# class P4RLResidualActor(nn.Module): 
#     def __init__(
#         self,
#         top_mlp_in_features: int,
#         action_dim: int,
#         submodule_config: dict, 
#         mlp_block_dims=[128, 128, 128],
#         activation="elu",
#     ):
        
#         super().__init__()
#         representation_dim = submodule_config["representation_dim"]
#         self.high_level_policy = MLP(top_mlp_in_features, representation_dim, mlp_block_dims, activation)
#         self.pretrained_module: InvDynamicsMLP = resolve_submodule(submodule_config)
#         self.residual_policy = MLP(top_mlp_in_features+action_dim, action_dim, mlp_block_dims, activation)
#         # initialize the last layer of the residual policy to zero
#         self.residual_policy[-1].weight.data.zero_()
#         self.residual_policy[-1].bias.data.zero_()
#         print("initialized P4RLResidualActor's residual policy last layer to zero.")

    
#     def forward(self, observations):
#         x_cut, a_cut, current_step_input = hamburger_prepare_inv_input_from_obs(observations)
#         implicit_goal = self.high_level_policy(current_step_input)
#         pre_out = self.pretrained_module.forward_RL_get_action(x_cut, a_cut, implicit_goal)
#         bottom_in = torch.cat([current_step_input, pre_out.detach().clone()], dim=-1)
#         residual_out = self.residual_policy(bottom_in)
#         out = pre_out + residual_out
#         return out
    

# class P4RLGatedActor(nn.Module):
#     def __init__(
#         self,
#         top_mlp_in_features,
#         action_dim,
#         submodule_config: dict, 
#         mlp_block_dims=[128, 128, 128],
#         activation: str="elu",
#         simple_splice: bool = False,  # if True, use a simple splice architecture without action residual policy and gating network
#     ):
#         super().__init__()

#         self.maneuver_pretrain_policy = MLP(top_mlp_in_features, submodule_config["representation_dim"], mlp_block_dims, activation)
#         assert submodule_config["mode"] == "inv"
#         self.pretrained_module: InvDynamicsMLP = resolve_submodule(submodule_config)

#         self.gating_value: torch.Tensor = torch.empty(0)

#         self.simple_splice = simple_splice
#         if not simple_splice:
#             self.action_policy = MLP(top_mlp_in_features, action_dim, mlp_block_dims, activation)
#             self.gating_network = MLP(top_mlp_in_features, 1, mlp_block_dims, activation)

#             self.action_policy[-1].weight.data.zero_()
#             self.action_policy[-1].bias.data.zero_()
        


#     def forward(self, observations):
#         """
#         Forward pass of the gated actor.
#         observations: (batch_size, num_obs)
#         returns: (batch_size, num_actions)
#         """

#         x_cut, a_cut, current_step_input = hamburger_prepare_inv_input_from_obs(observations)
#         next_obs_target = self.maneuver_pretrain_policy(current_step_input)
#         action_inv = self.pretrained_module.forward_RL_get_action(x_cut, a_cut, next_obs_target)

#         if self.simple_splice:
#             return action_inv
#         else:
#             action_a_policy = self.action_policy(current_step_input)
#             self.gating_value = nn.functional.sigmoid(self.gating_network(current_step_input))
#             # combine the predictions using the gating value
#             actions = self.gating_value * action_a_policy + (1 - self.gating_value) * action_inv
#             return actions


class P4RLAsymmetricActorCritic(ActorCriticBase):
    """Actor-critic whose actor and critic architectures are selected independently by name.

    Args:
        obs: mapping from observation-group name to a 2D tensor ``[batch, features]``; only the
            shapes are read here.
        obs_groups: mapping with the keys ``"policy"`` and ``"critic"``, each listing the observation
            groups whose feature widths are summed into the actor / critic input width.
        num_actions: number of action dimensions.
        actor_submodule_config: configuration of the actor's pretrained module. For every actor type
            except ``"mlp"`` the keys ``"dim_states"``, ``"dim_actions"`` and ``"input_timesteps"`` are
            read here.
        critic_submodule_config: same for the critic; only read when ``critic_type == "hamburger"``.
        critic_obs_normalization: whether to wrap critic observations in ``EmpiricalNormalization``.
            Must be ``False`` for the hamburger critic.
        actor_type: ``"hamburger"``, ``"residual"``, ``"gated"``, ``"spliced"`` or ``"mlp"``.
        critic_type: ``"mlp"`` or ``"hamburger"``.
        mlp_block_dims: hidden layer sizes forwarded to every network built here.
        activation: activation name forwarded to every network built here.
        init_noise_std: initial action standard deviation.
        noise_std_type: ``"scalar"`` (learn ``std``) or ``"log"`` (learn ``log_std``).
        **kwargs: ignored, but reported on stdout.

    Raises:
        ValueError: for an unknown ``actor_type``, ``critic_type`` or ``noise_std_type``.
    """

    is_recurrent = False

    def __init__(
        self,
        obs,
        obs_groups,
        num_actions,
        actor_submodule_config: dict, 
        critic_submodule_config: dict,
        critic_obs_normalization=False,
        actor_type="hamburger",  # "hamburger", "residual", "gated", "spliced" or "mlp"
        critic_type="mlp", # "mlp" or "hamburger"
        mlp_block_dims=[128, 128, 128],
        activation="elu",
        init_noise_std=1.0,
        noise_std_type: str = "scalar",
        **kwargs,
    ):
        if kwargs:
            print(
                "P4RLAsymmetricActorCritic.__init__ got unexpected arguments, which will be ignored: "
                + str(list(kwargs))
            )
        super().__init__()

        num_actor_obs = self._sum_obs_dims(obs, obs_groups["policy"])
        num_critic_obs = self._sum_obs_dims(obs, obs_groups["critic"])

        self.obs_groups = obs_groups

        self.normalization_setup()

        # actor
        if actor_type == "mlp":
            self.actor = MLP(num_actor_obs, num_actions, mlp_block_dims, activation)
        elif actor_type in ("hamburger", "residual", "gated", "spliced"):
            structured_actor_kwargs = dict(
                top_mlp_in_features=self._current_step_feature_dim(num_actor_obs, actor_submodule_config),
                action_dim=num_actions,
                submodule_config=actor_submodule_config,
                mlp_block_dims=mlp_block_dims,
                activation=activation,
            )
            if actor_type == "hamburger":
                self.actor = P4RLHamburgerActor(**structured_actor_kwargs)
            elif actor_type == "residual":
                # NOTE(review): P4RLResidualActor / P4RLGatedActor only appear as commented-out code
                # near this class -- confirm they are defined elsewhere before using these types.
                self.actor = P4RLResidualActor(**structured_actor_kwargs)
            else:
                # "gated" uses the full gated architecture, "spliced" the simple splice variant.
                self.actor = P4RLGatedActor(simple_splice=(actor_type == "spliced"), **structured_actor_kwargs)
        else:
            raise ValueError(
                f"Unknown actor type: {actor_type}. "
                "Should be 'hamburger', 'residual', 'gated', 'spliced' or 'mlp'"
            )

        # do not use obs normalization for actor
        self.actor_obs_normalization = False
        self.actor_obs_normalizer = torch.nn.Identity()

        # critic
        if critic_type == "mlp":
            self.critic = MLP(num_critic_obs, 1, mlp_block_dims, activation)
        elif critic_type == "hamburger":
            assert critic_obs_normalization is False, "pretrained inv dynamics module does not support obs normalization."
            # The critic slices its own observation with its own config, so its top MLP has to be
            # sized from the critic observation width and the critic config (not the actor's).
            self.critic = P4RLHamburgerCritic(
                top_mlp_in_features=self._current_step_feature_dim(num_critic_obs, critic_submodule_config),
                representation_dim=mlp_block_dims[0],
                submodule_config=critic_submodule_config,
                mlp_block_dims=mlp_block_dims,
                activation=activation
            )
            print(f"Critic Pretrained: {self.critic}")
        else:
            raise ValueError(f"Unknown critic type: {critic_type}. Should be 'mlp' or 'hamburger'")

        # critic observation normalization
        self.critic_obs_normalization = critic_obs_normalization
        if critic_obs_normalization:
            self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs)
        else:
            self.critic_obs_normalizer = torch.nn.Identity()

        print(f"Actor MLP: {self.actor}")
        print(f"Critic MLP: {self.critic}")

        # Action noise
        self.noise_std_type = noise_std_type
        if self.noise_std_type == "scalar":
            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        elif self.noise_std_type == "log":
            self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        # Action distribution (populated in update_distribution)
        self.distribution = Normal(0, 1) # dummy distribution to initialize. I assume this won't be used.
        # disable args validation for speedup (note: this changes the process-wide default of Normal)
        Normal.set_default_validate_args(False)

    @staticmethod
    def _sum_obs_dims(obs, group_names):
        """Return the summed feature width of the given observation groups (each must be 2D)."""
        total = 0
        for group_name in group_names:
            assert len(obs[group_name].shape) == 2, "The ActorCritic module only supports 1D observations."
            total += obs[group_name].shape[-1]
        return total

    @staticmethod
    def _current_step_feature_dim(num_obs, submodule_config):
        """Return the width of the current-step input of a hamburger-style network.

        One frame holds ``dim_states + dim_actions`` values (state plus last action; 45 for anymal and
        120 for g1 according to the original note). The observation stacks ``input_timesteps + 1`` such
        frames; every column beyond the stacked history is appended to the current frame.
        """
        single_frame_obs_dim = submodule_config["dim_states"] + submodule_config["dim_actions"]
        history_len = submodule_config["input_timesteps"] + 1
        return single_frame_obs_dim + (num_obs - single_frame_obs_dim * history_len)

    @staticmethod
    def _pretrained_params_of(network):
        """Return the parameters of ``network.pretrained_module``, or ``[]`` if there is none."""
        pretrained = getattr(network, "pretrained_module", None)
        if pretrained is None:
            return []
        return list(pretrained.parameters())

    def unfreeze_all(self):
        """Mark every actor and critic parameter as trainable."""
        for param in self.actor.parameters():
            param.requires_grad = True
        for param in self.critic.parameters():
            param.requires_grad = True

    def normalization_setup(self): 
        """No normalization is used here. """
        self.actor_obs_normalization = False
        self.critic_obs_normalization = False
        self.actor_obs_normalizer = torch.nn.Identity()
        self.critic_obs_normalizer = torch.nn.Identity()        

    def get_actor_pretrained_params(self):
        """Return the parameters of the actor's pretrained module (empty list for a plain actor)."""
        return self._pretrained_params_of(self.actor)

    def get_critic_pretrained_params(self):
        """Return the parameters of the critic's pretrained module (empty list for a plain critic)."""
        return self._pretrained_params_of(self.critic)



class JacobianActorWrapper(nn.Module):
    """
    Adapter that lets a resolved Jacobian sub-module act as the actor network.

    The sub-module is built from ``jacobian_module_cfg`` via ``resolve_submodule`` and stored as
    ``jacobian_mlp``; calling the wrapper forwards the observation to its ``forward_RL`` method.
    """

    def __init__(self, jacobian_module_cfg):
        super().__init__()
        resolved_module = resolve_submodule(jacobian_module_cfg)
        self.jacobian_mlp: JacobianMLP = resolved_module

    def forward(self, obs):
        """Return ``jacobian_mlp.forward_RL(obs)`` unchanged."""
        wrapped = self.jacobian_mlp
        return wrapped.forward_RL(obs)


class JacobianActorCritic(ActorCriticBase):
    """Actor-critic pairing a :class:`JacobianActorWrapper` policy with a plain MLP value function.

    Args:
        num_actor_obs: accepted for interface compatibility; not read by this constructor.
        num_critic_obs: input width of the value MLP.
        num_actions: number of action dimensions (size of the noise parameter).
        jacobian_module_cfg: configuration forwarded to :class:`JacobianActorWrapper`.
        mlp_dims: hidden layer sizes of the value MLP.
        activation: activation name resolved with ``resolve_nn_activation``.
        init_noise_std: initial action standard deviation.
        noise_std_type: ``"scalar"`` (learn ``std``) or ``"log"`` (learn ``log_std``).
        **kwargs: ignored, but reported on stdout.

    Raises:
        ValueError: for an unknown ``noise_std_type``.
    """

    is_recurrent = False

    def __init__(
        self,
        num_actor_obs,
        num_critic_obs,
        num_actions,
        jacobian_module_cfg: dict,
        mlp_dims=[128, 128, 128],
        activation="elu",
        init_noise_std=1.0,
        noise_std_type: str = "scalar",
        **kwargs,
    ):
        if kwargs:
            ignored_names = list(kwargs)
            print(
                "JacobianActorCritic.__init__ got unexpected arguments, which will be ignored: "
                + str(ignored_names)
            )
        super().__init__()

        # Policy
        self.actor = JacobianActorWrapper(jacobian_module_cfg=jacobian_module_cfg)

        # Value function: input layer, hidden-to-hidden layers (each followed by the activation),
        # then a linear head with a single output. The same activation instance is reused.
        nonlinearity = resolve_nn_activation(activation)
        value_layers = [nn.Linear(num_critic_obs, mlp_dims[0]), nonlinearity]
        for fan_in, fan_out in zip(mlp_dims[:-1], mlp_dims[1:]):
            value_layers.extend((nn.Linear(fan_in, fan_out), nonlinearity))
        value_layers.append(nn.Linear(mlp_dims[-1], 1))
        self.critic = nn.Sequential(*value_layers)

        # Action noise
        self.noise_std_type = noise_std_type
        if self.noise_std_type not in ("scalar", "log"):
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
        if self.noise_std_type == "scalar":
            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        else:
            self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))

        # Placeholder distribution; the real one is populated in update_distribution.
        self.distribution = Normal(0, 1)
        # disable args validation for speedup
        Normal.set_default_validate_args(False)

    def unfreeze_all(self):
        """Delegate to ``unfreeze_all`` of the wrapped Jacobian module."""
        jacobian_module = self.actor.jacobian_mlp
        jacobian_module.unfreeze_all()



    # def clip_grad_norm(self, max_norm, finetune=False):
    #     """
    #         If finetuning, this function SEPARATELY clips the gradients of the pretrained module and the output module 
    #         of respectively and the actor and critic model.
    #         If not finetune, this function falls back to the default behavior of clipping the gradients of the entire
    #         model, to ensure the consistency with the baseline.
    #     """
    #     if finetune:
    #         nn.utils.clip_grad_norm_(self.actor.residual_blocks.parameters(), max_norm*self.pretrain_clipping_scale)
    #         nn.utils.clip_grad_norm_(self.actor.final_layers.parameters(), max_norm)
    #         nn.utils.clip_grad_norm_(self.critic.residual_blocks.parameters(), max_norm*self.pretrain_clipping_scale)
    #         nn.utils.clip_grad_norm_(self.critic.final_layers.parameters(), max_norm)
    #     else:
    #         nn.utils.clip_grad_norm_(self.parameters(), max_norm) 
    
    # def freeze_pretrained_submodules(self):
    #     for param in self.actor.residual_blocks.parameters():
    #         param.requires_grad = False
    #     for param in self.critic.residual_blocks.parameters():
    #         param.requires_grad = False
    #     self._pretrained_modules_frozen = True

    # def unfreeze_pretrained_submodules(self):
    #     for param in self.actor.residual_blocks.parameters():
    #         param.requires_grad = True
    #     for param in self.critic.residual_blocks.parameters():
    #         param.requires_grad = True
    #     self._pretrained_modules_frozen = False
 

class HierarchicalModel(nn.Module):
    """
        In a HierarchicalModel, the model is composed of a high-level module which is often initialized from scratch, 
        and a low-level module which is often initialized from a pretrained model.

        Note that the input to high-level module must be at the end of the observation tensor, in consecutive 
        [input_dim] dimensions. 
    
    """
    def __init__(
        self,
        input_dim: int,
        intermediate_action_dim: int,
        final_output_dim: int,
        higher_level_input_dim: int,
        high_level_mlp_dims=[128, 128, 128],
        activation="elu",
        intermediate_action_space_is_raw: bool = False,
        residual_action: bool = False,
        low_level_module_config: Dict | None = None,
        is_critic=False,
        low_level_module_lr_scale: float = 0.1,
    ):
        self.input_dim = input_dim
        self.intermediate_action_dim = intermediate_action_dim
        self.residual_action = residual_action
        self.final_output_dim = final_output_dim
        self.intermediate_action_space_is_raw = intermediate_action_space_is_raw
        self.higher_level_input_dim = higher_level_input_dim
        self.is_critic = is_critic
        self.low_level_module_lr_scale = low_level_module_lr_scale
        super().__init__()
        activation = resolve_nn_activation(activation)
        
        assert low_level_module_config is not None, "low_level_module_config must be provided for HierarchicalModel"
        self.low_lovel_module = resolve_submodule(low_level_module_config)

        high_level_layers = []
        high_level_layers.append(nn.Linear(self.higher_level_input_dim, high_level_mlp_dims[0]))
        high_level_layers.append(activation)
        for layer_index in range(len(high_level_mlp_dims)-1):
            high_level_layers.append(nn.Linear(high_level_mlp_dims[layer_index], high_level_mlp_dims[layer_index + 1]))
            high_level_layers.append(activation)
        self.high_level_layers = nn.Sequential(*high_level_layers)

        if intermediate_action_space_is_raw:
            self.intermediate_action_layer = nn.Linear(high_level_mlp_dims[-1], intermediate_action_dim)
        else:
            self.intermediate_action_layer = nn.Linear(high_level_mlp_dims[-1], 64) # assumed 64 is latent space dim in inv-dynamics module; this must be consistent with the inv-dynamics module.

        if residual_action:
            self.residual_action_layer = nn.Linear(high_level_mlp_dims[-1], final_output_dim)
 
    def forward(self, observations):
        intermediate_action_latent = self.high_level_layers(observations[:, -self.higher_level_input_dim:])
        intermediate_action = self.intermediate_action_layer(intermediate_action_latent)

        out = self.low_lovel_module.forward_RL(observations, intermediate_action, self.intermediate_action_space_is_raw, self.is_critic)

        if self.residual_action:
            out = out + self.residual_action_layer(intermediate_action_latent)
        else:
            out = out
        return out
    
    def get_optimizer_params(self, base_lr):
        """
            Returns the parameters of the actor and critic networks.
        """
        return [
            {"params": self.high_level_layers.parameters(), "lr": base_lr},
            {"params": self.intermediate_action_layer.parameters(), "lr": base_lr},
            {"params": self.residual_action_layer.parameters(), "lr": base_lr},
            {"params": self.low_lovel_module.get_optim_group_for_base_lr(), "lr": base_lr},
            {"params": self.low_lovel_module.get_optim_group_for_reduced_lr(), "lr": base_lr*self.low_level_module_lr_scale},
        ]


class HierarchicalActorCritic(ActorCriticBase):
    """Actor-critic whose actor and critic are both :class:`HierarchicalModel` instances.

    The two models are built from the same settings and differ only in their observation width,
    their ``final_output_dim`` (``num_actions`` for the actor, ``1`` for the critic) and the
    ``is_critic`` flag.

    Args:
        num_actor_obs: observation width of the actor.
        num_critic_obs: observation width of the critic.
        num_actions: number of action dimensions.
        num_intermediate_actions: ``intermediate_action_dim`` of both models.
        intermediate_action_space_is_raw: forwarded to both models.
        low_level_module_lr_scale: forwarded to both models.
        high_level_mlp_dims: layer sizes of both high-level MLPs.
        residual_action: forwarded to both models.
        activation: activation name forwarded to both models.
        init_noise_std: initial action standard deviation.
        noise_std_type: ``"scalar"`` (learn ``std``) or ``"log"`` (learn ``log_std``).
        low_level_module_config: forwarded to both models.
        higher_level_input_dim: forwarded to both models.
        **kwargs: ignored, but reported on stdout.

    Raises:
        ValueError: for an unknown ``noise_std_type``.
    """

    is_recurrent = False

    def __init__(
        self,
        num_actor_obs,
        num_critic_obs,
        num_actions,
        num_intermediate_actions,
        intermediate_action_space_is_raw: bool,
        low_level_module_lr_scale: float,
        high_level_mlp_dims=[128, 128, 128],
        residual_action: bool = False,
        activation="elu",
        init_noise_std=1.0,
        noise_std_type: str = "scalar",
        low_level_module_config: Dict | None = None,
        higher_level_input_dim: int = 48,
        **kwargs,
    ):
        if kwargs:
            ignored_names = list(kwargs)
            print(
                "HierarchicalActorCritic.__init__ got unexpected arguments, which will be ignored: "
                + str(ignored_names)
            )
        super().__init__()

        # Settings that the policy and the value function have in common.
        shared_settings = dict(
            intermediate_action_dim=num_intermediate_actions,
            high_level_mlp_dims=high_level_mlp_dims,
            activation=activation,
            intermediate_action_space_is_raw=intermediate_action_space_is_raw,
            higher_level_input_dim=higher_level_input_dim,
            residual_action=residual_action,
            low_level_module_config=low_level_module_config,
            low_level_module_lr_scale=low_level_module_lr_scale,
        )

        # Policy
        self.actor = HierarchicalModel(
            input_dim=num_actor_obs,
            final_output_dim=num_actions,
            is_critic=False,
            **shared_settings,
        )
        # Value function
        self.critic = HierarchicalModel(
            input_dim=num_critic_obs,
            final_output_dim=1,
            is_critic=True,
            **shared_settings,
        )

        # Action noise
        self.noise_std_type = noise_std_type
        if self.noise_std_type not in ("scalar", "log"):
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
        if self.noise_std_type == "scalar":
            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        else:
            self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))

        # Placeholder distribution; the real one is populated in update_distribution.
        self.distribution = Normal(0, 1)
        # disable args validation for speedup
        Normal.set_default_validate_args(False)

    def get_optimizer_params(self, base_lr):
        """
            Returns the optimizer parameter groups of the actor, the critic and the action-noise
            parameter (in that order); the noise parameter uses ``base_lr``.
        """
        if self.noise_std_type == "scalar":
            noise_parameter = self.std
        elif self.noise_std_type == "log":
            noise_parameter = self.log_std
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        actor_groups = self.actor.get_optimizer_params(base_lr)
        critic_groups = self.critic.get_optimizer_params(base_lr)
        noise_groups = [{"params": noise_parameter, "lr": base_lr}]
        return actor_groups + critic_groups + noise_groups
    

# class GatedActorWithINV(nn.Module):
#     def __init__(
#         self,
#         next_obs_policy: nn.Module,
#         action_policy: nn.Module,
#         gating_network: nn.Module,
#         inv_dynamics_module: InvDynamicsMLP,
#     ):
#         super().__init__()
#         self.next_obs_policy = next_obs_policy
#         self.action_policy = action_policy
#         self.gating_network = gating_network
#         self.inv_dynamics_module = inv_dynamics_module

#         self.gating_value: torch.Tensor = torch.empty(0)


#     def get_current_timestep_observation(self, obs): # TODO
#         """
#             obs_t: [batch_size, 273]. See P4RL-Pre-Dynamic-Pedipulation-Flat-Blind-Anymal-D-v0's config for details.
#         """
#         b=obs.shape[0]
#         history_length_in_one_obs = 6
#         lin_vel, ang_vel, grav, jp, jv, a, command = obs[:, 0:18], obs[:, 18:36], obs[:, 36:54], obs[:, 54:126], obs[:, 126:198], obs[:, 198:270], obs[:, 270:273]
#         jp, a, lin_vel, ang_vel, grav, jv = jp.reshape(b, history_length_in_one_obs, -1), a.reshape(b, history_length_in_one_obs, -1), lin_vel.reshape(b, history_length_in_one_obs, -1), ang_vel.reshape(b, history_length_in_one_obs, -1), grav.reshape(b, history_length_in_one_obs, -1), jv.reshape(b, history_length_in_one_obs, -1)
#         obs_segment = torch.cat([lin_vel, ang_vel, grav, jp, jv], dim=-1) # [b, 6, 33]
#         current_step_input = torch.cat([obs_segment[:, -1, :], a[:, -1, :], command], dim=-1)  # [b, 48]
#         return current_step_input


#     def forward(self, observations):
#         """
#         Forward pass of the gated actor.
#         observations: (batch_size, num_obs)
#         returns: (batch_size, num_actions)
#         """

#         x_cut, a_cut, current_step_input = hamburger_prepare_inv_input_from_obs(observations)

#         # compute the gating value
#         self.gating_value, self.std = self.gating_network(current_step_input)
#         # compute the next observation prediction
#         next_obs_target = self.next_obs_policy(current_step_input)
#         # compute the action prediction
#         action_a_policy = self.action_policy(current_step_input)

#         action_inv = self.inv_dynamics_module.forward_RL_get_action(x_cut, a_cut, next_obs_target)

#         # combine the predictions using the gating value
#         actions = self.gating_value * action_a_policy + (1 - self.gating_value) * action_inv

#         return actions
    


class GatedActorEnsemble(nn.Module):
    """Actor that mixes a trainable policy with a frozen expert policy through a learned gate.

    All three sub-networks consume only the latest time step extracted from the
    history-stacked observation (see :meth:`get_current_timestep_observation`).
    The action returned by :meth:`forward` is
    ``g * action_policy_random(x) + (1 - g) * action_policy_expert(x)`` where
    ``g`` is the first output of ``gating_network``. The second output of
    ``gating_network`` is stored in :attr:`std`.

    Args:
        gating_network: Callable returning a ``(gating_value, std)`` pair.
        action_policy_random: Trainable policy.
        action_policy_expert: Expert policy; its parameters are frozen here.
    """

    # Layout of one flat observation row (273 values). See
    # P4RL-Pre-Dynamic-Pedipulation-Flat-Blind-Anymal-D-v0's config for details.
    # TODO: derive this layout from the environment config instead of hard-coding it.
    _HISTORY_LENGTH = 6
    _STACKED_GROUPS = (
        (0, 18),     # lin_vel
        (18, 36),    # ang_vel
        (36, 54),    # grav
        (54, 126),   # jp
        (126, 198),  # jv
        (198, 270),  # a
    )
    _COMMAND_SLICE = (270, 273)

    def __init__(
        self,
        gating_network: nn.Module,
        action_policy_random: nn.Module,
        action_policy_expert: nn.Module,
    ):
        super().__init__()
        self.gating_network = gating_network
        self.action_policy_random = action_policy_random
        self.action_policy_expert = action_policy_expert

        # Freeze the expert policy so that it is not updated during training.
        self.action_policy_expert.requires_grad_(False)

        # Outputs of the gating network from the most recent forward pass.
        # Both start as empty tensors so that reading them before the first
        # forward pass does not raise AttributeError.
        self.gating_value: torch.Tensor = torch.empty(0)
        self.std: torch.Tensor = torch.empty(0)

    def get_current_timestep_observation(self, obs):  # TODO
        """Extract the last history frame of every group plus the command.

        Args:
            obs: ``[batch_size, 273]`` history-stacked observation. Each stacked
                group is assumed to be flattened frame by frame with the
                current frame last - TODO confirm against the env config.

        Returns:
            ``[batch_size, 48]`` tensor ordered as
            ``lin_vel, ang_vel, grav, jp, jv, a, command``.

        Raises:
            ValueError: If ``obs`` has fewer columns than the layout requires.
        """
        required_width = self._COMMAND_SLICE[1]
        if obs.shape[1] < required_width:
            raise ValueError(
                f"Expected observations with at least {required_width} columns, got {obs.shape[1]}."
            )

        pieces = []
        for start, stop in self._STACKED_GROUPS:
            # The last frame of a frame-major flattened group is its trailing
            # `frame_width` columns; this equals reshape(b, 6, -1)[:, -1, :].
            frame_width = (stop - start) // self._HISTORY_LENGTH
            pieces.append(obs[:, stop - frame_width:stop])
        pieces.append(obs[:, self._COMMAND_SLICE[0]:self._COMMAND_SLICE[1]])
        return torch.cat(pieces, dim=-1)

    def forward(self, observations):
        """Forward pass of the gated actor.

        Args:
            observations: ``(batch_size, num_obs)`` history-stacked observations.

        Returns:
            The gate-weighted blend of the two policies' outputs. As a side
            effect, :attr:`gating_value` and :attr:`std` are overwritten with
            the gating network's outputs.
        """
        current_timestep_observation = self.get_current_timestep_observation(observations)

        self.gating_value, self.std = self.gating_network(current_timestep_observation)
        action_policy_r = self.action_policy_random(current_timestep_observation)
        action_policy_e = self.action_policy_expert(current_timestep_observation)

        return self.gating_value * action_policy_r + (1 - self.gating_value) * action_policy_e
    


class GatingAndStdNetwork(nn.Module):
    """Shared trunk with two sigmoid heads: a scalar gate and a bounded per-action std.

    The ``backend`` output is fed to both heads. Each head is
    ``ELU -> Linear -> Sigmoid``; the std head's ``(0, 1)`` output is rescaled
    affinely into ``[min_std, max_std]``.

    Args:
        backend: Trunk network. Its last module must expose ``out_features``
            (e.g. an ``nn.Linear``), which sets the heads' input width.
        action_dim: Number of std values to produce (one per action dimension).
        min_std: Std returned when the std head's sigmoid output is 0.
        max_std: Std returned when the std head's sigmoid output is 1.
    """

    def __init__(self, backend: nn.Sequential, action_dim: int, min_std: float = 0.2, max_std: float = 1.0):
        super().__init__()
        self.backend = backend
        feature_dim = backend[-1].out_features

        def make_head(width: int) -> nn.Sequential:
            # The sigmoid keeps the head's output strictly inside (0, 1).
            return nn.Sequential(nn.ELU(), nn.Linear(feature_dim, width), nn.Sigmoid())

        # Single gating value per sample.
        self.gating_output_head = make_head(1)
        # One std value per action dimension.
        self.std_output_head = make_head(action_dim)

        self.min_std = min_std
        self.max_std = max_std

    def forward(self, current_timestep_observations):
        """Return ``(gating_value, std_value)`` for the given observations."""
        features = self.backend(current_timestep_observations)
        gate = self.gating_output_head(features)
        std_fraction = self.std_output_head(features)
        # Map the (0, 1) fraction onto the configured [min_std, max_std] range.
        std_range = self.max_std - self.min_std
        return gate, self.min_std + std_range * std_fraction


class DLStdActorCritic(ActorCriticBase):
    """Actor-critic whose exploration std is shaped by a frozen ``dl`` module.

    The action mean is an MLP over the latest time step of the history-stacked
    observation. The per-action standard deviation is the learnable
    ``std_scale`` divided by a batch-normalised ``1 / |w|``, clamped to
    ``[0.2, 1.5]``, where ``w`` is the first output of ``self.dl.forward_RL``.

    Args:
        num_actor_obs: Unused; the actor input width is fixed to 48.
        num_critic_obs: Input width of the critic MLP.
        num_actions: Number of action dimensions.
        inv_module_cfg: Configuration passed to ``resolve_submodule`` to build
            the frozen ``dl`` module; the resolved module must report
            ``mode == "dl"``.
        mlp_dims: Hidden layer widths shared by actor and critic.
        activation: Activation name understood by ``resolve_nn_activation``.
        init_noise_std: Initial value of every entry of ``std_scale``.
        noise_std_type: Only ``"scalar"`` is supported.
        **kwargs: Ignored (a message listing them is printed).

    Raises:
        ValueError: If ``noise_std_type`` is not ``"scalar"``.
    """

    is_recurrent = False

    # Width of the vector returned by get_current_timestep_observation.  TODO: derive from config.
    _CURRENT_STEP_DIM = 48
    # Lower bound applied to |w| before inverting it, so that w == 0 cannot
    # produce inf (and subsequently NaN) standard deviations.
    _MIN_ABS_W = 1e-8

    def __init__(
        self,
        num_actor_obs,
        num_critic_obs,
        num_actions,
        inv_module_cfg: dict,
        mlp_dims=[128, 128, 128],
        activation="elu",
        init_noise_std=1.0,
        noise_std_type: str = "scalar",
        **kwargs,
    ):
        if kwargs:
            print(
                "DLStdActorCritic.__init__ got unexpected arguments, which will be ignored: "
                + str(list(kwargs.keys()))
            )
        super().__init__()

        # Pre-trained module used only to shape the exploration noise; kept frozen.
        self.dl: InvDynamicsMLP = resolve_submodule(inv_module_cfg)
        self.dl.requires_grad_(False)
        assert self.dl.mode == "dl"

        # Policy
        self.actor = nn.Sequential(*self.get_MLP_list(self._CURRENT_STEP_DIM, num_actions, mlp_dims, activation))
        self.critic = nn.Sequential(*self.get_MLP_list(num_critic_obs, 1, mlp_dims, activation))

        # Action noise
        self.noise_std_type = noise_std_type
        if self.noise_std_type == "scalar":
            self.std_scale = nn.Parameter(init_noise_std * torch.ones(num_actions))
        else:
            raise ValueError(
                f"Unknown standard deviation type: {self.noise_std_type}. DLStdActorCritic only supports 'scalar'"
            )

        # Action distribution (populated in update_distribution).
        # Dummy distribution to initialize; it is replaced on the first update_distribution call.
        self.distribution = Normal(0, 1)
        # disable args validation for speedup
        Normal.set_default_validate_args(False)

    def get_current_timestep_observation(self, obs):  # TODO
        """Extract the last history frame of every group plus the command.

        Args:
            obs: ``[batch_size, 273]`` history-stacked observation made of six
                groups of 6 frames each (lin_vel, ang_vel, grav, jp, jv, a)
                followed by a 3-value command. Each group is assumed to be
                flattened frame by frame with the current frame last - TODO confirm.

        Returns:
            ``[batch_size, 48]`` tensor ordered as
            ``lin_vel, ang_vel, grav, jp, jv, a, command``.

        Raises:
            ValueError: If ``obs`` has fewer than 273 columns.
        """
        history_length_in_one_obs = 6
        stacked_groups = ((0, 18), (18, 36), (36, 54), (54, 126), (126, 198), (198, 270))
        command_start, command_stop = 270, 273
        if obs.shape[1] < command_stop:
            raise ValueError(
                f"Expected observations with at least {command_stop} columns, got {obs.shape[1]}."
            )

        pieces = []
        for start, stop in stacked_groups:
            # Trailing columns of a frame-major group == reshape(b, 6, -1)[:, -1, :].
            frame_width = (stop - start) // history_length_in_one_obs
            pieces.append(obs[:, stop - frame_width:stop])
        pieces.append(obs[:, command_start:command_stop])
        return torch.cat(pieces, dim=-1)

    def update_distribution(self, observations):
        """Rebuild ``self.distribution`` for the given observations.

        Args:
            observations: History-stacked observations; passed in full to
                ``self.dl.forward_RL`` and, reduced to the latest time step,
                to the actor.

        Raises:
            ValueError: If ``noise_std_type`` is not ``"scalar"``.
        """
        if self.noise_std_type != "scalar":
            raise ValueError(
                f"Unknown standard deviation type: {self.noise_std_type}. DLStdActorCritic only supports 'scalar'"
            )

        # compute mean
        mean = self.actor(self.get_current_timestep_observation(observations))

        # w is expected to be [num_envs, num_actions] - TODO confirm; the second
        # output of forward_RL is not needed here.
        w, _ = self.dl.forward_RL(observations)

        # Inverse magnitude of w; |w| is floored so that w == 0 stays finite.
        coeff = 1.0 / torch.abs(w).clamp_min(self._MIN_ABS_W)

        # Normalise by the spread over the batch. With a single row the sample
        # std is undefined (torch.std returns NaN), so normalisation is skipped.
        if coeff.shape[0] > 1:
            spread = torch.std(coeff, dim=0, keepdim=True)
        else:
            spread = torch.ones_like(coeff)
        normalized_coeff = coeff / spread

        # compute standard deviation: a larger |w| yields a larger std, bounded to [0.2, 1.5]
        stds = torch.clamp(self.std_scale[None, :] / normalized_coeff, min=0.2, max=1.5)

        # create distribution
        self.distribution = Normal(mean, stds)

    @classmethod
    def get_MLP_list(cls, input_dim, output_dim, mlp_dims, activation) -> List[nn.Module]:
        """Return the layer list of an MLP ``input_dim -> mlp_dims... -> output_dim``.

        Every hidden ``nn.Linear`` is followed by the (shared) activation module;
        the output layer has no activation. An empty ``mlp_dims`` yields a
        single linear layer.
        """
        act = resolve_nn_activation(activation)
        widths = [input_dim, *mlp_dims]
        layers = []
        for in_width, out_width in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_width, out_width))
            layers.append(act)
        layers.append(nn.Linear(widths[-1], output_dim))
        return layers

    def unfreeze_all(self):
        """No-op; nothing is unfrozen by this class."""
        pass




# gated INV actor critic
class GatedActorCriticWithINV(ActorCriticBase):
    """Actor-critic whose actor is a ``GatedActorWithINV``.

    The actor is assembled from three MLPs operating on a 48-wide
    current-time-step observation (a next-observation policy, an action policy
    and a ``GatingAndStdNetwork``) plus the module resolved from
    ``inv_module_cfg``. The action std is read from ``self.actor.std`` rather
    than from a learnable parameter.

    Args:
        obs: Mapping from observation group name to a 2-D tensor; only the
            shapes are read here.
        obs_groups: Mapping with ``"policy"`` and ``"critic"`` entries listing
            the observation group names used by each network.
        num_actions: Number of action dimensions.
        inv_module_cfg: Configuration passed to ``resolve_submodule``; its
            ``"representation_dim"`` entry sets the next-observation policy's
            output width.
        mlp_dims: Hidden layer widths shared by all MLPs.
        activation: Activation name understood by ``resolve_nn_activation``.
        init_noise_std: Initial value of ``nominal_std``.
        noise_std_type: Only ``"scalar"`` is accepted.
        **kwargs: Ignored (a message listing them is printed).

    Raises:
        ValueError: If ``noise_std_type`` is not ``"scalar"``.
    """

    is_recurrent = False

    def __init__(
        self,
        obs,
        obs_groups,
        num_actions,
        inv_module_cfg: dict,
        mlp_dims=[128, 128, 128],
        activation="elu",
        init_noise_std=1.0,
        noise_std_type: str = "scalar",
        **kwargs,
    ):
        if kwargs:
            print(
                "GatedActorCriticWithINV.__init__ got unexpected arguments, which will be ignored: "
                + str(list(kwargs.keys()))
            )
        super().__init__()

        # The actor input width is fixed below, so the policy groups are only validated.
        for obs_group in obs_groups["policy"]:
            assert len(obs[obs_group].shape) == 2, "The ActorCritic module only supports 1D observations."
        num_critic_obs = 0
        for obs_group in obs_groups["critic"]:
            assert len(obs[obs_group].shape) == 2, "The ActorCritic module only supports 1D observations."
            num_critic_obs += obs[obs_group].shape[-1]

        self.obs_groups = obs_groups
        self.normalization_setup()

        # TODO: derive from the observation layout instead of hard-coding.
        current_timestep_obs_dim = 48

        next_obs_policy = nn.Sequential(
            *self.get_MLP_list(current_timestep_obs_dim, inv_module_cfg["representation_dim"], mlp_dims, activation)
        )
        action_policy = nn.Sequential(
            *self.get_MLP_list(current_timestep_obs_dim, num_actions, mlp_dims, activation)
        )
        # Produces both the gating value and the per-action std.
        gating_network = GatingAndStdNetwork(
            backend=nn.Sequential(*self.get_MLP_list(current_timestep_obs_dim, 128, mlp_dims, activation)),
            action_dim=num_actions,
            min_std=0.2,
            max_std=1.0,
        )

        # Policy
        self.actor: GatedActorWithINV = GatedActorWithINV(
            next_obs_policy=next_obs_policy,
            action_policy=action_policy,
            gating_network=gating_network,
            inv_dynamics_module=resolve_submodule(inv_module_cfg),
        )

        self.critic = nn.Sequential(*self.get_MLP_list(num_critic_obs, 1, mlp_dims, activation))

        # Action noise
        self.noise_std_type = noise_std_type
        if self.noise_std_type == "scalar":
            # Kept only for interface compliance (see exploration_params); the
            # distribution's std comes from self.actor.std instead.
            self.nominal_std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        # Action distribution (populated in update_distribution).
        # Dummy distribution to initialize; it is replaced on the first update_distribution call.
        self.distribution = Normal(0, 1)
        # disable args validation for speedup
        Normal.set_default_validate_args(False)

    def normalization_setup(self):
        """Disable observation normalization: both normalizers are identity modules."""
        self.actor_obs_normalization = False
        self.critic_obs_normalization = False
        self.actor_obs_normalizer = torch.nn.Identity()
        self.critic_obs_normalizer = torch.nn.Identity()

    def update_distribution(self, observations):
        """Run the actor and rebuild ``self.distribution`` from its mean and std.

        Raises:
            NotImplementedError: If ``noise_std_type`` is ``"log"``.
            ValueError: If ``noise_std_type`` is neither ``"scalar"`` nor ``"log"``.
        """
        # compute mean; the actor's forward pass is also what refreshes self.actor.std
        mean = self.actor(observations)
        # compute standard deviation and create distribution
        if self.noise_std_type == "scalar":
            self.distribution = Normal(mean, self.actor.std)
        elif self.noise_std_type == "log":
            raise NotImplementedError("GatedActorCriticWithINV does not support log std at the moment.")
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

    @property
    def std(self):
        """The actor's current ``std`` attribute."""
        return self.actor.std

    @property
    def exploration_params(self):
        """The ``nominal_std`` parameter."""
        return self.nominal_std

    @classmethod
    def get_MLP_list(cls, input_dim, output_dim, mlp_dims, activation) -> List[nn.Module]:
        """Return the layer list of an MLP ``input_dim -> mlp_dims... -> output_dim``.

        Every hidden ``nn.Linear`` is followed by the (shared) activation module;
        the output layer has no activation. An empty ``mlp_dims`` yields a
        single linear layer.
        """
        act = resolve_nn_activation(activation)
        widths = [input_dim, *mlp_dims]
        layers = []
        for in_width, out_width in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_width, out_width))
            layers.append(act)
        layers.append(nn.Linear(widths[-1], output_dim))
        return layers

    @property
    def gating_value(self):
        """The actor's current ``gating_value`` attribute."""
        return self.actor.gating_value

    def unfreeze_all(self):
        """No-op; nothing is unfrozen by this class."""
        pass




class GatedMultiActorCritic(ActorCriticBase):
    """Actor-critic intended to wrap a gated multi-policy actor.

    This class builds the critic and the noise parameters and provides
    :meth:`load_expert` for restoring expert weights from a checkpoint.

    NOTE(review): nothing visible in this class assigns ``self.actor`` or uses
    ``trained_policy_path``; ``update_distribution`` and ``gating_value``
    require ``self.actor`` to be attached first - confirm where that happens.

    Args:
        num_actor_obs: Unused here.
        num_critic_obs: Input width of the critic MLP.
        num_actions: Number of action dimensions.
        trained_policy_path: Unused here (see the review note above).
        mlp_dims: Hidden layer widths of the critic MLP.
        activation: Activation name understood by ``resolve_nn_activation``.
        init_noise_std: Initial noise std used to fill ``std`` / ``log_std``.
        noise_std_type: ``"scalar"`` or ``"log"``.
        **kwargs: Ignored (a message listing them is printed).

    Raises:
        ValueError: If ``noise_std_type`` is neither ``"scalar"`` nor ``"log"``.
    """

    is_recurrent = False

    def __init__(
        self,
        num_actor_obs,
        num_critic_obs,
        num_actions,
        trained_policy_path: str,
        mlp_dims=[128, 128, 128],
        activation="elu",
        init_noise_std=1.0,
        noise_std_type: str = "scalar",
        **kwargs,
    ):
        if kwargs:
            print(
                "GatedMultiActorCritic.__init__ got unexpected arguments, which will be ignored: "
                + str(list(kwargs.keys()))
            )
        super().__init__()

        # TODO: build the actor here. The sibling classes in this file use a
        # current-time-step observation width of 48 for their actor networks.

        self.critic = nn.Sequential(*self.get_MLP_list(num_critic_obs, 1, mlp_dims, activation))

        # Action noise
        self.noise_std_type = noise_std_type
        if self.noise_std_type == "scalar":
            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        elif self.noise_std_type == "log":
            self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        # Action distribution (populated in update_distribution).
        # Dummy distribution to initialize; it is replaced on the first update_distribution call.
        self.distribution = Normal(0, 1)
        # disable args validation for speedup
        Normal.set_default_validate_args(False)

    def load_expert(self, model: nn.Module, path: str):
        """Load expert weights from a checkpoint into ``model``.

        Args:
            model: Module that receives the weights.
            path: Checkpoint file containing a ``"model_state_dict"`` entry.
                Only its keys starting with ``"actor.final_layers."`` are used,
                with that prefix stripped.

        Raises:
            RuntimeError: From ``load_state_dict(strict=True)`` when the
                stripped keys do not exactly match ``model``'s state dict.
        """
        prefix = "actor.final_layers."
        # Map to CPU first so checkpoints saved on an unavailable device still
        # load; load_state_dict copies the values onto the model's own device.
        loaded_dict = torch.load(path, weights_only=True, map_location="cpu")["model_state_dict"]
        # Strip the prefix only where it leads the key, not wherever it occurs.
        processed_dict = {k[len(prefix):]: v for k, v in loaded_dict.items() if k.startswith(prefix)}

        model.load_state_dict(processed_dict, strict=True)
        print(f"Expert policy loaded from {path}")

    def update_distribution(self, observations):
        """Run the actor and rebuild ``self.distribution`` from its mean and std.

        Raises:
            NotImplementedError: If ``noise_std_type`` is ``"log"``.
            ValueError: If ``noise_std_type`` is neither ``"scalar"`` nor ``"log"``.
        """
        # compute mean
        mean = self.actor(observations)
        # compute standard deviation and create distribution
        if self.noise_std_type == "scalar":
            self.distribution = Normal(mean, self.actor.std)
        elif self.noise_std_type == "log":
            raise NotImplementedError("GatedMultiActorCritic does not support log std at the moment.")
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

    @classmethod
    def get_MLP_list(cls, input_dim, output_dim, mlp_dims, activation) -> List[nn.Module]:
        """Return the layer list of an MLP ``input_dim -> mlp_dims... -> output_dim``.

        Every hidden ``nn.Linear`` is followed by the (shared) activation module;
        the output layer has no activation. An empty ``mlp_dims`` yields a
        single linear layer.
        """
        act = resolve_nn_activation(activation)
        widths = [input_dim, *mlp_dims]
        layers = []
        for in_width, out_width in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_width, out_width))
            layers.append(act)
        layers.append(nn.Linear(widths[-1], output_dim))
        return layers

    @property
    def gating_value(self):
        """The actor's current ``gating_value`` attribute."""
        return self.actor.gating_value

    def unfreeze_all(self):
        """No-op; nothing is unfrozen by this class."""
        pass