from nflows.transforms.base import Transform
import torch


def _sas(x, epsilon, delta):
    """Apply the sinh-arcsinh map ``sinh((arcsinh(x) + epsilon) / delta)``.

    Args:
        x: Tensor to transform (element-wise).
        epsilon: Shift added in arcsinh space.
        delta: Divisor applied in arcsinh space after the shift.

    Returns:
        Tensor holding the transformed values.
    """
    shifted = torch.arcsinh(x) + epsilon
    return torch.sinh(shifted / delta)


def _isas(x, epsilon, delta):
    """Apply the inverse sinh-arcsinh map ``sinh(delta * arcsinh(x) - epsilon)``.

    Args:
        x: Tensor to transform (element-wise).
        epsilon: Shift subtracted in arcsinh space.
        delta: Multiplier applied in arcsinh space before the shift.

    Returns:
        Tensor holding the transformed values.
    """
    scaled = delta * torch.arcsinh(x)
    return torch.sinh(scaled - epsilon)


def _ldisas(x, epsilon, delta):
    """Return the element-wise log absolute derivative of ``_isas`` at ``x``.

    Differentiating ``sinh(delta * arcsinh(x) - epsilon)`` with respect to
    ``x`` gives ``delta * cosh(delta * arcsinh(x) - epsilon) / sqrt(1 + x^2)``.
    ``cosh`` is even, so the sign of its argument is irrelevant.

    Args:
        x: Tensor at which the derivative is evaluated.
        epsilon: Shift parameter of the transform.
        delta: Scale parameter of the transform.

    Returns:
        Tensor of ``log|d isas / dx|`` values.
    """
    cosh_term = torch.cosh(epsilon - delta * torch.arcsinh(x))
    derivative = delta * cosh_term / torch.sqrt(1 + x**2)
    return torch.log(torch.abs(derivative))


def sas_forward(X, epsilon, delta):
    """Run the forward sinh-arcsinh map and return its log-derivative.

    The forward log-derivative is obtained as the negated log-derivative of
    the inverse map evaluated at the transformed values.

    Args:
        X: Input tensor.
        epsilon: Shift parameter of the transform.
        delta: Scale parameter of the transform.

    Returns:
        Tuple ``(transformed, log_derivative)``; both are element-wise
        results of the helper calls and are not reduced here.
    """
    transformed = _sas(X, epsilon, delta)
    inverse_log_derivative = _ldisas(transformed, epsilon, delta)
    return transformed, -inverse_log_derivative


def sas_inverse(X, epsilon, delta):
    """Run the inverse sinh-arcsinh map and return its log-derivative.

    Args:
        X: Input tensor.
        epsilon: Shift parameter of the transform.
        delta: Scale parameter of the transform.

    Returns:
        Tuple ``(recovered, log_derivative)``; both are element-wise results
        of the helper calls and are not reduced here.
    """
    log_derivative = _ldisas(X, epsilon, delta)
    recovered = _isas(X, epsilon, delta)
    return recovered, log_derivative


class SinhArcSinhTransform(Transform):
    """Element-wise sinh-arcsinh transform wrapped as a ``Transform``.

    ``forward`` delegates to ``sas_forward`` and ``inverse`` to
    ``sas_inverse``, both with the ``epsilon`` and ``delta`` values stored on
    the instance. The element-wise log-derivatives are summed over every
    non-batch dimension so that one log-determinant per sample is returned.
    """

    def __init__(self, e, d):
        """Store the transform parameters.

        Args:
            e: Value used as ``epsilon`` (shift in arcsinh space).
            d: Value used as ``delta`` (scale in arcsinh space).
        """
        super().__init__()
        self.epsilon = e
        self.delta = d

    @staticmethod
    def _sum_except_batch(ld):
        """Reduce element-wise log-derivatives to one value per sample.

        Inputs with at most one dimension are returned flattened, exactly as
        before. Higher-rank inputs are summed over all dimensions after the
        first; a plain ``flatten()`` there would yield ``batch * features``
        entries instead of ``batch``.
        """
        if ld.dim() <= 1:
            return ld.flatten()
        return ld.flatten(start_dim=1).sum(dim=1)

    def forward(self, X, context=None):
        """Transform ``X`` and return ``(outputs, logabsdet)``.

        Args:
            X: Input tensor; the first dimension is treated as the batch.
            context: Unused; accepted for interface compatibility.

        Returns:
            The transformed tensor and a log-determinant with one entry per
            sample.
        """
        outputs, ld = sas_forward(X, self.epsilon, self.delta)
        return outputs, self._sum_except_batch(ld)

    def inverse(self, X, context=None):
        """Invert the transform on ``X`` and return ``(outputs, logabsdet)``.

        Args:
            X: Input tensor; the first dimension is treated as the batch.
            context: Unused; accepted for interface compatibility.

        Returns:
            The inverse-transformed tensor and a log-determinant with one
            entry per sample.
        """
        outputs, ld = sas_inverse(X, self.epsilon, self.delta)
        return outputs, self._sum_except_batch(ld)
