import warnings
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from typing_extensions import deprecated

from lambda_ac.nn.common import MLP, DoubleHeadMLP, conv_output_shape, orthogonal_init
from lambda_ac.rl_types import Distribution, Encoder, EncoderOutput
from lambda_ac.util.distributions import BoundedGaussian, DiracDistribution


class IdentityEncoder(Encoder):
    """Encoder that passes its input through unchanged.

    The constructor takes the same arguments as ``MLPEncoder`` so the two can
    be swapped, but ``num_hidden_layers``, ``hidden_dim`` and ``normalize`` are
    ignored: the stored attributes are always ``0``, ``0`` and ``False``.

    Args:
        input_dim: Size of the input; must equal ``feature_dim``.
        num_hidden_layers: Ignored (a warning is emitted if positive).
        hidden_dim: Ignored (a warning is emitted if positive).
        feature_dim: Size of the produced features.
        encoding_hidden: Width of the random ``hidden`` tensor returned by
            ``forward``.
        normalize: Ignored; ``self.normalize`` is always ``False``.

    Raises:
        ValueError: If ``input_dim`` differs from ``feature_dim``.
    """

    def __init__(
        self,
        input_dim: int,
        num_hidden_layers: int,
        hidden_dim: int,
        feature_dim: int,
        encoding_hidden: int = 50,
        normalize: bool = False,
    ):
        super().__init__()

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if input_dim != feature_dim:
            raise ValueError(
                "IdentityEncoder requires input_dim == feature_dim, "
                f"got input_dim={input_dim} and feature_dim={feature_dim}"
            )
        self.input_dim = feature_dim
        self.hidden_dim = 0
        self.num_hidden_layers = 0
        self.feature_dim = feature_dim
        self.encoding_hidden = encoding_hidden
        self.normalize = False
        self.norm = nn.Identity()

        if num_hidden_layers > 0 or hidden_dim > 0:
            warnings.warn("IdentityEncoder ignores hidden layers and hidden dim")

    def forward(self, x: torch.Tensor) -> EncoderOutput:
        """Wrap ``x`` in a ``DiracDistribution`` and attach random hidden noise.

        Args:
            x: Input tensor; its first dimension sets the number of rows of
                the hidden tensor.

        Returns:
            ``EncoderOutput(DiracDistribution(x), hidden)`` where ``hidden``
            has shape ``(x.shape[0], encoding_hidden)`` and holds standard
            normal samples scaled by ``0.01``.
        """
        # Sample directly on x's device rather than sampling on the CPU and
        # copying the result over on every forward pass.
        hidden = 0.01 * torch.randn(
            x.shape[0], self.encoding_hidden, device=x.device
        )
        return EncoderOutput(DiracDistribution(x), hidden)


