from __future__ import annotations

from itertools import product
import math
import numpy as np
import scipy as sp
import seaborn ; seaborn.set()
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold, ParameterGrid, ParameterSampler
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
try:
    from sklearn.utils._testing import ignore_warnings
except ImportError:
    from sklearn.utils.testing import ignore_warnings
from time import time

from base_model import BaseRKHSWeighting
from models import RWSign
from base_predictor import sign
from learners import BaseLearner, LeastSquaresLearner
from rks import _RKSEstimator, get_rks_params_from_rkhs_weighting


class _RKHSWeightingEstimator(BaseEstimator):
    """
    Utility class containing functions shared by RKHSWeightingRegressor and RKHSWeightingClassifier.

    The methods below read attributes that this class does not set itself:
    ``model`` (the prediction model), ``learner`` (whose ``loss`` is used to
    evaluate outputs) and, for the training loss, ``data_`` and ``targets_``.
    """

    def __init__(self):
        super().__init__()

    def _set_model(self, new_model: BaseRKHSWeighting):
        """
        Utility function that directly sets self.model to new_model.
        """
        self.model = new_model

    def _model_training_loss(self, model: BaseRKHSWeighting) -> np.ndarray:
        """
        Returns the loss of `model` on the stored training data.

        The output of `model` on self.data_ is compared to self.targets_
        with self.learner.loss.calculate. check_is_fitted is called first,
        so an unfitted estimator is rejected before any attribute is read.
        """
        check_is_fitted(self)
        output = model.output(self.data_)
        targets = self.targets_
        return self.learner.loss.calculate(output, targets)

    def raw_output(self, X) -> np.ndarray:
        """
        Returns the size m array containing
        the output of the prediction model for all examples in X.

        Equations-wise, this is Lambda alpha(X).
        """
        return self.model.output(X)


class RKHSWeightingRegressor(_RKHSWeightingEstimator, RegressorMixin):
    """
    Regressor whose training is delegated to `learner.fit_model` and whose
    predictions are the raw output of the resulting model.
    """

    def __init__(self, learner: BaseLearner, model: BaseRKHSWeighting) -> None:
        self.learner = learner
        self.model = model
        super().__init__()

    def fit(self, X: np.ndarray, y: np.ndarray, **kwargs) -> RKHSWeightingRegressor:
        """
        Validates (X, y), stores them as data_ / targets_ and replaces
        self.model by the model returned by self.learner.fit_model.

        Extra keyword arguments are forwarded to fit_model unchanged.
        Returns self.
        """
        validated_X, validated_y = check_X_y(X, y)
        self.data_, self.targets_ = validated_X, validated_y
        fitted_model = self.learner.fit_model(validated_X, validated_y, self.model, **kwargs)
        self.model = fitted_model
        # Flag read back by __sklearn_is_fitted__.
        self._is_fitted = True
        return self

    def predict(self, X):
        """
        Returns the model output for the validated array X.
        The estimator must have been fitted.
        """
        check_is_fitted(self)
        return self.model.output(check_array(X))

    def calculate_loss(self, X: np.ndarray, y: np.ndarray) -> np.ndarray:
        """
        Returns the learner's loss between the raw output on X and y.
        """
        return self.learner.loss.calculate(self.raw_output(X), y)

    def __sklearn_is_fitted__(self):
        """
        Check fitted status and return a Boolean value.
        """
        return getattr(self, "_is_fitted", False)


