import torch
import torch.nn as nn
import torch.nn.functional as F


class AffineNorm(nn.Module):
    """Applies (x - mu) / sigma with learnable mu and sigma > 0.

    Two unconstrained scalar parameters are stored:

    * ``b = -mu`` (the shift), and
    * ``a = softplus^{-1}(1 / sigma)``, so that ``softplus(a) == 1 / sigma``.

    Routing the scale through ``softplus`` keeps the effective ``1 / sigma``
    strictly positive during optimisation without explicit clamping.

    Args:
        mu: Initial shift; anything ``float()`` accepts.
        sigma: Initial scale; anything ``float()`` accepts. Must be finite and
            strictly positive.

    Raises:
        ValueError: If ``sigma <= 0``, if ``mu`` or ``sigma`` is NaN or
            infinite, or if ``sigma`` is so small that ``1 / sigma`` is not
            representable in the default dtype.
    """

    def __init__(self, mu=0.0, sigma=1.0):
        super().__init__()
        mu = torch.tensor(float(mu))
        sigma = torch.tensor(float(sigma))
        if torch.any(sigma <= 0):
            raise ValueError("sigma must be > 0")
        # `sigma <= 0` is False for NaN, and +inf would yield a = -inf, so
        # non-finite inputs must be rejected explicitly.
        if not (torch.isfinite(sigma) and torch.isfinite(mu)):
            raise ValueError("mu and sigma must be finite")
        inv_sigma = 1.0 / sigma
        if not torch.isfinite(inv_sigma):
            raise ValueError("sigma is too small: 1/sigma is not representable")
        b = -mu
        # Numerically stable inverse softplus, valid for y = 1/sigma > 0:
        #   log(expm1(y)) == y + log(-expm1(-y)).
        # The naive form overflows to inf once y exceeds ~88 in float32
        # (sigma below ~0.011), which would make `a` infinite.
        a = inv_sigma + torch.log(-torch.expm1(-inv_sigma))
        self.a = nn.Parameter(a, requires_grad=True)
        self.b = nn.Parameter(b, requires_grad=True)

    def forward(self, x):
        """Return ``softplus(a) * (x + b)``, i.e. ``(x - mu) / sigma``."""
        return F.softplus(self.a) * (x + self.b)
