from jaxtyping import Float
from torch import Tensor, nn


class Upscaler(nn.Module):
    """Feature upscaling block.

    Applies batch normalization to the input features, projects them to the
    requested output width with a linear layer, and finishes with a leaky
    ReLU activation.
    """

    def __init__(self, d_input: int, d_output: int) -> None:
        """Build the normalization, projection and activation layers.

        :param d_input: Number of features of input.
        :param d_output: Number of features of output.
        """
        super().__init__()

        # Layers are passed positionally so they are registered under the
        # indices 0, 1 and 2 of ``self.net``.
        normalize = nn.BatchNorm1d(d_input)
        project = nn.Linear(d_input, d_output)
        activate = nn.LeakyReLU()
        self.net = nn.Sequential(normalize, project, activate)

    def forward(
        self, x: Float[Tensor, "n_nodes d_input"]
    ) -> Float[Tensor, "n_nodes d_output"]:
        """Run the feature matrix through the upscaling pipeline.

        :param x: Feature matrix.
        :return: Upscaled feature matrix.
        """
        upscaled = self.net(x)
        return upscaled
