from __future__ import print_function, division, absolute_import
from copy import deepcopy
import numpy as np
import scipy as sp
from time import time

from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.linear_model import Lasso
from sklearn.exceptions import ConvergenceWarning

from base_predictor import Stump
from distribution import GaussianThresholdDistribution as StumpDistribution
from base_model import BaseRKHSWeighting

try:
    from sklearn.utils._testing import ignore_warnings
except ImportError:
    from sklearn.utils.testing import ignore_warnings


def get_rks_params_from_rkhs_weighting(rkhs_weighting: BaseRKHSWeighting, keep_centers=False):
    """Build constructor keyword arguments for an RKS estimator from an RKHS weighting model.

    The base predictor and the sampling distribution of ``rkhs_weighting`` are
    deep-copied, so the returned values do not alias the source model.

    Parameters
    ----------
    rkhs_weighting : BaseRKHSWeighting
        Model providing ``base_pred`` and ``dist`` attributes and, when
        ``keep_centers`` is true, ``get_features()`` and ``get_n_centers()``.
    keep_centers : bool, default=False
        When true, also export a deep copy of the model's features as
        ``'custom_centers'`` and its number of centers as ``'n_neurons'``.

    Returns
    -------
    dict
        Keys ``'base_predictor'``, ``'distribution_class'`` and
        ``'distribution'``; plus ``'custom_centers'`` and ``'n_neurons'``
        when ``keep_centers`` is true.
    """
    exported = dict(
        base_predictor=deepcopy(rkhs_weighting.base_pred),
        distribution_class=rkhs_weighting.dist.__class__,
        distribution=deepcopy(rkhs_weighting.dist),
    )
    if not keep_centers:
        return exported

    # Reuse the centers of the source model instead of resampling them.
    exported['custom_centers'] = deepcopy(rkhs_weighting.get_features())
    exported['n_neurons'] = rkhs_weighting.get_n_centers()
    return exported

