import unittest
import numpy as np
from parameterized import parameterized
from sklearn.metrics import r2_score, mean_squared_error
from scipy.stats import loguniform, uniform

from learners import LeastSquaresLearner, LassoLearner
from models.sign import RWSign
from rkhs_weightings import RKHSWeightingClassifier, RKHSWeightingRegressor
from rkhs_weightings import RKHSWeightingGridSearchCV, RKHSWeightingRandomSearchCV

from test_common import SAMPLE_SIZE, N_DIM, INSTANTIATIONS
from test_common import make_scaled_classification, make_scaled_regression

class TestRKHSWeightingClassifier(unittest.TestCase):
    """Checks on ``RKHSWeightingClassifier`` for every entry of INSTANTIATIONS."""

    @parameterized.expand([[I] for I in INSTANTIATIONS])
    def test_mse_smaller_than_one(self, I):
        """Training MSE of ``raw_output`` stays at or below 1.

        Repeats the check on 10 freshly generated classification datasets.
        Each dataset runs in its own ``subTest``, so one bad dataset does not
        hide the results of the others.

        Args:
            I: a model factory, called as ``I(X, data_y=y)``.
        """
        for trial in range(10):
            with self.subTest(trial=trial):
                X, y = make_scaled_classification(n_samples=SAMPLE_SIZE, n_features=N_DIM)
                model = I(X, data_y=y)
                clf = RKHSWeightingClassifier(LeastSquaresLearner(), model).fit(X, y)
                mse = mean_squared_error(y, clf.raw_output(X))
                msg = (f'MSE {mse} greater than 1 for {model.__class__.__name__} '
                       f'(trial {trial})')
                self.assertLessEqual(mse, 1, msg)

class TestRKHSWeightingCV(unittest.TestCase):
    def test_gridsearch_cv(self):
        estimator_class = RKHSWeightingClassifier
        learner_class = LeastSquaresLearner
        model_class = RWSign
        learner_param_grid = {
            'n_iter' : [50, 100],
            'regularization' : [1e-5, 1e-6]
        }
        model_param_grid = {
            'max_theta' : [0.5, 0.9]
        }
        cv = RKHSWeightingGridSearchCV(estimator_class, learner_class, model_class, 
                                       learner_param_grid, model_param_grid, verbose=False)
        X, y = make_scaled_classification()
        cv.fit(X, y)

        score = cv.score(X, y)
        msg = f"\n RKHSWeightingGridSearchCV clf has error of {1-score}"
        self.assertGreaterEqual(score, 0.5, msg) 

    def test_random_cv(self):
        estimator_class = RKHSWeightingClassifier
        learner_class = LeastSquaresLearner
        model_class = RWSign
        learner_param_grid = {
            'n_iter' : [50, 100],
            'regularization' : loguniform(1e-6, 1e6)
        }
        model_param_grid = {
            'max_theta' : uniform(loc=0.5, scale=0.4)
        }
        cv = RKHSWeightingRandomSearchCV(estimator_class, learner_class, model_class, 
                                         learner_param_grid, model_param_grid, n_iter=5, verbose=False)
        X, y = make_scaled_classification()
        cv.fit(X, y)

        score = cv.score(X, y)
        msg = "\n RKHSWeightingRandomSearchCV clf has error of {}".format(1-score)
        self.assertGreaterEqual(score, 0.7, msg) 

class TestRKHSWeightingRegressor(unittest.TestCase):
    """Checks on ``RKHSWeightingRegressor`` for every entry of INSTANTIATIONS."""

    @parameterized.expand([[I] for I in INSTANTIATIONS])
    def test_positive_R2(self, I):
        """Training R^2 of the fitted regressor is non-negative.

        A non-negative R^2 means the fit on the training data is no worse
        than always predicting the mean of ``y``.

        Args:
            I: a model factory, called as ``I(X, data_y=y)``.
        """
        X, y = make_scaled_regression(n_samples=100, n_features=5)
        model = I(X, data_y=y)
        reg = RKHSWeightingRegressor(LeastSquaresLearner(), model).fit(X, y)
        score = r2_score(y, reg.predict(X))
        # Name the built model rather than I. A factory entry without
        # __name__ (e.g. functools.partial) would make I.__name__ raise while
        # building the message. This also matches the classifier test.
        msg = f'Negative training R2 score {score} for {type(model).__name__}'
        self.assertGreaterEqual(score, 0, msg)


if __name__ == "__main__":
    # Running this file directly executes every TestCase defined in it.
    unittest.main()