import numpy as np



################################
class polynomial:
    """Scaled, orthonormal shifted-Legendre basis on [0, 1].

    Basis function ``k`` (``k = 0 .. n-1``) is ``sqrt(2k+1) * P_k(2x - 1)``,
    where ``P_k`` is the Legendre polynomial of degree ``k``.  Two rescaled
    copies of this basis are provided:

    * the "alpha" basis, where every function with ``k >= 1`` is multiplied
      by ``alpha``;
    * the "inverse alpha" basis, where every function with ``k >= 1`` is
      multiplied by ``1 / alpha``.

    The constant function (``k = 0``) is never rescaled.

    Attributes:
        n: number of basis functions (maximum degree ``n - 1``).
        coef_mat_alpha: ``(n, n)`` array; row ``k`` holds the power-series
            coefficients (in the variable ``2x - 1``, increasing degree) of
            alpha-scaled basis function ``k``.
        coef_mat_inverse_alpha: same layout, for the 1/alpha-scaled basis.
    """

    def __init__(self, n, alpha) -> None:
        """Build the coefficient tables and scale factors.

        Args:
            n: number of basis functions; must be a positive integer.
            alpha: scale applied to every non-constant basis function (its
                reciprocal is used for the inverse basis, so it must be
                non-zero).

        Raises:
            ValueError: if ``n < 1``.
        """
        from numpy.polynomial.legendre import leg2poly

        if n < 1:
            raise ValueError(f"n must be a positive integer, got {n!r}")
        self.n = n

        # Row k = power-series coefficients of P_k.  leg2poly of the unit
        # Legendre series e_k returns exactly k+1 coefficients because the
        # leading coefficient of P_k is non-zero.
        coef_mat = np.zeros((n, n), dtype=float)
        for k in range(n):
            unit = np.zeros(k + 1)
            unit[-1] = 1.0
            coef_mat[k, : k + 1] = leg2poly(unit)

        # int_0^1 P_k(2x-1)^2 dx = 1/(2k+1), so sqrt(2k+1) normalises P_k on
        # [0, 1] (equal to the former sqrt(2) * sqrt((2k+1)/2)).
        normalized_factor = np.sqrt(2 * np.arange(n) + 1)

        # The constant function keeps scale 1; all others get alpha / 1/alpha.
        alpha_vec = alpha * np.ones(n)
        alpha_vec[0] = 1
        inverse_alpha_vec = (1 / alpha) * np.ones(n)
        inverse_alpha_vec[0] = 1

        # Per-function scale factors; reused by the evaluation methods.
        self._factors_alpha = alpha_vec * normalized_factor
        self._factors_inverse_alpha = inverse_alpha_vec * normalized_factor

        self.coef_mat_alpha = coef_mat * self._factors_alpha[:, None]
        self.coef_mat_inverse_alpha = coef_mat * self._factors_inverse_alpha[:, None]

    def _evaluate(self, x, factors):
        """Return ``factors[k] * P_k(2x - 1)`` for ``k < n`` at every point of *x*.

        The Legendre polynomials are evaluated with numpy's three-term
        recurrence (``legvander``) instead of ``vander(x) @ coef.T``: the
        power-series coefficients of ``P_k`` grow exponentially with ``k``,
        so the monomial route loses all accuracy after a few dozen basis
        functions, while the recurrence stays accurate.

        Returns an array of shape ``x.shape + (n,)``; a scalar ``x`` is
        treated as a length-1 array, giving shape ``(1, n)``.
        """
        from numpy.polynomial.legendre import legvander

        new_x = 2 * np.asarray(x) - 1
        return legvander(new_x, self.n - 1) * factors

    def multivariate_all_basis_alpha(self, x):
        """Evaluate all alpha-scaled basis functions at each entry of *x*.

        Args:
            x: array-like of points (typically in [0, 1]); a 1-D input of
                length ``d`` gives a ``(d, n)`` result.

        Returns:
            Array of shape ``x.shape + (n,)``; entry ``[..., k]`` is basis
            function ``k`` evaluated at the corresponding point.
        """
        return self._evaluate(x, self._factors_alpha)

    def multivariate_all_basis_inverse_alpha(self, x):
        """Evaluate all 1/alpha-scaled basis functions at each entry of *x*.

        Same input/output conventions as ``multivariate_all_basis_alpha``.
        """
        return self._evaluate(x, self._factors_inverse_alpha)
    

#n=5
#polynomial(n,1).multivariate_all_basis_alpha( np.array([0.5,0.3]))
########## 


class generate_basis_mat:
    """Apply the alpha-scaled shifted-Legendre basis coordinate-wise.

    Holds a ``polynomial(n, alpha)`` instance and evaluates its alpha-scaled
    basis at every coordinate of a batch of ``dim``-dimensional points.
    """

    def __init__(self, n, dim, alpha):
        """
        Args:
            n: number of basis functions per coordinate.
            dim: number of coordinates per point.
            alpha: scale forwarded to ``polynomial``.
        """
        self.n = n
        self.polynomial = polynomial(n, alpha)
        self.dim = dim

    def all_x_multivariate(self, data):
        """Evaluate every basis function at every coordinate of every point.

        Args:
            data: array-like of points with shape ``(N, dim)``.  A single
                point given as a 1-D array of length ``dim`` is also accepted
                when ``dim > 1``.  When ``dim == 1``, a 1-D array is read as
                ``N`` one-dimensional points.

        Returns:
            Array of shape ``(N, dim, n)`` whose ``[i, j, k]`` entry is basis
            function ``k`` evaluated at coordinate ``j`` of point ``i``.
        """
        points = np.asarray(data)
        if points.ndim == 1 and self.dim > 1:
            # A lone point: promote it to a batch of one instead of letting
            # the final reshape fail.
            points = points[np.newaxis, :]
        # Evaluate all N*dim coordinates in one call, then regroup per point.
        values = self.polynomial.multivariate_all_basis_alpha(points.reshape(-1))
        return values.reshape(len(points), self.dim, self.n)

# NOTE(review): the triple-quoted text below is a bare string expression, not
# code: it has no runtime effect and nothing inside it is ever executed.  It
# holds a commented-out loop-based `all_x_multivariate` and a vectorised
# `all_x_multivariate` (flatten -> np.vander -> matmul -> reshape), which
# appear to be alternative versions of generate_basis_mat.all_x_multivariate.
# Consider deleting it (version control keeps history) or restoring it as
# real code.
""" 
    
    
    
    
    #def all_x_multivariate(self, data):
        
    #    basis_mat=[]
    #    for x in data:
    #        basis_mat.append(self.polynomial.multivariate_all_basis_alpha(x))
    #    return basis_mat
       
    def all_x_multivariate(self, data):
    # data : array of shape (N, dim)
        data = np.asarray(data, float)
        N, dim = data.shape

        # 1) map to [–1,1]
        new_x = 2*data - 1            # shape (N, dim)
    
        # 2) flatten to (N*dim,)
        flat = new_x.ravel()
    
        # 3) build the Vandermonde for all N*dim points in one go
        #    ⇒ shape (N*dim, n)
        V = np.vander(flat, self.n, increasing=True)
    
        # 4) one big BLAS call: (N*dim, n) @ (n, n) → (N*dim, n)
        C_flat = V.dot(self.polynomial.coef_mat_alpha.T)
    
        # 5) reshape back to (N, dim, n)
        return C_flat.reshape(N, dim, self.n)







        






"""






