import numpy as np
import unittest

from base_predictor import StumpParam

class TestBasePredictor(unittest.TestCase):
    """Unit tests for the predictors provided by ``base_predictor``."""

    def test_stump(self):
        """Check ``Stump.eval`` against a small hand-computed example.

        ``X`` is the 3x3 matrix ``[[0, 1, 2], [3, 4, 5], [6, 7, 8]]``. Three
        ``StumpParam`` values are evaluated on it. The result must match the
        hand-written +1/-1 matrix exactly, in both shape and values.
        """
        # Importing here means a failure to import ``Stump`` is reported as an
        # error of this test rather than of the whole module.
        from base_predictor import Stump

        X = np.arange(9).reshape(3, 3)
        centers = np.array([StumpParam(0, -1), StumpParam(1, 3.5), StumpParam(2, 9)])
        # Expected output: one row per row of X, one column per entry of
        # ``centers``. These values are consistent with reading
        # StumpParam(a, b) as (feature index, threshold) and emitting +1
        # where X[:, a] > b, else -1.
        # NOTE(review): that reading is inferred from the numbers; confirm it
        # against base_predictor. No element of X equals a threshold, so the
        # behaviour on ties is not exercised here.
        correct_output = np.array([[1, -1, -1], [1, 1, -1], [1, 1, -1]])
        returned_output = Stump().eval(centers, X)
        # assert_array_equal checks the shape strictly and reports a diff on
        # failure. A bare assertTrue(np.all(a == b)) relies on broadcasting
        # and reports only "False is not true".
        np.testing.assert_array_equal(returned_output, correct_output)


if __name__ == '__main__':
    # Run this module's TestCase classes when the file is executed directly.
    # module='__main__' is unittest.main's default, spelled out for clarity.
    unittest.main(module='__main__')