
import torch
import torch.nn as nn
from rsl_rl.rsl_rl.utils.utils import resolve_nn_activation, siren_init
from abc import abstractmethod, ABC
from rsl_rl.rsl_rl.addons.nn_utils import init_lstm, init_gru
import math
from einops import repeat


def prepare_dl_input_from_obs(obs):
    """Split a flat observation batch into history segments for the decoupled-linear ("dl") model.

    Column layout of ``obs`` (see the P4RL-Pre-Dynamic-Pedipulation-Flat-Blind-Anymal-D-v0 config):
    six history-stacked terms ``lin_vel[0:18] | ang_vel[18:36] | grav[36:54] | joint_pos[54:126] |
    joint_vel[126:198] | action[198:270]``, each reshaped into 6 steps, followed by a 3-dim command
    at ``[270:273]``. History index -1 is treated as the current step.

    Args:
        obs: Tensor of shape ``[batch_size, 273]``.

    Returns:
        ``(states, actions, current_step_input)`` where
        ``states`` is ``[batch_size, 5, 33]`` (history steps 1..5 of lin_vel, ang_vel, grav,
        joint_pos and joint_vel concatenated), ``actions`` is ``[batch_size, 5, 12]`` (history steps
        1..5) and ``current_step_input`` is ``[batch_size, 48]`` (latest state features, latest
        action and the command, mimicking the default policy input).
    """
    batch = obs.shape[0]
    steps = 6  # history length stacked inside one observation
    # Column ranges of the history-stacked terms, in layout order.
    bounds = ((0, 18), (18, 36), (36, 54), (54, 126), (126, 198), (198, 270))
    lin_vel, ang_vel, grav, jp, jv, act = (
        obs[:, lo:hi].reshape(batch, steps, -1) for lo, hi in bounds
    )
    command = obs[:, 270:273]

    states = torch.cat([lin_vel, ang_vel, grav, jp, jv], dim=-1)  # [b, 6, 33]
    # mimic the default policy input
    latest = torch.cat([states[:, -1, :], act[:, -1, :], command], dim=-1)  # [b, 48]
    return states[:, 1:, :], act[:, 1:, :], latest


def prepare_inv_input_from_obs(obs):
    """Split a flat observation batch into history segments for the inverse-dynamics model.

    Column layout of ``obs`` (see the P4RL-Pre-Dynamic-Pedipulation-Flat-Blind-Anymal-D-v0 config):
    lin_vel ``[0:18]``, ang_vel ``[18:36]``, grav ``[36:54]``, joint_pos ``[54:126]``,
    joint_vel ``[126:198]`` and action ``[198:270]``, each reshaped into 6 history steps, plus a
    3-dim command at ``[270:273]``. History index -1 is treated as the current step.

    Args:
        obs: Tensor of shape ``[batch_size, 273]``.

    Returns:
        ``(states, actions, current_step_input)`` where ``states`` is ``[batch_size, 4, 33]``
        (history steps 2..5), ``actions`` is ``[batch_size, 3, 12]`` (history steps 3..5) and
        ``current_step_input`` is ``[batch_size, 48]`` (latest state features, latest action
        and the command).
    """
    batch_size = obs.shape[0]
    history_len = 6  # history steps stacked inside one observation

    def _history(start, stop):
        # Columns [start, stop) hold `history_len` consecutive steps of one term.
        return obs[:, start:stop].reshape(batch_size, history_len, -1)

    lin_vel = _history(0, 18)
    ang_vel = _history(18, 36)
    grav = _history(36, 54)
    jp = _history(54, 126)
    jv = _history(126, 198)
    actions = _history(198, 270)
    command = obs[:, 270:273]

    states = torch.cat([lin_vel, ang_vel, grav, jp, jv], dim=-1)  # [b, 6, 33]
    # Latest state, latest action and command: mimics the default policy input.
    current_step_input = torch.cat([states[:, -1, :], actions[:, -1, :], command], dim=-1)  # [b, 48]
    # One fewer action than states: states 2..5 and actions 3..5.
    return states[:, 2:, :], actions[:, 3:, :], current_step_input



