import math
import numpy as np
from typing import Type
import unittest
from parameterized import parameterized
from sklearn.metrics import r2_score

from base_model import BaseRKHSWeighting
from base_predictor import StumpParam
from learners import BaseLearner, LeastSquaresLearner
from base_model import BaseRKHSWeighting
from models.stump import RWStumps
from models.relu import RWRelu, RWExpRelu
from models.sign import RWSign
from rkhs_weightings import RKHSWeightingClassifier
from test_common import SMALL_SAMPLE_SIZE, N_DIM, SMALL_N_ITER, SAMPLE_SIZE, N_ITER, N_DIM, RNG
from test_common import MC_PRECISION, N_MC, LARGE_N_MC
from test_common import LEARNERS, INSTANTIATIONS, ANALYTICAL_INSTANTIATIONS, LOSSES
from test_common import make_scaled_classification

def is_method_overridden(cls, method_name, base=None):
    """Return True if ``cls`` resolves ``method_name`` differently from ``base``.

    Parameters
    ----------
    cls : type
        The (sub)class to inspect.
    method_name : str
        Attribute name to look up on both classes.
    base : type, optional
        Reference class; defaults to ``BaseRKHSWeighting``.

    Returns
    -------
    bool
        False when both classes resolve the name to the same underlying object
        (including when neither defines it), True otherwise.
    """
    if base is None:
        base = BaseRKHSWeighting
    base_method = getattr(base, method_name, None)
    sub_method = getattr(cls, method_name, None)
    # Accessing a classmethod on a class yields a *new* bound method each time,
    # so identity would always differ; compare the wrapped functions instead.
    base_method = getattr(base_method, "__func__", base_method)
    sub_method = getattr(sub_method, "__func__", sub_method)
    return base_method is not sub_method

def basic_classification_error(learner_class: Type[BaseLearner], model_class: Type[BaseRKHSWeighting]):
    """Fit a classifier built from ``learner_class`` and ``model_class`` and
    return its misclassification rate on the training data itself.

    The model is instantiated before the learner; both receive the shared
    ``RNG``, so this order is kept in case either consumes random draws.
    """
    features, labels = make_scaled_classification()
    rkhs_model = model_class(data_x=features, data_y=labels, rng=RNG)
    booster = learner_class(n_iter=N_ITER, rng=RNG)
    fitted = RKHSWeightingClassifier(booster, rkhs_model).fit(features, labels)
    mismatches = labels != fitted.predict(features)
    return np.mean(mismatches)

def basic_classification_error_using_specific_loss(model_class: Type[BaseRKHSWeighting], loss):
    """Return the training misclassification rate of ``model_class`` fitted
    by a ``LeastSquaresLearner`` configured with the given loss.

    Parameters
    ----------
    model_class : Type[BaseRKHSWeighting]
        Model class to instantiate on the synthetic dataset.
    loss : callable
        Zero-argument factory (e.g. a loss class); ``loss()`` is passed to
        ``LeastSquaresLearner``.

    Returns
    -------
    float
        Fraction of training samples whose prediction differs from the label.
    """
    X, y = make_scaled_classification()
    model = model_class(data_x=X, data_y=y, rng=RNG)
    learner = LeastSquaresLearner(loss=loss(), n_iter=N_ITER, rng=RNG)
    clf = RKHSWeightingClassifier(learner, model).fit(X, y)
    pred = clf.predict(X)
    # Normalise by the number of samples actually returned above rather than
    # the SAMPLE_SIZE constant, which is not guaranteed to match.
    return np.mean(y != pred)

def max_output_after_one_iter(model_class: Type[BaseRKHSWeighting]):
    """Fit ``model_class`` with a single ``LeastSquaresLearner`` iteration.

    Returns
    -------
    tuple
        ``(max_training, max_theoretical)``: the largest absolute value of
        ``model.output(clf.data_)`` after fitting, and the value reported by
        ``model.max_output()``.
    """
    features, labels = make_scaled_classification()
    rkhs_model = model_class(data_x=features, data_y=labels, rng=RNG)
    single_step = LeastSquaresLearner(n_iter=1, rng=RNG)
    fitted = RKHSWeightingClassifier(single_step, rkhs_model).fit(features, labels)
    bound = fitted.model.max_output()
    outputs = fitted.model.output(fitted.data_)
    observed_peak = np.max(np.abs(outputs))
    return observed_peak, bound

