import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class Wavy(nn.Module):
    """Parameter-free module mapping each row of ``x`` to a (sin, cos) pair.

    Every row is reduced to its sum ``s`` over ``dim=1``; the phase is
    ``scale / (s**2 + eps)`` and ``forward`` returns ``[sin(phase), cos(phase)]``.

    Args:
        idim: Stored as ``self.idim``; only used as ``sqrt(idim)`` in ``sob``.
            (Callers in this file also use it as the feature width of ``x``.)
        scale: Numerator of the phase.
        eps: Constant added to the squared row sum in ``transform``.
            With the default ``eps=0`` a row summing to zero divides by zero.
    """

    def __init__(self, idim=16, scale=512, eps=0):
        super().__init__()
        self.eps = eps
        self.scale = scale
        self.idim = idim
        # Not read by any method of this class; kept so existing external
        # readers of ``.pi`` still see the same value.
        self.pi = 3.1415926535

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``scale / (sum(x, dim=1)**2 + eps)``, keeping dim 1 (size 1)."""
        row_sum = torch.sum(x, dim=1, keepdim=True)
        return self.scale / (row_sum ** 2 + self.eps)

    def sob(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``2 * scale * sqrt(idim) / |sum(x, dim=1)|**3``, keeping dim 1.

        NOTE(review): this looks like the Euclidean norm of the gradient of
        ``transform`` w.r.t. ``x`` when ``eps == 0`` and ``x`` has ``idim``
        columns -- confirm. ``eps`` is NOT used here, so for ``eps != 0`` this
        does not track ``transform`` exactly, and a zero row sum is unguarded.
        """
        row_sum = torch.sum(x, dim=1, keepdim=True)
        return 2 * self.scale * math.sqrt(self.idim) / (torch.abs(row_sum) ** 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``cat([sin(phase), cos(phase)], dim=1)`` (size 2 along dim 1)."""
        # Evaluate the phase once and reuse it for both sin and cos; computing
        # it twice doubles the reduction and the temporaries on large inputs.
        phase = self.transform(x)
        return torch.cat([torch.sin(phase), torch.cos(phase)], dim=1)

if __name__ == "__main__":
    # Script entry point: draw uniform-random inputs for three splits, evaluate
    # Wavy's forward output and its `sob` values on them, and save everything
    # to a single .pt file as {"train"|"valid"|"test": (X, wavy(X), wavy.sob(X))}.
    import os

    out_path = './data/wavy_data.pt'
    # Make sure the target directory exists *before* generating anything:
    # torch.save does not create parent directories, and failing at the very
    # end would throw away all of the generated data.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    # Insertion order matters: it fixes both the order of the torch.rand draws
    # (train, valid, test) and the key order of the saved dict.
    split_sizes = {"train": int(1e8), "valid": int(1e6), "test": int(1e6)}

    # No gradients are needed for data generation.
    with torch.no_grad():
        wavy = Wavy()
        datasets = {}
        for split, size in split_sizes.items():
            X = torch.rand(size, wavy.idim)
            datasets[split] = (X, wavy(X), wavy.sob(X))

        torch.save(datasets, out_path)

    
