import torch
import torch.nn as nn

def clamp_preserve_gradients(x, min, max):
    """Clamp ``x`` in the forward pass while passing gradients straight through.

    The result has the values of ``x.clamp(min, max)``. Its gradient with
    respect to ``x`` is the identity everywhere, including where clamping
    was active (a straight-through estimator).

    Args:
        x: Input tensor.
        min: Lower bound forwarded to ``Tensor.clamp`` (may be ``None``).
        max: Upper bound forwarded to ``Tensor.clamp`` (may be ``None``).

    Returns:
        A tensor whose values equal the clamped ``x``.
    """
    # Detach only the correction term so autograd sees the identity map on `x`.
    correction = x.clamp(min, max).sub(x).detach()
    return x + correction

class NegativeLog(nn.Module):
    """Elementwise bijection ``y = -log(x)`` with inverse ``x = exp(-y)``.

    Both directions return ``(output, log_jac)``. ``log_jac`` is the
    elementwise log absolute derivative of the map applied in that direction.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, **kwargs):
        """Map ``x`` to ``-log(x)``.

        In the forward value, ``x`` is clamped from below to the smallest
        positive normal number of its dtype. Gradients still flow through
        unchanged via ``clamp_preserve_gradients``. Extra keyword arguments
        are ignored.

        Returns:
            ``(y, log_jac)`` with ``y = -log(x)`` and
            ``log_jac = log|dy/dx| = -log(x)``, i.e. the same tensor as ``y``.
        """
        floor = torch.finfo(x.dtype).tiny
        safe_x = clamp_preserve_gradients(x, min=floor, max=None)
        neg_log = -torch.log(safe_x)
        # log|d(-log x)/dx| = log(1/x) = -log x, which is the output itself.
        return neg_log, neg_log

    def inverse(self, y, **kwargs):
        """Map ``y`` back to ``exp(-y)``; extra keyword arguments are ignored.

        Returns:
            ``(x, log_jac)`` with ``x = exp(-y)`` and
            ``log_jac = log|dx/dy| = -y``.
        """
        recovered = torch.exp(-y)
        return recovered, -y


class ExpNegative(nn.Module):
    """Elementwise bijection ``x = exp(-y)`` with inverse ``y = -log(x)``.

    This is ``NegativeLog`` with the directions swapped. Both directions
    return ``(output, log_jac)``, where ``log_jac`` is the elementwise log
    absolute derivative of the map applied in that direction.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y, **kwargs):
        """Map ``y`` to ``exp(-y)``; extra keyword arguments are ignored.

        Returns:
            ``(x, log_jac)`` with ``x = exp(-y)`` and
            ``log_jac = log|dx/dy| = -y``.
        """
        result = torch.exp(-y)
        return result, -y

    def inverse(self, x, **kwargs):
        """Map ``x`` back to ``-log(x)``.

        In the forward value, ``x`` is clamped from below to the smallest
        positive normal number of its dtype. Gradients still flow through
        unchanged via ``clamp_preserve_gradients``. Extra keyword arguments
        are ignored.

        Returns:
            ``(y, log_jac)`` with ``y = -log(x)`` and
            ``log_jac = log|dy/dx| = -log(x)``, i.e. the same tensor as ``y``.
        """
        floor = torch.finfo(x.dtype).tiny
        safe_x = clamp_preserve_gradients(x, min=floor, max=None)
        neg_log = -torch.log(safe_x)
        # log|d(-log x)/dx| = log(1/x) = -log x, which is the output itself.
        return neg_log, neg_log
