import unittest
from parameterized import parameterized
from typing import Type

from base_model import BaseRKHSWeighting
from learners import LeastSquaresLearner
from rkhs_weightings import RKHSWeightingClassifier
from test_common import INSTANTIATIONS, make_scaled_classification, RNG, N_ITER, make_scaled_regression
from rks import *

from sklearn.metrics import r2_score


def basic_classification_error(model_class: Type[BaseRKHSWeighting], solver='default', n_iter=N_ITER):
    """Return the training misclassification rate of an RKSClassifier.

    The RKS parameters are derived from an instance of ``model_class`` fitted
    to the scaled classification dataset. The classifier is then trained and
    evaluated on that same data.

    Args:
        model_class: RKHS weighting class. It is instantiated on the data only
            to derive the RKS parameters.
        solver: Value written to the ``'solver'`` RKS parameter.
        n_iter: Value written to the ``'n_neurons'`` RKS parameter. Defaults to
            ``N_ITER``, which matches ``basic_regression_r2``. It is appended
            last so existing positional calls keep working.

    Returns:
        Fraction of training samples whose prediction differs from the label.
    """
    X, y = make_scaled_classification()
    model = model_class(data_x=X, data_y=y, rng=RNG)
    rks_params = get_rks_params_from_rkhs_weighting(model)
    # Override the derived settings so callers control model size and solver.
    rks_params['n_neurons'] = n_iter
    rks_params['solver'] = solver
    clf = RKSClassifier(**rks_params).fit(X, y)
    pred = clf.predict(X)
    return np.mean(y != pred)

def basic_classification_error_with_custom_centers(model_class: Type[BaseRKHSWeighting]):
    """Return the training error of an RKSClassifier seeded from a fitted weighting.

    The function first fits an ``RKHSWeightingClassifier`` with a
    ``LeastSquaresLearner``. It then derives RKS parameters from the fitted
    weighting model with ``keep_centers=True``. Unlike
    ``basic_classification_error``, neither ``'n_neurons'`` nor ``'solver'``
    is overridden here.

    Args:
        model_class: RKHS weighting class to instantiate on the data.

    Returns:
        Fraction of training samples whose prediction differs from the label.
    """
    X, y = make_scaled_classification()
    # Same construction order as before: the weighting model first, then the
    # learner, since both receive the shared RNG.
    weighting = model_class(data_x=X, data_y=y, rng=RNG)
    learner = LeastSquaresLearner(n_iter=N_ITER, rng=RNG)
    weighting_clf = RKHSWeightingClassifier(learner, weighting).fit(X, y)

    # keep_centers=True presumably carries the fitted model's centers into
    # the RKS parameters -- confirm against get_rks_params_from_rkhs_weighting.
    rks_params = get_rks_params_from_rkhs_weighting(weighting_clf.model, keep_centers=True)
    rks_clf = RKSClassifier(**rks_params).fit(X, y)
    return np.mean(y != rks_clf.predict(X))

def basic_regression_r2(model_class: Type[BaseRKHSWeighting], n_iter=N_ITER, solver='default'):
    """Return the training-set R^2 of an RKSRegressor.

    The RKS parameters are derived from an instance of ``model_class`` fitted
    to the scaled regression dataset. The regressor is then trained and scored
    on that same data.

    Args:
        model_class: RKHS weighting class used to derive the RKS parameters.
        n_iter: Value written to the ``'n_neurons'`` RKS parameter.
        solver: Value written to the ``'solver'`` RKS parameter.

    Returns:
        ``sklearn.metrics.r2_score`` of the in-sample predictions.
    """
    X, y = make_scaled_regression()
    params = get_rks_params_from_rkhs_weighting(
        model_class(data_x=X, data_y=y, rng=RNG)
    )
    params['n_neurons'] = n_iter
    params['solver'] = solver
    regressor = RKSRegressor(**params).fit(X, y)
    return r2_score(y, regressor.predict(X))

class TestInstantiations(unittest.TestCase):
    """Smoke tests for RKS models derived from every class in INSTANTIATIONS.

    ``parameterized.expand`` generates one test method per weighting class.
    """

    @parameterized.expand([[I] for I in INSTANTIATIONS])
    def test_acceptable_error(self, I):
        """Training error with the default solver must be below 50%."""
        error = basic_classification_error(I)
        msg = "\n{}: Error above 50%".format(I.__name__)
        self.assertGreater(0.5, error, msg)

    @parameterized.expand([[I] for I in INSTANTIATIONS])
    def test_acceptable_error_lasso(self, I):
        """Training error with the 'lasso' solver must be below 50%."""
        error = basic_classification_error(I, solver='lasso')
        msg = "\n{}: Error above 50%".format(I.__name__)
        self.assertGreater(0.5, error, msg)

    @parameterized.expand([[I] for I in INSTANTIATIONS])
    def test_acceptable_error_with_custom_centers(self, I):
        """Training error must be below 50% when RKS params keep the fitted centers."""
        # Must call the custom-centers helper. The plain helper would just
        # repeat test_acceptable_error and leave this path untested.
        error = basic_classification_error_with_custom_centers(I)
        msg = "\n{}: Error above 50%".format(I.__name__)
        self.assertGreater(0.5, error, msg)

    @parameterized.expand([[I] for I in INSTANTIATIONS])
    def test_r2_smaller_than_1(self, I):
        """Training R^2 (default solver) must be below 1 for several neuron counts."""
        for n_iter in range(100, 501, 100):
            # subTest makes a failure report which n_iter broke the bound.
            with self.subTest(n_iter=n_iter):
                r2 = basic_regression_r2(I, n_iter=n_iter)
                msg = "R^2 above 1 for {}".format(I.__name__)
                self.assertGreater(1, r2, msg)

    @parameterized.expand([[I] for I in INSTANTIATIONS])
    def test_r2_smaller_than_1_lasso(self, I):
        """Training R^2 ('lasso' solver) must be below 1 for several neuron counts."""
        for n_iter in range(100, 501, 100):
            with self.subTest(n_iter=n_iter):
                r2 = basic_regression_r2(I, n_iter=n_iter, solver='lasso')
                msg = "R^2 above 1 for {}".format(I.__name__)
                self.assertGreater(1, r2, msg)

    @parameterized.expand([[I] for I in INSTANTIATIONS])
    def test_remove_useless_centers(self, I):
        """remove_useless_features() must not change raw_output on the training data."""
        for n_iter in range(10, 51, 10):
            with self.subTest(n_iter=n_iter):
                X, y = make_scaled_regression()
                model = I(data_x=X, data_y=y, rng=RNG)
                rks_params = get_rks_params_from_rkhs_weighting(model)
                rks_params['n_neurons'] = n_iter
                rks_params['solver'] = 'default'
                clf = RKSRegressor(**rks_params).fit(X, y)
                # Force at least one zero weight so there is a feature to prune.
                clf.output_weights[0] = 0.0
                output_before = clf.raw_output(X)
                clf.remove_useless_features()
                output_after = clf.raw_output(X)
                msg = 'output changed after removing useless features for {}'.format(I.__name__)
                self.assertTrue(np.allclose(output_before, output_after), msg)


if __name__ == "__main__":
    # Running this file directly hands control to unittest's command-line
    # runner, which collects the TestCase classes defined in this module.
    unittest.main()