class RKHSWeightingClassifier(_RKHSWeightingEstimator, ClassifierMixin):
    """
    Classifier built on an RKHS weighting model trained on +1 / -1 targets.

    Labels are encoded relative to ``classes_``: the label ``classes_[0]``
    becomes the target +1 and every other label becomes -1. When the labels
    are exactly {-1, 1}, ``classes_`` is reordered to ``[1, -1]`` so that
    labels and targets coincide. Predictions are decoded back with
    ``classes_[0]`` (target +1) and ``classes_[1]`` (otherwise), so the
    class is effectively binary.
    """

    def __init__(self, learner: BaseLearner, model: BaseRKHSWeighting, is_fitted=False) -> None:
        self.learner = learner
        self.model = model
        # Every __init__ parameter has to be stored under its own name:
        # BaseEstimator.get_params() (used by clone() and repr()) reads each
        # of them back with getattr and fails on a missing attribute.
        self.is_fitted = is_fitted
        super().__init__()

    def fit(self, X: np.ndarray, y: np.ndarray, **kwargs) -> RKHSWeightingClassifier:
        """
        Validates and encodes (X, y), then replaces self.model by the model
        returned by self.learner.fit_model on the encoded targets.

        Extra keyword arguments are forwarded to fit_model unchanged.
        Returns self.
        """
        self._preprocessing(X, y)
        data = self.data_
        targets = self.targets_
        self.model = self.learner.fit_model(data, targets, self.model, **kwargs)
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Returns the predicted class label of every example in X.

        The sign of the raw output is taken; entries equal to 0 after that
        step are set to +1 before decoding to class labels.
        """
        check_is_fitted(self)
        X = check_array(X)
        targets = sign(self.raw_output(X))
        targets[targets == 0] = 1
        return self._targets_to_classes(targets)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Returns the logistic function of the raw output on X.

        The array has the shape of the raw output (one value per output
        entry, not one column per class); values above 0.5 correspond to
        ``classes_[0]`` in _proba_to_classes.
        """
        return self._proba_from_output(self.raw_output(X))

    def _preprocessing(self, X: np.ndarray, y: np.ndarray):
        """
        Validates (X, y) and stores data_, classes_ and targets_.
        """
        X, y = check_X_y(X, y)
        self.data_ = X
        self.classes_, self.targets_ = self._classes_preprocessing(y)

    def _classes_preprocessing(self, y: np.ndarray):
        """
        Stores y_, classes_ and targets_ from the labels y and returns
        (classes_, targets_).
        """
        self.y_ = y
        self.classes_ = np.unique(y)
        # np.unique already returns distinct values, so matching the set
        # {-1, 1} implies exactly two classes. Put +1 first so that it is
        # the label encoded as target +1.
        if set(self.classes_) == {-1, 1}:
            self.classes_ = np.array([1, -1])
        self.targets_ = self._classes_to_targets(y)
        return self.classes_, self.targets_

    def _classes_to_targets(self, classes: np.ndarray) -> np.ndarray:
        """
        Encodes labels: classes_[0] -> +1, anything else -> -1.
        """
        return np.where(classes == self.classes_[0], 1, -1)

    def _targets_to_classes(self, targets: np.ndarray) -> np.ndarray:
        """
        Decodes targets: +1 -> classes_[0], anything else -> classes_[1].
        """
        return np.where(targets == 1, self.classes_[0], self.classes_[1])

    def _proba_from_output(self, output: np.ndarray) -> np.ndarray:
        """
        Maps raw outputs to values in [0, 1] with the logistic function.
        """
        return self._logistic_function(output)

    def _logistic_function(self, array: np.ndarray) -> np.ndarray:
        """
        Element-wise logistic sigmoid (scipy.special.expit).
        """
        return sp.special.expit(array)

    def _proba_to_classes(self, proba: np.ndarray) -> np.ndarray:
        """
        Returns classes_[0] where proba > 0.5 and classes_[1] elsewhere.
        """
        classes = np.zeros(proba.shape, self.classes_.dtype)
        classes[proba > 0.5] = self.classes_[0]
        classes[proba <= 0.5] = self.classes_[1]
        return classes

    def _model_training_01_loss(self, model: BaseRKHSWeighting) -> float:
        """
        Returns the 0-1 loss (1 - accuracy) of `model` on the training data.

        Predictions are class labels while targets_ holds +1 / -1 values, so
        the predictions are encoded back to targets before the comparison;
        comparing labels to targets directly is only valid when the labels
        are themselves {-1, 1}.
        """
        proba = self._proba_from_output(model.output(self.data_))
        pred = self._proba_to_classes(proba)
        pred_targets = self._classes_to_targets(pred)
        return 1 - accuracy_score(self.targets_, pred_targets)

    def training_01_loss(self) -> float:
        """
        Returns the 0-1 loss of the current model on the training data.
        """
        # The fitted model is stored in self.model by fit(); no `model_`
        # attribute is ever set on this class.
        return self._model_training_01_loss(self.model)

    def calculate_loss(self, X: np.ndarray, y: np.ndarray) -> np.ndarray:
        """
        Returns the learner's loss between the raw output on X and the
        +1 / -1 encoding of the labels y.
        """
        output = self.raw_output(X)
        targets = self._classes_to_targets(y)
        return self.learner.loss.calculate(output, targets)

    
