import torch
import torch.nn as nn


class ActionSampler(nn.Module):
    """Conditional VAE that reconstructs / samples actions given states.

    The encoder maps a concatenated ``(state, action)`` pair to the mean and
    log standard deviation of a diagonal Gaussian over a latent vector. The
    decoder maps a concatenated ``(state, latent)`` pair back to an action.

    Args:
        state_dim: Number of features in a state vector.
        action_dim: Number of features in an action vector.
        latent_dim: Size of the latent vector.
        device: Stored on ``self.device`` for callers that read it. It is not
            used to place tensors; sampled latents follow the device of the
            ``state`` argument instead.
    """

    def __init__(self, state_dim, action_dim, latent_dim=16, device="cuda"):
        super().__init__()

        # Layer order inside each Sequential is kept fixed so that existing
        # state_dict checkpoints keep loading.
        self.enc = nn.Sequential(
            nn.Linear(state_dim + action_dim, 128),
            nn.LeakyReLU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, 128),
        )

        # Two heads on top of the shared encoder trunk.
        self.enc_mean = nn.Linear(128, latent_dim)
        self.enc_log_std = nn.Linear(128, latent_dim)

        self.dec = nn.Sequential(
            nn.Linear(state_dim + latent_dim, 128),
            nn.LeakyReLU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, 128),
            nn.Linear(128, action_dim),
        )

        self.latent_dim = latent_dim
        self.device = device

    def forward(self, state, action):
        """Encode ``(state, action)``, sample a latent, and decode an action.

        Args:
            state: Tensor concatenated with ``action`` along dim 1; expected
                to be ``(batch, state_dim)``.
            action: Tensor expected to be ``(batch, action_dim)``.

        Returns:
            Tuple ``(u, mean, std)``: the decoded action, and the mean and
            standard deviation of the latent Gaussian produced by the encoder.
        """
        hidden = self.enc(torch.cat([state, action], 1))
        mean = self.enc_mean(hidden)

        # NOTE: the log-std is NOT clamped here (a ``.clamp(-4, 15)`` was
        # disabled upstream), so ``exp`` may overflow for very large outputs.
        log_std = self.enc_log_std(hidden)
        std = torch.exp(log_std)

        # Reparameterised sample: mean + std * eps, with eps ~ N(0, I).
        z = mean + std * torch.randn_like(std)

        u = self.decode(state, z)

        return u, mean, std

    def decode(self, state, z=None):
        """Decode an action from ``state`` and a latent vector.

        Args:
            state: Tensor expected to be ``(batch, state_dim)``.
            z: Optional latent tensor ``(batch, latent_dim)``. When ``None``,
                a latent is drawn from a standard normal (no clipping is
                applied).

        Returns:
            The decoder output for the concatenated ``(state, z)`` input.
        """
        if z is None:
            # Draw the latent on the same device as ``state`` so the
            # concatenation below works wherever the module actually lives,
            # rather than relying on the constructor's ``device`` argument.
            z = torch.randn(
                (state.shape[0], self.latent_dim), device=state.device
            )

        return self.dec(torch.cat([state, z], 1))

    def save(self, filename):
        """Write this module's ``state_dict`` to ``filename``."""
        torch.save(self.state_dict(), filename)