import numpy as np
import torch

from scipy.special import legendre

# Module-level NumPy random Generator with a fixed seed (1) so draws are reproducible.
rng = np.random.default_rng(1)

def _legendre_embedding(degree, t):
    """Evaluate the Legendre polynomial ``P_degree`` element-wise on ``t``.

    The polynomial is evaluated by SciPy on a detached CPU copy of ``t``, so no
    gradient flows back through this feature. The result is placed on ``t``'s
    device, with ``t``'s dtype when ``t`` is floating point and torch's default
    floating dtype otherwise.
    """
    values = legendre(degree)(t.detach().cpu().numpy())
    dtype = t.dtype if t.is_floating_point() else torch.get_default_dtype()
    # ``torch.as_tensor`` (unlike ``torch.Tensor``) honours device/dtype and
    # also accepts the NumPy scalar produced for 0-d inputs.
    return torch.as_tensor(values, dtype=dtype, device=t.device)


def _triangle_wave(x):
    """Period-2 triangle wave in [-1, 1]: 1 at even ``x``, -1 at odd ``x``."""
    # ``%`` on tensors is torch.remainder (Python semantics), so the result
    # lies in [0, 2) even for negative ``x``.
    return 2 * torch.abs(x % 2 - 1) - 1


# Level-0 ("primitive") time embeddings. Each entry is called as
# ``fn(degree, t)`` with a tensor ``t`` and returns a tensor shaped like ``t``.
level_0_time_embedding_functions = {
    "legendre": _legendre_embedding,
    "monomial": lambda degree, t: t**degree,
    # sin / cos / triangle_* expect t in [-1, 1].
    "sin": lambda degree, t: torch.sin(torch.pi*degree*t/2)/np.sqrt(2),
    "cos": lambda degree, t: torch.cos(torch.pi*degree*t/2)/np.sqrt(2),
    # Sin-like triangle wave: 0 at t=0, rising for small t > 0. The shift is a
    # quarter period (1.5 == -0.5 mod 2); a shift of 1 (half a period) would
    # make this exactly the negation of "triangle_2".
    "triangle_1": lambda degree, t: _triangle_wave(degree*t + 1.5),
    # Cos-like triangle wave: 1 at t=0.
    "triangle_2": lambda degree, t: _triangle_wave(degree*t),
    # Ignores ``degree``; returns the (halved) time itself.
    "time_copy": lambda degree, t: t/2,
}

def _alternating_embedding(even_key, odd_key):
    """Return an embedding that alternates between two level-0 functions.

    Even ``degree`` values dispatch to ``even_key`` and odd ones to
    ``odd_key``. Either way the level-0 function receives ``degree // 2``, so
    degrees ``2k`` and ``2k + 1`` form a pair that shares the level-0 degree ``k``.
    """
    def embedding(degree, t):
        chosen = odd_key if degree % 2 else even_key
        return level_0_time_embedding_functions[chosen](degree=degree // 2, t=t)

    return embedding


# Level-1 embeddings: each interleaves a pair of level-0 functions by degree.
level_1_time_embedding_functions = {
    # Degree 0 maps to level-0 "sin" with degree 0, i.e. sin(0) == 0 everywhere.
    "fourier": _alternating_embedding("sin", "cos"),
    "triangle": _alternating_embedding("triangle_1", "triangle_2"),
}

# Level-2 embeddings: none defined yet.
level_2_time_embedding_functions = dict()

# Flat lookup over every level. If a name appears at several levels, the entry
# from the later (higher) level wins, while the key keeps its first position.
time_embedding_functions = {
    name: fn
    for table in (
        level_0_time_embedding_functions,
        level_1_time_embedding_functions,
        level_2_time_embedding_functions,
    )
    for name, fn in table.items()
}

# Per-level tables mirroring the ones above, presumably for embeddings that take
# extra parameters (TODO confirm); both are currently empty.
level_0_time_embedding_functions_more_params = dict()
level_1_time_embedding_functions_more_params = dict()

# Flat lookup over both levels; later levels override earlier ones on clashes.
time_embedding_functions_more_params = {
    name: fn
    for table in (
        level_0_time_embedding_functions_more_params,
        level_1_time_embedding_functions_more_params,
    )
    for name, fn in table.items()
}