class RKHSWeightingGridSearchCV(BaseEstimator):
    """
    Cross-validated exhaustive search over learner and model parameters.

    Every combination of a learner parameter setting and a model parameter
    setting is scored by K-fold cross-validation; the best combination is
    then refitted on the whole data set and exposed as ``best_estimator_``.

    Parameters
    ----------
    estimator_class : class
        Estimator to build; must be a subclass of _RKSEstimator or of
        _RKHSWeightingEstimator.
    learner_class, model_class : class
        Classes instantiated with the candidate learner / model parameters.
    learner_param_grid, model_param_grid : dict or list of dict
        Grids handed to sklearn's ParameterGrid.
    folds : int
        Number of cross-validation folds.
    rng :
        Passed as ``rng`` to model_class and learner_class.
    verbose : bool
        Whether to print the duration of each candidate evaluation.
    """

    def __init__(self,
                 estimator_class,
                 learner_class=LeastSquaresLearner,
                 model_class=RWSign,
                 learner_param_grid={},
                 model_param_grid={},
                 folds=5,
                 rng=None,
                 verbose=True) -> None:
        # The default grids are shared dict objects; they are only read,
        # never mutated, by this class.
        self.estimator_class = estimator_class
        self.learner_class = learner_class
        self.model_class = model_class
        self.learner_param_grid = learner_param_grid
        self.model_param_grid = model_param_grid
        self.folds = folds
        self.rng = rng
        self.verbose = verbose

    def fit(self, X: np.ndarray, y: np.ndarray, **kwargs) -> RKHSWeightingGridSearchCV:
        """
        Scores every candidate parameter pair, then refits the best one.

        **kwargs is accepted but not used by this method. Returns self.
        """
        self._initialize(X, y)
        n_fits = 0
        param_list = self.get_learner_param_list(self.learner_param_grid, self.model_param_grid)
        n_total_fits = len(param_list)
        for learner_params, model_params in param_list:
            start = time()
            self._fit_and_score_one_combination(learner_params, model_params)
            fit_time = time() - start
            n_fits += 1
            if self.verbose:
                print(f'Fit {n_fits} of {n_total_fits} done in {fit_time} seconds.')
        self._refit()
        return self

    def get_learner_param_list(self, learner_param_grid, model_param_grid):
        """
        Returns the list of all (learner_params, model_params) pairs, i.e.
        the Cartesian product of the two parameter grids.
        """
        return list(product(ParameterGrid(learner_param_grid), ParameterGrid(model_param_grid)))

    def _initialize(self, X, y):
        """
        Stores the data and resets the search results.
        """
        self.data_ = X
        self.targets_ = y
        self.best_score_ = -math.inf
        self.best_estimator_ = None
        self.best_learner_params_ = {}
        self.best_model_params_ = {}

    def _get_clf(self, learner_params: dict, model_params: dict):
        """
        Builds an unfitted estimator for one parameter combination.

        None is accepted for either parameter dict and treated as empty.

        Raises
        ------
        TypeError
            If estimator_class is neither an _RKSEstimator nor an
            _RKHSWeightingEstimator subclass.
        """
        learner_params = {} if learner_params is None else learner_params
        model_params = {} if model_params is None else model_params
        if 'data_x' not in model_params and 'data_y' not in model_params:
            # NOTE(review): the model is always built from the full data
            # stored by _initialize, including when this is called from
            # avg_cv_score for a single fold. Confirm that the model's use
            # of data_x / data_y does not leak held-out examples.
            model = self.model_class(data_x=self.data_, data_y=self.targets_, **model_params, rng=self.rng)
        else:
            model = self.model_class(**model_params, rng=self.rng)
        if issubclass(self.estimator_class, _RKSEstimator):
            rks_params = get_rks_params_from_rkhs_weighting(model)
            return self.estimator_class(**learner_params, **rks_params)
        if issubclass(self.estimator_class, _RKHSWeightingEstimator):
            learner = self.learner_class(**learner_params, rng=self.rng)
            return self.estimator_class(learner, model)
        # Falling through would return None and fail later with an
        # unrelated AttributeError on `.fit`; fail here with the cause.
        raise TypeError(
            "estimator_class must be a subclass of _RKSEstimator or "
            f"_RKHSWeightingEstimator, got {self.estimator_class!r}."
        )

    def _fit_and_score_one_combination(self, learner_params: dict, model_params: dict):
        """
        Cross-validates one combination and records it if it has the
        strictly highest score so far.
        """
        score = self.avg_cv_score(learner_params, model_params, self.data_, self.targets_)
        if score > self.best_score_:
            self.best_score_ = score
            self.best_learner_params_ = learner_params
            self.best_model_params_ = model_params

    def _refit(self):
        """
        Fits an estimator with the best parameters on the whole data set and
        stores it as best_estimator_ (duration in refit_time_).
        """
        start_refit = time()
        clf = self._get_clf(self.best_learner_params_, self.best_model_params_)
        clf.fit(self.data_, self.targets_)
        self.refit_time_ = time() - start_refit
        self.best_estimator_ = clf

    def predict(self, X):
        """
        Predicts with the refitted best estimator.
        """
        # NotFittedError derives from AttributeError, which is what an
        # unfitted instance raised before this check was added.
        check_is_fitted(self, "best_estimator_")
        return self.best_estimator_.predict(X)

    def score(self, X, y):
        """
        Scores (X, y) with the refitted best estimator.
        """
        check_is_fitted(self, "best_estimator_")
        return self.best_estimator_.score(X, y)

    def avg_cv_score(self, learner_params: dict, model_params: dict, X: np.ndarray, y: np.ndarray) -> float:
        """
        Returns the mean estimator.score over the K folds of (X, y).

        The folds are shuffled with a fixed random_state, so every parameter
        combination is evaluated on the same splits. X and y are indexed
        with integer index arrays.
        """
        kf = KFold(n_splits=self.folds, shuffle=True, random_state=0)
        total_score = 0
        for train_index, test_index in kf.split(X):
            X_train, X_test = X[train_index], X[test_index]
            y_train, y_test = y[train_index], y[test_index]
            estimator = self._get_clf(learner_params, model_params)
            estimator.fit(X_train, y_train)
            total_score += estimator.score(X_test, y_test)
        return total_score / self.folds

    
