# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
import torch.nn as nn
from typing import Union, List, Any, Dict
from torch.distributions import Normal
import warnings

from rsl_rl.networks import MLP, EmpiricalNormalization, Memory

from rsl_rl.modules.actor_critic import ActorCritic, ExtendableActorCritic, ExtendableModel, ActorCriticBase
from rsl_rl.addons.dynamics.modules_recurrent import *

class ActorCriticRecurrent(nn.Module):
    """Actor-critic with a separate recurrent memory in front of the actor MLP and of the critic MLP.

    Observations are passed as a mapping from group name to tensor. The groups listed in
    ``obs_groups["policy"]`` are concatenated along the last dimension, optionally normalized, fed through
    ``memory_a`` and then the actor MLP; the groups in ``obs_groups["critic"]`` take the same path through
    ``memory_c`` and the critic MLP. Actions are drawn from a diagonal Gaussian whose mean is the actor output
    and whose standard deviation is a learned, state-independent parameter.
    """

    is_recurrent = True

    def __init__(
        self,
        obs,
        obs_groups,
        num_actions,
        actor_obs_normalization=False,
        critic_obs_normalization=False,
        actor_hidden_dims=None,
        critic_hidden_dims=None,
        activation="elu",
        init_noise_std=1.0,
        noise_std_type: str = "scalar",
        rnn_type="lstm",
        rnn_hidden_dim=256,
        rnn_num_layers=1,
        **kwargs,
    ):
        """Build the recurrent memories, MLPs, optional normalizers and the action-noise parameter.

        Args:
            obs: Mapping from observation-group name to a sample tensor, used here only to infer input sizes.
                Every group referenced by ``obs_groups`` must be 2D; its last dimension is the feature size.
            obs_groups: Mapping with keys ``"policy"`` and ``"critic"``, each listing the observation groups
                concatenated into that network's input.
            num_actions: Size of the action vector.
            actor_obs_normalization: Whether to normalize actor inputs with ``EmpiricalNormalization``.
            critic_obs_normalization: Whether to normalize critic inputs with ``EmpiricalNormalization``.
            actor_hidden_dims: Hidden layer sizes of the actor MLP; ``None`` means ``[256, 256, 256]``.
            critic_hidden_dims: Hidden layer sizes of the critic MLP; ``None`` means ``[256, 256, 256]``.
            activation: Activation name forwarded to ``MLP``.
            init_noise_std: Initial standard deviation of the action distribution.
            noise_std_type: ``"scalar"`` to learn the std directly, ``"log"`` to learn its logarithm.
            rnn_type: Recurrent cell type forwarded to ``Memory``.
            rnn_hidden_dim: Hidden size of each ``Memory``; also the input size of both MLPs.
            rnn_num_layers: Number of stacked recurrent layers in each ``Memory``.
            **kwargs: May contain the deprecated ``rnn_hidden_size``; anything else is reported and ignored.

        Raises:
            ValueError: If ``noise_std_type`` is neither ``"scalar"`` nor ``"log"``.
        """
        if "rnn_hidden_size" in kwargs:
            warnings.warn(
                "The argument `rnn_hidden_size` is deprecated and will be removed in a future version. "
                "Please use `rnn_hidden_dim` instead.",
                DeprecationWarning,
            )
            if rnn_hidden_dim == 256:  # Only override if the new argument is at its default
                rnn_hidden_dim = kwargs.pop("rnn_hidden_size")
        if kwargs:
            print(
                "ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()),
            )
        super().__init__()

        # Resolve the list defaults here: a list literal in the signature would be a single object shared by
        # every instance (mutable default argument).
        if actor_hidden_dims is None:
            actor_hidden_dims = [256, 256, 256]
        if critic_hidden_dims is None:
            critic_hidden_dims = [256, 256, 256]

        # get the observation dimensions
        self.obs_groups = obs_groups
        num_actor_obs = self._count_obs_dim(obs, obs_groups["policy"])
        num_critic_obs = self._count_obs_dim(obs, obs_groups["critic"])

        # actor: recurrent memory -> MLP producing the action mean
        self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim)
        self.actor = MLP(rnn_hidden_dim, num_actions, actor_hidden_dims, activation)
        # actor observation normalization (Identity keeps the call sites unconditional)
        self.actor_obs_normalization = actor_obs_normalization
        if actor_obs_normalization:
            self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs)
        else:
            self.actor_obs_normalizer = torch.nn.Identity()
        print(f"Actor RNN: {self.memory_a}")
        print(f"Actor MLP: {self.actor}")

        # critic: recurrent memory -> MLP producing a single value
        self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim)
        self.critic = MLP(rnn_hidden_dim, 1, critic_hidden_dims, activation)
        # critic observation normalization
        self.critic_obs_normalization = critic_obs_normalization
        if critic_obs_normalization:
            self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs)
        else:
            self.critic_obs_normalizer = torch.nn.Identity()
        print(f"Critic RNN: {self.memory_c}")
        print(f"Critic MLP: {self.critic}")

        # Action noise: one learnable std (or log-std) per action dimension, independent of the state.
        self.noise_std_type = noise_std_type
        if self.noise_std_type == "scalar":
            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        elif self.noise_std_type == "log":
            self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        # Action distribution (populated in update_distribution)
        self.distribution = None
        # disable args validation for speedup (note: this is a global setting on torch's Normal class)
        Normal.set_default_validate_args(False)

    @property
    def action_mean(self):
        """Mean of the most recently built action distribution."""
        return self.distribution.mean

    @property
    def action_std(self):
        """Standard deviation of the most recently built action distribution."""
        return self.distribution.stddev

    @property
    def entropy(self):
        """Entropy of the current distribution, summed over the last (action) dimension."""
        return self.distribution.entropy().sum(dim=-1)

    def reset(self, dones=None):
        """Forward ``dones`` to the ``reset`` of both recurrent memories."""
        self.memory_a.reset(dones)
        self.memory_c.reset(dones)

    def forward(self):
        raise NotImplementedError

    def update_distribution(self, obs):
        """Build ``self.distribution`` as a Normal centred on ``self.actor(obs)``.

        Args:
            obs: Input of the actor MLP (in this class, the output of ``memory_a``).

        Raises:
            ValueError: If ``noise_std_type`` is not ``"scalar"`` or ``"log"``.
        """
        # compute mean
        mean = self.actor(obs)
        # compute standard deviation, broadcast to the mean's shape
        if self.noise_std_type == "scalar":
            std = self.std.expand_as(mean)
        elif self.noise_std_type == "log":
            std = torch.exp(self.log_std).expand_as(mean)
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
        # create distribution
        self.distribution = Normal(mean, std)

    def act(self, obs, masks=None, hidden_states=None):
        """Sample actions for ``obs``; ``masks``/``hidden_states`` are passed through to ``memory_a``."""
        obs = self.get_actor_obs(obs)
        obs = self.actor_obs_normalizer(obs)
        out_mem = self.memory_a(obs, masks, hidden_states).squeeze(0)
        self.update_distribution(out_mem)
        return self.distribution.sample()

    def act_inference(self, obs):
        """Return the deterministic action (the actor MLP output) for ``obs``."""
        obs = self.get_actor_obs(obs)
        obs = self.actor_obs_normalizer(obs)
        out_mem = self.memory_a(obs).squeeze(0)
        return self.actor(out_mem)

    def evaluate(self, obs, masks=None, hidden_states=None):
        """Return the critic's value estimate; ``masks``/``hidden_states`` are passed through to ``memory_c``."""
        obs = self.get_critic_obs(obs)
        obs = self.critic_obs_normalizer(obs)
        out_mem = self.memory_c(obs, masks, hidden_states).squeeze(0)
        return self.critic(out_mem)

    @staticmethod
    def _count_obs_dim(obs, group_names):
        """Return the summed last-dimension size of the named observation groups (each must be 2D)."""
        total = 0
        for group_name in group_names:
            assert len(obs[group_name].shape) == 2, "The ActorCriticRecurrent module only supports 1D observations."
            total += obs[group_name].shape[-1]
        return total

    def _concat_obs(self, obs, role):
        """Concatenate, along the last dimension, the groups listed under ``self.obs_groups[role]``."""
        return torch.cat([obs[group_name] for group_name in self.obs_groups[role]], dim=-1)

    def get_actor_obs(self, obs):
        """Return the concatenated ``"policy"`` observation groups."""
        return self._concat_obs(obs, "policy")

    def get_critic_obs(self, obs):
        """Return the concatenated ``"critic"`` observation groups."""
        return self._concat_obs(obs, "critic")

    def get_actions_log_prob(self, actions):
        """Log-probability of ``actions`` under the current distribution, summed over the last dimension."""
        return self.distribution.log_prob(actions).sum(dim=-1)

    def get_hidden_states(self):
        """Return ``(actor_hidden_states, critic_hidden_states)`` from the two memories."""
        return self.memory_a.hidden_states, self.memory_c.hidden_states

    def update_normalization(self, obs):
        """Update the running statistics of whichever observation normalizers are enabled."""
        if self.actor_obs_normalization:
            actor_obs = self.get_actor_obs(obs)
            self.actor_obs_normalizer.update(actor_obs)
        if self.critic_obs_normalization:
            critic_obs = self.get_critic_obs(obs)
            self.critic_obs_normalizer.update(critic_obs)

    def load_state_dict(self, state_dict, strict=True):
        """Load the parameters of the actor-critic model.

        Args:
            state_dict (dict): State dictionary of the model.
            strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this
                           module's state_dict() function.

        Returns:
            bool: Whether this training resumes a previous training. This flag is used by the `load()` function of
                  `OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation).
        """

        super().load_state_dict(state_dict, strict=strict)
        return True


