from typing import Dict, List, Tuple, Type, Union

import gymnasium as gym
import torch as th
from torch import nn

from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp, MlpExtractor
from stable_baselines3.common.utils import get_device

class BaseHFeaturesExtractor(BaseFeaturesExtractor):
    """
    Base class for features extractors that receive, next to the regular
    observation, an additional "hidden" observation.

    :param observation_space: Observation space the extractor is built for.
    :param hidden_obs_shape: Shape of the hidden observation, stored as-is on
        ``self.hidden_obs_shape`` for subclasses.
    :param features_dim: Number of features extracted; must be strictly positive.
    """

    def __init__(self, observation_space: gym.Space, hidden_obs_shape: Tuple, features_dim: int = 0) -> None:
        # Intentionally bypass ``BaseFeaturesExtractor.__init__`` in the MRO
        # (presumably landing on ``nn.Module.__init__``); the attributes this
        # class relies on are assigned explicitly below instead.
        super(BaseFeaturesExtractor, self).__init__()
        assert features_dim > 0
        self._features_dim = features_dim
        self._observation_space = observation_space
        self.hidden_obs_shape = hidden_obs_shape


class HiddenObsExtractor(BaseHFeaturesExtractor):
    """
    Features extractor that encodes the observation and the hidden observation
    with two separate MLPs (``Linear`` + ``ReLU`` per layer) and concatenates
    the two encodings along dim 1.

    :param observation_space: Observation space; its flattened size
        (``get_flattened_obs_dim``) is the input size of the observation encoder.
    :param hidden_obs_shape: Shape of one hidden observation (without the batch
        dimension); the product of its entries is the input size of the
        hidden-observation encoder.
    :param obs_encoder_arch: Layer sizes of the observation encoder.
        ``None`` or empty leaves the (flattened) observation unencoded.
    :param hidden_obs_encoder_arch: Layer sizes of the hidden-observation
        encoder. ``None`` or empty leaves the (flattened) hidden observation
        unencoded.
    """

    def __init__(
            self,
            observation_space: gym.Space,
            hidden_obs_shape: Tuple,
            obs_encoder_arch: Union[List[int], None] = None,
            hidden_obs_encoder_arch: Union[List[int], None] = None,
    ) -> None:
        if obs_encoder_arch is None:
            obs_encoder_arch = []
        if hidden_obs_encoder_arch is None:
            hidden_obs_encoder_arch = []

        obs_dim = get_flattened_obs_dim(observation_space)
        # Flattened size of one hidden observation; equals hidden_obs_shape[0]
        # for 1-D shapes, but also supports multi-dimensional hidden observations.
        hidden_obs_dim = 1
        for size in hidden_obs_shape:
            hidden_obs_dim *= int(size)

        # Each encoder outputs the size of its last layer, or its input
        # unchanged when it has no layers.
        obs_feature_dim = obs_encoder_arch[-1] if obs_encoder_arch else obs_dim
        hidden_obs_feature_dim = (
            hidden_obs_encoder_arch[-1] if hidden_obs_encoder_arch else hidden_obs_dim
        )

        # Pass the real concatenated size directly rather than a placeholder
        # that would have to be overwritten afterwards.
        super().__init__(observation_space, hidden_obs_shape,
                         features_dim=obs_feature_dim + hidden_obs_feature_dim)

        # output_dim=-1: no extra projection layer is requested, so each encoder
        # ends at the last entry of its arch (matching the feature dims above).
        obs_encoder_net = create_mlp(obs_dim, -1, obs_encoder_arch, nn.ReLU)
        self.obs_encoder = nn.Sequential(*obs_encoder_net)
        hidden_obs_encoder_net = create_mlp(hidden_obs_dim, -1, hidden_obs_encoder_arch, nn.ReLU)
        self.hidden_obs_encoder = nn.Sequential(*hidden_obs_encoder_net)

    def forward(self, observations: th.Tensor, hidden_obs: th.Tensor) -> th.Tensor:
        """
        Encode both inputs and concatenate the encodings along dim 1.

        :param observations: Batch of observations; every dim after the first
            is flattened before encoding (a no-op for ``(batch, n)`` input).
        :param hidden_obs: Batch of hidden observations; flattened the same way.
        :return: Tensor of shape ``(batch, features_dim)``.
        """
        # Flatten so the Linear layers sized from the *flattened* dims accept
        # multi-dimensional inputs, and so th.cat(dim=1) sees 2-D tensors.
        obs_features = self.obs_encoder(observations.float().flatten(start_dim=1))
        hidden_features = self.hidden_obs_encoder(hidden_obs.float().flatten(start_dim=1))
        return th.cat([obs_features, hidden_features], dim=1)

