import math
from typing import Optional

import torch

from batlinet.builders import DATA_TRANSFORMATIONS
from batlinet.data.transformation.base import BaseDataTransformation


@DATA_TRANSFORMATIONS.register()
class MinMaxScaleDataTransformation(BaseDataTransformation):
    """Affinely map values from ``[min, max]`` onto ``[-1, 1]``.

    The ``min``/``max`` statistics are fixed once and reused for every later
    call, so all data passed through one instance shares a single scale.
    They are fixed on the first call to :meth:`transform`. If both bounds are
    supplied by the user, the first call to :meth:`inverse_transform` can also
    fix them. Values outside the fitted range map outside ``[-1, 1]``.

    Args:
        min_val: Fixed lower bound. If ``None``, the minimum of the first
            tensor passed to :meth:`transform` is used.
        max_val: Fixed upper bound. If ``None``, the maximum of the first
            tensor passed to :meth:`transform` is used.
    """

    # Added to the range so constant data (max == min) does not divide by zero.
    # Used identically in both directions so inverse_transform exactly undoes
    # transform.
    _EPS = 1e-8

    def __init__(
        self,
        min_val: Optional[float] = None,
        max_val: Optional[float] = None,
    ):
        self.user_min = min_val
        self.user_max = max_val
        self._fitted = False

    def _fit(self, data: torch.Tensor) -> None:
        """Fix ``self.min`` / ``self.max`` from the user bounds or from ``data``."""
        # as_tensor accepts Python scalars (same dtype as torch.tensor) and
        # tensor bounds (without torch.tensor's copy-construct warning).
        if self.user_min is None:
            self.min = data.min()
        else:
            self.min = torch.as_tensor(self.user_min, device=data.device)
        if self.user_max is None:
            self.max = data.max()
        else:
            self.max = torch.as_tensor(self.user_max, device=data.device)
        self._fitted = True

    def _bounds(self, device: torch.device):
        """Return ``(min, max)`` moved to ``device`` (a no-op if already there)."""
        # The stats live on the device of the tensor they were fitted on; a
        # later call on another device would otherwise fail with a
        # device-mismatch error.
        return self.min.to(device), self.max.to(device)

    @torch.no_grad()
    def transform(self, data: torch.Tensor) -> torch.Tensor:
        """Scale ``data`` so that ``[min, max]`` maps onto ``[-1, 1]``.

        On the first call, any bound not supplied by the user is taken from
        ``data`` and kept for all later calls.
        """
        if not self._fitted:
            self._fit(data)
        lo, hi = self._bounds(data.device)

        # Scale to [0, 1]
        scaled = (data - lo) / (hi - lo + self._EPS)
        # Scale to [-1, 1]
        return scaled * 2 - 1

    @torch.no_grad()
    def inverse_transform(self, data: torch.Tensor) -> torch.Tensor:
        """Undo :meth:`transform`, mapping ``[-1, 1]`` back onto ``[min, max]``.

        Raises:
            RuntimeError: If the statistics have not been fitted yet and at
                least one bound must come from data seen by :meth:`transform`.
        """
        if not self._fitted:
            if self.user_min is None or self.user_max is None:
                raise RuntimeError(
                    "MinMaxScaleDataTransformation.inverse_transform called "
                    "before transform: min/max statistics are not fitted yet."
                )
            # Both bounds are user-supplied, so no data statistics are needed;
            # data only provides the target device here.
            self._fit(data)
        lo, hi = self._bounds(data.device)

        # Rescale from [-1, 1] to [0, 1]
        scaled = (data + 1) / 2
        # Rescale to original range
        return scaled * (hi - lo + self._EPS) + lo
