import torch
import torch.nn as nn

# Width of the latent code: output size of Model's encoder and input size of
# its decoder.
DIM_LATENT = 32


class Model(nn.Module):
    """Encoder/decoder MLP applied to a concatenation of input tensors.

    The tensors passed to ``forward`` are joined along the last dimension,
    encoded to a ``DIM_LATENT``-wide code, then decoded to ``dim_out``
    features.

    Args:
        dim_out: Width of the decoder output.
        args: Mapping with ``"state_dim"`` and ``"action_dim"`` entries; their
            sum is the expected width of the concatenated input.
    """

    def __init__(self, dim_out, args):
        super().__init__()

        hidden = 256
        in_features = args["state_dim"] + args["action_dim"]

        self._enc = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, DIM_LATENT),
        )
        self._dec = nn.Sequential(
            nn.Linear(DIM_LATENT, hidden),
            nn.ReLU(),
            # LayerNorm(256),  # optional normalization, currently disabled
            nn.Linear(hidden, dim_out),
        )

        # Optimizer and loss are stored on the module; presumably an external
        # training loop drives them (not visible here — confirm with callers).
        trainable = [*self._enc.parameters(), *self._dec.parameters()]
        self.optim = torch.optim.Adam(trainable, lr=0.0005)
        self.criterion = nn.MSELoss()

    def forward(self, _in, if_return_latent=False):
        """Encode ``torch.cat(_in, dim=-1)``; decode unless the latent is requested.

        Args:
            _in: Sequence of tensors, concatenated along the last dimension.
            if_return_latent: When truthy, return the encoder output instead
                of the decoder output.
        """
        latent = self._enc(torch.cat(_in, dim=-1))
        return latent if if_return_latent else self._dec(latent)


class LayerNorm(nn.Module):
    """Per-sample normalization over all non-batch dimensions.

    Each sample ``x[i]`` is standardized with the mean and standard deviation
    of *all* of its elements (everything after dim 0 is flattened). This
    differs from :class:`torch.nn.LayerNorm` in three ways:

    * The std uses Bessel's correction (``Tensor.std`` default).
    * ``eps`` is added to the std, not to the variance.
    * For a 3-D input ``(N, T, F)``, ``T`` and ``F`` are normalized jointly
      rather than per position.

    With ``affine=True``, a learnable per-feature scale ``_gamma`` and shift
    ``beta`` are applied along the last dimension. ``_gamma`` is initialized
    from U(0, 1) and ``beta`` to zeros. ``x.shape[-1]`` must therefore equal
    ``num_features``.

    Args:
        num_features: Size of the last input dimension (affine parameter size).
        eps: Added to the std to avoid division by zero.
        affine: Whether to learn a per-feature scale and shift.
    """

    def __init__(self, num_features, eps=1e-5, affine=True):
        super().__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            # Attribute names (`_gamma`, `beta`) are state_dict keys; keep them
            # unchanged so existing checkpoints still load.
            self._gamma = nn.Parameter(torch.empty(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # reshape (not view): view fails on non-contiguous inputs such as the
        # result of transpose/permute.
        flat = x.reshape(x.size(0), -1)

        # Per-sample statistics, shaped (N, 1, ..., 1) to broadcast over x.
        stat_shape = [-1] + [1] * (x.dim() - 1)
        mean = flat.mean(dim=1).view(*stat_shape)
        std = flat.std(dim=1).view(*stat_shape)

        y = (x - mean) / (std + self.eps)
        if self.affine:
            # Affine parameters act on the last dimension (channels-last).
            # A channels-first (N, C, ...) variant would instead use
            # [1, -1] + [1] * (x.dim() - 2).
            param_shape = [1] * (x.dim() - 1) + [-1]
            y = self._gamma.view(*param_shape) * y + self.beta.view(*param_shape)
        return y
