import numpy as np
import random
import math
from scipy import linalg
from copy import copy

def adaptive_exploration(X, Y, WP, EN):
    '''
    Greedily pick sparse RBF center points and fit their weights.

    A random subset of at most EN samples is drawn. Starting from the
    sample closest to the subset mean, centers are added one at a time:
    each new Gaussian column exp(-SP2*||x - c||^2) is folded into an
    incremental Householder QR factorisation (house1), the weights are
    obtained from the triangular system, and the next center is the
    candidate whose region currently has the largest mean squared residual.

    Paras:
    X: input features, 2-D array-like of shape (samples, dim)
    Y: input targets, expected to be one-dimensional with one value per sample
    WP: some working parameters,
        WP[0] is the upper bound on the condition estimate of R
              (ratio of the largest to the smallest |diagonal| entry)
        WP[1] is the minimum drop in relative RMS error required to keep
              exploring
        WP[2] is the factor of shape parameters: the basis function takes
              the value WP[2] at the largest squared distance from the mean
    EN: input data limit (maximum number of samples used for the fit)

    Returns:
    n: number of selected centers
    P_centers: indices of the centers in the original (un-shuffled) X
    alpha: RBF weights, one per center
    SP2: shape parameter used in exp(-SP2*r^2)
    Re_index: indices of the randomly chosen working samples
    res: residuals Y - prediction evaluated on all original samples
    '''
    # Targets are held as floats: house1 writes reflected values back into a
    # copy of Y, which would be silently truncated for an integer array.
    X, Y = np.array(X), np.array(Y, dtype=float)
    N0 = len(Y)
    dim = np.shape(X)[1]
    X0, Y0 = copy(X), copy(Y)
    N = min(EN, N0)

    # Data order rearrangement: random subset without replacement
    Re_index = np.array(random.sample(range(N0), N))
    X, Y = X[Re_index], Y[Re_index]

    # PX: denote which area input datas belong to
    # P_centers: index of center points (first one is closest to the mean)
    # distances: squared distance between inputs and nearest centers
    PX = np.zeros((N, 1))
    sq_dist_to_mean = np.sum((X - np.mean(X, axis=0))**2, axis=1)
    P_centers = [np.argmin(sq_dist_to_mean)]
    distances = np.sum((X - X[P_centers])**2, axis=1).reshape(-1, 1)

    # I[k] == 1 marks candidate k (in t_P_centers) as still available
    I = [0]
    XR2 = np.max(sq_dist_to_mean)
    nY, uY, beta, U = math.sqrt(np.mean(Y**2)), copy(Y), [], None
    # SP2: shape parameter; RE: previous relative error; GK: kernel columns
    SP2, RE, GK = -math.log(WP[2])/XR2, 2, None
    n = 0

    # Seed the candidate pool with dim+1 farthest-point candidates.
    t_P_centers, t_distances, t_PX = copy(P_centers), copy(distances), copy(PX)
    for i in range(dim + 1):
        t_P_centers, t_distances, t_PX = findN1(X, t_P_centers, t_distances, t_PX, i)
        I.append(1)

    while True:
        u = np.exp(-SP2*np.sum((X - X[P_centers[n]])**2, axis=1))
        U, uY, beta = house1(n, N, u, U, uY, beta)
        # R: upper-triangular factor stored in the top-left block of U
        # uY: the targets with all reflectors applied (Q^T Y)
        R = np.triu(U[:n+1, :n+1])
        if GK is None:
            GK = u.reshape(-1, 1)
        else:
            GK = np.concatenate((GK, u.reshape(-1, 1)), axis=1)
        # alpha: alpha in RBF
        alpha = linalg.solve(R, uY[0:n+1])

        # CR: compare with WP[0]. Diagonal entries of R may be negative, so
        # the ratio has to be taken over magnitudes.
        dR = np.abs(np.diagonal(R))
        CR = np.max(dR)/np.min(dR)

        res = Y - np.dot(GK, alpha)
        # RE - RE2: compare with WP[1]; together with CR it decides the
        # termination of the exploration
        RE2 = math.sqrt(np.mean(res**2))/nY
        J = CR < WP[0] and RE - RE2 > WP[1]
        # Another center needs one more row for its reflector, so growth is
        # only possible while fewer than N centers are in use.
        can_grow = n + 1 < N

        if can_grow and (n < dim + 2 or J):
            t_P_centers, t_distances, t_PX = findN1(X, t_P_centers, t_distances, t_PX, dim+n+1)
            I.append(1)
            ID = np.where(np.array(I))[0]
            SE = np.zeros(len(ID))
            for j, region in enumerate(ID):
                members = np.where(t_PX == region)[0]
                # An empty region keeps score 0; a NaN mean would otherwise
                # always win the argmax below.
                if members.size:
                    SE[j] = np.mean(res[members]**2)
            worst = ID[np.argmax(SE)]
            I[worst] = 0
            RE = RE2
            P_centers, distances, PX = addN1(X, P_centers, distances, PX, n, t_P_centers[worst])
            n += 1
        else:
            # Map the centers back to indices of the original data.
            P_centers = Re_index[P_centers]
            n += 1
            break

    # Evaluate the fitted model on every original sample.
    RF = np.zeros(N0)
    for i in range(n):
        RF += alpha[i]*np.exp(-SP2*np.sum((X0 - X0[P_centers[i]])**2, 1))
    res = Y0 - RF
    return n, P_centers, alpha, SP2, Re_index, res