class _RKSEstimator(BaseEstimator):
    """Common machinery for random-feature (RKS) estimators.

    A set of features is either sampled from a distribution or supplied by
    the caller, the inputs are propagated through ``base_predictor`` and a
    linear output layer is fitted on top, either with a regularised linear
    system (``solver='default'``) or with a Lasso (``solver='lasso'``).

    Parameters
    ----------
    n_neurons : int, default=1000
        Number of features to sample. Also used as the normalisation
        constant in :meth:`get_hidden` and :meth:`kernel`.
    regularization : float, default=1e-5
        Regularisation strength (see :meth:`get_regularisation`).
    base_predictor : object, default=Stump()
        Object exposing ``eval(features, X)``.
        NOTE(review): the default instance is created once at class
        definition time and is therefore shared by all estimators that
        rely on the default.
    distribution_class : type, default=StumpDistribution
        Called as ``distribution_class(n_dim=..., sigma=..., rng=...,
        **distribution_params)`` when ``distribution`` is ``None``.
    distribution : object or None, default=None
        Ready-made distribution; a deep copy of it is used when given.
    custom_centers : sequence or None, default=None
        Features to use instead of sampling new ones.
    sigma : float, default=1.0
        Forwarded to ``distribution_class``.
    divide_reg_by_T : bool, default=True
        Whether the regularisation is divided by ``n_neurons``.
    solver : {'default', 'lasso'}, default='default'
        Fitting strategy used by :meth:`fit`.
    rng : seed or numpy Generator, default=None
        Passed to ``np.random.default_rng``.
    **distribution_params
        Extra keyword arguments forwarded to ``distribution_class``.
    """

    def __init__(self,
                 n_neurons=1000,
                 regularization=1e-5,
                 base_predictor=Stump(),
                 distribution_class=StumpDistribution,
                 distribution=None,
                 custom_centers=None,
                 sigma=1.0,
                 divide_reg_by_T=True,
                 solver='default',
                 rng=None,
                 **distribution_params):
        self.n_neurons = n_neurons
        self.regularization = regularization
        self.base_predictor = base_predictor
        self.distribution_class = distribution_class
        self.distribution = distribution
        self.custom_centers = custom_centers
        self.sigma = sigma
        self.divide_reg_by_T = divide_reg_by_T
        self.solver = solver
        self.rng = np.random.default_rng(rng)
        self.distribution_params = distribution_params

    def preprocess_targets(self, y):
        """Return the targets in the form used by the solvers (identity here)."""
        return y

    def initial_sampling(self, X: np.ndarray):
        """Set ``self.dist`` and ``self.features`` for the input matrix ``X``.

        ``X.shape[1]`` is used as the dimension of the sampling distribution
        when no ready-made ``distribution`` was supplied.
        """
        # define the sampling distribution
        if self.distribution is None:
            n = X.shape[1]
            self.dist = self.distribution_class(n_dim=n, sigma=self.sigma, rng=self.rng, **self.distribution_params)
        else:
            self.dist = deepcopy(self.distribution)

        # generate random weights
        if self.custom_centers is None:
            self.features = [self.dist.sample() for _ in range(self.n_neurons)]
        else:
            self.features = self.custom_centers

    def get_regularisation(self):
        """
            Return the effective regularisation strength.

            Dividing by T is justified by Equation 39 of the RW paper.

            This is caused by the model being divided by T.
        """
        reg = self.regularization
        if self.divide_reg_by_T:
            # Non in-place division: ``reg /= ...`` would mutate
            # ``self.regularization`` if it were a numpy array.
            reg = reg / self.n_neurons
        return reg

    def fit(self, X, y):
        """Fit the output weights with the configured solver and return ``self``.

        Raises
        ------
        ValueError
            If ``self.solver`` is neither ``'default'`` nor ``'lasso'``.
        """
        if self.solver == 'lasso':
            return self._lasso_fit(X, y)
        if self.solver == 'default':
            return self._default_fit(X, y)
        # Previously an unknown solver silently returned None without fitting.
        raise ValueError(
            "Unknown solver {!r}; expected 'default' or 'lasso'.".format(self.solver)
        )

    def _default_fit(self, X, y):
        """Fit the output weights by solving the regularised normal equations.

        Records ``sample_time_``, ``propagation_time_`` and ``solve_time_``
        (in seconds) and stores the solution in ``self.output_weights``.
        """
        # Imported explicitly: a bare ``import scipy`` does not guarantee that
        # the ``scipy.linalg`` submodule is available as an attribute.
        from scipy import linalg as sp_linalg

        Y = self.preprocess_targets(y)
        start = time()
        self.initial_sampling(X)
        self.sample_time_ = time() - start

        start = time()
        Psi = self.get_hidden(X)  # m x n_neurons
        self.propagation_time_ = time() - start

        # define the linear system to solve
        start = time()
        reg = self.get_regularisation()
        A = Psi.T @ Psi + reg * self.get_regularization_gram(X=X)
        B = Psi.T @ Y

        # compute the output weights
        try:
            self.output_weights = sp_linalg.cho_solve(sp_linalg.cho_factor(A), B)
        except (np.linalg.LinAlgError, ValueError):
            # Cholesky failed (e.g. A is not numerically positive definite):
            # fall back to a general solver.
            self.output_weights = np.linalg.solve(A, B)
        self.solve_time_ = time() - start
        return self

    @ignore_warnings(category=ConvergenceWarning)
    def _lasso_fit(self, X, y):
        """Fit sparse output weights with a Lasso, then prune unused features.

        Convergence warnings raised by the Lasso are suppressed by the
        decorator. Pruning is delegated to :meth:`remove_useless_features`,
        which subclasses must implement.
        """
        Y = self.preprocess_targets(y)
        self.initial_sampling(X)
        reg = self.get_regularisation()
        Psi = self.get_hidden(X)  # m x n_neurons

        # compute the output weights
        lasso = Lasso(alpha=reg, fit_intercept=False, max_iter=5*self.n_neurons)
        lasso.fit(Psi, Y)
        self.output_weights = lasso.coef_.copy().T
        self.remove_useless_features()

        return self

    def kernel(self, X: np.ndarray, Y: np.ndarray) -> np.ndarray:
        """Return the empirical kernel matrix between the rows of ``X`` and ``Y``.

        Computed as the product of the base-predictor outputs on the current
        features, scaled by ``1 / n_neurons``.
        """
        base_pred_X = self.base_predictor.eval(self.features, X) # m1 x T
        base_pred_Y = self.base_predictor.eval(self.features, Y) # m2 x T
        output = (1 / self.n_neurons) * base_pred_X @ base_pred_Y.T
        return output

    def get_hidden(self, X):
        """Return the base-predictor outputs on ``X`` divided by ``n_neurons``."""
        return self.base_predictor.eval(self.features, X) / self.n_neurons

    def get_regularization_gram(self, **kwargs):
        """Return the matrix that multiplies the regularisation (identity here).

        Keyword arguments are accepted and ignored by this implementation.
        """
        return np.eye(self.n_neurons)

    def l2p_norm_approx(self) -> float:
        """Return the largest entry of the fitted output weights.

        NOTE(review): this is the signed maximum, not the maximum absolute
        value -- confirm that this is the intended approximation.
        """
        return np.max(self.output_weights)

    def remove_useless_features(self, tol=1e-10):
        """Drop features whose weights are negligible; implemented by subclasses."""
        raise NotImplementedError

    def get_n_centers(self):
        """Return the number of features currently held by the model."""
        return len(self.features)

