
import numpy as np



class abstract_mean_element:
    """Bundle a pointwise-evaluable callable with descriptive metadata.

    The constructor stores its four arguments unchanged as attributes of the
    same name:

    * ``name``   -- identifier of the element.
    * ``rkhs``   -- object describing the ambient space (presumably an RKHS;
      it is only stored here, never used).
    * ``mu_def`` -- callable invoked once per node by :meth:`evaluate`.
    * ``norm``   -- stored as given (presumably the norm of the element).
    """

    def __init__(self, name, rkhs, mu_def, norm):
        self.name, self.rkhs = name, rkhs
        self.mu_def, self.norm = mu_def, norm

    def evaluate(self, nodes):
        """Return ``[mu_def(nodes[0]), ..., mu_def(nodes[len(nodes) - 1])]``.

        ``nodes`` must support ``len`` and integer indexing; the result is
        always a plain Python list.
        """
        return [self.mu_def(nodes[k]) for k in range(len(nodes))]



def mean_element(d, g_list, sigma_list, onb_list):
    """Build an evaluator for a truncated basis expansion, tensorised over ``d`` axes.

    The one-dimensional evaluator is

        mu(x) = sum_k g_list[k] * sigma_list[k] * onb_list[k].evaluate_at_X(x)

    where ``k`` runs over the first ``len(g_list)`` terms; ``sigma_list`` and
    ``onb_list`` are truncated to that length.

    Parameters
    ----------
    d : int
        Number of coordinates. ``d == 1`` returns the one-dimensional
        evaluator itself; any other value returns a function of an indexable
        ``x`` that multiplies ``mu(x[0]) * ... * mu(x[d - 1])``.
    g_list, sigma_list : sequences of numbers
        Coefficient factors, multiplied term by term.
    onb_list : sequence
        Objects exposing ``evaluate_at_X(x)``.

    Returns
    -------
    callable
        The evaluator described above.

    Notes
    -----
    The coefficients ``g * sigma`` are computed once, when this function is
    called, rather than on every evaluation; later in-place changes to
    ``g_list`` or ``sigma_list`` do not affect an evaluator already built.
    """
    n_terms = len(g_list)
    basis = onb_list[0:n_terms]
    # The coefficients are independent of x, so build them once here instead
    # of once per evaluation (the tensor variant evaluates d times per point).
    coeffs = np.asarray(
        [g * sigma for g, sigma in zip(g_list, sigma_list[0:n_terms])]
    )

    def mean_element_aux(x):
        basis_values = np.asarray([phi.evaluate_at_X(x) for phi in basis])
        return np.dot(basis_values, coeffs)

    if d == 1:
        return mean_element_aux

    def tensor_mean_element_aux(x):
        # Product of the one-dimensional evaluator over the first d coordinates.
        product = 1
        for axis in range(d):
            product *= mean_element_aux(x[axis])
        return product

    return tensor_mean_element_aux
    
    
def pseudo_mean_element(d, g_list, sigma_list, onb_list):
    """Return a callable evaluating a truncated expansion, as a ``d``-fold product.

    For a scalar argument the expansion is

        m(x) = sum_k g_list[k] * sigma_list[k] * onb_list[k].evaluate_at_X(x)

    with ``k`` ranging over the first ``len(g_list)`` indices (``sigma_list``
    and ``onb_list`` are cut to that length).

    Parameters
    ----------
    d : int
        ``1`` yields ``m`` itself. Otherwise the returned callable takes an
        indexable ``x`` and returns ``m(x[0]) * m(x[1]) * ... * m(x[d - 1])``.
    g_list, sigma_list : sequences of numbers
        Factors whose element-wise product gives the expansion coefficients.
    onb_list : sequence
        Elements must provide an ``evaluate_at_X(x)`` method.

    Returns
    -------
    callable

    Notes
    -----
    The coefficient vector is fixed at construction time, so mutating
    ``g_list`` / ``sigma_list`` afterwards has no effect on the result.

    NOTE(review): this computes the same quantity as ``mean_element`` in this
    file -- confirm whether a distinct "pseudo" definition was intended.
    """
    truncation = len(g_list)
    onb_trunc = onb_list[0:truncation]
    sigma_trunc = sigma_list[0:truncation]
    # Hoisted out of the evaluator: these weights never depend on x, so there
    # is no reason to rebuild the array on every call.
    weights = np.asarray([g * s for (g, s) in zip(g_list, sigma_trunc)])

    def mean_element_aux(x):
        phi_x = np.asarray([phi.evaluate_at_X(x) for phi in onb_trunc])
        return np.dot(phi_x, weights)

    def tensor_mean_element_aux(x):
        result = 1
        for i in range(d):
            result *= mean_element_aux(x[i])
        return result

    return mean_element_aux if d == 1 else tensor_mean_element_aux




def sigma_seq(n_list, s):
    """Map each index ``n`` to ``ceil(n / 2) ** (-2 * s)``.

    * ``n`` equal to 1 or 2 contributes the integer ``1``.
    * an even ``n > 2`` contributes ``1 / (n / 2) ** (2 * s)``.
    * an odd ``n > 2`` contributes ``1 / ((n + 1) / 2) ** (2 * s)``.

    Any other entry (for example values below 1, or non-integers above 2)
    contributes nothing, so the result can be shorter than ``n_list``.

    Returns a list in the order of ``n_list``.
    """
    values = []
    for n in n_list:
        if n == 1 or n == 2:
            values.append(1)
            continue
        if not n > 2:
            continue
        remainder = n % 2
        if remainder == 0:
            half_index = n / 2
        elif remainder == 1:
            half_index = (n + 1) / 2
        else:
            # Neither even nor odd (non-integer input): skipped.
            continue
        values.append(1 / np.power(half_index, 2 * s))
    return values



def ezq_rate(n_list, sigma_list):
    """Return the tail sums of ``sigma_list`` starting at each index in ``n_list``.

    For every ``n`` in ``n_list`` the result holds ``np.sum(sigma_list[n:])``,
    in the same order. Slicing rules apply: an ``n`` at or past the end gives
    an empty tail (sum ``0.0``), and a negative ``n`` counts from the end.
    """
    return [np.sum(sigma_list[start:]) for start in n_list]




def get_Sobolev_spectrum(s, M):
    """Return the leading entries of the sequence ``1, 1, 1, 2^-2s, 2^-2s, 3^-2s, ...``.

    Entry 0 is the integer ``1``; entry ``i >= 1`` is
    ``ceil(i / 2) ** (-2 * s)``, so every value after the first appears twice.

    Parameters
    ----------
    s : number
        Exponent parameter; each entry decays like ``frequency ** (-2 * s)``.
    M : int
        Truncation order. The list has ``M`` entries for ``M >= 1``; for
        ``M <= 1`` it is ``[1]``.

    Returns
    -------
    list
        The spectrum, in order of increasing index.
    """
    spectrum = [1]
    for i in range(1, M):
        # (i + 1) // 2 == ceil(i / 2): i/2 for even i, (i + 1)/2 for odd i.
        # The base is converted to float on purpose: with an integer base and
        # an integer exponent np.power works in fixed-width integers, which
        # silently wraps for large results (e.g. 79 ** 10 exceeds int64) and
        # rejects negative integer exponents.
        frequency = float((i + 1) // 2)
        spectrum.append(1 / np.power(frequency, 2 * s))
    return spectrum