def build_mlp(input_dims: int, hidden_dims: list[int], output_dims: int, activation_name: str = "elu") -> nn.Sequential:
    """Build a fully connected network ``input_dims -> *hidden_dims -> output_dims``.

    Each hidden layer is followed by an activation; the output layer is linear.

    Args:
        input_dims: Number of input features.
        hidden_dims: Hidden layer widths. An entry of ``-1`` is replaced by ``input_dims``.
            An empty list yields a single ``nn.Linear(input_dims, output_dims)``.
        output_dims: Number of output features.
        activation_name: Name passed to ``resolve_nn_activation``. The special value
            ``"siren"`` uses ``"siren_30"`` after the first layer and ``"siren_1"`` after the
            others, and initialises every linear layer with ``siren_init``.

    Returns:
        The network as an ``nn.Sequential``.
    """
    # -1 means "same width as the input"
    hidden_dims = [input_dims if dim == -1 else dim for dim in hidden_dims]
    is_siren = activation_name == "siren"

    if not hidden_dims:
        # No hidden layers: a single linear map, nothing to activate.
        mlp_model = nn.Sequential(nn.Linear(input_dims, output_dims))
    else:
        if is_siren:
            activation_first = resolve_nn_activation("siren_30")
            activation_following = resolve_nn_activation("siren_1")
        else:
            activation_first = resolve_nn_activation(activation_name)
            activation_following = resolve_nn_activation(activation_name)

        network_layers = [nn.Linear(input_dims, hidden_dims[0]), activation_first]
        for in_dim, out_dim in zip(hidden_dims[:-1], hidden_dims[1:]):
            network_layers.append(nn.Linear(in_dim, out_dim))
            # NOTE: the same activation module instance is reused after every later hidden layer.
            network_layers.append(activation_following)
        network_layers.append(nn.Linear(hidden_dims[-1], output_dims))
        mlp_model = nn.Sequential(*network_layers)

    if is_siren:
        first_linear = mlp_model[0]
        for module in mlp_model.modules():
            if isinstance(module, nn.Linear):
                siren_init(module, is_first=module is first_linear)
    return mlp_model


class ScaleFollowingSigmoid(nn.Module):
    """Affinely map values from ``[0, 1]`` (e.g. a preceding ``nn.Sigmoid``) onto ``[min, max]``.

    Computes ``min + tensor * (max - min)``. It applies no sigmoid itself and does not clamp.
    """

    def __init__(self, min=-5.0, max=5.0):
        super().__init__()
        # The parameter names shadow builtins but are kept for keyword-argument compatibility.
        self.min = min
        self.max = max

    def forward(self, tensor):
        span = self.max - self.min
        return tensor * span + self.min


