import numpy as np

def infer_nonsta_dir(X, Y, width, Wt):
    """Compute a kernel dependence statistic between two time-smoothed Gram matrices.

    The columns of ``X`` and ``Y`` are scaled to unit standard deviation,
    Gaussian kernels are built on ``X``, ``Y`` and the time index ``0..T-1``,
    and two T x T matrices are derived from them (labelled ``P(Y|X)`` and
    ``P(X)`` below).  Each is converted into a Gaussian Gram matrix with a
    median-based bandwidth, double-centred, and the returned value is
    ``sum(Mg.T * Mg2) / T**2``.

    Parameters
    ----------
    X : ndarray of shape (T, d)
        First variable, one sample per row.
    Y : ndarray of shape (T, d_y)
        Second variable; must have the same number of rows as ``X``.
    width : float
        Kernel width used for ``X`` and ``Y`` (the kernel parameter is
        ``1 / width**2``).
    Wt : float
        Kernel width used for the time index.

    Returns
    -------
    float
        The test statistic.

    Raises
    ------
    ValueError
        If ``X`` or ``Y`` is not two-dimensional.
    """
    T, _ = X.shape  # also enforces that X is 2-D
    if np.ndim(Y) != 2:
        raise ValueError("Y must be a 2-D array with one sample per row")

    # Scale every column to unit standard deviation.
    X = X * (1.0 / np.std(X, axis=0))
    Y = Y * (1.0 / np.std(Y, axis=0))

    theta = 1 / width**2
    lambda_ = 2  # ridge term added to the Gram matrices before inversion
    ridge = lambda_ * np.eye(T)

    # NOTE: the former ``GP_hyp`` switch was hard-coded to 0 and its enabled
    # branch was an empty placeholder, so only the fixed-hyperparameter path
    # is implemented here.

    # P(Y|X)
    Kxx = kernel(X, X, (theta, 1))
    Kyy = kernel(Y, Y, (theta, 1))
    t_idx = np.arange(T)[:, np.newaxis]
    Ktt = kernel(t_idx, t_idx, (1 / Wt**2, 1))
    invK = pdinv(Kxx * Ktt + ridge)
    # NOTE(review): the name suggests a cube, but this is the matrix product
    # Kxx @ Kxx -- confirm against the reference derivation.
    Kxx3 = np.dot(Kxx, Kxx)
    prod_invK = np.dot(np.dot(invK, Kyy), invK)
    Ml = 1 / T**2 * np.dot(np.dot(Ktt, Kxx3 * prod_invK), Ktt)
    Mg = _gaussian_gram(Ml)

    # P(X)
    invK2 = pdinv(Ktt + ridge)
    Ml2 = Ktt.dot(invK2).dot(Kxx).dot(invK2).dot(Ktt)
    Mg2 = _gaussian_gram(Ml2)

    # Double-centre both Gram matrices, then take the normalised inner product.
    H = np.eye(T) - 1 / T * np.ones([T, T])
    Mg = H.dot(Mg).dot(H)
    Mg2 = H.dot(Mg2).dot(H)
    return 1 / T**2 * np.sum(Mg.T * Mg2)


def _gaussian_gram(M):
    """Map a square Gram matrix to a Gaussian Gram matrix with median bandwidth."""
    diag = np.diag(M)
    # D[i, j] = M[i, i] + M[j, j] - 2 * M[i, j], built by broadcasting instead
    # of multiplying dense diagonal matrices with all-ones matrices.
    D = diag[:, np.newaxis] + diag[np.newaxis, :] - 2 * M
    # Bandwidth is the median of the strictly lower-triangular entries of D.
    sigma_sq = np.median(D[np.tril_indices(D.shape[0], -1)])
    if sigma_sq == 0:
        return np.zeros(D.shape)
    return np.exp(-D / sigma_sq / 2)


def kernel(x, xKern, theta):
    """Evaluate a Gaussian kernel between the rows of ``x`` and ``xKern``.

    The result is ``theta[1] * exp(-dist2(x, xKern) * theta[0] / 2)``.

    Parameters
    ----------
    x, xKern : ndarray
        Sample matrices, one sample per row.
    theta : sequence of two numbers
        ``theta[0]`` is the inverse squared width and ``theta[1]`` the
        amplitude.  If ``theta[0]`` is 0 the inverse squared width is chosen
        as ``2 / median`` of the positive lower-triangular squared distances.
        ``theta`` is never modified, so tuples are accepted.

    Returns
    -------
    ndarray
        Kernel matrix with one row per row of ``x`` and one column per row of
        ``xKern``.
    """
    n2 = dist2(x, xKern)
    # Keep the chosen value in a local: assigning back into ``theta`` fails
    # for tuples and would silently alter the caller's list.
    inv_width_sq = theta[0]
    if inv_width_sq == 0:
        inv_width_sq = 2 / np.median(n2[np.tril(n2) > 0])
    wi2 = inv_width_sq / 2
    return theta[1] * np.exp(-n2 * wi2)
def dist2(x, c):
    """Return the squared Euclidean distances between rows of ``x`` and ``c``.

    Entry ``[i, j]`` of the result is ``|x[i]|^2 + |c[j]|^2 - 2 * x[i].c[j]``;
    any negative value is replaced by 0.

    Parameters
    ----------
    x : ndarray
        Data matrix, one point per row.
    c : ndarray
        Centre matrix, one point per row, with the same number of columns
        as ``x``.

    Returns
    -------
    ndarray
        Matrix with ``x.shape[0]`` rows and ``c.shape[0]`` columns.
    """
    n_points, n_centres = x.shape[0], c.shape[0]

    sq_norm_x = np.sum((x ** 2).T, 0)
    sq_norm_c = np.sum((c ** 2).T, 0)
    cross = x.dot(c.T)

    # Replicate the squared norms along rows/columns, then subtract the
    # cross term.
    sq_dist = (
        np.outer(sq_norm_x, np.ones(n_centres))
        + np.outer(np.ones(n_points), sq_norm_c)
        - 2. * cross
    )

    # Clamp negative entries to zero.
    sq_dist[sq_dist < 0] = 0
    return sq_dist

def pdinv(mat):
    """Invert a positive-definite matrix via its Cholesky factorisation.

    Parameters
    ----------
    mat : ndarray
        Square matrix for which ``np.linalg.cholesky`` succeeds.

    Returns
    -------
    ndarray
        The inverse of ``mat``.

    Raises
    ------
    numpy.linalg.LinAlgError
        Propagated from ``np.linalg.cholesky`` when the factorisation fails.
    """
    size = mat.shape[0]
    lower = np.linalg.cholesky(mat)  # mat = lower @ lower.T
    lower_inv = np.linalg.solve(lower, np.eye(size))
    # inv(mat) = inv(lower).T @ inv(lower)
    return lower_inv.T.dot(lower_inv)