def vector_norm_relative_diff(a, b):
    """Return ``||a - b|| / max(||a||, ||b||)`` using Euclidean norms.

    Parameters
    ----------
    a, b : array_like
        Arrays of matching (or broadcastable) shape.

    Returns
    -------
    float
        Relative difference; 0.0 when both inputs are all-zero (instead of
        the NaN that 0/0 would produce).
    """
    a = np.asarray(a)
    b = np.asarray(b)
    diff_norm = np.linalg.norm(a - b)
    scale = max(np.linalg.norm(a), np.linalg.norm(b))
    if scale == 0:
        # Both inputs are zero vectors, hence identical.
        return 0.0
    return diff_norm / scale

class TestInstantiations(unittest.TestCase):
    """Generic checks run against every model class in ``INSTANTIATIONS``."""

    @parameterized.expand([[L, I] for L in LEARNERS for I in INSTANTIATIONS])
    def test_acceptable_error(self, L, I):
        """Each learner/model pair must reach a training error below 0.5."""
        error = basic_classification_error(L, I)
        msg = "\n{} and {}: Error above 50%".format(L.__name__, I.__name__)
        self.assertGreater(0.5, error, msg)

    @parameterized.expand([[I, loss] for I in INSTANTIATIONS for loss in LOSSES])
    def test_acceptable_error_using_specific_loss(self, I, loss):
        """Each model/loss pair (with LeastSquaresLearner) must reach a training error below 0.5."""
        error = basic_classification_error_using_specific_loss(I, loss)
        msg = "\n{} and {}: Error above 50%".format(I.__name__, loss.__name__)
        self.assertGreater(0.5, error, msg)

    @parameterized.expand([[I] for I in INSTANTIATIONS])
    def test_max_output(self, I):
        """After one iteration, outputs on training data must not exceed ``max_output()``."""
        # Repeated 10 times; presumably make_scaled_classification draws a new
        # random dataset on each call — confirm in test_common.
        for _ in range(10):
            max_training, max_theoretical = max_output_after_one_iter(I)
            msg = "\n{}: Training output {} greater than max theoretical value {}".format(I.__name__, max_training, max_theoretical)
            self.assertGreaterEqual(max_theoretical, max_training, msg)

    @parameterized.expand([[I] for I in INSTANTIATIONS])
    def test_constants(self, I):
        """Check theta >= 0 and the ordering theta <~ iota <~ kappa (1.2x slack)."""
        for _ in range(10):
            X, y = make_scaled_classification()
            model = I(data_x=X, data_y=y, rng=RNG)
            theta = model.theta()
            iota = model.iota()
            kappa = model.kappa()
            self.assertGreaterEqual(theta, 0, f'\n{I.__name__}: theta = {theta} < 0')
            # NOTE(review): the 1.2 factor presumably absorbs estimation noise
            # in these constants — confirm.
            self.assertGreaterEqual(1.2*iota, theta, f'\n{I.__name__}: 1.2*iota = {1.2*iota} < theta = {theta}')
            self.assertGreaterEqual(1.2*kappa, iota, f'\n{I.__name__}: 1.2*kappa = {1.2*kappa} < iota = {iota}')

    @parameterized.expand([[I] for I in INSTANTIATIONS])
    def test_rng(self, I):
        """Two models built with the same seed must sample identical centers."""
        X, y = make_scaled_classification()
        model1 = I(data_x=X, data_y=y, rng=1)
        model2 = I(data_x=X, data_y=y, rng=1)
        msg = f'{I.__name__} : Two sampled centers with same random seed are different.'
        value1 = model1.sample_center()
        value2 = model2.sample_center()
        if isinstance(value1, np.ndarray):
            # Compare element-wise: comparing only the sums would let two
            # different arrays with equal totals pass.
            np.testing.assert_array_equal(value1, value2, err_msg=msg)
        else:
            self.assertEqual(value1, value2, msg)


