import torch
from torch import nn


class StandardScalerFromUniform(nn.Module):
    """
    A layer to normalize input features using standard scaling. The input
    features are assumed to be uniformly distributed within given min and max
    values.

    Args:
        lb (torch.Tensor): A 1D tensor of lower bounds for each feature.
        ub (torch.Tensor): A 1D tensor of upper bounds for each feature.
    """

    lb: torch.Tensor
    ub: torch.Tensor
    mean: torch.Tensor
    std: torch.Tensor

    def __init__(self, lb: torch.Tensor, ub: torch.Tensor):
        super().__init__()

        # Register min, max, and range as buffers
        self.register_buffer("lb", lb)
        self.register_buffer("ub", ub)

        mean = (self.ub + self.lb) / 2
        std = (self.ub - self.lb) / (12**0.5)
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies standard normalization."""
        return (x - self.mean) / (self.std + 1e-8)

class MinMaxScaler(nn.Module):
    """
    A layer to normalize input features using min-max scaling to the range [-1, 1].

    ``forward`` returns ``2 * (x - lb) / (range + eps) - 1`` where
    ``range = ub - lb``, so ``lb`` maps to ``-1`` and ``ub`` to (almost) ``1``.

    Args:
        lb (torch.Tensor): A 1D tensor of lower bounds for each feature.
            Any array-like accepted by ``torch.as_tensor`` also works.
        ub (torch.Tensor): A 1D tensor of upper bounds for each feature.
            Any array-like accepted by ``torch.as_tensor`` also works.
        eps (float): Constant added to the range to avoid division by zero
            when ``lb == ub``. Defaults to ``1e-8``.

    Raises:
        ValueError: If any upper bound is smaller than its lower bound.
    """

    lb: torch.Tensor
    ub: torch.Tensor
    range: torch.Tensor
    eps: float

    def __init__(self, lb: torch.Tensor, ub: torch.Tensor, eps: float = 1e-8):
        super().__init__()

        # Copy the bounds so the buffers never alias the caller's tensors:
        # load_state_dict copies into buffers in place, which would otherwise
        # silently overwrite the tensors the caller passed in.
        lb = torch.as_tensor(lb).detach().clone()
        ub = torch.as_tensor(ub).detach().clone()
        if torch.any(ub < lb):
            raise ValueError(
                "Each upper bound in `ub` must be >= the matching lower bound in `lb`."
            )
        self.eps = eps

        # Register min, max and their range as buffers so they move with the
        # module and appear in state_dict.
        self.register_buffer("lb", lb)
        self.register_buffer("ub", ub)
        self.register_buffer("range", ub - lb)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies min-max normalization to the range [-1, 1]."""
        return 2 * (x - self.lb) / (self.range + self.eps) - 1
