import unittest
import torch

from sbsep.scaler import Scaler


class TestScaler(unittest.TestCase):
    """Unit tests for :class:`sbsep.scaler.Scaler`."""

    def test_scaler(self):
        """Check that scaling [a, b] hits -1 and 1 at the ends and that
        ``unscale`` undoes ``scale``."""
        a, b = 0.0, 50.0

        # 0, 1, ..., 50 as float64: spans exactly [a, b] for the a, b above.
        t0 = torch.arange(51, dtype=torch.float64)
        sc = Scaler(a, b)
        t = sc.scale(t0)

        # Compare as Python floats with a tolerance. The scaled endpoints come
        # out of floating-point arithmetic, so exact equality would depend on
        # the precise formula Scaler uses internally.
        self.assertAlmostEqual(t.min().item(), -1.0, places=12)
        self.assertAlmostEqual(t.max().item(), 1.0, places=12)

        # Round trip: unscale(scale(x)) should reproduce the input. assertLess
        # reports the actual error on failure, unlike assertTrue.
        t_restored = sc.unscale(t)
        err = (t_restored - t0).norm().item()
        self.assertLess(err, 1e-8)


if __name__ == "__main__":
    # Executed directly (not imported): hand this module's TestCase classes
    # to the standard unittest command-line runner.
    unittest.main(module="__main__")
