import numpy as np

def oakoh04(xx):
    """Evaluate the 15-input test function at every row of ``xx``.

    For a point ``x`` the value is::

        a1 . x  +  a2 . sin(x)  +  a3 . cos(x)  +  x^T M x

    with the fixed coefficient rows ``a1``, ``a2``, ``a3`` and the fixed
    15x15 matrix ``M`` defined below.
    (Presumably the Oakley & O'Hagan 2004 benchmark, going by the function
    name -- confirm against the original reference.)

    Parameters
    ----------
    xx : numpy.ndarray
        Array with one 15-component point per row, i.e. shape ``(n, 15)``.

    Returns
    -------
    numpy.ndarray
        Column of function values; shape ``(n, 1)`` for ``(n, 15)`` input.
    """
    # Coefficients of the linear, sine and cosine terms (each a 1x15 row).
    lin_coef = np.array([[
        0.0118, 0.0456, 0.2297, 0.0393, 0.1177,
        0.3865, 0.3897, 0.6061, 0.6159, 0.4005,
        1.0741, 1.1474, 0.7880, 1.1242, 1.1982,
    ]])
    sin_coef = np.array([[
        0.4341, 0.0887, 0.0512, 0.3233, 0.1489,
        1.0360, 0.9892, 0.9672, 0.8977, 0.8083,
        1.8426, 2.4712, 2.3946, 2.0045, 2.2621,
    ]])
    cos_coef = np.array([[
        0.1044, 0.2057, 0.0774, 0.2730, 0.1253,
        0.7526, 0.8570, 1.0331, 0.8388, 0.7970,
        2.2145, 2.0382, 2.4004, 2.0541, 1.9845,
    ]])
    # Matrix of the quadratic form x^T M x, one matrix row per bracketed group.
    quad_mat = np.array([
        [-0.022482886, -0.18501666, 0.13418263, 0.36867264, 0.17172785,
         0.13651143, -0.44034404, -0.081422854, 0.71321025, -0.44361072,
         0.50383394, -0.024101458, -0.045939684, 0.21666181, 0.055887417],
        [0.25659630, 0.053792287, 0.25800381, 0.23795905, -0.59125756,
         -0.081627077, -0.28749073, 0.41581639, 0.49752241, 0.083893165,
         -0.11056683, 0.033222351, -0.13979497, -0.031020556, -0.22318721],
        [-0.055999811, 0.19542252, 0.095529005, -0.28626530, -0.14441303,
         0.22369356, 0.14527412, 0.28998481, 0.23105010, -0.31929879,
         -0.29039128, -0.20956898, 0.43139047, 0.024429152, 0.044904409],
        [0.66448103, 0.43069872, 0.29924645, -0.16202441, -0.31479544,
         -0.39026802, 0.17679822, 0.057952663, 0.17230342, 0.13466011,
         -0.35275240, 0.25146896, -0.018810529, 0.36482392, -0.32504618],
        [-0.12127800, 0.12463327, 0.10656519, 0.046562296, -0.21678617,
         0.19492172, -0.065521126, 0.024404669, -0.096828860, 0.19366196,
         0.33354757, 0.31295994, -0.083615456, -0.25342082, 0.37325717],
        [-0.28376230, -0.32820154, -0.10496068, -0.22073452, -0.13708154,
         -0.14426375, -0.11503319, 0.22424151, -0.030395022, -0.51505615,
         0.017254978, 0.038957118, 0.36069184, 0.30902452, 0.050030193],
        [-0.077875893, 0.0037456560, 0.88685604, -0.26590028, -0.079325357,
         -0.042734919, -0.18653782, -0.35604718, -0.17497421, 0.088699956,
         0.40025886, -0.055979693, 0.13724479, 0.21485613, -0.011265799],
        [-0.092294730, 0.59209563, 0.031338285, -0.033080861, -0.24308858,
         -0.099798547, 0.034460195, 0.095119813, -0.33801620, 0.0063860024,
         -0.61207299, 0.081325416, 0.88683114, 0.14254905, 0.14776204],
        [-0.13189434, 0.52878496, 0.12652391, 0.045113625, 0.58373514,
         0.37291503, 0.11395325, -0.29479222, -0.57014085, 0.46291592,
         -0.094050179, 0.13959097, -0.38607402, -0.44897060, -0.14602419],
        [0.058107658, -0.32289338, 0.093139162, 0.072427234, -0.56919401,
         0.52554237, 0.23656926, -0.011782016, 0.071820601, 0.078277291,
         -0.13355752, 0.22722721, 0.14369455, -0.45198935, -0.55574794],
        [0.66145875, 0.34633299, 0.14098019, 0.51882591, -0.28019898,
         -0.16032260, -0.068413337, -0.20428242, 0.069672173, 0.23112577,
         -0.044368579, -0.16455425, 0.21620977, 0.0042702105, -0.087399014],
        [0.31599556, -0.027551859, 0.13434254, 0.13497371, 0.054005680,
         -0.17374789, 0.17525393, 0.060258929, -0.17914162, -0.31056619,
         -0.25358691, 0.025847535, -0.43006001, -0.62266361, -0.033996882],
        [-0.29038151, 0.034101270, 0.034903413, -0.12121764, 0.026030714,
         -0.33546274, -0.41424111, 0.053248380, -0.27099455, -0.026251302,
         0.41024137, 0.26636349, 0.15582891, -0.18666254, 0.019895831],
        [-0.24388652, -0.44098852, 0.012618825, 0.24945112, 0.071101888,
         0.24623792, 0.17484502, 0.0085286769, 0.25147070, -0.14659862,
         -0.084625150, 0.36931333, -0.29955293, 0.11044360, -0.75690139],
        [0.041494323, -0.25980564, 0.46402128, -0.36112127, -0.94980789,
         -0.16504063, 0.0030943325, 0.052792942, 0.22523648, 0.38390366,
         0.45562427, -0.18631744, 0.0082333995, 0.16670803, 0.16045688],
    ])

    # Work with points as columns so each coefficient row can be applied
    # with a single matrix product.
    points = xx.T
    linear_part = lin_coef @ points
    sine_part = sin_coef @ np.sin(points)
    cosine_part = cos_coef @ np.cos(points)
    # x^T M x for every point at once: elementwise product, then sum over
    # the component axis.
    quadratic_part = np.sum(points * (quad_mat @ points), axis=0)

    total = linear_part + sine_part + cosine_part + quadratic_part
    # Transpose back so there is one function value per input row.
    return total.T
 
 
