from base_predictor import Relu
from distribution import GaussianDistribution
from kernel import GaussianKernel, ExponentialKernel
from base_model import BaseRKHSWeighting
from utils import array_times_vector, get_scalar_over_norm


import math
import numpy as np
from scipy.special import erf
import scipy.stats as stats

def estimate_max_norm(data):
    """
    Estimate an upper bound on the Euclidean norm of the rows of ``data``.

    The largest observed row norm is increased by the largest per-column
    standard deviation multiplied by the standard normal quantile at
    ``m / (m + 1)``, where ``m`` is the number of rows.

    Parameters
    ----------
    data : np.ndarray
        Two-dimensional array; row norms are taken along axis 1 and
        standard deviations along axis 0.

    Returns
    -------
    float
        The largest row norm plus the quantile-scaled largest column
        standard deviation.
    """
    # Unpacking enforces a 2-D input; only the row count is needed afterwards.
    n_rows, _ = data.shape
    row_norms = np.linalg.norm(data, axis=1)
    largest_std = max(np.std(data, axis=0))
    largest_norm = max(row_norms)
    quantile = stats.norm.ppf(n_rows / (n_rows + 1))
    return largest_norm + largest_std * quantile


class RWRelu(BaseRKHSWeighting):
    """
    RKHS weighting model combining a Gaussian parameter distribution, a
    Gaussian kernel and a ReLU base predictor.

    Parameters
    ----------
    data_x : np.ndarray
        Data array of shape (n_examples, n_features).

    sigma : {'auto'} or float, default='auto'
        Parameter of the Gaussian distribution. With 'auto' it is set to
        ``2 * sqrt(max_theta)``.

    gamma : {'auto'} or float, default='auto'
        Parameter of the Gaussian kernel. With 'auto' it is derived from
        ``sigma`` and ``max_theta`` through ``get_gamma_from_theta``.

    max_theta : float, default=0.5
        Value from which sigma and/or gamma are derived when they are
        'auto'.

    rng : RandomNumberGenerator or int or None, default=None
        The random number generator or random seed to use.

    **kwargs
        Additional keyword arguments passed along to the
        BaseRKHSWeighting __init__.

    Attributes
    ----------
    sigma : float
        The adjusted sigma parameter for the model.

    gamma : float
        The adjusted gamma parameter for the model.

    n_dim : int
        Number of dimensions of the input data.

    max_theta : float
        The ``max_theta`` value given at construction.

    max_data_norm_estimator : float
        Result of ``estimate_max_norm`` on the input data.

    dist : Distribution
        Distribution of the parameters. p in the equations.

    kernel : Kernel
        K in the equations

    base : BasePredictor
        phi in the equations
    """
    def __init__(self, data_x: np.ndarray, sigma='auto', gamma='auto',
                 max_theta=0.5, rng=None, **kwargs) -> None:
        self.n_dim = data_x.shape[1]
        self.max_theta = max_theta
        self.max_data_norm_estimator = estimate_max_norm(data_x)
        # sigma has to be stored first: the 'auto' gamma reads self.sigma.
        self.sigma = self.get_adjusted_sigma(sigma)
        self.gamma = self.get_adjusted_gamma(gamma)
        super().__init__(
            data_x,
            GaussianDistribution(n_dim=self.n_dim, sigma=self.sigma, rng=rng),
            GaussianKernel(gamma=self.gamma),
            Relu(),
            rng,
            **kwargs
        )

    def _exact_expectations(self, X: np.ndarray) -> np.ndarray:
        """
        Evaluate the closed-form expectation for the inputs ``X``.

        Parameters
        ----------
        X : np.ndarray
            Input points; their norms are taken along axis 1.

        Returns
        -------
        np.ndarray
            The bracketed Gaussian/erf term, scaled along axis 1 by a
            per-feature exponential factor, along axis 0 by the input
            norms, and by a global coefficient.
            NOTE(review): presumably one row per input and one column per
            feature -- confirm against get_scalar_over_norm.
        """
        dim = self.n_dim
        sig = self.sigma
        gam = self.gamma
        var_sum = sig**2 + gam**2

        features = self.get_features()
        feature_norms = self.get_feature_norms()
        input_norms = np.linalg.norm(X, axis=1)

        root_pi = math.sqrt(math.pi)
        # zeta**2 is the harmonic combination 1 / (1/sigma^2 + 1/gamma^2).
        zeta = math.sqrt(1 / (1/sig**2 + 1/gam**2))
        root2_zeta = math.sqrt(2) * zeta
        shifted = (sig**2 / var_sum) * get_scalar_over_norm(X, features)

        prefactor = (gam**2 / var_sum)**(dim/2) / (2 * root_pi)
        feature_decay = np.exp(-feature_norms**2 / (2 * var_sum))
        bracket = (
            root2_zeta * np.exp(-shifted**2 / root2_zeta**2)
            + root_pi * shifted * (1 + erf(shifted / root2_zeta))
        )

        scaled = array_times_vector(bracket, feature_decay, axis=1)
        scaled = array_times_vector(scaled, input_norms, axis=0)
        return prefactor * scaled

    def true_theta(self):
        """Return theta computed in closed form from sigma, gamma, n_dim
        and the estimated maximum data norm."""
        sig = self.sigma
        gam = self.gamma
        dim = self.n_dim
        shrink = (1 + 2*sig**2 / gam**2)**(-(dim-1)/4)
        return shrink * sig * self.max_data_norm_estimator / math.sqrt(2 * math.pi)

    def true_iota(self):
        """Return iota: sigma times the estimated maximum data norm,
        divided by sqrt(2 * pi)."""
        scaled_norm = self.sigma * self.max_data_norm_estimator
        return scaled_norm / math.sqrt(2 * math.pi)

    def true_kappa(self):
        """Return kappa: sigma times the estimated maximum data norm,
        divided by sqrt(2)."""
        scaled_norm = self.sigma * self.max_data_norm_estimator
        return scaled_norm / math.sqrt(2)

    def get_adjusted_sigma(self, sigma):
        """Resolve ``sigma``: 'auto' gives ``2 * sqrt(max_theta)``, any
        other value is converted to float."""
        return 2 * math.sqrt(self.max_theta) if sigma == 'auto' else float(sigma)

    def get_adjusted_gamma(self, gamma):
        """Resolve ``gamma``: 'auto' derives it from ``max_theta`` via
        ``get_gamma_from_theta``, any other value is converted to float."""
        return self.get_gamma_from_theta(self.max_theta) if gamma == 'auto' else float(gamma)

    def get_gamma_from_theta(self, theta):
        """Return ``sqrt(2) * sigma / sqrt(theta**(-4/n_dim) - 1)``.

        The square root is real and non-zero only for ``0 < theta < 1``.
        """
        dim = self.n_dim
        numerator = math.sqrt(2) * self.sigma
        return numerator / math.sqrt(theta**(-4/dim) - 1)

    def get_gamma_from_expect(self, var):
        """Return ``sigma / sqrt(var**(-2/n_dim) - 1)``.

        The square root is real and non-zero only for ``0 < var < 1``.
        """
        dim = self.n_dim
        return self.sigma / math.sqrt(var**(-2/dim) - 1)
    