################ P4RL ########################

# NOTE: the classes below have not been tested since the merge with the main-rsl branch!

class ExtendableModelRecurrent(ExtendableModel):
    """Recurrent variant of ``ExtendableModel``.

    ``forward`` hands the full observation (plus optional ``masks``/``hidden_states``) to every submodule's
    ``forward_RL``, concatenates the last ``direct_pathway_dim`` observation features with the submodule
    outputs, and runs the result through ``final_layers``. ``submodules`` and ``final_layers`` are not created
    here — presumably by ``ExtendableModel.__init__``; confirm there.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        final_mlp_dims=None,
        direct_pathway_dim: int = 15,
        activation="elu",
        submodule_configs: List[dict] | None = None
    ):
        """Forward all arguments to ``ExtendableModel``.

        Args:
            input_dim: Size of the observation vector.
            output_dim: Size of the model output.
            final_mlp_dims: Hidden sizes of the final MLP; ``None`` means ``[128, 128, 128]``.
            direct_pathway_dim: Number of trailing observation features passed straight to the final MLP
                (0 disables the direct pathway).
            activation: Activation name forwarded to the parent.
            submodule_configs: Submodule configurations forwarded to the parent.
        """
        # Resolve the default here instead of using a shared mutable list in the signature.
        if final_mlp_dims is None:
            final_mlp_dims = [128, 128, 128]

        super().__init__(
            input_dim,
            output_dim,
            final_mlp_dims,
            direct_pathway_dim,
            activation,
            submodule_configs,
        )

    def forward(self, observations, masks=None, hidden_states=None):
        """Run every submodule, then the final layers, on ``observations``.

        Args:
            observations: Observation tensor.
            masks: Forwarded unchanged to each submodule's ``forward_RL``.
            hidden_states: Forwarded unchanged to each submodule's ``forward_RL`` (the same object for all).

        Returns:
            Output of ``self.final_layers``.
        """
        # NOTE(review): ``[:, -d:]`` slices dim 1, which is the feature dim only for 2D input — confirm the
        # observation layout when this is called with ``masks``/``hidden_states`` (sequence mode).
        final_mlp_inputs = [observations[:, -self.direct_pathway_dim:]] if self.direct_pathway_dim > 0 else []
        for submodule in self.submodules.values():
            submodule_output = submodule.forward_RL(observations, masks, hidden_states)
            final_mlp_inputs.append(submodule_output)
        final_layer_input = torch.cat(final_mlp_inputs, dim=-1)
        out = self.final_layers(final_layer_input)
        return out

    def reset(self, dones=None):
        """Forward ``dones`` to ``reset`` of every submodule that defines one.

        Submodules without ``reset`` are skipped with a warning. The previous ``print`` fired on every call,
        and ``reset`` may be called very often (it receives ``dones``). Python's default warning filters show
        each distinct message only once.
        """
        for submodule in self.submodules.values():
            if hasattr(submodule, 'reset'):
                submodule.reset(dones)
            else:
                warnings.warn(f"Submodule {submodule} does not have a reset method.", stacklevel=2)

    def get_hidden_states(self):
        """Return the hidden states of the submodule registered under the key ``"dynamic"``.

        Only that submodule is queried; if no submodule has that key, the lookup into ``self.submodules``
        fails (presumably with ``KeyError`` — confirm the container type in ``ExtendableModel``).
        """
        return self.submodules["dynamic"].get_hidden_states()
    

class ExtendableActorCriticRecurrent(ActorCriticRecurrent):
    """Recurrent actor-critic whose actor and critic are ``ExtendableModelRecurrent`` networks.

    Unlike ``ActorCriticRecurrent`` this class builds no standalone ``memory_a``/``memory_c`` modules and no
    observation normalizers: any recurrence lives inside the submodules created from ``submodule_configs``,
    and the actor/critic are constructed with flat input sizes ``num_actor_obs``/``num_critic_obs``.

    NOTE(review): ``act``, ``act_inference`` and ``evaluate`` are still inherited from
    ``ActorCriticRecurrent`` and read ``obs_groups``/``memory_a``/``memory_c``, which this class does not
    define. TODO: override them once the observation format passed to this policy is confirmed.
    """

    is_recurrent = True
    actor: ExtendableModelRecurrent
    critic: ExtendableModelRecurrent

    def __init__(
        self,
        num_actor_obs,
        num_critic_obs,
        num_actions,
        direct_pathway_dim=15,
        final_mlp_dims=None,
        activation="elu",
        init_noise_std=1.0,
        noise_std_type: str = "scalar",
        submodule_configs: List[dict] | None = None,
        **kwargs: Any,
    ):
        """Build the extendable actor and critic plus the action-noise parameter.

        Args:
            num_actor_obs: Input size of the actor.
            num_critic_obs: Input size of the critic.
            num_actions: Size of the action vector (actor output size).
            direct_pathway_dim: Trailing observation features routed straight to each final MLP.
            final_mlp_dims: Hidden sizes of both final MLPs; ``None`` means ``[128, 128, 128]``.
            activation: Activation name forwarded to both models.
            init_noise_std: Initial standard deviation of the action distribution.
            noise_std_type: ``"scalar"`` to learn the std directly, ``"log"`` to learn its logarithm.
            submodule_configs: Submodule configurations; the same list is given to actor and critic.
            **kwargs: Reported and ignored.

        Raises:
            ValueError: If ``noise_std_type`` is neither ``"scalar"`` nor ``"log"``.
        """
        if kwargs:
            print(
                "ExtendableActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()),
            )

        # Resolve the default here instead of using a shared mutable list in the signature.
        if final_mlp_dims is None:
            final_mlp_dims = [128, 128, 128]

        # Initialise only the nn.Module machinery. ``ActorCriticRecurrent.__init__`` requires ``obs``,
        # ``obs_groups`` and ``num_actions`` (calling it with no arguments raised TypeError) and would build
        # LSTM memories and MLPs that this class replaces with the extendable models below.
        nn.Module.__init__(self)

        self.direct_pathway_dim = direct_pathway_dim

        # No observation normalization for the extendable models. These attributes are read by the inherited
        # ``update_normalization``; defining them makes that method a no-op instead of an AttributeError.
        # ``nn.Identity`` holds no parameters or buffers, so the state_dict is unaffected.
        self.actor_obs_normalization = False
        self.critic_obs_normalization = False
        self.actor_obs_normalizer = nn.Identity()
        self.critic_obs_normalizer = nn.Identity()

        # Policy
        self.actor = ExtendableModelRecurrent(
            input_dim=num_actor_obs,
            output_dim=num_actions,
            direct_pathway_dim=direct_pathway_dim,
            final_mlp_dims=final_mlp_dims,
            activation=activation,
            submodule_configs=submodule_configs
        )

        # Value function
        self.critic = ExtendableModelRecurrent(
            input_dim=num_critic_obs,
            output_dim=1,
            direct_pathway_dim=direct_pathway_dim,
            final_mlp_dims=final_mlp_dims,
            activation=activation,
            submodule_configs=submodule_configs
        )

        # Action noise: one learnable std (or log-std) per action dimension, independent of the state.
        self.noise_std_type = noise_std_type
        if self.noise_std_type == "scalar":
            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        elif self.noise_std_type == "log":
            self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        # Action distribution (populated in update_distribution). A standard Normal placeholder is kept here
        # instead of None; it is expected to be replaced before any action statistic is read.
        self.distribution = Normal(0, 1)
        # disable args validation for speedup (note: this is a global setting on torch's Normal class)
        Normal.set_default_validate_args(False)

    def reset(self, dones=None):
        """Forward ``dones`` to the ``reset`` of the actor and critic models.

        Overrides the parent version, which resets ``memory_a``/``memory_c`` — modules this class never creates.
        """
        self.actor.reset(dones)
        self.critic.reset(dones)

    def get_hidden_states(self):
        """Return ``(actor_hidden_states, critic_hidden_states)`` from the two extendable models.

        Overrides the parent version, which reads ``memory_a``/``memory_c`` — modules this class never creates.
        """
        return self.actor.get_hidden_states(), self.critic.get_hidden_states()


