import matplotlib.pyplot as plot
import numpy as np
from scipy.special import erf

from base_predictor import Stump
from distribution import GaussianThresholdDistribution
from kernel import IndicatorGaussianKernel
from base_model import BaseRKHSWeighting
from utils import array_times_vector




import math


class RWStumps(BaseRKHSWeighting):
    """
    Instantiation of the model using stumps as the base predictor.

    Distribution is Uniform on the chosen variable, and Gaussian on the value of the threshold.

    Kernel is Gaussian on the threshold, 0 when the variables are different.

    Parameters
    ----------
    data_x : np.ndarray
        Data array of shape (n_examples, n_features). Its number of columns
        sets ``n_dim``; the array itself is forwarded to the base model's
        ``__init__``.

    sigma : {'auto'} or float, default='auto'
        Parameter of the Gaussian distribution on the threshold (it enters the
        closed-form expectations as a standard deviation). 'auto' resolves to 1.0.

    gamma : {'auto'} or float, default='auto'
        Parameter of the Gaussian kernel. 'auto' resolves to the adjusted sigma.

    rng : RandomNumberGenerator or int or None, default=None
        The random number generator or random seed to use.

    **kwargs
        Additional keyword arguments passed along to the Model __init__.

    Attributes
    ----------
    sigma : float
        The adjusted sigma parameter for the model.

    gamma : float
        The adjusted gamma parameter for the model.

    n_dim : int
        Number of dimensions of the input data.

    dist : Distribution
        Distribution of the parameters. p in the equations.

    kernel : Kernel
        K in the equations

    base : BasePredictor
        phi in the equations
    """
    def __init__(self,
                 data_x: np.ndarray,
                 sigma='auto',
                 gamma='auto',
                 rng=None,
                 **kwargs) -> None:
        self.n_dim = data_x.shape[1]
        # sigma must be resolved first: the 'auto' gamma is defined from self.sigma.
        sigma = self.sigma = self.get_adjusted_sigma(sigma)
        gamma = self.gamma = self.get_adjusted_gamma(gamma)
        dist = GaussianThresholdDistribution(n_dim=self.n_dim, sigma=sigma, rng=rng)
        kernel = IndicatorGaussianKernel(gamma=gamma)
        base = Stump()
        super().__init__(data_x, dist, kernel, base, rng, **kwargs)

    def _exact_expectations(self, X: np.ndarray) -> np.ndarray:
        """
        Compute the closed-form expectation associated with every feature.

        Parameters
        ----------
        X : np.ndarray
            2-D input array; for feature j, column ``idx_j`` is read.

        Returns
        -------
        np.ndarray
            Array whose entry (i, j) is

                zeta / (sigma * n_dim)
                * exp(-t_j**2 / (2 * (sigma**2 + gamma**2)))
                * erf((X[i, idx_j] - sigma**2 / (sigma**2 + gamma**2) * t_j)
                      / (sqrt(2) * zeta))

            with ``zeta = (1/sigma**2 + 1/gamma**2) ** -0.5`` and
            ``(idx_j, t_j)`` the variable index and threshold of feature j.
            NOTE(review): assumes ``array_times_vector(A, v, axis=1)`` scales
            column j of A by v[j] and keeps A's shape, i.e. the result is
            (n_examples, n_features_of_model) -- confirm in utils.
        """
        n = self.n_dim
        s = self.sigma
        g = self.gamma
        s2g2 = s**2 + g**2
        W = self.get_features()
        # Explicit dtypes: with no features, np.array([]) would be float64 and
        # could not be used as a column index below (IndexError).
        W_idxs = np.array([stump_param.idx for stump_param in W], dtype=int)
        W_values = np.array([stump_param.value for stump_param in W], dtype=float)

        # zeta is the std of the product of the N(0, s^2) threshold density and
        # the Gaussian kernel of width g; W2prime is the mean of that product.
        zeta = math.sqrt(1 / (1/s**2 + 1/g**2))
        sqrt2zeta = math.sqrt(2) * zeta
        W2prime = s**2 / s2g2 * W_values

        # 1/n comes from the uniform choice among the n_dim variables.
        coef = zeta / (s * n)
        exp_norms = np.exp(-W_values**2 / (2 * s2g2))
        erf_stump = erf((X[:, W_idxs] - W2prime) / sqrt2zeta)

        return coef * array_times_vector(erf_stump, exp_norms, axis=1)

    def true_theta(self) -> float:
        """Return (1 + 2 * sigma**2 / gamma**2) ** (-1/4) / sqrt(n_dim)."""
        s = self.sigma
        g = self.gamma
        n = self.n_dim
        return (1 + 2*s**2 / g**2)**(-1/4) / math.sqrt(n)

    def true_iota(self) -> float:
        """Return the iota constant of this model, which is always 1.0."""
        return 1.0

    def true_kappa(self) -> float:
        """Return the kappa constant of this model, which is always 1.0."""
        return 1.0

    def get_adjusted_sigma(self, sigma) -> float:
        """Resolve ``sigma``: 'auto' becomes 1.0, anything else is cast to float."""
        return 1.0 if sigma == 'auto' else float(sigma)

    def get_adjusted_gamma(self, gamma) -> float:
        """Resolve ``gamma``: 'auto' becomes ``self.sigma``, anything else is cast to float.

        Requires ``self.sigma`` to be set beforehand.
        """
        return self.sigma if gamma == 'auto' else float(gamma)

    def plot_variable(self, idx=0, min=-3, max=3, n_points=1000, show=True):
        """
        Plots the contribution of variable 'idx' to the output of the model.

        A copy of the model restricted to the centers whose variable index
        equals ``idx`` is evaluated on points that vary only along that
        variable (all other coordinates are 0).

        Parameters
        ----------
        idx : int, default=0
            Index of the variable to plot.
        min, max : float, default=-3, 3
            Range of values of variable ``idx`` to evaluate. (These names
            shadow the builtins; kept for keyword-argument compatibility.)
        n_points : int, default=1000
            Number of evaluation points in the range.
        show : bool, default=True
            Whether to call ``plot.show()`` after plotting.

        Returns
        -------
        tuple
            ``(values, output)``: the evaluated values of variable ``idx`` and
            the restricted model's output at those points.
        """
        params = self.list_of_params
        selected = [
            i for i in range(self.get_n_centers())
            if int(params[i][0]) == idx
        ]
        valid_centers = [params[i] for i in selected]
        valid_coefs = [self.coefs[i] for i in selected]

        values = np.linspace(start=min, stop=max, num=n_points)
        points = np.zeros(shape=(n_points, self.n_dim))
        points[:, idx] = values

        # Work on a copy so the original model's features are left untouched.
        new_model = self.copy()
        new_model.set_features(valid_centers, valid_coefs)

        output = new_model.output(points)
        plot.plot(values, output)
        if show:
            plot.show()
        return values, output