class DynamicsModule(nn.Module, ABC):
    """Abstract base class for dynamics modules.

    All dynamics modules should inherit from this class and implement the abstract methods
    ``forward`` and ``forward_RL``.
    """

    # Set by subclasses: "inv" for inverse dynamics or "fwd" for forward dynamics
    # (some subclasses accept additional modes).
    mode: str

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x, a, mask):
        """Forward pass of the dynamics module.

        Args:
            x: Input states sequence. [batch_size, T, input_dim_states]
            a: Input actions sequence. [batch_size, T, input_dim_actions]
            mask: Mask for the input sequence. [batch_size, T]

        Always expect equal length of x and a; it's OK that the target is present in the input,
        it won't be used. The alignment is always the one used in the RL training loop:
        a_{t-1} is paired with s_t.

            s_{t-1} --**a_{t-1}--> s_t**
        """

    @abstractmethod
    def forward_RL(self, observations):
        """Forward pass used inside the RL loop; the exact inputs are subclass-specific."""

    def get_number_of_trainable_parameters(self) -> int:
        """Return the number of parameter elements with ``requires_grad=True``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)



    


# Define MLP Model
class InvDynamicsMLP(DynamicsModule):
    """MLP dynamics model with a mode-specific output head.

    Every input timestep is embedded separately (``state_encoder`` / ``action_encoder`` /
    ``state_delta_encoder``). The flattened embeddings are fed to ``main_MLP`` and then to a
    ``output_layer`` that depends on ``mode``:

    - ``"inv"``: inverse dynamics; see :meth:`forward_inv`. Outputs ``dim_actions`` values
      squashed into ``[-5, 5]``.
    - ``"fwd"``: forward dynamics; outputs 21 values.
    - ``"jacobian"``: outputs a ``dim_actions x dim_actions`` matrix plus a ``dim_actions`` bias.
    - ``"dl"`` (decoupled linear): outputs a ``dim_actions`` weight and a ``dim_actions`` bias.
    """

    def __init__(self, 
                 dim_states, 
                 dim_actions, 
                 input_timesteps, 
                 representation_dim, 
                 hidden_dims: list[int], 
                 mode: str = "inv",
                 lstm_core: bool = False,
                 activation_name: str = "elu",
                 weight_path = None, 
                 finetune_frozen = False,
                 device: str = "cpu",
                 **kwargs):
        """
            Initialize the MLP dynamics model.
            Args:
                dim_states: Dimension of the state space.
                dim_actions: Dimension of the action space.
                input_timesteps: Number of timesteps T in the input sequence (must be >= 2).
                    ``main_MLP`` takes ``2*T - 2`` embeddings of size ``representation_dim``.
                representation_dim: Dimension of the per-timestep embeddings.
                hidden_dims: List of hidden layer dimensions of ``main_MLP``.
                mode: One of "inv", "fwd", "jacobian" or "dl".
                lstm_core: Only in "inv" mode: add a 2-layer LSTM over concatenated
                    (state, action) steps that :meth:`forward_inv` uses instead of the MLP path.
                activation_name: Activation function to use in the MLPs.

                weight_path: Path to pretrained weights (optional). The special value
                    "random_init" keeps the random initialization.
                finetune_frozen: If True (and weight_path is given), freeze the pretrained weights.
                device: Device to move the model to.

            Raises:
                ValueError: for an unknown ``mode``.
                Exception: if the pretrained weights cannot be loaded (original error chained).
        """
        super().__init__()
        if kwargs:
            print(
                "InvDynamicMLP.__init__ got unexpected arguments, which will be ignored: "
                + str([key for key in kwargs.keys()])
            )
        self.module_type = "inv_dynamic"
        self.dim_states = dim_states
        self.dim_actions = dim_actions
        self.device = device # NOTE: not used
        assert input_timesteps >= 2, "Input timesteps must be at least 2!"
        self.input_timesteps = input_timesteps
        self.mode = mode

        self.hidden_dim = hidden_dims
        self.state_encoder = build_mlp(dim_states, [128], representation_dim, activation_name)
        # "inv" with exactly 2 timesteps has no context actions. Every other mode always
        # encodes the T-1 actions following the states, so it needs the encoder for any T.
        if input_timesteps > 2 or mode != "inv":
            self.action_encoder = build_mlp(dim_actions, [128], representation_dim, activation_name)

        if mode == "inv":
            self.state_delta_encoder = build_mlp(dim_states, [128], representation_dim, activation_name)
            if lstm_core:
                self.LSTM_core = nn.LSTM(
                    input_size=dim_states+dim_actions,
                    hidden_size=representation_dim,
                    num_layers=2,
                    batch_first=True,
                )

        penultimate_dim = representation_dim
        # 2T-2 embeddings: e.g. (T-1 states + T-1 actions) for "fwd", or
        # (1 delta + T-1 states + T-2 actions) for "inv".
        self.main_MLP = build_mlp(representation_dim*(input_timesteps*2-2), hidden_dims, penultimate_dim, activation_name)
        
        if mode == "jacobian":
            self.output_layer = nn.Sequential(
                    resolve_nn_activation(activation_name),
                    nn.Linear(penultimate_dim, self.dim_actions*(dim_actions + 1))
            )
        elif mode == "dl":  # stands for decoupled linear
            self.output_layer = nn.Sequential(
                    resolve_nn_activation(activation_name),
                    nn.Linear(penultimate_dim, 2*self.dim_actions)  
            )
            self.backbone_output_dim = 48+12*2 # TODO temporary solution
        elif mode == "inv":
            self.output_layer = nn.Sequential(
                    resolve_nn_activation(activation_name),
                    nn.Linear(penultimate_dim, dim_actions),
                    nn.Sigmoid(),
                    ScaleFollowingSigmoid(min=-5.0, max=5.0)  # scale the output to match the action range
            )
        elif mode == "fwd":
            # NOTE(review): output size 21 is hard-coded (not derived from dim_states) — confirm.
            self.output_layer = nn.Sequential(
                    resolve_nn_activation(activation_name),
                    nn.Linear(penultimate_dim, 21)
            )
        else:
            raise ValueError(f"Invalid mode: {mode}. Supported modes are 'inv', 'fwd', 'jacobian', and 'dl'.")
        
        if weight_path is not None:
            if weight_path != "random_init":
                try:
                    self.load_state_dict(torch.load(weight_path, weights_only=True))
                except Exception as err:
                    raise Exception("Failed to load pretrained weights; Check if the path is correct and the architectures match.") from err
            else:
                print("The inv dynamics module is initialized with random weights: only for comparison.")
            self.weight_path = weight_path
            if finetune_frozen:
                self.freeze_pretrained_weights()

        # NOTE: unlike the RNN/Transformer variants, the default PyTorch init is kept here;
        # reinitialize_weights() is not called.
        self.to(device)
        print(f"INV Dynamics trainable parameters: {self.get_number_of_trainable_parameters()}")

    def freeze_pretrained_weights(self):
        """Disable gradients for state_encoder, action_encoder (if built), main_MLP and output_layer.

        ``state_delta_encoder`` and ``LSTM_core`` stay trainable.
        NOTE(review): forward_RL replaces the state_delta_encoder embedding with ``top_out``,
        which may be why it is not frozen — confirm.
        """
        for name in ("state_encoder", "action_encoder", "main_MLP", "output_layer"):
            submodule = getattr(self, name, None)
            if submodule is not None:
                submodule.requires_grad_(False)

    def forward_inv(self, x_cut, a_cut, desired_delta_x):
        """Predict an action from a state/action context and a desired state change ("inv" mode).

        Args:
            x_cut: Context states, presumably ``[batch_size, T_s, dim_states]``.
            a_cut: Context actions, presumably ``[batch_size, T_a, dim_actions]``. Used only
                when ``input_timesteps > 2`` or by the LSTM core.
            desired_delta_x: Desired state change, presumably ``[batch_size, dim_states]``.
                Ignored by the LSTM core.

        Returns:
            Predicted action, ``[batch_size, dim_actions]``, within ``[-5, 5]``.

        NOTE(review): main_MLP expects ``1 + T_s + T_a == 2*input_timesteps - 2`` embeddings,
        e.g. ``T_s = T-1`` and ``T_a = T-2`` — confirm with the caller.
        """
        if hasattr(self, "LSTM_core"):
            per_timestep_input = torch.cat((x_cut, a_cut), dim=-1)  # [batch_size, T_s, dim_states + dim_actions]
            lstm_out, _ = self.LSTM_core(per_timestep_input)
            return self.output_layer(lstm_out[:, -1, :])  # [batch_size, dim_actions]

        embedded_states = torch.flatten(self.state_encoder(x_cut), start_dim=1)
        embedded_delta_states = self.state_delta_encoder(desired_delta_x)  # [batch_size, representation_dim]
        if self.input_timesteps > 2:
            # the first action is not used because we do not know the state before the first action
            # the last action is not used because it is our target
            embedded_actions = torch.flatten(self.action_encoder(a_cut), start_dim=1)
            main_MLP_input = torch.cat((embedded_delta_states, embedded_states, embedded_actions), dim=-1)
        else:
            main_MLP_input = torch.cat((embedded_delta_states, embedded_states), dim=-1)
        mlp_o = self.main_MLP(main_MLP_input)
        return self.output_layer(mlp_o)

    def forward(self, x, a, mask=None):
        """Run the "fwd", "jacobian" or "dl" head on aligned sequences (``mask`` is ignored).

        Args:
            x: States, presumably ``[batch_size, T, dim_states]``.
            a: Actions, presumably ``[batch_size, T, dim_actions]``. ``a[:, t]`` is paired with
                ``x[:, t]``, so states ``0..T-2`` are encoded together with actions ``1..T-1``.

        Returns:
            - "fwd": ``[batch_size, 21]``.
            - "jacobian": ``(jacobian [batch_size, dim_actions, dim_actions], bias [batch_size, dim_actions])``.
            - "dl": ``(weight [batch_size, dim_actions], bias [batch_size, dim_actions])``.

        Raises:
            NotImplementedError: in "inv" mode (use :meth:`forward_inv`).
            ValueError: for an unknown mode.
        """
        if self.mode == "inv":
            raise NotImplementedError("Use forward_inv method for inverse dynamics.")
        if self.mode not in ("fwd", "jacobian", "dl"):
            raise ValueError(f"Invalid mode: {self.mode}. Supported modes are 'inv', 'fwd', 'jacobian', and 'dl'.")

        embedded_states = torch.flatten(self.state_encoder(x[:, :-1]), start_dim=1)
        embedded_actions = torch.flatten(self.action_encoder(a[:, 1:]), start_dim=1)
        main_MLP_input = torch.cat((embedded_states, embedded_actions), dim=-1)
        mlp_o = self.main_MLP(main_MLP_input)
        out = self.output_layer(mlp_o)

        if self.mode == "jacobian":
            n = self.dim_actions
            jacobian = out[:, :n * n].reshape(x.shape[0], n, n)
            bias_state = out[:, n * n:].reshape(x.shape[0], n)
            return jacobian, bias_state
        if self.mode == "dl":  # stands for decoupled linear
            weight_state = out[:, :self.dim_actions]
            bias_state = out[:, self.dim_actions:]
            return weight_state, bias_state
        return out

    def forward_RL_get_action(self, *args) -> torch.Tensor:
        """Predict an action from ``args = (x_cut, a_cut, top_out)``.

        ``top_out`` is used directly in place of the ``state_delta_encoder`` embedding.
        Requires ``action_encoder``.
        """
        x_cut, a_cut, top_out = args
        embedded_states = torch.flatten(self.state_encoder(x_cut), start_dim=1)
        embedded_delta_states = top_out
        embedded_actions = torch.flatten(self.action_encoder(a_cut), start_dim=1)
        main_MLP_input = torch.cat((embedded_delta_states, embedded_states, embedded_actions), dim=-1)
        mlp_o = self.main_MLP(main_MLP_input)
        return self.output_layer(mlp_o)

    def forward_RL(self, *args) -> torch.Tensor:
        """Forward pass used inside the RL loop.

        - "dl": ``args = (observations,)`` with flat observations as accepted by
          ``prepare_dl_input_from_obs``. Returns ``(weight_state, bias_state)``.
        - "fwd" / "inv": ``args = (x_cut, a_cut, top_out)``. Returns ``main_MLP`` features
          (``output_layer`` is NOT applied). In "inv" mode ``top_out`` replaces the
          ``state_delta_encoder`` embedding.

        Raises:
            NotImplementedError: for any other mode.
        """
        if self.mode == "dl":
            observations = args[0]
            x, a, current_step_input = prepare_dl_input_from_obs(observations)
            weight_state, bias_state = self.forward(x, a)
            # if using EAC with decoupled linear, use following line:
            # return torch.cat((current_step_input, weight_state, bias_state), dim=-1)  # [batch_size, 48+12*2]
            # if using DLStdActorCritic, use following line: 
            return weight_state, bias_state
        
        if self.mode == "fwd":
            x_cut, a_cut, top_out = args
            embedded_states = torch.flatten(self.state_encoder(x_cut), start_dim=1)
            embedded_actions = torch.flatten(self.action_encoder(a_cut), start_dim=1)
            main_MLP_input = torch.cat((embedded_states, embedded_actions, top_out), dim=-1)
            return self.main_MLP(main_MLP_input)
        
        if self.mode == "inv":
            x_cut, a_cut, top_out = args
            embedded_states = torch.flatten(self.state_encoder(x_cut), start_dim=1)
            embedded_delta_states = top_out
            embedded_actions = torch.flatten(self.action_encoder(a_cut), start_dim=1)
            main_MLP_input = torch.cat((embedded_delta_states, embedded_states, embedded_actions), dim=-1)
            return self.main_MLP(main_MLP_input)
            
        raise NotImplementedError("This method is not implemented for the PIDM")

    @torch.no_grad()
    def get_pred_error_as_intrinsic_reward(self, x, a, dones): # TODO
        """
            Calculate the intrinsic reward based on the predicted action and the actual action taken.

            Returns the per-sample L2 norm of ``forward(x, a) - a[:, -1, :]``, set to 0 where
            ``dones`` is true.
            NOTE(review): ``forward`` raises NotImplementedError in "inv" mode, so this currently
            only runs in the other modes — TODO.
        """
        predicted_action = self.forward(x, a)
        a_tm1 = a[:, -1, :]
        # L2 norm of the prediction error (TODO: other metrics could be used).
        errors = torch.linalg.norm(predicted_action - a_tm1, dim=-1)  # [num_envs]
        # Casting to bool is important: `~` is a logical NOT only for bool tensors.
        not_done = ~(dones.to(torch.bool))
        reward_mat = torch.zeros_like(errors)
        reward_mat[not_done] = errors[not_done]
        return reward_mat

    def reinitialize_weights(self):
        """
            Reinitialize every nn.Linear with Kaiming-uniform weights and zero biases.
            Does nothing when the LSTM core is present.
        """
        ####### TODO
        if hasattr(self, "LSTM_core"):
            return
        #######
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)



class InvDynamicsRNN(DynamicsModule):
    """Recurrent (GRU/LSTM) dynamics model producing one output per transition.

    Each transition ``(t-1, t)`` is encoded by ``per_step_encoder``. The sequence of encodings
    runs through ``main_RNN``, and ``output_layer`` maps every hidden state to ``dim_actions`` values.
    """

    def __init__(self, 
                 dim_states, 
                 dim_actions, 
                 representation_dim, 
                 hidden_dim: int, 
                 num_layers: int = 1,
                 rnn_type: str = "gru",
                 activation_name: str = "elu",
                 mode: str = "inv",
                 weight_path = None, 
                 finetune_frozen = False,
                 device: str = "cpu",
                 **kwargs):
        """
            Initialize the recurrent dynamics model.
            Args:
                dim_states: Dimension of the state space.
                dim_actions: Dimension of the action space.
                representation_dim: Output size of the per-step encoder (RNN input size).
                hidden_dim: Hidden size of the RNN.
                num_layers: Number of stacked RNN layers.
                rnn_type: "gru" or "lstm" (case-insensitive).
                activation_name: Activation function to use in the per-step encoder MLP.
                mode: "inv" or "fwd".

                weight_path: Path to pretrained weights (optional). The special value
                    "random_init" keeps the random initialization.
                finetune_frozen: If True (and weight_path is given), freeze the pretrained weights.
                device: Device to move the model to.

            Raises:
                ValueError: for an unsupported ``rnn_type``.
                Exception: if the pretrained weights cannot be loaded (original error chained).
        """
        super().__init__()
        if kwargs:
            print(
                "InvDynamicsRNN.__init__ got unexpected arguments, which will be ignored: "
                + str([key for key in kwargs.keys()])
            )
        self.module_type = "inv_dynamic"
        self.dim_states = dim_states
        self.dim_actions = dim_actions
        self.device = device # NOTE: not used
        self.mode = mode

        self.hidden_dim = hidden_dim
        # "inv" sees (s_{t-1}, action, s_t); "fwd" sees (s_{t-1}, action).
        per_step_input_dim = dim_states*2+dim_actions if mode=="inv" else dim_states+dim_actions
        self.per_step_encoder = build_mlp(per_step_input_dim, [128], representation_dim, activation_name)
        
        self.rnn_type = rnn_type.lower()
        rnn_classes = {"gru": nn.GRU, "lstm": nn.LSTM}
        if self.rnn_type not in rnn_classes:
            raise ValueError(f"Unsupported RNN type: {self.rnn_type}. Supported types are 'gru' and 'lstm'.")
        self.main_RNN = rnn_classes[self.rnn_type](
            input_size=representation_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,  # input shape (batch, seq, feature)
            # PyTorch applies RNN dropout only between stacked layers; with one layer it has
            # no effect and only emits a UserWarning.
            dropout=0.1 if num_layers > 1 else 0.0,
        )
        
        self.output_layer = nn.Linear(hidden_dim, dim_actions)

        # Initialize BEFORE loading so that pretrained weights are not overwritten.
        self.reinitialize_weights()
        
        if weight_path is not None:
            if weight_path != "random_init":
                try:
                    self.load_state_dict(torch.load(weight_path, weights_only=True))
                except Exception as err:
                    raise Exception("Failed to load pretrained weights; Check if the path is correct and the architectures match.") from err
            else:
                print("The inv dynamics module is initialized with random weights: only for comparison.")
            self.weight_path = weight_path
            if finetune_frozen:
                self.freeze_pretrained_weights()

        self.to(device)
        print(f"INV Dynamics trainable parameters: {self.get_number_of_trainable_parameters()}")

    def freeze_pretrained_weights(self):
        """Disable gradients for all sub-networks."""
        self.per_step_encoder.requires_grad_(False)
        self.main_RNN.requires_grad_(False)
        self.output_layer.requires_grad_(False)

    def make_per_timestep_input(self, x, a, mask):
        """Build per-transition inputs and the matching transition mask.

        Returns:
            ``(input_per_timestep, input_mask)``. ``input_per_timestep`` has T-1 steps.
            ``input_mask`` is ``mask[:, 1:] & mask[:, :-1]``, or None when ``mask`` is None.

        Raises:
            ValueError: for an unknown mode.
        """
        if self.mode == "inv":
            # [s_{t-1}, a[t-1], s_t] -> [batch_size, T-1, 2*dim_states + dim_actions]
            input_per_timestep = torch.cat((x[:, :-1, :], a[:, :-1, :], x[:, 1:, :]), dim=-1)
        elif self.mode == "fwd":
            # [s_{t-1}, a[t]] -> [batch_size, T-1, dim_states + dim_actions]
            input_per_timestep = torch.cat((x[:, :-1, :], a[:, 1:, :]), dim=-1)
        else:
            raise ValueError(f"Invalid mode: {self.mode}. Supported modes are 'inv' and 'fwd'.")
        # A transition is valid only if both of its endpoints are valid.
        input_mask = None if mask is None else mask[:, 1:] & mask[:, :-1]
        return input_per_timestep, input_mask

    def forward(self, x, a, mask=None):
        """
            args:
                x: states, presumably [batch_size, T, dim_states]
                a: actions, presumably [batch_size, T, dim_actions]
                mask: optional [batch_size, T] validity mask
            returns:
                [batch_size, T-1, dim_actions]
        """
        input_per_timestep, _input_mask = self.make_per_timestep_input(x, a, mask)
        # NOTE(review): the transition mask is not applied; padded steps reach the RNN unchanged.
        embeddings = self.per_step_encoder(input_per_timestep)
        h, _ = self.main_RNN(embeddings)  # h: [batch_size, T-1, hidden_dim]
        return self.output_layer(h)  # [batch_size, T-1, dim_actions]

    def forward_RL(self, observations):
        raise NotImplementedError("This method is not implemented for the PIDM")

    def reinitialize_weights(self):
        """
            Reinitialize the weights of the model.
            Linear layers get Kaiming-uniform weights and zero biases; the RNN is initialized
            with init_lstm / init_gru according to ``rnn_type``.
        """
        for mlp in (self.per_step_encoder, self.output_layer):
            for layer in mlp.modules():
                if isinstance(layer, nn.Linear):
                    nn.init.kaiming_uniform_(layer.weight)
                    if layer.bias is not None:
                        nn.init.zeros_(layer.bias)
        if self.rnn_type == "lstm":
            init_lstm(self.main_RNN)
        elif self.rnn_type == "gru":
            init_gru(self.main_RNN)
    

class SinusoidalPositionalEncoding(nn.Module):
    """Add fixed sine/cosine positional encodings to a batch-first sequence.

    Even feature indices receive ``sin(pos * w_i)`` and odd indices ``cos(pos * w_i)`` with
    ``w_i = 10000 ** (-2i / d_model)``. The table is a registered buffer (not trainable).
    """

    def __init__(self, d_model, max_len=1000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)  # [max_len, d_model]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        # One frequency per even index: ceil(d_model / 2) entries.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)  # even indices
        # Odd d_model has one fewer odd column than frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])  # odd indices

        self.register_buffer('pe', pe.unsqueeze(0))  # Shape: (1, max_len, d_model), non-trainable

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, d_model)
        Returns:
            Tensor with positional encodings added
        Raises:
            ValueError: if seq_len exceeds the max_len given at construction.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum length {max_len} of the positional encoding."
            )
        return x + self.pe[:, :seq_len]


def reinitialize_weights_tf(self):
    """Intended to re-initialise the parameters of a module; used as a ``module.apply`` callback
    (e.g. ``self.apply(reinitialize_weights_tf)`` in InvDynamicsTransformer.reinitialize_weights).

    NOTE(review): as written, this function only *defines* ``init_fn`` and never calls it, so it
    has no effect on ``self``. Invoking ``init_fn(self)`` at the end would make it effective, but
    any caller that runs it after loading pretrained weights would then overwrite those weights.
    """
    def init_fn(m):
        # NOTE(review): in current PyTorch, nn.Linear, nn.LayerNorm and nn.Embedding all define
        # reset_parameters, so they presumably always take this first branch and the elif
        # branches below are unreachable for them — confirm against the installed torch version.
        if hasattr(m, 'reset_parameters'):
            m.reset_parameters()
        elif isinstance(m, nn.Linear):
            # Kaiming-uniform weights with a=sqrt(5) and a fan-in-bounded uniform bias.
            nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
            if m.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(m.bias, -bound, bound)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0, std=0.02)


class InvDynamicsTransformer(DynamicsModule):
    """Transformer-decoder dynamics model.

    Encoded states and actions, each with sinusoidal position encodings, form the decoder
    memory. A single query token attends to that memory: in "inv" mode it is the encoding of
    the last state difference, and in "fwd" mode it is a learned seed vector. ``output_layer``
    maps the decoded token to ``dim_actions`` values.
    """

    def __init__(self, 
                 dim_states, 
                 dim_actions, 
                 input_timesteps, 
                 representation_dim, 
                 feedforward_dim: int = 512,
                 num_layers: int = 4,
                 num_heads: int = 4, 
                 mode: str = "inv",
                 activation_name: str = "elu",
                 weight_path = None, 
                 finetune_frozen = False,
                 device: str = "cpu",
                 **kwargs):
        """
            Initialize the transformer dynamics model.
            Args:
                dim_states: Dimension of the state space.
                dim_actions: Dimension of the action space (and of the output).
                input_timesteps: Number of timesteps T in the input sequence (must be >= 2).
                representation_dim: Token size (d_model of the decoder).
                feedforward_dim: Feedforward size of each decoder layer.
                num_layers: Number of decoder layers.
                num_heads: Number of attention heads.
                mode: "inv" or "fwd".
                activation_name: Activation for the MLPs and the decoder layers.

                weight_path: Path to pretrained weights (optional). The special value
                    "random_init" keeps the random initialization.
                finetune_frozen: If True (and weight_path is given), freeze the pretrained weights.
                device: Device to move the model to.

            Raises:
                Exception: if the pretrained weights cannot be loaded (original error chained).
        """
        super().__init__()
        if kwargs:
            print(
                "InvDynamicsTransformer.__init__ got unexpected arguments, which will be ignored: "
                + str([key for key in kwargs.keys()])
            )
        self.module_type = "inv_dynamic"
        self.dim_states = dim_states
        self.dim_actions = dim_actions
        self.device = device # NOTE: not used
        assert input_timesteps >= 2, "Input timesteps must be at least 2!"
        self.input_timesteps = input_timesteps
        self.mode = mode

        self.state_encoder = build_mlp(dim_states, [128], representation_dim, activation_name)
        # "inv" with exactly 2 timesteps has no context actions. "fwd" always encodes T-1 actions.
        if input_timesteps > 2 or mode != "inv":
            self.action_encoder = build_mlp(dim_actions, [128], representation_dim, activation_name)

        if mode == "inv":
            self.state_delta_encoder = build_mlp(dim_states, [128], representation_dim, activation_name)
        elif mode == "fwd":
            self.next_state_seed = nn.Parameter(torch.randn(representation_dim), requires_grad=True)

        # nn.TransformerDecoderLayer only accepts the strings "relu"/"gelu". Any other name is
        # resolved to an activation module, which the layer accepts as a callable.
        if activation_name in ("relu", "gelu"):
            decoder_activation = activation_name
        else:
            decoder_activation = resolve_nn_activation(activation_name)

        self.tf_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=representation_dim,
                nhead=num_heads,
                dim_feedforward=feedforward_dim,
                activation=decoder_activation,
                batch_first=True,
                dropout=0.1
            ),
            num_layers=num_layers,
        )
        self.output_layer = build_mlp(representation_dim, [256, 128], dim_actions, activation_name)
        self.pe = SinusoidalPositionalEncoding(representation_dim, max_len=input_timesteps)

        # Initialize BEFORE loading so that pretrained weights are never overwritten.
        # NOTE(review): reinitialize_weights_tf currently defines but never applies its init_fn.
        self.reinitialize_weights()

        if weight_path is not None:
            if weight_path != "random_init":
                try:
                    self.load_state_dict(torch.load(weight_path, weights_only=True))
                except Exception as err:
                    raise Exception("Failed to load pretrained weights; Check if the path is correct and the architectures match.") from err
            else:
                print("The inv dynamics module is initialized with random weights: only for comparison.")
            self.weight_path = weight_path
            if finetune_frozen:
                self.freeze_pretrained_weights()

        self.to(device)
        print(f"INV Dynamics trainable parameters: {self.get_number_of_trainable_parameters()}")

    def freeze_pretrained_weights(self):
        """Disable gradients for state_encoder, action_encoder (if built), tf_decoder and output_layer."""
        for name in ("state_encoder", "action_encoder", "tf_decoder", "output_layer"):
            submodule = getattr(self, name, None)
            if submodule is not None:
                submodule.requires_grad_(False)

    def forward(self, x, a, mask=None):
        """
            args:
                x: states, presumably [batch_size, T, dim_states]
                a: actions, presumably [batch_size, T, dim_actions]
                mask: ignored
            returns:
                [batch_size, dim_actions]
            raises:
                ValueError: for an unknown mode.
        """
        if self.mode == "inv":
            states_tokens = self.pe(self.state_encoder(x[:, :-1]))  # [batch_size, T-1, representation_dim]
            memory_parts = [states_tokens]
            if hasattr(self, "action_encoder"):
                # actions 1..T-2: the first has no preceding state, the last is the target
                memory_parts.append(self.pe(self.action_encoder(a[:, 1:-1])))  # [batch_size, T-2, representation_dim]
            delta_states_tokens = self.state_delta_encoder(x[:, -1, :] - x[:, -2, :])

            memory_tokens = torch.cat(memory_parts, dim=1)  # [batch_size, 2T-3, representation_dim]
            out_tf = self.tf_decoder(
                tgt=delta_states_tokens.unsqueeze(1),  # [batch_size, 1, representation_dim]
                memory=memory_tokens,
            )
            out = self.output_layer(out_tf.squeeze(1))

        elif self.mode == "fwd":
            states_tokens = self.pe(self.state_encoder(x[:, :-1]))  # [batch_size, T-1, representation_dim]
            action_tokens = self.pe(self.action_encoder(a[:, 1:]))  # [batch_size, T-1, representation_dim]

            memory_tokens = torch.cat((states_tokens, action_tokens), dim=1)  # [batch_size, 2T-2, representation_dim]

            seeds = repeat(self.next_state_seed, "d -> b 1 d", b=memory_tokens.shape[0])
            out_tf = self.tf_decoder(
                tgt=seeds,  # [batch_size, 1, representation_dim]
                memory=memory_tokens,
            )
            out = self.output_layer(out_tf.squeeze(1))

        else:
            raise ValueError(f"Invalid mode: {self.mode}. Supported modes are 'inv' and 'fwd'.")

        return out  # [batch_size, dim_actions]

    def reinitialize_weights(self):
        self.apply(reinitialize_weights_tf)

    def forward_RL(self, observations):
        raise NotImplementedError("This method is not implemented for the PIDM")