# QR decomposition with householder
def house1(i, N, u, U, uY, beta):
    """Append column ``u`` to an incremental Householder QR factorisation.

    Paras:
    i: zero-based index of the column being added
    N: number of rows taking part in the factorisation
    u: the new column
    U: factor storage built so far (None for the first column); column j
       holds the R entries in rows 0..j and the tail of the j-th
       Householder vector below them
    uY: right-hand side with the earlier reflectors already applied
    beta: list of reflector coefficients; the new one is appended in place

    Returns the extended U, the updated copy of uY and the beta list.
    """
    col = copy(u)
    rhs = copy(uY)

    # Replay every stored reflector on the new column.
    for j in range(i):
        w = np.concatenate(([1], U[j+1:N, j]), axis=0)
        col[j:N] = col[j:N] - beta[j]*(np.dot(w, col[j:N])*w)

    # Build the reflector for the remaining part of the column.
    v, new_beta = house(col[i:N])
    beta.append(new_beta)

    # Apply it to the column itself and to the right-hand side.
    col[i:N] = col[i:N] - beta[i]*(np.dot(v, col[i:N]))*v
    rhs[i:N] = rhs[i:N] - beta[i]*(np.dot(v, rhs[i:N]))*v

    # Stored column: R entries on and above the diagonal, reflector tail below.
    packed = np.concatenate(
        (col[0:i+1].reshape(-1, 1), v[1:].reshape(-1, 1)), axis=0)
    U = packed if U is None else np.concatenate((U, packed), axis=1)
    return U, rhs, beta


def house(x):
    """Compute a Householder vector for ``x``.

    Returns ``(v, beta)`` with ``v[0] == 1`` such that
    ``(I - beta * v v^T) x`` has zeros in every entry after the first.
    When the entries after the first are already all zero, no reflection is
    needed and ``beta`` is 0.

    Paras:
    x: 1-D vector
    """
    # Work in floating point so that assigning v[0] is never truncated.
    x = np.asarray(x, dtype=float)
    n = len(x)
    sig = np.dot(x[1:n], x[1:n])
    v = np.concatenate(([1.0], x[1:n]), axis=0)
    if sig == 0:
        beta = 0.0
    else:
        # Euclidean norm of the whole vector: leading entry plus the tail.
        mu = math.sqrt(x[0]**2 + sig)
        if x[0] <= 0:
            v[0] = x[0] - mu
        else:
            # Algebraically equal to x[0] - mu but avoids cancellation.
            v[0] = -sig/(x[0] + mu)
        beta = 2*v[0]**2/(v[0]**2 + sig)
        v = v/v[0]
    return v, beta