class TestI1(unittest.TestCase):
    """Closed-form checks of ``RWSign.expectations`` at the single point x = 1."""

    @staticmethod
    def _expectation_with_center(center):
        """Build a one-point RWSign (sigma=1, gamma=1, max_theta=0.5), add
        ``center`` with weight 1 and return ``expectations(point)[0, 0]``."""
        point = np.array([[1]])
        sign_model = RWSign(data_x=point, sigma=1, gamma=1, max_theta=0.5)
        sign_model.add_center(center, 1)
        return sign_model.expectations(point)[0, 0]

    def test_expectation_center_is_0(self):
        self.assertAlmostEqual(0, self._expectation_with_center([0]))

    def test_expectation_center_is_1(self):
        computed = self._expectation_with_center([1])
        expected = math.exp(-0.25) / math.sqrt(2) * math.erf(0.5)
        self.assertAlmostEqual(expected, computed)


class TestI2(unittest.TestCase):
    """Closed-form checks of ``RWRelu.expectations`` at the single point x = 1."""

    @staticmethod
    def _expectation_with_center(center):
        """Build a one-point RWRelu (sigma=1, gamma=1, max_theta=0.5), add
        ``center`` with weight 1 and return ``expectations(point)[0, 0]``."""
        point = np.array([[1]])
        relu_model = RWRelu(sigma=1, gamma=1, max_theta=0.5, data_x=point)
        relu_model.add_center(center, 1)
        return relu_model.expectations(point)[0, 0]

    def test_expectation_center_is_0(self):
        computed = self._expectation_with_center([0])
        expected = 1 / math.sqrt(8 * math.pi)
        self.assertAlmostEqual(expected, computed)

    def test_expectation_center_is_1(self):
        computed = self._expectation_with_center([1])
        decay = math.exp(-0.25)
        bracket = decay + math.sqrt(math.pi / 4) * (1 + math.erf(0.5))
        expected = decay * bracket / math.sqrt(8 * math.pi)
        self.assertAlmostEqual(expected, computed)


class TestI3(unittest.TestCase):
    """Closed-form check of ``RWStumps.expectations`` for one stump center."""

    def test_expectation(self):
        point = np.array([[1]])
        stumps = RWStumps(data_x=point, sigma=1, gamma=1)
        stumps.add_center(StumpParam(0, 0), 1)
        expected = math.erf(1) / math.sqrt(2)
        self.assertAlmostEqual(expected, stumps.expectations(point)[0, 0])


class TestI4(unittest.TestCase):
    """Placeholder test case; it currently defines no tests."""
        

class TestI5(unittest.TestCase):
    """Placeholder test case; it currently defines no tests."""

class TestI6(unittest.TestCase):
    """Closed-form check of ``RWExpRelu.expectations`` with one center at 0."""

    def test_expectation(self):
        X = np.array([[-1], [1]])
        model = RWExpRelu(sigma=1, gamma=1, data_x=X)
        model.add_center([0], 1)
        # Closed form: 1/sqrt(2*pi) * exp(0) * (1/sqrt(2) * sqrt(2)).
        terms = [1 / math.sqrt(2 * math.pi),
                 math.exp(0),
                 1 / math.sqrt(2) * math.sqrt(2)]
        correct_expectation = math.prod(terms)
        # NOTE(review): expectations are evaluated on the two-point X, and
        # [0, 0] presumably selects the entry for X[0] = -1 (not x = 1).
        # Confirm that this is the point the closed form above targets.
        returned_expectation = model.expectations(X)[0, 0]
        self.assertAlmostEqual(correct_expectation, returned_expectation)
        
if __name__ == "__main__":
    # Run every test case in this module when it is executed as a script.
    unittest.main()