import torch
import numpy as np


class Normalizer(torch.nn.Module):

    def __init__(self, _min: np.ndarray | torch.Tensor, _max: np.ndarray | torch.Tensor):

        super().__init__()

        self._min = torch.Tensor(_min) if type(_min) is np.ndarray else _min
        self._max = torch.Tensor(_max) if type(_max) is np.ndarray else _max

    def forward(self, x):
        return (x - self._min) / (self._max - self._min)


class DeNormalizer(torch.nn.Module):

    def __init__(self, _min: np.ndarray | torch.Tensor, _max: np.ndarray | torch.Tensor):
        super().__init__()

        self._min = torch.Tensor(_min) if type(_min) is np.ndarray else _min
        self._max = torch.Tensor(_max) if type(_max) is np.ndarray else _max

    def forward(self, x):
        return (x * (self._max - self._min)) + self._min
