# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
import torch.nn as nn

from rsl_rl.modules.normalizer import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization
from rsl_rl.utils import resolve_nn_activation


class DynamicsExploration(nn.Module):
    """Dynamics-based exploration module for P4RL.

    The intrinsic reward is the prediction error of an (inverse) dynamics model: given consecutive
    states ``s_t`` and ``s_tp1``, the model predicts the action ``a_t`` that was taken, and the reward is
    the Euclidean distance between the actual and the predicted action, scaled by ``reward_weight``.

    Note:
        The dynamics model is not built by this class; it must be supplied through the
        ``dynamics_model`` constructor argument (any module callable as ``dynamics_model(s_t, s_tp1)``).
        The ``cluster`` network is built from the constructor arguments but is not used by the
        methods of this class.
    """

    def __init__(
        self,
        num_states: int,
        cluster_dim: int,
        cluster_hidden_dims: list[int],
        reward_weight: float = 1.0,
        activation: str = "elu",
        device: str = "cpu",
        dynamics_model: nn.Module | None = None,
    ):
        """Initialize the exploration module.

        Args:
            num_states: Input dimension of the cluster network.
            cluster_dim: Output dimension of the cluster network.
            cluster_hidden_dims: Hidden layer sizes of the cluster network. An entry of -1 is replaced
                by ``num_states``.
            reward_weight: Scale applied to the intrinsic reward.
            activation: Activation name, resolved with ``resolve_nn_activation``.
            device: Device on which the networks are placed.
            dynamics_model: Optional (inverse) dynamics model, called as ``dynamics_model(s_t, s_tp1)``
                and expected to return the predicted action. Required by ``get_intrinsic_reward``.
        """
        # initialize parent class
        super().__init__()

        # Store parameters
        self.num_states = num_states
        self.num_outputs = cluster_dim
        self.device = device
        self.reward_weight = reward_weight
        # incremented once per successful-input call to get_intrinsic_reward (i.e. per env step)
        self.update_counter = 0

        # Create network architecture
        self.cluster = self._build_mlp(num_states, cluster_hidden_dims, cluster_dim, activation).to(self.device)
        # Assigning an nn.Module registers it as a submodule, so train()/eval()/to()/state_dict() cover it.
        # When None, it is a plain attribute and the state_dict is unchanged.
        self.dynamics_model = dynamics_model.to(self.device) if dynamics_model is not None else None

    def get_intrinsic_reward(
        self,
        observation: torch.Tensor,
        next_observation: torch.Tensor | None = None,
        actions: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the intrinsic reward from the inverse dynamics prediction error.

        Args:
            observation: States ``s_t``.
            next_observation: States ``s_tp1`` observed after taking ``actions`` in ``observation``.
            actions: Actions ``a_t`` that were actually taken.

        Returns:
            A tuple ``(intrinsic_reward, observation)``: ``intrinsic_reward`` is
            ``reward_weight * ||a_t - a_t_pred||`` over the last dimension, and ``observation`` is the
            input returned unchanged.

        Raises:
            ValueError: If ``next_observation`` or ``actions`` is not given.
            RuntimeError: If no dynamics model was configured.
        """
        if next_observation is None or actions is None:
            raise ValueError("Both 'next_observation' and 'actions' are required to compute the intrinsic reward.")
        # note: the counter is updated number of env steps per learning iteration
        self.update_counter += 1
        # Predict the action that links s_t to s_tp1
        actions_pred = self.forward_inv_dynamics_model(observation, next_observation)
        # The intrinsic reward is the (scaled) distance between actual and predicted actions
        intrinsic_reward = self.error_function(actions, actions_pred) * self.reward_weight

        return intrinsic_reward, observation

    def error_function(self, a_t: torch.Tensor, a_t_pred: torch.Tensor) -> torch.Tensor:
        """Return the Euclidean norm of ``a_t - a_t_pred`` over the last dimension."""
        return torch.linalg.norm(a_t - a_t_pred, dim=-1)

    def forward_inv_dynamics_model(self, s_t: torch.Tensor, s_tp1: torch.Tensor) -> torch.Tensor:
        """Predict the action taken between ``s_t`` and ``s_tp1`` with the configured dynamics model.

        Raises:
            RuntimeError: If no dynamics model was configured.
        """
        if self.dynamics_model is None:
            raise RuntimeError("No dynamics model configured. Pass 'dynamics_model' to the constructor.")
        return self.dynamics_model(s_t, s_tp1)

    def forward(self, *args, **kwargs):
        raise RuntimeError("Forward method is not implemented. Use get_intrinsic_reward instead.")

    def train(self, mode: bool = True):
        """Set training mode on this module and every registered submodule; returns ``self``."""
        # nn.Module.train sets self.training, recurses into children and returns self
        return super().train(mode)

    def eval(self):
        return self.train(False)

    """
    Private Methods
    """

    @staticmethod
    def _build_mlp(input_dims: int, hidden_dims: list[int], output_dims: int, activation_name: str = "elu"):
        """Build an MLP ``input_dims -> *hidden_dims -> output_dims``.

        An activation follows every hidden layer; none follows the output layer. Hidden sizes of -1
        are replaced by ``input_dims``. With no hidden layers the network is a single linear layer.
        """
        # resolve hidden dimensions
        # if dims is -1 then we use the number of inputs
        hidden_dims = [input_dims if dim == -1 else dim for dim in hidden_dims]
        if not hidden_dims:
            # no hidden layer -> plain linear map, no activation needed
            return nn.Sequential(nn.Linear(input_dims, output_dims))

        # resolve activation function
        activation = resolve_nn_activation(activation_name)
        layer_dims = [input_dims, *hidden_dims]
        network_layers = []
        # input and hidden layers, each followed by the activation
        for in_dim, out_dim in zip(layer_dims[:-1], layer_dims[1:]):
            network_layers.append(nn.Linear(in_dim, out_dim))
            network_layers.append(activation)
        # last layer (linear output)
        network_layers.append(nn.Linear(hidden_dims[-1], output_dims))
        return nn.Sequential(*network_layers)