import torch
import numpy as np


class Standardizer(torch.nn.Module):

    def __init__(self, mean: np.ndarray | torch.Tensor | float = 0.0, std: np.ndarray | torch.Tensor | float = 1.0):

        super().__init__()

        self._mean = torch.Tensor(mean) if type(mean) is np.ndarray else mean
        self._std = torch.Tensor(std) if type(std) is np.ndarray else std

    @staticmethod
    def standardize(x: np.ndarray) -> np.ndarray:

        variance = np.var(x)
        var_rec_sqrt = 1 / np.sqrt(variance + 1e-8)
        mean_val = np.mean(x)
        batch_values = x * var_rec_sqrt - mean_val * var_rec_sqrt

        return batch_values

    def set_stats(self, mean: np.ndarray | torch.Tensor, std: np.ndarray | torch.Tensor):

        self._mean = torch.Tensor(mean) if type(mean) is np.ndarray else mean
        self._std = torch.Tensor(std) if type(std) is np.ndarray else std

    def forward(self, x):
        return (x - self._mean) / self._std


class DeStandardizer(torch.nn.Module):

    def __init__(self, mean: np.ndarray | torch.Tensor | float = 0.0, std: np.ndarray | torch.Tensor | float = 1.0):

        super().__init__()

        self._mean = torch.Tensor(mean) if type(std) is np.ndarray else mean
        self._std = torch.Tensor(std) if type(std) is np.ndarray else std

    def set_stats(self, mean: np.ndarray | torch.Tensor, std: np.ndarray | torch.Tensor):
        self._mean = torch.Tensor(mean) if type(mean) is np.ndarray else mean
        self._std = torch.Tensor(std) if type(std) is np.ndarray else std

    def forward(self, x):
        return (x * self._std) + self._mean