class MlpHExtractor(MlpExtractor):
    """
    ``MlpExtractor`` variant with a third, independent head: a log-score value net.

    Three non-shared MLPs are built on top of the extracted features: the
    policy net, the value net and the log-score value net.

    :param feature_dim: Input dimension of the policy and log-score value nets.
    :param reward_feature_dim: Input dimension of the value net; ``None`` means
        ``feature_dim`` is used.
    :param net_arch: Either a list of layer sizes used for all three nets, or a
        dict with optional keys ``"pi"``, ``"vf"`` and ``"log_score_vf"``. A
        missing key yields no layers, i.e. that net acts as an identity.
    :param activation_fn: Activation module class inserted after every ``Linear``.
    :param device: Device the three networks are moved to.
    """

    def __init__(
        self,
        feature_dim: int,
        reward_feature_dim: Union[int, None],
        net_arch: Union[List[int], Dict[str, List[int]]],
        activation_fn: Type[nn.Module],
        device: Union[th.device, str] = "auto",
    ) -> None:

        # Bypass ``MlpExtractor.__init__``; this class builds all of its
        # networks itself below.
        super(MlpExtractor, self).__init__()
        device = get_device(device)

        if isinstance(net_arch, dict):
            # Note: if a key is not specified, assume a linear (identity) network
            pi_layers_dims = net_arch.get("pi", [])  # Layer sizes of the policy network
            vf_layers_dims = net_arch.get("vf", [])  # Layer sizes of the value network
            log_score_vf_layers_dims = net_arch.get("log_score_vf", [])  # Layer sizes of the log_score value network
        else:
            pi_layers_dims = vf_layers_dims = log_score_vf_layers_dims = net_arch

        value_input_dim = feature_dim if reward_feature_dim is None else reward_feature_dim

        # Built in the order policy -> value -> log_score value so parameter
        # initialization consumes the RNG in the same order as before.
        policy_net, self.latent_dim_pi = self._build_layers(feature_dim, pi_layers_dims, activation_fn)
        value_net, self.latent_dim_vf = self._build_layers(value_input_dim, vf_layers_dims, activation_fn)
        log_score_value_net, self.latent_dim_log_score_vf = self._build_layers(
            feature_dim, log_score_vf_layers_dims, activation_fn
        )

        # Create networks
        # If the list of layers is empty, the network will just act as an Identity module
        self.policy_net = nn.Sequential(*policy_net).to(device)
        self.value_net = nn.Sequential(*value_net).to(device)
        self.log_score_value_net = nn.Sequential(*log_score_value_net).to(device)

    @staticmethod
    def _build_layers(
        input_dim: int,
        layer_dims: List[int],
        activation_fn: Type[nn.Module],
    ) -> Tuple[List[nn.Module], int]:
        """
        Build a ``Linear`` + activation pair per entry of ``layer_dims``.

        :return: The list of modules and the output dimension of the last layer
            (``input_dim`` when ``layer_dims`` is empty).
        """
        layers: List[nn.Module] = []
        last_dim = input_dim
        for layer_dim in layer_dims:
            layers.append(nn.Linear(last_dim, layer_dim))
            layers.append(activation_fn())
            last_dim = layer_dim
        return layers, last_dim

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        :return: ``(latent_policy, latent_value, latent_log_score_value)``, each
            computed by its own network from ``features``.
        """
        return self.forward_actor(features), self.forward_critic(features), self.forward_log_score_critic(features)

    def forward_log_score_critic(self, features: th.Tensor) -> th.Tensor:
        """Return the latent of the log-score value network for ``features``."""
        return self.log_score_value_net(features)


class MlpCExtractor(MlpExtractor):
    """
    ``MlpExtractor`` variant with a third, independent head: a negative-cost value net.

    Three non-shared MLPs are built on top of the extracted features: the
    policy net, the value net and the neg-cost value net.

    :param feature_dim: Input dimension of all three nets.
    :param net_arch: Either a list of layer sizes used for all three nets, or a
        dict with optional keys ``"pi"``, ``"vf"`` and ``"neg_cost_vf"``. A
        missing key yields no layers, i.e. that net acts as an identity.
    :param activation_fn: Activation module class inserted after every ``Linear``.
    :param device: Device the three networks are moved to.
    """

    def __init__(
        self,
        feature_dim: int,
        net_arch: Union[List[int], Dict[str, List[int]]],
        activation_fn: Type[nn.Module],
        device: Union[th.device, str] = "auto",
    ) -> None:

        # Bypass ``MlpExtractor.__init__``; this class builds all of its
        # networks itself below.
        super(MlpExtractor, self).__init__()
        device = get_device(device)

        if isinstance(net_arch, dict):
            # Note: if a key is not specified, assume a linear (identity) network
            pi_layers_dims = net_arch.get("pi", [])  # Layer sizes of the policy network
            vf_layers_dims = net_arch.get("vf", [])  # Layer sizes of the value network
            neg_cost_vf_layers_dims = net_arch.get("neg_cost_vf", [])  # Layer sizes of the neg_cost value network
        else:
            pi_layers_dims = vf_layers_dims = neg_cost_vf_layers_dims = net_arch

        # Built in the order policy -> value -> neg_cost value so parameter
        # initialization consumes the RNG in the same order as before.
        policy_net, self.latent_dim_pi = self._build_layers(feature_dim, pi_layers_dims, activation_fn)
        value_net, self.latent_dim_vf = self._build_layers(feature_dim, vf_layers_dims, activation_fn)
        neg_cost_value_net, self.latent_dim_neg_cost_vf = self._build_layers(
            feature_dim, neg_cost_vf_layers_dims, activation_fn
        )

        # Create networks
        # If the list of layers is empty, the network will just act as an Identity module
        self.policy_net = nn.Sequential(*policy_net).to(device)
        self.value_net = nn.Sequential(*value_net).to(device)
        self.neg_cost_value_net = nn.Sequential(*neg_cost_value_net).to(device)

    @staticmethod
    def _build_layers(
        input_dim: int,
        layer_dims: List[int],
        activation_fn: Type[nn.Module],
    ) -> Tuple[List[nn.Module], int]:
        """
        Build a ``Linear`` + activation pair per entry of ``layer_dims``.

        :return: The list of modules and the output dimension of the last layer
            (``input_dim`` when ``layer_dims`` is empty).
        """
        layers: List[nn.Module] = []
        last_dim = input_dim
        for layer_dim in layer_dims:
            layers.append(nn.Linear(last_dim, layer_dim))
            layers.append(activation_fn())
            last_dim = layer_dim
        return layers, last_dim

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        :return: ``(latent_policy, latent_value, latent_neg_cost_value)``, each
            computed by its own network from ``features``.
        """
        return self.forward_actor(features), self.forward_critic(features), self.forward_neg_cost_critic(features)

    def forward_neg_cost_critic(self, features: th.Tensor) -> th.Tensor:
        """Return the latent of the neg-cost value network for ``features``."""
        return self.neg_cost_value_net(features)