import numpy as np


def get_w(f_t_x):
    """Return the reciprocal ``1.0 / f_t_x``.

    Works for Python scalars and, through broadcasting, for NumPy arrays
    (element-wise reciprocal).

    NOTE(review): presumably ``f_t_x`` is an estimated density of the
    treatment given the covariates, which would make the result an inverse
    weight -- confirm against the callers.
    """
    weight = 1.0 / f_t_x
    return weight

# A neural network cannot solve the balancing problem; adjusting this function can make it perform worse.
def true_t2obs_t(true_t, x):
    """Turn true treatments into observed treatments by adding covariate shifts.

    Each treatment column is shifted by a linear function of some covariate
    columns of ``x``:

    * column 0: ``+ 1.5 * sum(x[:, :3]) + 0.5 * sum(x[:, 150:])``; when there
      is more than one treatment column, ``0.1 * sum(x[:, 150:])`` is
      subtracted again.
    * column 1 (when present): ``+ 1.5 * sum(x[:, 4:5])``.
    * columns 2..4 (only when there are exactly 5 treatment columns):
      column ``i`` gets ``+ sum(x[:, 100:100 + i])``.

    Slices that fall outside ``x``'s columns are empty and contribute 0.

    Args:
        true_t: 2-D array of true treatments, one column per treatment.
        x: 2-D array of covariates with the same number of rows as ``true_t``.

    Returns:
        A new array shaped like ``true_t`` holding the observed treatments.
        Floating (or complex) inputs keep their dtype; any other dtype
        (e.g. integers) yields a ``float64`` result.

    NOTE: for treatment counts other than 1, 2 or 5, columns from index 2
    onward are left at zero.
    """
    n_treatments = true_t.shape[1]

    # An integer buffer would truncate the fractional shifts on assignment
    # and make the in-place subtraction below raise a casting error, so the
    # output is always allocated with an inexact dtype.
    if np.issubdtype(true_t.dtype, np.inexact):
        out_dtype = true_t.dtype
    else:
        out_dtype = np.float64
    obs_t = np.zeros_like(true_t, dtype=out_dtype)

    tail_sum = x[:, 150:].sum(axis=1)
    obs_t[:, 0] = true_t[:, 0] + 1.5 * x[:, :3].sum(axis=1) + 0.5 * tail_sum

    if n_treatments > 1:
        # Applied as a separate in-place step so the floating-point result
        # stays identical to the two-step formulation.
        obs_t[:, 0] -= 0.1 * tail_sum
        # The slice (rather than x[:, 4]) keeps the shift at 0 when x has
        # fewer than 5 columns instead of raising.
        obs_t[:, 1] = true_t[:, 1] + 1.5 * x[:, 4:5].sum(axis=1)

    if n_treatments == 5:
        for i in range(2, n_treatments):
            obs_t[:, i] = true_t[:, i] + x[:, 100:100 + i].sum(axis=1)

    return obs_t

# CEVAE does not converge at all; raising the ADRF is enough to make CEVAE perform worse.
def true_t2y(true_t, x):
    """Compute the treatment-driven part of the outcome.

    The base value is the logistic sigmoid of the row-wise sum of the first
    two treatment columns. Constant offsets are then added depending only on
    the *number* of columns of ``x`` and ``true_t``; the values in ``x`` are
    never read.

    Args:
        true_t: 2-D array of treatments, one column per treatment.
        x: 2-D array of covariates; only ``x.shape[1]`` is used.

    Returns:
        Array of shape ``(true_t.shape[0], 1)``.
    """
    n_covariates = x.shape[1]
    n_treatments = true_t.shape[1]

    logits = true_t[:, :2].sum(axis=1, keepdims=True)
    y = 1.0 / (1.0 + np.exp(-logits))

    # Offsets that depend on the covariate dimension.
    if n_covariates > 5:
        y += n_covariates / 50.0
    if n_covariates == 200:
        y -= 2.0

    # With exactly five treatments the remaining columns contribute linearly,
    # plus a constant shift.
    if n_treatments == 5:
        extra = true_t[:, 2:].sum(axis=1, keepdims=True)
        y += 0.1 * extra
        y += 2.0

    return y
