import scipy as sp
from base_predictor import Sign, np
from distribution import GaussianDistribution, np
from kernel import ExponentialKernel, GaussianKernel, PolynomialKernel, np
from loss import np
from base_model import BaseRKHSWeighting
from utils import array_plus_vector, array_times_vector, get_scalar_over_norm, np


import numpy as np
from scipy.special import binom, erf
from sklearn.svm import LinearSVC


import math


class RWSign(BaseRKHSWeighting):
    """
    Instantiation of the model using Gaussian kernel and distribution,
    and sign as the base predictor.

    Parameters
    ----------
    data_x : np.ndarray
        Data array of shape (n_examples, n_features). The number of input
        dimensions is read from ``data_x.shape[1]`` and the array is
        forwarded to the parent ``__init__``.

    sigma : float, default=1.0
        Parameter of the Gaussian distribution. It is stored as given; no
        'auto' value is resolved for sigma in this class.

    gamma : {'auto'} or float, default='auto'
        Parameter of the Gaussian kernel. With 'auto', gamma is derived from
        ``sigma`` and ``max_theta`` (see ``get_gamma_from_theta``); any other
        value is converted with ``float``.

    max_theta : float, default=0.5
        Value of theta used to derive gamma when gamma is 'auto'. It must
        then lie strictly between 0 and 1.

    dist_mean : None or np.array, default=None
        Mean vector of the distribution p.

    rng : RandomNumberGenerator or int or None, default=None
        The random number generator or random seed to use.

    **kwargs
        Additional keyword arguments passed along to the parent __init__.

    Attributes
    ----------
    sigma : float
        The sigma parameter of the model.

    gamma : float
        The adjusted gamma parameter for the model.

    max_theta : float
        The theta value used when gamma is 'auto'.

    n_dim : int
        Number of dimensions of the input data.

    svm : LinearSVC
        The fitted linear SVM. Only exists after ``_fit_svm`` has run.

    dist : Distribution
        Distribution of the parameters. p in the equations.

    kernel : Kernel
        K in the equations

    base : BasePredictor
        phi in the equations
    """
    def __init__(self,
                 data_x: np.ndarray,
                 sigma=1.0,
                 gamma='auto',
                 max_theta=0.5,
                 dist_mean=None,
                 rng=None,
                 **kwargs) -> None:
        # n_dim, max_theta and sigma must be set before gamma is resolved:
        # get_gamma_from_theta reads all three from self.
        self.n_dim = data_x.shape[1]
        self.max_theta = max_theta
        self.sigma = sigma
        gamma = self.gamma = self.get_adjusted_gamma(gamma)
        dist = GaussianDistribution(n_dim=self.n_dim, mean=dist_mean, sigma=sigma, rng=rng)
        kernel = GaussianKernel(gamma=gamma)
        base = Sign()
        super().__init__(data_x, dist, kernel, base, rng, **kwargs)

    def get_and_set_mean_parameter(self, X, y):
        """
        Fit a linear SVM on (X, y) and use its weights as the distribution mean.

        Returns
        -------
        mean : np.ndarray
            The flattened SVM coefficient vector that was set as the mean.
        """
        mean = self._get_good_mean_parameter(X, y)
        self._set_mean_parameter(mean)
        return mean

    def _get_good_mean_parameter(self, X, y):
        """Return the flattened coefficients of a linear SVM fitted on (X, y)."""
        clf = self._fit_svm(X, y)
        return clf.coef_.flatten()

    def _set_mean_parameter(self, param):
        """Set ``param`` as the mean of the distribution ``self.dist``."""
        self.dist.set_mean(param)

    def _fit_svm(self, X, y):
        """Fit a ``LinearSVC`` (fixed ``random_state=0``) on (X, y), store and return it."""
        self.svm = LinearSVC(random_state=0).fit(X, y)
        return self.svm

    def _get_fitted_svm(self):
        """Return the SVM stored by ``_fit_svm``.

        Raises ``AttributeError`` if ``_fit_svm`` has not been called yet,
        because ``self.svm`` is only created there.
        """
        return self.svm

    def _exact_expectations(self, X: np.ndarray) -> np.ndarray:
        """
        Evaluate the closed-form expression for a distribution with mean
        ``self.dist.mean``.

        The result is ``global_coef * erf(arr / sqrt2zeta)`` scaled by one
        exponential weight per feature, where ``arr`` is computed by
        ``get_scalar_over_norm`` from ``X`` and the shifted features.
        """
        n = self.n_dim
        w_0 = self.dist.mean
        s = self.sigma
        g = self.gamma

        s2g2 = s**2 + g**2
        sqrt2zeta = 1 / math.sqrt(0.5 * (1/s**2 + 1/g**2))
        global_coef = (1 + s**2 / g**2)**(-n/2)

        # One row per feature (the axis=1 norm below relies on W being 2-D).
        W = np.array(self.get_features())
        W_norms = self.get_feature_norms()

        # Convex combination of each feature and the distribution mean.
        W_prime = (s**2 * W + g**2 * w_0) / s2g2

        w_0_norm = np.sqrt(w_0 @ w_0)
        if w_0_norm != 0:
            mixed_norms = np.linalg.norm(W / g**2 + w_0 / s**2, axis=1)
        else:
            # With a zero mean the expression above reduces to ||w|| / g**2;
            # presumably W_norms holds the row norms of W, so they are reused
            # instead of being recomputed — confirm in get_feature_norms.
            mixed_norms = W_norms / g**2

        exp_norms = np.exp(-W_norms**2 / (2*g**2)
                           - w_0_norm**2 / (2*s**2)
                           + (s**2 * g**2 / (2*s2g2)) * mixed_norms**2)
        arr = get_scalar_over_norm(X, W_prime)
        expects = global_coef * array_times_vector(erf(arr / sqrt2zeta), exp_norms, axis=1)

        return expects

    def _exact_centered_expectations(self, X: np.ndarray) -> np.ndarray:
        """
        Evaluate the closed-form expression without any mean term.

        Unlike ``_exact_expectations`` this method never reads
        ``self.dist.mean``; it only uses sigma, gamma, the features and
        their norms.
        """
        n = self.n_dim
        s = self.sigma
        g = self.gamma
        s2g2 = s**2 + g**2

        W = self.get_features()
        W_norms = self.get_feature_norms()
        exp_norms = np.exp(-W_norms**2 / (2*s2g2))

        arr = get_scalar_over_norm(X, W)
        erf_coef = 1 / (g * math.sqrt(2 * (1 + g**2 / s**2)))
        global_coef = (1 + s**2 / g**2)**(-n/2)

        return global_coef * array_times_vector(erf(erf_coef * arr), exp_norms, axis=1)

    def true_theta(self):
        """Return ``(1 + 2 * sigma**2 / gamma**2) ** (-n_dim / 4)``."""
        s = self.sigma
        g = self.gamma
        n = self.n_dim
        return (1 + 2*s**2 / g**2)**(-n/4)

    def true_iota(self):
        """Return the constant 1.0."""
        return 1.0

    def true_kappa(self):
        """Return the constant 1.0."""
        return 1.0

    def get_adjusted_gamma(self, gamma):
        """Resolve ``gamma``: derive it from ``max_theta`` if 'auto', else cast to float."""
        if gamma == 'auto':
            return self.get_gamma_from_theta(self.max_theta)
        else:
            return float(gamma)

    def get_gamma_from_theta(self, theta):
        """
        Return the gamma for which ``true_theta()`` equals ``theta``.

        This inverts ``theta = (1 + 2 * sigma**2 / gamma**2) ** (-n_dim / 4)``
        using the current ``self.sigma`` and ``self.n_dim``.

        Raises
        ------
        ValueError
            If ``theta`` is not strictly between 0 and 1. Outside that range
            the term under the square root is zero, negative or complex.
        """
        if not 0 < theta < 1:
            raise ValueError(
                f"theta must lie strictly between 0 and 1, got {theta!r}."
            )
        n = self.n_dim
        return math.sqrt(2) * self.sigma / math.sqrt(theta**(-4/n) - 1)



