import torch
import torch.nn as nn
from rsl_rl.utils.distribution import SquashedNormal
from typing import Tuple, List

class Actor(nn.Module):
    """Feed-forward policy network whose output is used as the mean of a Normal action distribution.

    Args:
        state_dim: Size of the last dimension of the observation input.
        action_dim: Size of the last dimension of the network output.
        layers_array: Hidden layer widths, passed through to ``create_net``.
        activation: Name of the activation placed after each hidden layer.
            Matched case-insensitively, so both ``"ReLU"`` and ``"relu"`` work.
        use_layer_norm: If True, a ``LayerNorm`` follows each hidden ``Linear``.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        layers_array: List[int],
        activation: str = "ReLU",
        use_layer_norm: bool = False,
    ):
        super().__init__()

        # Activation names are looked up in lowercase; normalise the case here
        # so that the default "ReLU" resolves to an actual module.
        activation = activation.lower()

        # Outputs mean value of action
        self.actor = create_net(state_dim, action_dim, layers_array, activation, use_layer_norm)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw network output (the action mean) for ``x``."""
        return self.actor(x)

    def get_actions(
        self, x: torch.Tensor, variance: torch.Tensor, is_deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.distributions.Normal]:
        """Compute actions and the Normal distribution they come from.

        Args:
            x: Observations fed to the network.
            variance: Spread of the action distribution, broadcast against the mean.
            is_deterministic: If True, return the mean instead of a sample.

        Returns:
            ``(actions, dist)``: ``actions`` is the mean when ``is_deterministic``
            is True and a reparameterised sample (``rsample``) otherwise;
            ``dist`` is the ``Normal`` built from the mean and ``variance``.
        """
        # variance has the size of [population_size, batch_size or 1, action_dim]
        # assert len(variance.shape) == 3 and len(x.shape) == 3
        mean = self.actor(x)
        # NOTE(review): the second argument of torch.distributions.Normal is the
        # scale (standard deviation). `variance` is passed through unchanged, so
        # callers presumably supply a std-like quantity here -- confirm.
        dist = torch.distributions.Normal(mean, variance)
        if is_deterministic:
            actions = mean
        else:
            actions = dist.rsample()
        return actions, dist