# The mixin is listed first, as recommended by scikit-learn, so that its
# attributes take precedence over those of ``BaseEstimator`` in the MRO.
class RKSClassifier(ClassifierMixin, _RKSEstimator):
    """Random-feature classifier trained on +1/-1 one-vs-rest targets.

    The constructor signature is spelled out (rather than inherited) and
    forwards every argument unchanged to ``_RKSEstimator.__init__``.
    """

    def __init__(self,
                 n_neurons=1000,
                 regularization=1e-5,
                 base_predictor=Stump(),
                 distribution_class=StumpDistribution,
                 distribution=None,
                 custom_centers=None,
                 sigma=1.0,
                 divide_reg_by_T=True,
                 solver='default',
                 rng=None,
                 **distribution_params):
        super().__init__(n_neurons=n_neurons,
                         regularization=regularization,
                         base_predictor=base_predictor,
                         distribution_class=distribution_class,
                         distribution=distribution,
                         custom_centers=custom_centers,
                         sigma=sigma,
                         divide_reg_by_T=divide_reg_by_T,
                         solver=solver,
                         rng=rng,
                         **distribution_params)

    def preprocess_targets(self, y):
        """Encode the 1-D label vector ``y`` as a ``(len(y), n_classes)`` matrix.

        Entry ``(i, r)`` is ``+1`` when ``y[i]`` equals ``self.classes[r]``
        and ``-1`` otherwise. As a side effect, ``self.classes`` and
        ``self.n_classes`` are set.
        """
        y = np.asarray(y)
        self.classes = np.unique(y)
        # For {-1, +1} labels, put the positive class first so that column 0
        # (the one returned by ``raw_output``) scores the label +1.
        if set(self.classes) == {-1, 1}:
            self.classes = np.array([1, -1])
        self.n_classes = len(self.classes)
        # generate one-hot (+1 / -1) vector matrix
        return np.where(y[:, None] == self.classes[None, :], 1.0, -1.0)

    def raw_output(self, X):
        """Return the real-valued score of ``self.classes[0]`` for each row of ``X``.

        Returns zeros when the model holds no feature.
        """
        if len(self.features) == 0:
            return np.zeros(X.shape[0])
        Psi = self.get_hidden(X)  # m x n_neurons
        output = Psi @ self.output_weights
        return output[:, 0]

    def predict(self, X):
        """Return, for each row of ``X``, the class with the highest score.

        When the model holds no feature, every row is assigned
        ``self.classes[0]``. The result is always a numpy array.
        """
        if len(self.features) == 0:
            return np.full(X.shape[0], self.classes[0])
        Psi = self.get_hidden(X)  # m x n_neurons
        return self.classes[np.argmax(Psi @ self.output_weights, axis=1)]

    def remove_useless_features(self, tol=1e-10):
        """Drop the features whose output weights are all within ``tol`` of zero.

        A feature is kept as soon as its weight for *any* class exceeds
        ``tol`` in absolute value; looking at the first class only would
        discard features that matter for the other classes. Sets
        ``n_non_zero_coefs_`` to the number of kept features.
        """
        weights = np.asarray(self.output_weights)
        if weights.ndim == 1:
            # Guard against a 1-D weight vector (e.g. a single target column).
            weights = weights[:, None]
        nonzero_indices = np.flatnonzero(np.any(np.abs(weights) > tol, axis=1))
        self.features = [self.features[i] for i in nonzero_indices]
        self.output_weights = weights[nonzero_indices, :]
        self.n_non_zero_coefs_ = len(nonzero_indices)
    
# The mixin is listed first, as recommended by scikit-learn, so that its
# attributes take precedence over those of ``BaseEstimator`` in the MRO.
class RKSRegressor(RegressorMixin, _RKSEstimator):
    """Random-feature regressor.

    The constructor signature is spelled out (rather than inherited) and
    forwards every argument unchanged to ``_RKSEstimator.__init__``.
    """

    def __init__(self,
                 n_neurons=1000,
                 regularization=1e-5,
                 base_predictor=Stump(),
                 distribution_class=StumpDistribution,
                 distribution=None,
                 custom_centers=None,
                 sigma=1.0,
                 divide_reg_by_T=True,
                 solver='default',
                 rng=None,
                 **distribution_params):
        super().__init__(n_neurons=n_neurons,
                         regularization=regularization,
                         base_predictor=base_predictor,
                         distribution_class=distribution_class,
                         distribution=distribution,
                         custom_centers=custom_centers,
                         sigma=sigma,
                         divide_reg_by_T=divide_reg_by_T,
                         solver=solver,
                         rng=rng,
                         **distribution_params)

    def predict(self, X):
        """Return the predicted targets for ``X`` (same as :meth:`raw_output`)."""
        return self.raw_output(X)

    def raw_output(self, X):
        """Return the model output for each row of ``X``.

        Returns zeros when the model holds no feature.
        """
        if len(self.features) == 0:
            return np.zeros(X.shape[0])
        Psi = self.get_hidden(X)  # m x n_neurons
        return Psi @ self.output_weights

    def remove_useless_features(self, tol=1e-10):
        """Drop the features whose output weight is within ``tol`` of zero.

        The weights are flattened to 1-D first. Sets ``n_non_zero_coefs_``
        to the number of kept features.
        """
        # Ensure output_weights is 1D for regression
        output_weights_flat = np.ravel(self.output_weights)
        nonzero_indices = np.flatnonzero(np.abs(output_weights_flat) > tol)
        self.features = [self.features[i] for i in nonzero_indices]
        self.output_weights = output_weights_flat[nonzero_indices]
        self.n_non_zero_coefs_ = len(nonzero_indices)