class RWExpSign(BaseRKHSWeighting):
    """
    Instantiation of the model using Exponential kernel and Gaussian distribution,
    and sign as the base predictor.

    Parameters
    ----------
    data_x : np.ndarray
        Data array of shape (n_examples, n_features). The number of input
        dimensions is read from ``data_x.shape[1]`` and the array is
        forwarded to the parent ``__init__``.

    sigma : float, default=1.0
        Parameter of the Gaussian distribution. It is stored as given; no
        'auto' value is resolved for sigma in this class.

    gamma : {'auto'} or float, default='auto'
        Parameter of the Exponential kernel. With 'auto', gamma is derived
        from ``sigma`` and ``max_theta`` (see ``get_gamma_from_theta``); any
        other value is converted with ``float``.

    max_theta : float, default=1.5
        Value used to derive gamma when gamma is 'auto'. It must then be
        strictly greater than 1.

    rng : RandomNumberGenerator or int or None, default=None
        The random number generator or random seed to use.

    **kwargs
        Additional keyword arguments passed along to the parent __init__.

    Attributes
    ----------
    sigma : float
        The sigma parameter of the model.

    gamma : float
        The adjusted gamma parameter for the model.

    max_theta : float
        The value used when gamma is 'auto'.

    n_dim : int
        Number of dimensions of the input data.

    dist : Distribution
        Distribution of the parameters. p in the equations.

    kernel : Kernel
        K in the equations

    base : BasePredictor
        phi in the equations
    """
    def __init__(self,
                 data_x: np.ndarray,
                 sigma=1.0,
                 gamma='auto',
                 max_theta=1.5,
                 rng=None,
                 **kwargs) -> None:
        # n_dim, max_theta and sigma must be set before gamma is resolved:
        # get_gamma_from_theta reads all three from self.
        self.n_dim = data_x.shape[1]
        self.max_theta = max_theta
        self.sigma = sigma
        gamma = self.gamma = self.get_adjusted_gamma(gamma)
        dist = GaussianDistribution(n_dim=self.n_dim, sigma=sigma, rng=rng)
        kernel = ExponentialKernel(gamma=gamma)
        base = Sign()
        super().__init__(data_x, dist, kernel, base, rng, **kwargs)

    def _exact_expectations(self, X: np.ndarray) -> np.ndarray:
        """
        Evaluate the closed-form expression for this kernel/distribution pair.

        The result is ``erf(cst * arr)`` scaled by ``exp((cst * ||w||)**2)``
        per feature, with ``cst = sigma / (gamma**2 * sqrt(8 * pi))`` and
        ``arr`` computed by ``get_scalar_over_norm`` from ``X`` and the
        features.
        """
        s = self.sigma
        g = self.gamma

        W = self.get_features()
        W_norms = self.get_feature_norms()

        cst = s / (g**2 * math.sqrt(8 * math.pi))
        exp_norms = np.exp((cst * W_norms) ** 2)
        arr = get_scalar_over_norm(X, W)

        return array_times_vector(erf(cst * arr), exp_norms, axis=1)

    def true_theta(self):
        """Return ``(1 - sigma**2 / (2 * gamma**2)) ** (-n_dim / 2)``.

        The base is positive only when ``sigma**2 < 2 * gamma**2``.
        """
        s = self.sigma
        g = self.gamma
        n = self.n_dim
        return (1 - s**2 / (2*g**2))**(-n/2)

    def true_iota(self):
        """Return ``(1 - sigma**2 / (2 * gamma**2)) ** (-n_dim / 2)``.

        This is the same expression as ``true_theta``.
        """
        s = self.sigma
        g = self.gamma
        n = self.n_dim
        return (1 - s**2 / (2*g**2))**(-n/2)

    def true_kappa(self):
        """Return ``(1 - sigma**2 / gamma**2) ** (-n_dim / 4)``.

        The base is positive only when ``sigma**2 < gamma**2``.
        """
        s = self.sigma
        g = self.gamma
        n = self.n_dim
        return (1 - s**2 / g**2)**(-n/4)

    def get_adjusted_gamma(self, gamma):
        """Resolve ``gamma``: derive it from ``max_theta`` if 'auto', else cast to float."""
        if gamma == 'auto':
            return self.get_gamma_from_theta(self.max_theta)
        else:
            return float(gamma)

    def get_gamma_from_theta(self, theta):
        """
        Return ``sigma / sqrt(1 - theta ** (-4 / n_dim))``.

        NOTE(review): the returned gamma satisfies
        ``(1 - sigma**2 / gamma**2) ** (-n_dim / 4) == theta``, which is the
        expression in ``true_kappa``, not the one in ``true_theta`` — confirm
        which quantity ``max_theta`` is meant to control.

        Raises
        ------
        ValueError
            If ``theta`` is not strictly greater than 1. Otherwise the term
            under the square root is zero, negative or complex.
        """
        if not theta > 1:
            raise ValueError(
                f"theta must be strictly greater than 1, got {theta!r}."
            )
        n = self.n_dim
        return self.sigma / math.sqrt(1 - theta**(-4/n))