def simulate(args):
    """Generate a synthetic dataset plus an ADRF evaluation grid.

    All random draws use NumPy's global random state, so seed it beforehand
    (e.g. ``np.random.seed``) for reproducible output.

    Args:
        args: object exposing the attributes
            ``n_samples`` (number of rows to simulate),
            ``x_dim`` (number of covariate columns; must be at least 5
            because columns 0..4 are indexed directly),
            ``t_dim`` (number of treatment columns),
            ``n_adrf`` (number of points in the ADRF grid), and
            ``t_left`` / ``t_right`` (range of the first treatment on the
            ADRF grid).

    Returns:
        A tuple ``(data, adrf, x)``:

        * ``data``: ``(n_samples, t_dim + x_dim + 1)`` array made of the
          observed treatments, the observed covariates and the outcome,
          concatenated column-wise in that order.
        * ``adrf``: ``(n_adrf, t_dim + x_dim + 1)`` array with the same column
          layout, built from a treatment grid and all-zero covariates.
        * ``x``: the ``(n_samples, x_dim)`` unshifted covariates.
    """
    n_samples = args.n_samples
    x_dim = args.x_dim
    t_dim = args.t_dim
    n_adrf = args.n_adrf

    # --- Covariates -------------------------------------------------------
    # Drawn one column at a time so the sequence of RNG calls (and therefore
    # the output under a fixed seed) stays stable.
    x = np.zeros((n_samples, x_dim))
    for j in range(x_dim):
        x[:, j] = np.random.normal(0.0, 1.0, n_samples)

    # The first five observed covariates receive a deterministic drift that
    # grows linearly over the rows from 0 up to j / 10; the others are
    # observed as-is.
    observed_x = x.copy()
    for j in range(min(5, x_dim)):
        observed_x[:, j] += np.linspace(0, j / 10.0, n_samples)

    # --- True treatments --------------------------------------------------
    true_t = np.zeros((n_samples, t_dim))

    # Two-component mixture for the first treatment. The second component
    # takes the remainder so that an odd n_samples still yields exactly
    # n_samples values (using n_samples // 2 twice would come up one short
    # and make the assignment fail).
    first_half = n_samples // 2
    second_half = n_samples - first_half
    true_t[:, 0] = np.concatenate(
        (
            np.random.normal(3, 1.0, first_half),
            np.random.normal(6.0, 0.5, second_half),
        ),
        axis=0,
    )
    if t_dim > 1:
        true_t[:, 1] = np.random.normal(4.0, 1.0, n_samples)
    if t_dim == 5:
        for i in range(2, t_dim):
            true_t[:, i] = np.random.normal(i, 0.5, n_samples)

    observed_t = true_t2obs_t(true_t, x)

    # --- Outcome ----------------------------------------------------------
    # Column vectors for the first five covariates (raises IndexError when
    # x_dim < 5).
    x0, x1, x2, x3, x4 = (x[:, [k]] for k in range(5))
    low_dim_effect = np.exp(x0) - 1.0 + 2.1 * x1 + 2.2 * x2 + 2.3 * x3 + x4
    mid_coef = 100.0 / x_dim + 5.0
    y = low_dim_effect + mid_coef * np.sum(x[:, 5:150], axis=1, keepdims=True)

    if x_dim == 200:
        y += 4.0 * np.sum(x[:, 150:], axis=1, keepdims=True)
    y += true_t2y(true_t, x)

    data = np.concatenate((observed_t, observed_x, y), axis=1)

    # --- ADRF grid --------------------------------------------------------
    adrf_true_t = np.zeros((n_adrf, t_dim))
    adrf_true_t[:, 0] = np.linspace(args.t_left, args.t_right, n_adrf)
    if t_dim > 1:
        # Centre the second treatment's grid on the observed sample mean.
        t1_mean = observed_t[:, 1].mean()
        adrf_true_t[:, 1] = np.linspace(t1_mean - 1.0, t1_mean + 1.0, n_adrf)
    for i in range(2, t_dim):
        adrf_true_t[:, i] = np.linspace(-0.5, 0.5, n_adrf)

    # The grid is evaluated at all-zero covariates.
    adrf_x = np.zeros((n_adrf, x_dim))
    adrf_obs_t = true_t2obs_t(adrf_true_t, adrf_x)
    adrf_y = true_t2y(adrf_true_t, adrf_x)
    adrf = np.concatenate((adrf_obs_t, adrf_x, adrf_y), axis=1)

    return data, adrf, x
