import torch
from torch import nn

from lambda_ac.nn.common import MLP, DoubleHeadMLP
from lambda_ac.nn.critic_models import Encoder
from lambda_ac.rl_types import ActorModule, Encoder, EncoderActorModule, FeatureInput
from lambda_ac.util.distributions import DiracDistribution, TanhGaussian


class GaussianActor(ActorModule):
    """Stochastic actor head that maps features to a ``TanhGaussian``.

    The head is a ``DoubleHeadMLP`` whose two outputs are handed to
    ``TanhGaussian`` as ``(mean, std)``.

    Args:
        input_dim: Width of the encoded feature vector.
        action_dim: Output width requested from the network.
        hidden_dim: Hidden width passed to ``DoubleHeadMLP``.
        hidden_layers: Hidden-layer count passed to ``DoubleHeadMLP``.
        depend_on_hidden: When true, ``x.hidden`` is concatenated to
            ``x.encoded`` before the network, and the network input width
            is ``2 * input_dim``.
    """

    def __init__(
        self,
        input_dim: int,
        action_dim: int,
        hidden_dim: int,
        hidden_layers: int,
        depend_on_hidden: bool = False,
    ):
        super().__init__()
        self.action_dim = action_dim
        self.depend_on_hidden = depend_on_hidden

        # Doubling presumes ``hidden`` is as wide as ``encoded`` -- TODO
        # confirm against the encoder that produces the FeatureInput.
        if depend_on_hidden:
            net_in_features = 2 * input_dim
        else:
            net_in_features = input_dim

        self.net = DoubleHeadMLP(
            net_in_features,
            action_dim,
            hidden_dim,
            hidden_layers,
            nonlinearity=nn.ELU,
            normalize_input=True,
            normalize_last_layer=False,
            apply_spectral_norm=False,
            batch_norm=False,
        )

    def forward(self, x: FeatureInput) -> TanhGaussian:
        """Build the action distribution for the given features.

        Args:
            x: Feature container; ``x.encoded`` is always read, ``x.hidden``
                only when ``depend_on_hidden`` is set.

        Returns:
            ``TanhGaussian`` constructed from the two network outputs.
        """
        if not self.depend_on_hidden:
            net_input = x.encoded
        else:
            # Join encoded and hidden features along the last dimension.
            net_input = torch.cat((x.encoded, x.hidden), dim=-1)

        loc, scale = self.net(net_input)
        return TanhGaussian(loc, scale)


class DeterministicActor(ActorModule):
    """Deterministic actor head that maps features to a ``DiracDistribution``.

    A single ``MLP`` output is squashed with ``tanh`` and wrapped in a
    ``DiracDistribution``.

    Args:
        input_dim: Width of the encoded feature vector.
        action_dim: Output width requested from the network.
        hidden_dim: Hidden width passed to ``MLP``.
        hidden_layers: Hidden-layer count passed to ``MLP``.
        depend_on_hidden: When true, ``x.hidden`` is concatenated to
            ``x.encoded`` before the network, and the network input width
            is ``2 * input_dim``.
    """

    def __init__(
        self,
        input_dim: int,
        action_dim: int,
        hidden_dim: int,
        hidden_layers: int,
        depend_on_hidden: bool = False,
    ):
        super().__init__()
        self.action_dim = action_dim
        self.depend_on_hidden = depend_on_hidden

        # Doubling presumes ``hidden`` is as wide as ``encoded`` -- TODO
        # confirm against the encoder that produces the FeatureInput.
        if depend_on_hidden:
            net_in_features = 2 * input_dim
        else:
            net_in_features = input_dim

        self.net = MLP(
            net_in_features,
            action_dim,
            hidden_dim,
            hidden_layers,
            nonlinearity=nn.ELU,
            normalize_input=True,
            normalize_last_layer=False,
            batch_norm=False,
        )

    def forward(self, x: FeatureInput) -> DiracDistribution:
        """Compute the deterministic action for the given features.

        Args:
            x: Feature container; ``x.encoded`` is always read, ``x.hidden``
                only when ``depend_on_hidden`` is set.

        Returns:
            ``DiracDistribution`` wrapping the tanh-squashed network output.
        """
        if not self.depend_on_hidden:
            net_input = x.encoded
        else:
            # Join encoded and hidden features along the last dimension.
            net_input = torch.cat((x.encoded, x.hidden), dim=-1)

        # tanh bounds every output component to (-1, 1).
        squashed = torch.tanh(self.net(net_input))
        return DiracDistribution(squashed)


class EncoderActor(EncoderActorModule):
    """Actor that runs an encoder and feeds its features to an actor head.

    ``depend_on_hidden`` and ``action_dim`` are copied from ``head`` so the
    combined module exposes the same attributes as the head itself.

    Args:
        encoder: Module called on the raw input; its output is converted
            with ``FeatureInput.from_output``.
        head: Actor head applied to the resulting ``FeatureInput``.
    """

    def __init__(
        self,
        encoder: Encoder,
        head: ActorModule,
    ):
        super().__init__(encoder)
        self.depend_on_hidden = head.depend_on_hidden
        self.action_dim = head.action_dim

        self.encoder = encoder
        self.head = head

    def forward(self, x: torch.Tensor, detach_encoder: bool = False) -> TanhGaussian:
        """Encode ``x`` and return the head's action distribution.

        Args:
            x: Raw input passed to the encoder.
            detach_encoder: When true, the encoded features are detached
                before the head is applied.

        Returns:
            Whatever ``self.head`` returns. NOTE(review): the annotation
            says ``TanhGaussian``, but a ``DeterministicActor`` head returns
            a ``DiracDistribution``.
        """
        enc = self.encoder(x)
        enc = FeatureInput.from_output(enc)
        if detach_encoder:
            # The result of detach() used to be thrown away. If
            # FeatureInput.detach() returns a new object (as
            # torch.Tensor.detach() does), that made the flag a no-op.
            # Rebind when a value comes back; keep the original object when
            # detach() works in place and returns None.
            detached = enc.detach()
            if detached is not None:
                enc = detached
        return self.head(enc)