class MLPEncoder(Encoder):
    """Deterministic encoder: an ``MLP`` followed by an ELU activation.

    Args:
        input_dim: Size of the input passed to the ``MLP``.
        num_hidden_layers: Number of hidden layers passed to the ``MLP``.
        hidden_dim: Hidden width passed to the ``MLP``.
        feature_dim: Output size passed to the ``MLP``.
        encoding_hidden: Width of the random ``hidden`` tensor returned by
            ``forward``.
        normalize: Stored on the instance as ``self.normalize``; the ``MLP``
            itself is always built with ``normalize_input=False``.
    """

    def __init__(
        self,
        input_dim: int,
        num_hidden_layers: int,
        hidden_dim: int,
        feature_dim: int,
        encoding_hidden: int = 50,
        normalize: bool = False,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.feature_dim = feature_dim
        self.normalize = normalize
        self.num_hidden_layers = num_hidden_layers
        self.hidden_dim = hidden_dim
        self.encoding_hidden = encoding_hidden

        self.net = MLP(
            input_dim, feature_dim, hidden_dim, num_hidden_layers, normalize_input=False
        )

    def forward(self, state: torch.Tensor) -> EncoderOutput:
        """Encode ``state`` and attach random hidden noise.

        Args:
            state: Input tensor fed to the underlying ``MLP``.

        Returns:
            ``EncoderOutput(DiracDistribution(features), hidden)`` where
            ``features`` is ``elu(net(state))`` and ``hidden`` has shape
            ``(features.shape[0], encoding_hidden)`` with standard normal
            samples scaled by ``0.01``.
        """
        features = F.elu(self.net(state))
        # Sample directly on the feature device rather than sampling on the
        # CPU and copying the result over on every forward pass.
        hidden = 0.01 * torch.randn(
            features.shape[0], self.encoding_hidden, device=features.device
        )
        return EncoderOutput(DiracDistribution(features), hidden)


class ProbabilisticMLPEncoder(Encoder):
    """Stochastic encoder: a ``DoubleHeadMLP`` feeding a ``BoundedGaussian``.

    Args:
        input_dim: Size of the input passed to the ``DoubleHeadMLP``.
        num_hidden_layers: Number of hidden layers passed to the network.
        hidden_dim: Hidden width passed to the network.
        feature_dim: Output size passed to the network.
        encoding_hidden: Width of the random ``hidden`` tensor returned by
            ``forward``.
        normalize: Stored on the instance as ``self.normalize``; the network
            is always built with ``normalize_input=False`` and
            ``normalize_last_layer=False``.
    """

    def __init__(
        self,
        input_dim: int,
        num_hidden_layers: int,
        hidden_dim: int,
        feature_dim: int,
        encoding_hidden: int = 50,
        normalize: bool = False,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.feature_dim = feature_dim
        self.normalize = normalize
        self.num_hidden_layers = num_hidden_layers
        self.hidden_dim = hidden_dim
        self.encoding_hidden = encoding_hidden

        self.net = DoubleHeadMLP(
            input_dim,
            feature_dim,
            hidden_dim,
            num_hidden_layers,
            normalize_input=False,
            normalize_last_layer=False,
        )

    def forward(self, state: torch.Tensor) -> EncoderOutput:
        """Encode ``state`` into a distribution and attach random hidden noise.

        Args:
            state: Input tensor fed to the underlying ``DoubleHeadMLP``.

        Returns:
            ``EncoderOutput(BoundedGaussian(x, sigma), hidden)`` where
            ``x, sigma`` are the two outputs of the network (presumably the
            location and scale of the Gaussian -- confirm against
            ``BoundedGaussian``) and ``hidden`` has shape
            ``(x.shape[0], encoding_hidden)`` with standard normal samples
            scaled by ``0.01``.
        """
        x, sigma = self.net(state)
        # Sample directly on the output device rather than sampling on the
        # CPU and copying the result over on every forward pass.
        hidden = 0.01 * torch.randn(
            x.shape[0], self.encoding_hidden, device=x.device
        )
        return EncoderOutput(BoundedGaussian(x, sigma), hidden)


# ``deprecated`` must be called with a message; applied bare it would receive
# the class itself as the message and leave ``ConvEncoder`` unusable.
@deprecated("ConvEncoder is deprecated.")
class ConvEncoder(Encoder):
    """Convolutional encoder for image-based observations. Taken from DrQ. It is not twinned.

    Four 3x3 convolutions (stride 2, then three of stride 1) with ELU between
    them, followed by a linear head, ``LayerNorm`` and ``Sigmoid``.

    The linear head's input size is computed for an 84x84 spatial input, so
    observations are assumed to be 84x84 images.

    Args:
        input_dim: Observation shape; only ``input_dim[0]`` is read here, as
            the number of input channels of the first convolution.
        hidden_dim: Accepted but not used by this class.
        feature_dim: Size of the output feature vector.
    """

    def __init__(
        self, input_dim: Tuple[int, int, int], hidden_dim: int, feature_dim: int
    ):
        super().__init__()

        self.input_dim = input_dim
        self.feature_dim = feature_dim

        self.num_layers = 4
        self.num_filters = 32
        self.output_dim = 35
        self.output_logits = False

        # Track the spatial size through the conv stack (hard-coded 84x84
        # input) to size the linear head; the result is indexed as
        # (hw[0], hw[1]) below.
        hw = conv_output_shape(84, kernel_size=3, stride=2)
        for _ in range(3):
            hw = conv_output_shape(hw, kernel_size=3, stride=1)

        self._convs = nn.Sequential(
            nn.Conv2d(input_dim[0], self.num_filters, 3, stride=2),
            nn.ELU(),
            nn.Conv2d(self.num_filters, self.num_filters, 3, stride=1),
            nn.ELU(),
            nn.Conv2d(self.num_filters, self.num_filters, 3, stride=1),
            nn.ELU(),
            nn.Conv2d(self.num_filters, self.num_filters, 3, stride=1),
        )
        self._head = nn.Linear(self.num_filters * hw[0] * hw[1], self.feature_dim)
        self._normalization = nn.Sequential(
            nn.LayerNorm([self.feature_dim]),
            nn.Sigmoid(),
        )

    def _forward_conv(self, obs):
        """Run the conv stack and flatten everything but the first dimension."""
        conv_out = self._convs(obs)
        return conv_out.view(conv_out.size(0), -1)

    def forward(self, obs):
        """Encode ``obs`` into a feature tensor.

        Unlike the MLP-based encoders in this module, this returns the
        tensor directly rather than an ``EncoderOutput``.

        Args:
            obs: Image batch fed to the conv stack (assumed
                ``(batch, input_dim[0], 84, 84)`` -- confirm with callers).

        Returns:
            Tensor of shape ``(batch, feature_dim)`` after ``LayerNorm`` and
            ``Sigmoid``, so every value lies in ``(0, 1)``.
        """
        flat = self._forward_conv(obs)
        features = self._head(flat)
        return self._normalization(features)
