import torch
import torch.nn as nn
import torch.nn.functional as F


class ExpL(nn.Module):
    '''
    Exponential-linear activation: strictly positive and continuous at 0.

        f(x) = exp(x)   for x <= 0
        f(x) = x + 1    for x >= 0

    Both branches agree at f(0) = 1.

    The input is clipped to each branch's half-line before that branch is
    evaluated. As a result, exp() never receives a positive argument and
    cannot overflow for large x. Gradients also stay finite. A
    torch.where-based version would back-propagate 0 * inf = NaN through
    the unused exp branch.
    '''
    def __init__(self):
        super().__init__()

    def forward(self, x):
        origin = torch.zeros_like(x)
        negative_part = torch.minimum(x, origin)   # x where x < 0, else 0
        positive_part = torch.maximum(x, origin)   # x where x > 0, else 0
        # exp(neg) is exp(x) on the negative side and 1 on the positive side;
        # adding the positive part turns the latter into x + 1.
        return torch.exp(negative_part) + positive_part


class LogExpL(nn.Module):
    '''
    Natural log of the exponential-linear activation (see ExpL), computed
    without first forming ExpL(x).

        log f(x) = x            for x <= 0   (log(exp(x)))
        log f(x) = log1p(x)     for x >= 0   (log(x + 1))

    Skipping the explicit exp/log round trip avoids underflowing to
    log(0) = -inf for very negative x. log1p keeps precision for small
    positive x.
    '''
    def __init__(self):
        super().__init__()

    def forward(self, x):
        origin = torch.zeros_like(x)
        # Identity branch: x on the negative side, 0 elsewhere.
        linear_branch = torch.minimum(x, origin)
        # log1p branch: log(1 + x) on the positive side, log1p(0) = 0 elsewhere.
        log_branch = torch.log1p(torch.maximum(x, origin))
        return linear_branch + log_branch


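# Not defined here: `torch.nn` already provides `nn.LogSigmoid` (functional
# form: `F.logsigmoid`). The derivation below is kept for reference only.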
# class LogSigmoid(nn.Module):
#     '''
#     log of Sigmoid Function
#            log[1/(1+exp(-x))] 
#         = -log[1+exp(-x)] 
#         = -F.softplus(-x)
#         = F.softplus(x, beta=-1)
#     '''
#     def __init__(self):
#         super().__init__()
#     
#     def forward(self, x):
#         return F.softplus(x, beta=-1)

