import torch
import torch.nn.functional as F
from torch import nn
import math
import torch.distributions as td


class VAE(nn.Module):
    """Vanilla conditional Variational Auto-Encoder over actions.

    The encoder maps a ``(state, action)`` pair to a diagonal Gaussian over a
    ``latent_dim``-dimensional latent; the decoder maps ``(state, z)`` back to
    an action. When ``max_action`` is not ``None`` the decoded action is
    squashed with ``tanh`` and scaled into ``[-max_action, max_action]``;
    otherwise the last linear layer's output is returned unbounded.

    Args:
        state_dim: Size of the last dimension of ``state`` inputs.
        action_dim: Size of the last dimension of ``action`` inputs/outputs.
        latent_dim: Size of the latent vector ``z``.
        max_action: Output scale of the decoder, or ``None`` for a linear
            (unbounded) decoder output.
        device: Device that sampled prior latents (``decode`` with
            ``z=None``) and ``select_action`` inputs are moved to.
        hidden_dim: Width of the two hidden layers of encoder and decoder.
    """

    def __init__(self, state_dim, action_dim, latent_dim, max_action, device, hidden_dim=750):
        super().__init__()
        # Encoder: concat(state, action) -> hidden -> hidden -> (mean, log_std)
        self.e1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.e2 = nn.Linear(hidden_dim, hidden_dim)

        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.log_std = nn.Linear(hidden_dim, latent_dim)

        # Decoder: concat(state, z) -> hidden -> hidden -> action
        self.d1 = nn.Linear(state_dim + latent_dim, hidden_dim)
        self.d2 = nn.Linear(hidden_dim, hidden_dim)
        self.d3 = nn.Linear(hidden_dim, action_dim)

        self.relu = nn.ReLU()
        self.max_action = max_action
        self.latent_dim = latent_dim
        self.device = device

    def forward(self, state, action):
        """Encode, sample ``z`` with the reparameterization trick, decode.

        Returns:
            ``(u, mean, std, log_std)`` where ``u`` is the reconstructed
            action and the rest are the encoder's Gaussian parameters.
        """
        mean, std, log_std = self.encode(state, action)
        # Reparameterization: z = mean + std * eps, eps ~ N(0, I).
        z = mean + std * torch.randn_like(std)
        u = self.decode(state, z)
        return u, mean, std, log_std

    @staticmethod
    def _kl_terms(mean, std, log_std):
        """Return the unreduced elementwise KL(N(mean, std^2) || N(0, 1)) terms."""
        # 2 * log_std == log(std^2), since encode() sets std = exp(log_std).
        return -0.5 * (1 + 2 * log_std - mean.pow(2) - std.pow(2))

    def elbo_loss(self, state, action):
        """Return the per-sample VAE loss: MSE reconstruction + KL.

        Both terms are averaged over the last dimension (action dims for the
        reconstruction, latent dims for the KL), so the result keeps the
        inputs' leading dimensions. Lower is better.
        """
        u, mean, std, log_std = self.forward(state, action)
        recon_loss = ((u - action) ** 2).mean(-1)
        KL_loss = self._kl_terms(mean, std, log_std).mean(-1)
        return recon_loss + KL_loss

    def encode(self, state, action):
        """Map ``(state, action)`` to the latent Gaussian parameters.

        Returns:
            ``(mean, std, log_std)``; ``log_std`` is clamped to ``[-4, 15]``
            and ``std == exp(log_std)``.
        """
        h = self.relu(self.e1(torch.cat([state, action], -1)))
        h = self.relu(self.e2(h))
        mean = self.mean(h)
        # Clamped for numerical stability
        log_std = self.log_std(h).clamp(-4, 15)
        std = torch.exp(log_std)
        return mean, std, log_std

    def decode(self, state, z=None):
        """Decode an action from ``state`` and latent ``z``.

        If ``z`` is ``None``, a standard-normal latent is sampled on
        ``self.device`` and clipped to ``[-0.5, 0.5]``; in that case ``state``
        is assumed to be 2-D ``(batch, state_dim)`` since only
        ``state.shape[0]`` is used for the batch size.
        """
        if z is None:  # When sampling from the VAE, the latent vector is clipped to [-0.5, 0.5]
            z = torch.randn((state.shape[0], self.latent_dim)).to(self.device).clamp(-0.5, 0.5)

        a = self.relu(self.d1(torch.cat([state, z], -1)))
        a = self.relu(self.d2(a))
        if self.max_action is not None:
            return self.max_action * torch.tanh(self.d3(a))
        return self.d3(a)

    def train_vae(self, replay_buffer, vae_optimizer, batch, mod, t=0, logger=None):
        """Run one optimization step on a batch sampled from ``replay_buffer``.

        Args:
            replay_buffer: Object whose ``sample_vae(batch, mod)`` returns
                ``(state, action, ind)``.
            vae_optimizer: Optimizer stepped after backprop (presumably over
                this module's parameters — confirm at the call site).
            batch: Forwarded to ``replay_buffer.sample_vae`` (presumably the
                batch size — verify against the buffer).
            mod: Forwarded unchanged to ``replay_buffer.sample_vae``; its
                meaning is defined by the buffer.
            t: Step index passed to ``logger.log``.
            logger: Optional object with a ``log(key, value, step)`` method.

        Returns:
            The ``ind`` value returned by ``replay_buffer.sample_vae``.

        The module is switched to train mode for the step and left in eval
        mode on return.
        """
        self.train()
        state, action, ind = replay_buffer.sample_vae(batch, mod)
        recon, mean, std, log_std = self.forward(state, action)
        recon_loss = F.mse_loss(recon, action)
        KL_loss = self._kl_terms(mean, std, log_std).mean()
        vae_loss = recon_loss + KL_loss

        vae_optimizer.zero_grad()
        vae_loss.backward()
        vae_optimizer.step()
        if logger:
            # Detach so a logger that stores values does not keep each
            # step's autograd graph alive.
            logger.log('train/vae_loss', vae_loss.detach(), t)
        self.eval()
        return ind

    @torch.no_grad()
    def select_action(self, state):
        """Decode one action for a single state using a sampled, clipped latent.

        ``state`` is expected to support ``.reshape`` (e.g. a NumPy array); it
        is flattened to a ``(1, -1)`` float tensor on ``self.device``. Returns
        a flat NumPy array.
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        action = self.decode(state).cpu().numpy().flatten()
        return action