#find one more center point candidate
def findN1(X, P_centers, distances, PX, K):
    """Add the sample farthest from the current centers as a new candidate.

    Paras:
    X: sample matrix, one row per sample
    P_centers: list of center indices; the new index is appended in place
    distances: column vector of squared distances to the nearest center
    PX: region label per sample; samples strictly closer to the new
        candidate than to any earlier one are relabelled K+1 in place
    K: label offset of the new region

    Returns P_centers, the refreshed distance column and PX.
    """
    farthest = np.argmax(distances)
    to_candidate = np.sum((X - X[farthest, :])**2, axis=1).reshape(-1, 1)
    both = np.concatenate((distances, to_candidate), axis=1)

    # Column 1 wins only when the candidate is strictly closer (ties keep
    # the existing assignment).
    captured = np.flatnonzero(np.argmin(both, axis=1) == 1)
    PX[captured] = K + 1

    P_centers.append(farthest)
    return P_centers, np.min(both, axis=1).reshape(-1, 1), PX

#add one more center point
def addN1(X, P_centers, distances, PX, K, I):
    """Register sample ``I`` as a new center.

    Paras:
    X: sample matrix, one row per sample
    P_centers: list of center indices; ``I`` is appended in place
    distances: column vector of squared distances to the nearest center
    PX: region label per sample; samples strictly closer to the new center
        than to any earlier one are relabelled K+1 in place
    K: label offset of the new region
    I: row index of the sample that becomes a center

    Returns P_centers, the refreshed distance column and PX.
    """
    to_new = np.sum((X - X[I, :])**2, axis=1).reshape(-1, 1)
    both = np.concatenate((distances, to_new), axis=1)

    # Ties stay with the previous assignment because argmin prefers column 0.
    captured = np.flatnonzero(np.argmin(both, axis=1) == 1)
    PX[captured] = K + 1

    P_centers.append(I)
    return P_centers, np.min(both, axis=1).reshape(-1, 1), PX

#split input space
def Split_X(X, Re_index, fs, res):
    """Choose a direction and a threshold that split the input space in two.

    Paras:
    X: sample matrix, 2-D array-like of shape (samples, dim)
    Re_index: indices of the rows of X (and entries of res) to work on
    fs: if equal to 1 the threshold is a randomly chosen percentile
        (between the 38th and the 61st) of the projections, otherwise it is
        their median
    res: residual per sample, aligned with the rows of X

    Returns:
    nv: direction vector pointing from the anchor sample to the sample
        farthest from it
    center: threshold on the projection np.dot(X, nv)
    """
    X, res = np.array(X), np.array(res)
    dim = np.shape(X)[1]
    X, res = X[Re_index], res[Re_index]
    # Count the rows actually in use: Re_index may select only part of the
    # data, and the region labels below must match that subset exactly.
    N = len(res)
    sq_dist_to_mean = np.sum((X - np.mean(X, axis=0))**2, axis=1)
    if dim == 1:
        a = np.argmax(sq_dist_to_mean)
    else:
        # Spread dim+1 anchors with farthest-point sampling and keep the one
        # whose region has the largest mean squared residual.
        P0 = [np.argmax(sq_dist_to_mean)]
        D = np.sum((X - X[P0])**2, axis=1).reshape(-1, 1)
        PX = np.zeros((N, 1))
        SE = np.zeros((dim + 1, 1))
        for i in range(dim):
            P0, D, PX = findN1(X, P0, D, PX, i)
        for i in range(dim + 1):
            members = np.where(PX == i)[0]
            # An empty region keeps score 0; a NaN mean would otherwise
            # always win the argmax.
            if members.size:
                SE[i] = np.mean(res[members]**2)
        a = P0[np.argmax(SE)]
    b = np.argmax(np.sum((X - X[a])**2, axis=1))
    nv = (X[b] - X[a])
    XV = np.dot(X, nv)
    if fs == 1:
        center = np.percentile(XV, 37 + np.random.randint(1, 25))
    else:
        center = np.median(XV)
    return nv, center
