import math

import torch

def standardize(y):
    """Return ``y`` shifted so that its mean is zero.

    Despite the name, only the mean is removed: the result is NOT divided by
    the standard deviation, so the scale of ``y`` is preserved.
    """
    offset = y.mean()
    return y - offset

def f1(x):
    """Main-effect shape ``-2 * sin(2 * x)``, centered with ``standardize``."""
    wave = torch.sin(2 * x)
    return standardize(-2 * wave)

def f2(x):
    """Main-effect shape ``x**2 / 2 + 1``, centered with ``standardize``.

    Because ``standardize`` subtracts the mean, the ``+ 1`` offset cancels
    (up to floating-point rounding); it is kept to mirror the formula.
    """
    half_square = x.pow(2) / 2
    return standardize(half_square + 1)
    
def f3(x):
    """Linear main-effect shape ``x - 1/2``, centered with ``standardize``."""
    shifted = x - 0.5
    return standardize(shifted)

def f4(x1):
    """Main-effect shape ``exp(-x1) + exp(-1) - 1``, centered with ``standardize``.

    The constant ``exp(-1) - 1`` cancels (up to rounding) once the mean is
    removed. It is added as a Python scalar rather than a 1-element CPU tensor,
    so the expression works for ``x1`` on any device (a shape-(1,) CPU tensor
    cannot be combined with a CUDA tensor) and follows ``x1``'s dtype.
    """
    y = torch.exp(-x1) + math.exp(-1) - 1
    return standardize(y)

def f5(x1, x2):
    """Pairwise interaction ``exp(sin(x1) + cos(x2) - 1)``, centered with ``standardize``."""
    exponent = torch.sin(x1) + torch.cos(x2) - 1
    return standardize(torch.exp(exponent))

def generate_dataset(nsample, nfeature, UB, LB, type):
    """Generate synthetic additive-model splits (train / valid / test).

    For each split, ``X`` is drawn uniformly from ``[LB, UB)`` with shape
    ``(n, nfeature)`` and the target is the sum of four component functions of
    columns of ``X`` plus ``torch.rand`` noise. The components per ``type`` are:

    * ``'only_main'``:            f1(x0) + f2(x1) + f3(x2) + f4(x3)
    * ``'weak_main'``:            same, with f4 scaled by 0.01
    * ``'inter_no_overlap'``:     f1(x0) + f2(x1) + f3(x2) + f5(x3, x4)
    * ``'inter_mild_overlap'``:   f1(x0) + f2(x1) + f3(x2) + f5(x2, x3)
    * ``'inter_strong_overlap'``: f1(x0) + f2(x1) + f3(x2) + f5(x1, x2)
    * ``'only_inter'``:           f5(x0, x1) + f5(x2, x3)  (func3 = func4 = 0)

    Args:
        nsample: Split sizes in the order [train, valid, test]. At most three
            entries; fewer entries produce only the leading splits.
        nfeature: Number of columns of ``X``. Must cover every column the
            chosen ``type`` reads (3 to 5 depending on ``type``).
        UB: Upper bound of the uniform feature distribution.
        LB: Lower bound of the uniform feature distribution.
        type: Scenario name (see above). Shadows the builtin ``type`` but is
            kept for backward compatibility with existing callers.

    Returns:
        Dict mapping ``'Train'``/``'Valid'``/``'Test'`` to a dict with keys
        ``'data'`` (``X``), ``'target'`` (``y``) and ``'true_func'`` (list of
        the four component tensors).

    Raises:
        ValueError: If ``type`` is unknown, ``nfeature`` is too small for it,
            or ``nsample`` has more than three entries.
    """
    # Columns each scenario reads (highest column index used + 1).
    required_features = {
        'only_main': 4,
        'weak_main': 4,
        'inter_no_overlap': 5,
        'inter_mild_overlap': 4,
        'inter_strong_overlap': 3,
        'only_inter': 4,
    }
    # Validate before drawing anything so a bad call consumes no RNG state.
    if type not in required_features:
        raise ValueError(
            f"unknown dataset type {type!r}; expected one of {sorted(required_features)}"
        )
    if nfeature < required_features[type]:
        raise ValueError(
            f"type {type!r} needs at least {required_features[type]} features, got {nfeature}"
        )

    names = ['Train', 'Valid', 'Test']
    nsample = list(nsample)  # allow any iterable while still checking its length
    if len(nsample) > len(names):
        raise ValueError(
            f"nsample may have at most {len(names)} entries (train, valid, test), "
            f"got {len(nsample)}"
        )

    dataset = {}
    for split_name, n in zip(names, nsample):
        X = torch.FloatTensor(n, nfeature).uniform_(LB, UB)

        func1 = f1(X[:, 0])
        func2 = f2(X[:, 1])
        func3 = f3(X[:, 2])

        if type == 'only_main':
            func4 = f4(X[:, 3])
        elif type == 'weak_main':
            func4 = 0.01 * f4(X[:, 3])
        elif type == 'inter_no_overlap':
            func4 = f5(X[:, 3], X[:, 4])
        elif type == 'inter_mild_overlap':
            func4 = f5(X[:, 2], X[:, 3])
        elif type == 'inter_strong_overlap':
            func4 = f5(X[:, 1], X[:, 2])
        else:  # the only remaining validated type: 'only_inter'
            func1 = f5(X[:, 0], X[:, 1])
            func2 = f5(X[:, 2], X[:, 3])
            func3 = torch.zeros_like(func1)
            func4 = torch.zeros_like(func1)

        # NOTE(review): torch.rand draws U[0, 1) noise, which has mean 0.5 rather
        # than 0 -- confirm whether torch.randn (zero-mean Gaussian) was intended.
        y = func1 + func2 + func3 + func4 + torch.rand(n)
        dataset[split_name] = {
            'data': X,
            'target': y,
            'true_func': [func1, func2, func3, func4],
        }

    return dataset