class RWExpRelu(BaseRKHSWeighting):
    """
    Instantiation of the model using Gaussian distribution, exponential kernel
    and ReLU as the base predictor.

    Parameters
    ----------
    data_x : np.ndarray
        Data array of shape (n_examples, n_features).

    sigma : {'auto'} or float, default='auto'
        Parameter of the Gaussian distribution. 'auto' resolves to 1.0.

    gamma : {'auto'} or float, default='auto'
        Parameter of the exponential kernel. 'auto' derives it from
        ``sigma`` and ``max_kappa`` through ``get_gamma_from_kappa``.

    max_kappa : float, default=50
        Value from which gamma is derived when gamma is 'auto'. It must be
        greater than 1 in that case.

    rng : RandomNumberGenerator or int or None, default=None
        The random number generator or random seed to use.

    **kwargs
        Additional keyword arguments passed along to the
        BaseRKHSWeighting __init__.

    Attributes
    ----------
    sigma : float
        The adjusted sigma parameter for the model.

    gamma : float
        The adjusted gamma parameter for the model.

    n_dim : int
        Number of dimensions of the input data.

    max_kappa : float
        The ``max_kappa`` value given at construction.

    max_data_norm_estimator : float
        Result of ``estimate_max_norm`` on the input data.

    dist : Distribution
        Distribution of the parameters. p in the equations.

    kernel : Kernel
        K in the equations

    base : BasePredictor
        phi in the equations
    """
    def __init__(self,
                 data_x: np.ndarray,
                 sigma='auto',
                 gamma='auto',
                 max_kappa=50,
                 rng=None,
                 **kwargs) -> None:
        self.n_dim = data_x.shape[1]
        self.max_kappa = max_kappa
        self.max_data_norm_estimator = estimate_max_norm(data_x)
        # sigma is resolved and stored first because the 'auto' gamma is
        # computed from self.sigma.
        sigma = self.sigma = self.get_adjusted_sigma(sigma)
        gamma = self.gamma = self.get_adjusted_gamma(gamma)
        dist = GaussianDistribution(n_dim=self.n_dim, sigma=sigma, rng=rng)
        kernel = ExponentialKernel(gamma=gamma)
        base = Relu()
        super().__init__(data_x, dist, kernel, base, rng, **kwargs)

    def _exact_expectations(self, X: np.ndarray) -> np.ndarray:
        """
        Evaluate the closed-form expectation for the inputs ``X``.

        Parameters
        ----------
        X : np.ndarray
            Input points; their norms are taken along axis 1.

        Returns
        -------
        np.ndarray
            The bracketed Gaussian/erf term, scaled along axis 1 by a
            per-feature exponential factor, along axis 0 by the input
            norms, and by a global coefficient.
        """
        s = self.sigma
        g = self.gamma

        W = self.get_features()
        W_norms = self.get_feature_norms()
        X_norms = np.linalg.norm(X, axis=1)

        sqrtpi = math.sqrt(math.pi)
        # NOTE(review): RWRelu rescales this term by a sigma/gamma-dependent
        # factor, whereas no rescaling is applied here even though exp_norms
        # below depends on sigma and gamma -- confirm this is intended for
        # ExponentialKernel.
        arr = get_scalar_over_norm(X, W)

        # Algebraically equal to 1 / (2 * sqrt(pi)); kept in its factored
        # form so the floating-point result is unchanged.
        global_coef = (s / math.sqrt(2)) / (sqrtpi * math.sqrt(2) * s)
        exp_norms = np.exp(W_norms**2 * (s**2) / (8*g**4))
        big_parenthesis_term1 = (math.sqrt(2) * s) * np.exp(-arr**2 / (2*s**2))
        big_parenthesis_term2 = sqrtpi * arr * (1 + erf(arr / (math.sqrt(2) * s)))
        big_parenthesis = big_parenthesis_term1 + big_parenthesis_term2

        expectation = global_coef * array_times_vector(array_times_vector(big_parenthesis, exp_norms, axis=1), X_norms, axis=0)

        return expectation

    def true_theta(self):
        """Return theta, which for this model is the same value as
        ``true_iota()``."""
        return self.true_iota()

    def true_iota(self):
        """Return iota computed in closed form from sigma, gamma, n_dim and
        the estimated maximum data norm.

        The second factor has a positive base only when
        ``sigma**2 < 2 * gamma**2``.
        """
        n = self.n_dim
        s = self.sigma
        g = self.gamma
        factors = [
            s * self.max_data_norm_estimator / (math.sqrt(2 * math.pi)),
            (1 - s**2 / (2 * g**2))**(-(n-1)/2)
        ]
        return math.prod(factors)

    def true_kappa(self):
        """Return kappa computed in closed form from sigma, gamma, n_dim and
        the estimated maximum data norm.

        The second factor has a positive base only when ``sigma < gamma``.
        """
        n = self.n_dim
        s = self.sigma
        g = self.gamma
        factors = [
            s**2 * self.max_data_norm_estimator**2 / 2,
            (1 - s**2 / g**2)**(-n/2 - 1)
        ]
        return math.sqrt(math.prod(factors))

    def get_adjusted_sigma(self, sigma):
        """Resolve ``sigma``: 'auto' gives 1.0, any other value is converted
        to float."""
        if sigma == 'auto':
            # Return a float so both branches yield the same type.
            return 1.0
        else:
            return float(sigma)

    def get_adjusted_gamma(self, gamma):
        """Resolve ``gamma``: 'auto' derives it from ``max_kappa``, any other
        value is converted to float.

        Raises
        ------
        ValueError
            If ``gamma`` is 'auto' and ``max_kappa`` is not greater than 1,
            since no real, finite gamma exists in that case.
        """
        if gamma == 'auto':
            # For max_kappa <= 1 the formula divides by zero or takes the
            # square root of a negative number (NaN); fail with a clear
            # message instead. `not (... > 1)` also rejects NaN.
            if not self.max_kappa > 1:
                raise ValueError(
                    "max_kappa must be greater than 1 to derive gamma "
                    f"automatically; got {self.max_kappa!r}"
                )
            # np.sqrt yields np.float64; convert so both branches return a
            # plain float.
            return float(self.get_gamma_from_kappa(self.max_kappa))
        else:
            return float(gamma)

    def get_gamma_from_kappa(self, theta):
        """Return ``sqrt(sigma**2 / (1 - theta**(-2/n_dim)))``.

        The result is real and finite only for ``theta > 1``.
        """
        n = self.n_dim
        return np.sqrt(self.sigma**2 / (1 - theta**(-2/n)))