class RKHSWeightingRandomSearchCV(RKHSWeightingGridSearchCV):
    """
    Randomized counterpart of RKHSWeightingGridSearchCV.

    Instead of enumerating every combination, ``n_iter`` candidate
    (learner_params, model_params) pairs are drawn with sklearn's
    ParameterSampler. The grids may contain lists of values or objects
    exposing ``rvs``, as supported by ParameterSampler.
    """

    def __init__(self,
                 estimator_class,
                 learner_class=LeastSquaresLearner,
                 model_class=RWSign,
                 learner_param_grid={},
                 model_param_grid={},
                 folds=5,
                 n_iter=10,
                 rng=None,
                 verbose=True) -> None:
        super().__init__(estimator_class,
                         learner_class=learner_class,
                         model_class=model_class,
                         learner_param_grid=learner_param_grid,
                         model_param_grid=model_param_grid,
                         folds=folds,
                         rng=rng,
                         verbose=verbose)
        self.n_iter = n_iter

    @ignore_warnings(category=UserWarning)
    def get_learner_param_list(self, learner_param_grid, model_param_grid):
        """
        Returns a list of sampled (learner_params, model_params) pairs.

        The two grids are merged into one joint search space and sampled
        with a single ParameterSampler. Sampling each grid separately and
        zipping the results truncates the list to the smaller sampler: a
        grid made only of lists yields at most as many samples as it has
        combinations, so an empty learner grid would leave one candidate
        whatever n_iter is. The UserWarning emitted by ParameterSampler
        when the joint grid is smaller than n_iter is silenced.
        """
        from collections.abc import Mapping

        learner_prefix, model_prefix = "learner__", "model__"

        def as_grid_list(grid):
            # ParameterSampler accepts one mapping or a sequence of them.
            return [grid] if isinstance(grid, Mapping) else list(grid)

        # Prefix the keys so that a learner and a model parameter sharing a
        # name do not collide in the joint space.
        joint_space = [
            {
                **{learner_prefix + key: values for key, values in learner_grid.items()},
                **{model_prefix + key: values for key, values in model_grid.items()},
            }
            for learner_grid, model_grid in product(as_grid_list(learner_param_grid),
                                                    as_grid_list(model_param_grid))
        ]

        sampled_pairs = []
        for sample in ParameterSampler(joint_space, n_iter=self.n_iter, random_state=0):
            learner_params = {key[len(learner_prefix):]: value
                              for key, value in sample.items()
                              if key.startswith(learner_prefix)}
            model_params = {key[len(model_prefix):]: value
                            for key, value in sample.items()
                            if key.startswith(model_prefix)}
            sampled_pairs.append((learner_params, model_params))
        return sampled_pairs