class SACActor(Actor):
    """
    SAC policy head that predicts both the mean and the log-std of the action distribution.

    Our implementation is based on Stable Baselines3 implementation of SAC.

    Args:
        state_dim: Size of the last dimension of the observation input.
        action_dim: Number of action components; the network emits twice this many values.
        layers_array: Hidden layer widths, passed through to the parent ``Actor``.
        activation: Name of the hidden-layer activation.
        log_std_min: Lower clamp applied to the predicted log-std.
        log_std_max: Upper clamp applied to the predicted log-std.
        use_layer_norm: If True, a ``LayerNorm`` follows each hidden ``Linear``.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        layers_array: List[int],
        activation: str,
        log_std_min: float = -5,
        log_std_max: float = 5,
        use_layer_norm: bool = False,
    ) -> None:
        # The output is doubled so it can be split into a mean half and a log-std half.
        super().__init__(state_dim, action_dim * 2, layers_array, activation, use_layer_norm)

        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, SquashedNormal]:
        """Return ``get_actions(x)`` with stochastic sampling (``is_deterministic=False``)."""
        return self.get_actions(x)

    def get_actions(self, state: torch.Tensor, is_deterministic: bool = False) -> Tuple[torch.Tensor, SquashedNormal]:
        """Compute actions and the ``SquashedNormal`` distribution for ``state``.

        Args:
            state: Observations fed to the network.
            is_deterministic: If True, return the predicted mean instead of a sample.

        Returns:
            ``(actions, dist)``: ``actions`` is ``mu`` when ``is_deterministic`` is
            True and ``dist.rsample()`` otherwise; ``dist`` is built from ``mu``
            and the exponentiated, clamped log-std.
        """
        # Split the last dimension into two equal halves: mean and log-std.
        mu, log_std = self.actor(state).chunk(2, dim=-1)
        # Clamp in log space before exponentiating to keep the std in a bounded range.
        std = log_std.clamp(self.log_std_min, self.log_std_max).exp()
        dist = SquashedNormal(mu, std)
        if is_deterministic:
            # NOTE(review): this returns the raw `mu`, not a value obtained from
            # `dist`. If SquashedNormal squashes its samples (e.g. with tanh),
            # deterministic and sampled actions live in different ranges -- confirm
            # against SquashedNormal and the callers.
            actions = mu
        else:
            actions = dist.rsample()
        return actions, dist


class DoubleQCritic(nn.Module):
    """Two separately parameterised Q-networks evaluated on the same (state, action) input.

    Both networks are built by ``create_net`` with ``state_dim + act_dim`` inputs
    and a single output.
    """

    def __init__(
        self,
        state_dim: int,
        act_dim: int,
        layers_array: List[int],
        activation: str,
        use_layer_norm: bool = False,
    ) -> None:
        super().__init__()

        # Each network consumes the state and action concatenated on the last axis.
        joint_dim = state_dim + act_dim
        self.net_q1 = create_net(joint_dim, 1, layers_array, activation, use_layer_norm)
        self.net_q2 = create_net(joint_dim, 1, layers_array, activation, use_layer_norm)

    def get_q_min(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return the element-wise minimum of the two Q estimates."""
        q1_value, q2_value = self.get_q1_q2(state, action)
        return torch.min(q1_value, q2_value)

    def get_q1_q2(self, state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return both Q estimates as ``(q1, q2)`` for the given state/action pair."""
        state_action = torch.cat((state, action), dim=-1)
        q1_value = self.net_q1(state_action)
        q2_value = self.net_q2(state_action)
        return q1_value, q2_value

    def get_q1(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return only the first network's Q estimate."""
        state_action = torch.cat((state, action), dim=-1)
        q1_value = self.net_q1(state_action)
        return q1_value

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return the minimum Q estimate; same as ``get_q_min``."""
        return self.get_q_min(state, action)


def create_net(
    input_size: int,
    output_size: int,
    layers_array: List[int],
    activation: str,
    use_layer_norm: bool = False,
) -> nn.Sequential:
    """Build a fully connected network as an ``nn.Sequential``.

    Each hidden layer is ``Linear`` -> (optional ``LayerNorm``) -> activation;
    the final layer is a plain ``Linear`` with no activation after it.

    Args:
        input_size: Number of input features.
        output_size: Number of output features.
        layers_array: Hidden layer widths, in order. May be empty, in which
            case the network is a single ``Linear(input_size, output_size)``.
        activation: Activation name, resolved with ``get_activation``.
        use_layer_norm: If True, insert a ``LayerNorm`` after each hidden ``Linear``.

    Returns:
        The assembled ``nn.Sequential``.
    """
    layers = []
    in_features = input_size
    # Walk the hidden widths while tracking the previous width; this handles the
    # first layer and an empty `layers_array` without special-casing.
    for hidden_size in layers_array:
        layers.append(nn.Linear(in_features, hidden_size))
        if use_layer_norm:
            layers.append(nn.LayerNorm(hidden_size))
        layers.append(get_activation(activation))
        in_features = hidden_size
    layers.append(nn.Linear(in_features, output_size))
    return nn.Sequential(*layers)

def get_activation(act_name):
    """Return a new activation module for the given name.

    The lookup is case-insensitive. Supported names are ``"elu"``, ``"selu"``,
    ``"relu"``, ``"lrelu"``, ``"tanh"`` and ``"sigmoid"``.

    Args:
        act_name: Name of the activation function.

    Returns:
        A freshly constructed ``nn.Module`` for the requested activation.

    Raises:
        NotImplementedError: If ``act_name`` is ``"crelu"``, which is recognised
            but has no implementation.
        ValueError: If ``act_name`` is not a recognised activation name.
    """
    activations = {
        "elu": nn.ELU,
        "selu": nn.SELU,
        "relu": nn.ReLU,
        "lrelu": nn.LeakyReLU,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
    }
    # Normalise so names such as "ReLU" or "Tanh" resolve to the same module.
    key = str(act_name).lower()
    if key == "crelu":
        raise NotImplementedError("crelu activation is not implemented")
    if key not in activations:
        # Fail here with the offending name rather than handing back None,
        # which would only surface later as an unrelated-looking error.
        raise ValueError(
            f"invalid activation function: {act_name!r}; expected one of {sorted(activations)}"
        )
    return activations[key]()
