import numpy as np
from utils import data_normalize
# func_dic = {1: 'np.sin', 2: 'np.cos', 3: 'np.tanh'}


def SyntheticData(dim_list, nums, seed, CI=True, noise_scale=0.5, T_scale=0.2):
    """Generate a synthetic (X, Y, Z) triple from a shared latent Z.

    Z is drawn as ``nums`` rows of i.i.d. standard-normal values with
    ``dim_z`` columns. X and Y are each produced from the same Z by
    ``PNL_process222``. Each call uses its own derived seed, so each gets its
    own projection, transform and noise.

    When ``CI`` is False, one extra standard-normal column ``T`` (shape
    ``(nums, 1)``) is drawn, scaled by ``T_scale``, and added (broadcast) to
    both X and Y. This gives them a common term that does not pass through Z.
    Presumably "CI" means X and Y are conditionally independent given Z;
    confirm against callers.

    Args:
        dim_list: ``[dim_x, dim_y, dim_z]`` column counts for X, Y and Z.
        nums: number of rows/samples.
        seed: seed for NumPy's global RNG (this function reseeds it).
        CI: if False, add the shared ``T_scale * T`` term to X and Y.
        noise_scale: forwarded to ``PNL_process222``.
        T_scale: multiplier for the shared term (only used when ``CI`` is False).

    Returns:
        ``(X, Y, Z)``, each passed through ``data_normalize`` from ``utils``.
    """
    np.random.seed(seed)
    dim_x, dim_y, dim_z = dim_list
    seed1, seed2 = np.random.randint(0, 100000, 2)
    Z = np.random.randn(nums, dim_z)

    # The shared term must be drawn here: PNL_process222 reseeds the global
    # RNG, so drawing it afterwards would take values from a different stream.
    shared = None if CI else np.random.randn(nums, 1)

    X = PNL_process222(Z, dim_x, nums, seed1, noise_scale)
    Y = PNL_process222(Z, dim_y, nums, seed2, noise_scale)

    if shared is not None:
        X += T_scale * shared
        Y += T_scale * shared

    return data_normalize(X), data_normalize(Y), data_normalize(Z)


def PNL_process222(Z, dim_X, nums, seed, noise_scale, T=None, T_scale=None):
    """Map the latent ``Z`` to ``dim_X`` noisy, nonlinearly transformed columns.

    Steps (all randomness comes from NumPy's global RNG, reseeded with ``seed``):

    1. Project: ``Z @ W`` with ``W`` of shape ``(Z.shape[-1], dim_X)`` drawn
       from a standard normal.
    2. Transform: pick an id from ``randint(1, 8)`` and apply
       1: ``sin(pi*x)``, 2: ``cos(pi*x)``, 3: ``x**2 / sqrt(d)``,
       4: ``x / sqrt(d)``, 5: ``exp(x) / sqrt(d)``, 6: ``2**x / sqrt(d)``,
       7: identity. Here ``d = Z.shape[-1]``.
    3. Add noise of shape ``(nums, dim_X)``, scaled by ``noise_scale``.
       It is standard normal when ``seed`` is even and ``uniform(-1, 1)``
       when ``seed`` is odd.

    ``Z`` itself is not modified. ``T`` and ``T_scale`` are accepted but
    unused; they are kept so existing call signatures still work.

    Returns:
        The transformed array plus noise.
    """
    np.random.seed(seed)
    n_latent = Z.shape[-1]
    root = np.sqrt(n_latent)

    weights = np.random.randn(n_latent, dim_X)
    mixed = Z @ weights

    transforms = {
        1: lambda a: np.sin(a * np.pi),
        2: lambda a: np.cos(a * np.pi),
        3: lambda a: a ** 2 / root,
        4: lambda a: a / root,
        5: lambda a: np.exp(a) / root,
        6: lambda a: 2 ** a / root,
    }
    # Two ids are drawn but only the first is used. Keep size=2 so the RNG
    # state before the noise draw matches the historical datasets.
    choice = int(np.random.randint(1, 8, 2)[0])
    # NOTE(review): id 7 leaves the mixture unscaled, unlike id 4, which
    # divides by sqrt(d). Confirm whether that asymmetry is intended.
    mixed = transforms.get(choice, lambda a: a)(mixed)

    if seed % 2 == 0:
        noise = noise_scale * np.random.randn(nums, dim_X)
    else:
        noise = noise_scale * np.random.uniform(-1, 1, (nums, dim_X))

    return mixed + noise


def T_func(T, seed):
    """Apply a seed-selected elementwise transform to ``T``.

    Reseeds NumPy's global RNG with ``seed`` and draws two ids from
    ``randint(1, 5)``. Only the first id is used:

    * 1: return ``np.sin(T)``
    * 2: return ``np.cos(T)``
    * 3 or 4: return ``T`` unchanged (the same object, not a copy)

    Not called by any live code in this file; the only references are in
    commented-out lines of ``SyntheticData``.
    """
    np.random.seed(seed)
    selector = np.random.randint(1, 5, 2)[0]
    if selector == 1:
        return np.sin(T)
    if selector == 2:
        return np.cos(T)
    return T




