import torch
import torch.nn as nn


def actmodule(activation):
    """Build a fresh activation module from its string name.

    Recognised names are 'softplus', 'relu', 'leakyrelu', 'prelu', 'elu'
    and 'tanh'; each module is constructed with its default arguments.

    Raises:
        ValueError: if ``activation`` matches none of the recognised names.
    """
    # Ordered (name, module class) pairs; scanned with ``==`` so that any
    # input type is compared the same way a chain of equality tests would.
    known = (
        ('softplus', nn.Softplus),
        ('relu', nn.ReLU),
        ('leakyrelu', nn.LeakyReLU),
        ('prelu', nn.PReLU),
        ('elu', nn.ELU),
        ('tanh', nn.Tanh),
    )
    for label, module_cls in known:
        if activation == label:
            # Instantiate on demand so every call returns a new module.
            return module_cls()
    raise ValueError('unknown activation function specified')


def draw_normal(mean, lnvar):
    """Draw one sample from a normal distribution given mean and log-variance.

    The sample is ``mean + noise * exp(0.5 * lnvar)``, where ``noise`` is
    standard-normal with the same shape, dtype and device as ``lnvar``.
    Writing the draw this way keeps it differentiable with respect to both
    ``mean`` and ``lnvar``.
    """
    # exp(0.5 * log-variance) is the standard deviation.
    sigma = (0.5 * lnvar).exp()
    noise = torch.randn_like(sigma)
    return mean + noise * sigma


@torch.jit.script
def kldiv_normal_normal(mean1: torch.Tensor, lnvar1: torch.Tensor,
                        mean2: torch.Tensor, lnvar2: torch.Tensor) -> torch.Tensor:
    """KL divergence KL(N(mean1, exp(lnvar1)) || N(mean2, exp(lnvar2))).

    Two layouts of the log-variances are supported:

    * both ``lnvar1`` and ``lnvar2`` are 2-D: variances are given per
      element and the per-element terms are summed over ``dim=1``;
    * both are 1-D: one shared log-variance per row, applied to all
      ``mean1.shape[1]`` components of that row.

    The means are indexed/summed along ``dim=1`` in both cases, so they are
    presumably laid out as (batch, dim) -- confirm against callers.

    Returns:
        A tensor with the divergence reduced over ``dim=1``.

    Raises:
        ValueError: if the ranks of ``lnvar1`` and ``lnvar2`` are not both 1
            or both 2 (when called through TorchScript the error may be
            wrapped in the interpreter's own exception type).
    """
    if lnvar1.ndim == 2 and lnvar2.ndim == 2:
        # Per-element form: var1/var2 - 1 + ln(var2/var1) + (m2-m1)^2 / var2.
        var_ratio = (lnvar1 - lnvar2).exp()
        sq_diff = (mean2 - mean1).pow(2)
        per_dim = var_ratio - 1.0 + lnvar2 - lnvar1 + sq_diff / lnvar2.exp()
        return 0.5 * torch.sum(per_dim, dim=1)
    elif lnvar1.ndim == 1 and lnvar2.ndim == 1:
        # Shared-variance form: the variance-only term is identical for every
        # component, so it is multiplied by the component count instead of
        # being summed.
        n_dims = mean1.shape[1]
        var_term = (lnvar1 - lnvar2).exp() - 1.0 + lnvar2 - lnvar1
        sq_dist = torch.sum((mean2 - mean1).pow(2), dim=1)
        return 0.5 * (n_dims * var_term + sq_dist / lnvar2.exp())
    else:
        raise ValueError('lnvar1 and lnvar2 must both be 1-D or both be 2-D tensors')