import numpy as np
from scipy.special import binom, factorial2
from jax import numpy as jnp, vmap
from jax.numpy import linalg as jla


def hermite_coefficient(k, j):
    if (k - j) % 2:
        return 0
    if j > k:
        return 0
    i = (k - j) // 2
    return (-1) ** i * binom(k, 2 * i) * factorial2(2 * i - 1)


def gegenbauer_coefficient(k, j, d):
    """Return the ``x**j`` coefficient of the degree-``k`` polynomial for parameter ``d``.

    The value is ``hermite_coefficient(k, j)`` multiplied by a product of
    ratios that depend on ``d``.  Because the Hermite factor vanishes when
    ``k - j`` is odd, so does this coefficient.

    NOTE(review): judging by the name this is a Gegenbauer-type polynomial
    indexed by a dimension ``d``; the exact normalization is not visible
    from this function alone -- confirm against the derivation.

    Args:
        k: Polynomial degree.
        j: Power of ``x`` whose coefficient is requested.
        d: Parameter entering the numerator/denominator factors.

    Returns:
        The scaled coefficient, or ``0`` when ``j > k``.
    """
    if j > k:
        return 0

    n_pairs = (k - j) // 2

    # Numerator factors: an arithmetic run with step 2 starting at
    # d + 2 * floor((k - 1) / 2) ...
    top = d + 2 * ((k - 1) // 2) + np.arange(0, k - 2 * n_pairs - 1, 2)
    # ... extended with `n_pairs` neutral ones so that it lines up
    # element-wise with the denominator factors below.
    top = np.concatenate([top, np.ones(n_pairs, dtype=top.dtype)])

    # Denominator factors: d - 1, d + 1, d + 3, ...
    bottom = d - 1 + np.arange(0, k - 1, 2)

    scale = np.prod(top / bottom)
    return hermite_coefficient(k, j) * scale


def HeMat(p):
    """Return the ``(p + 1) x (p + 1)`` table of Hermite coefficients.

    Entry ``[j, k]`` is ``hermite_coefficient(k, j)``: column ``k`` lists the
    monomial coefficients of the degree-``k`` polynomial in ascending power
    order, so ``HeMat(p) @ c`` turns Hermite-basis coefficients ``c`` into
    ascending monomial coefficients.
    """
    degrees = range(p + 1)
    rows = [
        [hermite_coefficient(degree, power) for degree in degrees]
        for power in degrees
    ]
    return np.array(rows)


def GegMat(p, d):
    """Return the ``(p + 1) x (p + 1)`` table of ``gegenbauer_coefficient`` values.

    Entry ``[j, k]`` is ``gegenbauer_coefficient(k, j, d)``: column ``k`` lists
    the monomial coefficients of the degree-``k`` polynomial (for parameter
    ``d``) in ascending power order.
    """
    degrees = range(p + 1)
    rows = [
        [gegenbauer_coefficient(degree, power, d) for degree in degrees]
        for power in degrees
    ]
    return np.array(rows)

def get_sigma(he_coef):
    """Build a function evaluating a polynomial given by Hermite-basis coefficients.

    Args:
        he_coef: Sequence of coefficients; entry ``k`` weights the
            degree-``k`` polynomial tabulated by ``HeMat``.

    Returns:
        A callable ``sigma(x)`` that evaluates the resulting polynomial at
        ``x`` via ``jnp.polyval``.
    """
    degree = len(he_coef) - 1
    weights = jnp.array(he_coef)
    to_monomial = jnp.array(HeMat(degree))

    # Monomial coefficients in ascending power order.
    ascending = to_monomial @ weights

    def _sigma(x):
        # jnp.polyval expects the highest power first, hence the reversal.
        return jnp.polyval(ascending[::-1], x)

    return _sigma

def get_smoothed_sigma(he_coef, d):
    """Build a "smoothed" version of the polynomial defined by ``he_coef``.

    The returned callable takes ``(w, x, lam)`` and:

    1. converts ``he_coef`` to monomial coefficients with ``HeMat`` and scales
       the degree-``k`` coefficient by ``||x||**k``;
    2. re-expresses them in the basis tabulated by ``GegMat(p, d)`` (using
       its inverse);
    3. multiplies the ``k``-th coefficient by the ``k``-th basis polynomial
       evaluated at ``1 / sqrt(1 + lam**2)``;
    4. maps back to monomial coefficients and evaluates the polynomial at
       ``(w @ x) / ||x||``.

    NOTE(review): ``w`` and ``x`` are presumably 1-D vectors of length ``d``
    and ``lam`` a scalar -- confirm against callers.  ``x`` with zero norm
    leads to a division by zero.

    Args:
        he_coef: Sequence of Hermite-basis coefficients (length ``p + 1``).
        d: Parameter forwarded to ``GegMat``.

    Returns:
        A callable ``sigma(w, x, lam)``.
    """
    p = len(he_coef) - 1
    weights = jnp.array(he_coef)
    hermite_to_monomial = jnp.array(HeMat(p))
    basis_to_monomial = jnp.array(GegMat(p, d))
    monomial_to_basis = jla.inv(basis_to_monomial)

    def _sigma(w, x, lam):
        projection = w @ x
        radius = jla.norm(x)

        # Ascending monomial coefficients, degree k scaled by radius**k.
        monomial = hermite_to_monomial @ weights
        monomial = monomial * radius ** jnp.arange(p + 1)

        # Coordinates in the GegMat basis.
        spectral = monomial_to_basis @ monomial

        # Evaluate every basis polynomial (rows of the transposed table,
        # reversed to highest-power-first for polyval) at the same point.
        cos = 1 / jnp.sqrt(1 + lam**2)
        evaluate_rows = vmap(jnp.polyval, (0, None))
        spectral = spectral * evaluate_rows(basis_to_monomial.T[:, ::-1], cos)

        # Back to monomial coefficients and evaluate on the normalized input.
        smoothed = basis_to_monomial @ spectral
        return jnp.polyval(smoothed[::-1], projection / radius)

    return _sigma