def get_xy_preprocess_tranforms(x_train,y_train,xpreprocess,train_data_size,input_dim):
    """Derive the quantities used to preprocess (x, y) training values.

    Parameters
    ----------
    x_train : numpy.ndarray
        Training inputs, one point per row (statistics are taken along
        axis 0).
    y_train : numpy.ndarray
        Training target values.
    xpreprocess : {'none', 'axis_rescale', 'whiten'}
        How the x values are to be preprocessed:

        * ``'none'``         -- identity matrix of size ``input_dim``;
        * ``'axis_rescale'`` -- diagonal matrix with 1/(sd of each x
          component) on the diagonal;
        * ``'whiten'``       -- pre-whitening matrix built from the SVD of
          the covariance of the x components.
    train_data_size : int
        Divisor used for the covariance estimate in the ``'whiten'`` case.
    input_dim : int
        Size of the identity matrix returned in the ``'none'`` case.

    Returns
    -------
    tuple
        ``(m_y, sd_y, m_x, prep_mat_x)``: mean and standard deviation of
        ``y_train``, per-component mean of ``x_train``, and the x
        preprocessing matrix selected by ``xpreprocess``.

    Raises
    ------
    ValueError
        If ``xpreprocess`` is not one of the supported modes.
    """
    # Reject unknown modes before doing any (potentially expensive) work;
    # otherwise prep_mat_x would never be assigned.
    if xpreprocess not in ('none', 'axis_rescale', 'whiten'):
        raise ValueError(
            "xpreprocess must be 'none', 'axis_rescale' or 'whiten', got %r"
            % (xpreprocess,)
        )

    m_x = np.average(x_train, axis=0)
    m_y = np.average(y_train)
    sd_y = np.std(y_train, dtype=np.float64)

    if xpreprocess == 'axis_rescale':
        sd_x = np.std(x_train, axis=0, dtype=np.float64)
        prep_mat_x = np.diag(1.0 / sd_x)
    elif xpreprocess == 'whiten':
        # Centre once and reuse, rather than forming x_train - m_x twice.
        centred = x_train - m_x
        cov_x = np.dot(np.transpose(centred), centred) / float(train_data_size)
        U, S, _ = np.linalg.svd(cov_x)  # singular value decomposition
        # Small offset keeps the scaling finite when a singular value is ~0.
        epsilon = 1e-5
        prep_mat_x = np.diag(1.0 / np.sqrt(S + epsilon)).dot(U.T)
    else:  # 'none'
        prep_mat_x = np.identity(input_dim, dtype=np.float64)

    return m_y, sd_y, m_x, prep_mat_x
    
def preprocess_y_vals(y_vals,m_y,sd_y):
    """Standardise y values: subtract ``m_y``, then divide by ``sd_y``.

    The input is not modified; a new value/array is returned.
    """
    return (y_vals - m_y) / sd_y
 
 
data_size = 5 * (10 ** 7)

# Generate raw function values at standard-normal sample points.
x = np.random.randn(data_size,15)
y = oakoh04(x)
# Apply an affine transform to the function values only, so that over the
# sampled x distribution they have mean 0 and mean-square value 1.
# NOTE: only m_y and sd_y are used below; m_x and prep_mat_x are computed
# (in 'whiten' mode) but are not applied to x.
m_y, sd_y, m_x, prep_mat_x =  get_xy_preprocess_tranforms(x,y,'whiten',data_size,15)
y= preprocess_y_vals(y,m_y,sd_y)


# Assemble the (data_size, 16) array of xy values: columns 0-14 hold the x
# components and column 15 holds the standardised y value.
# Slice assignment replaces the element-by-element Python loop (7.5e8
# iterations at this data_size), and flattening y avoids assigning a
# one-element array into a scalar slot.
xy_data = np.empty((data_size, 16), dtype=np.float64)
xy_data[:, :15] = x
xy_data[:, 15] = np.ravel(y)


# Write the npy file of xy values.
np.save('base_oak_xy_data.npy', xy_data)