#!/usr/bin/python3

'''
Scaled Hyperbolic Tangent (sTanh)
'''
__copyright__ = '''This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Department of Chemistry and Applied Biosciences, Reiher Group.
See LICENSE.txt for details.'''

import torch
from torch import Tensor
from torch.nn.modules.module import Module


class sTanh(Module):
    r'''Element-wise scaled hyperbolic tangent activation (sTanh).

    The layer multiplies the ordinary hyperbolic tangent by the fixed factor
    1.59223:

    .. math::
        \text{sTanh}(x) = 1.59223 \cdot \tanh(x)
                        = 1.59223 \cdot \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}

    Because :math:`\tanh` maps onto :math:`(-1, 1)`, the outputs lie in the
    open interval :math:`(-1.59223, 1.59223)`. The function is odd, so
    :math:`\text{sTanh}(0) = 0`. The layer has no learnable parameters.

    Shape:
        - input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = lmlp.stanh.sTanh()
        >>> input = torch.randn(2)
        >>> output = m(input)
    '''

    def forward(self, input: Tensor) -> Tensor:
        '''
        Apply sTanh to every element of ``input``.

        Args:
            input: tensor of any shape.

        Return:
            A new tensor of the same shape holding ``1.59223 * tanh(input)``.
        '''
        # Out-of-place multiply on purpose: tanh's backward reuses its own
        # output, so modifying ``squashed`` in place would break autograd.
        squashed = torch.tanh(input)
        return squashed.mul(1.59223)
