import numpy as np


def _fold_split(kk, k, n0, T):
    """Return ``(train_idx, test_idx)`` row indices for fold ``kk`` of ``k``.

    Fold ``kk`` validates on rows ``kk*n0 .. (kk+1)*n0 - 1``; the last fold
    (``kk == k - 1``) also absorbs the remainder rows up to ``T``. Every other
    row is used for training.
    """
    start = kk * n0
    stop = T if kk == k - 1 else (kk + 1) * n0
    test_idx = np.arange(start, stop)
    train_idx = np.concatenate((np.arange(start), np.arange(stop, T)))
    return train_idx, test_idx


def _median_theta(Z, coe):
    """Return the RBF precision ``1 / width**2`` for the rows of ``Z``.

    ``width = 2 * coe * sqrt(0.5 * median(d2))`` where ``d2`` are the strictly
    positive squared distances between distinct row pairs (i < j). ``Z`` may
    have any number of columns.
    """
    sq = dist2(Z, Z)
    upper = sq[np.triu_indices(Z.shape[0], k=1)]
    width = coe * np.sqrt(0.5 * np.median(upper[upper > 0]))
    width = width * 2
    return 1 / (width ** 2)


def local_score_CV_general(Data, Xi, PAi, parameters, kernel_coe_data, kernel_coe_index=None):
    """Cross-validated kernel score of column(s) ``Xi`` given parent columns ``PAi``.

    The rows of a copy of ``Data`` are shuffled with NumPy's global RNG and
    split into ``k`` contiguous folds. Gaussian kernels with median-heuristic
    widths are built and centered. For each fold, a regularized
    kernel-regression term is evaluated on the held-out rows. With no parents,
    only the kernel of ``Xi`` is used.

    Parameters
    ----------
    Data : 2-D array; rows are samples, columns are variables.
    Xi : column index (or indices) of the target variable(s).
    PAi : column indices of the parent variables; empty for no parents.
    parameters : sequence ``(k, lambda)``. ``k`` is the number of folds and
        ``lambda`` is the regression regularization.
    kernel_coe_data : multiplier on the median-heuristic kernel width for
        ordinary data columns.
    kernel_coe_index : optional multiplier used instead of
        ``kernel_coe_data`` when the last listed parent is the final column of
        ``Data``. With no parents, it applies when the last target index is
        the final column. Presumably that column is an index (e.g. time or
        domain) variable; confirm with callers.

    Returns
    -------
    The average over folds of
    ``(nv**2 * log(2*pi) + nv * C + trace(A)) / 2``, where ``nv`` is the fold's
    validation size and ``C`` is half the log-determinant of the fold's ``B``
    matrix. This looks like a Gaussian negative log-likelihood (presumably
    lower is better); confirm against the caller.
    """
    Xi = np.atleast_1d(Xi)
    PAi = np.atleast_1d(PAi)
    Data_ = np.copy(Data)
    np.random.shuffle(Data_)  # random fold assignment (global RNG)
    T, d = Data_.shape
    X = Data_[:, Xi]
    k = parameters[0]
    regression_lambda = parameters[1]
    n0 = T // k  # rows per fold; the last fold takes the remainder
    gamma = 0.01
    H0 = np.eye(T) - np.ones([T, T]) / T  # centering in feature space
    has_index_coe = kernel_coe_index is not None
    CV = 0

    if len(PAi) != 0:
        # The index coefficient applies only to the last parent column, and only
        # when that column is the final column of Data.
        index_include = has_index_coe and PAi[-1] + 1 == d
        PA = Data_[:, PAi]

        # Centered kernel for X (always uses the data width coefficient).
        Kx = H0.dot(kernel(X, X, (_median_theta(X, kernel_coe_data), 1))).dot(H0)

        # Product kernel over parent columns, each with its own width.
        n_pa = PA.shape[1]
        Kpa = np.ones([T, T])
        for m in range(n_pa):
            col = PA[:, [m]]
            coe = kernel_coe_index if (index_include and m == n_pa - 1) else kernel_coe_data
            Kpa = Kpa * kernel(col, col, (_median_theta(col, coe), 1))
        Kpa = H0.dot(Kpa).dot(H0)

        for kk in range(k):
            tr, te = _fold_split(kk, k, n0, T)
            nv, n1 = te.size, tr.size
            Kx_te = Kx[np.ix_(te, te)]
            Kx_tr = Kx[np.ix_(tr, tr)]
            Kx_tr_te = Kx[np.ix_(tr, te)]
            Kpa_tr = Kpa[np.ix_(tr, tr)]
            Kpa_tr_te = Kpa[np.ix_(tr, te)]

            c = n1 * regression_lambda ** 2 / gamma
            tmp1 = pdinv(Kpa_tr + n1 * regression_lambda * np.eye(n1))
            tmp2 = tmp1.dot(Kx_tr).dot(tmp1)
            tmp3 = tmp1.dot(pdinv(np.eye(n1) + c * tmp2)).dot(tmp1)
            A = (Kx_te
                 + Kpa_tr_te.T.dot(tmp2).dot(Kpa_tr_te)
                 - 2 * Kx_tr_te.T.dot(tmp1).dot(Kpa_tr_te)
                 - c * Kx_tr_te.T.dot(tmp3).dot(Kx_tr_te)
                 - c * Kpa_tr_te.T.dot(tmp1).dot(Kx_tr).dot(tmp3).dot(Kx_tr).dot(tmp1).dot(Kpa_tr_te)
                 + 2 * c * Kx_tr_te.T.dot(tmp3).dot(Kx_tr).dot(tmp1).dot(Kpa_tr_te)) / gamma

            B = c * tmp2 + np.eye(n1)
            C = np.sum(np.log(np.diag(np.linalg.cholesky(B))))  # 0.5 * log det(B)
            CV += (nv * nv * np.log(2 * np.pi) + nv * C + np.trace(A)) / 2
    else:
        index_include = has_index_coe and Xi[-1] + 1 == d
        coe = kernel_coe_index if index_include else kernel_coe_data
        Kx = H0.dot(kernel(X, X, (_median_theta(X, coe), 1))).dot(H0)

        for kk in range(k):
            tr, te = _fold_split(kk, k, n0, T)
            nv, n1 = te.size, tr.size
            Kx_te = Kx[np.ix_(te, te)]
            Kx_tr = Kx[np.ix_(tr, tr)]
            Kx_tr_te = Kx[np.ix_(tr, te)]

            A = (Kx_te - 1 / (gamma * n1) * Kx_tr_te.T.dot(
                pdinv(np.eye(n1) + 1 / (gamma * n1) * Kx_tr)).dot(Kx_tr_te)) / gamma
            B = 1 / (gamma * n1) * Kx_tr + np.eye(n1)
            C = np.sum(np.log(np.diag(np.linalg.cholesky(B))))  # 0.5 * log det(B)
            CV += (nv * nv * np.log(2 * np.pi) + nv * C + np.trace(A)) / 2

    return CV / k


