import unittest

import torch

from scrawl.metrics import RegressionAccuracy


class TestRegressionAccuracy(unittest.TestCase):
    """Tests for ``scrawl.metrics.RegressionAccuracy``."""

    def test_regression_accuracy(self):
        """A fixed batch of 8 value pairs should score an accuracy of 0.75.

        0.75 of 8 pairs means exactly 6 pairs are expected to be counted as
        accurate by ``RegressionAccuracy(threshold=0.1)``.
        """
        metric = RegressionAccuracy(threshold=0.1)
        # NOTE(review): pair 0 (0.5 vs 0.4) differs by ~0.1, i.e. right at the
        # threshold in absolute terms, and pair 2 (2.0 vs 1.0) differs by 1.0.
        # Whether pair 0 counts depends on how RegressionAccuracy compares
        # (absolute vs. relative error, < vs. <=) — confirm against the metric.
        accuracy = metric(
            torch.tensor([0.5, 0.50, 2.0, 8.0, 0.4, 0.1, 0.2, 0.3]),
            torch.tensor([0.4, 0.49, 1.0, 8.0, 0.4, 0.1, 0.2, 0.3]),
        )

        # Normalise to a plain Python float so this works whether the metric
        # returns a 0-d tensor or a float. Compare with a tolerance instead of
        # exact float equality; a failure then also reports a readable number
        # rather than a tensor repr.
        self.assertAlmostEqual(float(accuracy), 0.75)


if __name__ == "__main__":
    # Executing this file directly loads and runs the TestCases defined in it.
    unittest.main(module="__main__")
