import torch
import torch.nn as nn
import torch.nn.functional as F
from .perception import MLP

# NOTE(review): by name these look like clamp bounds for a log standard
# deviation, but nothing in this chunk uses them — confirm against callers.
MIN_LOGSTD = -20.
MAX_LOGSTD = 2.


class Encoder(nn.Module):
    """Project embedded observations through an MLP, optionally L2-normalising.

    Args:
        input_dim: first size handed to ``MLP`` (the input width it is built with).
        output_dim: last size handed to ``MLP`` (the output width it is built with).
        hidden_size: width repeated for every hidden layer.
        n_hiddens: how many times ``hidden_size`` is repeated between the
            input and output sizes (default 2).
        normalize: when True, ``forward`` rescales each output vector to unit
            L2 norm along the last dimension.
    """

    def __init__(self, input_dim, output_dim, hidden_size, n_hiddens=2, normalize=True):
        super().__init__()
        # Sizes are passed positionally as (input, hidden..., output); presumably
        # MLP builds consecutive layers from them — confirm in .perception.
        layer_dims = (input_dim, *([hidden_size] * n_hiddens), output_dim)
        self.layers = MLP(*layer_dims)
        self.normalize = normalize

    def forward(self, embedded_obs):
        """Run the MLP on ``embedded_obs``.

        Returns the raw MLP output, or that output L2-normalised over its last
        dimension when ``self.normalize`` is set.
        """
        features = self.layers(embedded_obs)
        if not self.normalize:
            return features
        return F.normalize(features, p=2.0, dim=-1)