def kernel(x, xKern, theta):
    """Gaussian (RBF) kernel matrix between the rows of ``x`` and ``xKern``.

    Entry ``(i, j)`` is ``theta[1] * exp(-theta[0] / 2 * ||x_i - xKern_j||^2)``.

    Parameters
    ----------
    x, xKern : arrays of row vectors with the same number of columns
        (presumably 2-D).
    theta : sequence ``(precision, amplitude)``. If ``theta[0] == 0``, the
        precision is chosen by the median heuristic ``2 / median(d2)`` over the
        positive lower-triangular squared distances. ``theta`` itself is never
        modified, so tuples are accepted.
    """
    n2 = dist2(x, xKern)
    precision = theta[0]
    if precision == 0:
        precision = 2 / np.median(n2[np.tril(n2) > 0])
    return theta[1] * np.exp(-n2 * (precision / 2))


def dist2(x, c):
    """Squared Euclidean distances between every row of ``x`` and every row of ``c``.

    Uses the expansion ``|a|^2 + |b|^2 - 2 a.b``. Entries that come out
    slightly negative from floating-point cancellation are clamped to 0.
    The result has shape ``(x.shape[0], c.shape[0])``.
    """
    n_rows = x.shape[0]
    n_cols = c.shape[0]
    x_norms = np.sum((x ** 2).T, 0)
    c_norms = np.sum((c ** 2).T, 0)

    row_part = (np.ones([n_cols, 1]) * x_norms).T
    col_part = np.ones([n_rows, 1]) * c_norms
    sq = row_part + col_part - 2. * (x.dot(c.T))

    negative = sq < 0
    if np.any(negative):
        sq[negative] = 0
    return sq


def pdinv(mat):
    """Invert a symmetric positive-definite matrix through its Cholesky factor.

    With ``mat = L @ L.T`` (``L`` lower triangular), this returns
    ``inv(L).T @ inv(L)``, which equals ``inv(mat)``. ``np.linalg.cholesky``
    raises ``numpy.linalg.LinAlgError`` when ``mat`` is not positive definite.
    """
    lower = np.linalg.cholesky(mat)
    lower_inv = np.linalg.solve(lower, np.eye(mat.shape[0]))
    upper_inv = lower_inv.T
    return upper_inv.dot(upper_inv.T)
