import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.preprocessing import PolynomialFeatures
from scipy.spatial.distance import pdist, squareform
from generalized_likelihood import local_score_CV_general

def _fit_predict(X_, y_, reg_type):
    """Fit the regressor selected by ``reg_type`` and return in-sample predictions shaped like ``y_``."""
    if reg_type == 'QR':
        # PolynomialFeatures() uses its library defaults; the first output
        # column is dropped (presumably the constant bias term, which
        # LinearRegression's own intercept already covers).
        X_ = PolynomialFeatures().fit_transform(X_)[:, 1:]
    elif reg_type == 'GPR':
        # Rescale inputs by the median pairwise sample distance. Skip the
        # rescaling when the median is 0 (or NaN), which would otherwise
        # turn the inputs into NaN/inf.
        med_w = np.median(pdist(X_, 'euclidean'))
        if med_w > 0:
            X_ = X_ / med_w

    # Any reg_type other than 'LR'/'QR' falls back to a Gaussian process.
    if reg_type in ('LR', 'QR'):
        reg = LinearRegression()
    else:
        reg = GaussianProcessRegressor()
    reg.fit(X_, y_)

    # Some estimators return a 1-D prediction for a single-column target;
    # force the (n, 1) shape of y_ so that y_ - y_pred never broadcasts to
    # an (n, n) matrix.
    return np.asarray(reg.predict(X_)).reshape(y_.shape)


def BIC_input_graph(X, g, reg_type='LR', score_type='BIC'):
    """Compute the (unpenalized) BIC-style fit term of graph ``g`` on data ``X``.

    For every column ``i`` of ``X`` the column is regressed on the columns
    ``j`` with ``abs(g[i][j]) > 0.1``; a column without such parents is
    predicted by its mean. The residual sums of squares (RSS) of all columns
    are then combined according to ``score_type``.

    :param X: 2-D data array of shape (n, d)
    :param g: d x d graph; row ``i`` marks the parents of column ``i``
    :param reg_type: 'LR' (linear), 'QR' (linear on polynomial features) or
        'GPR' (Gaussian process); any other value also uses a Gaussian
        process but without input rescaling and without the +1.0 RSS offset
    :param score_type: 'BIC' returns ``log(sum(RSS) / n + 1e-8)``;
        'BIC_different_var' returns ``sum(log(RSS_i / n + 1e-8))``
    :return: the score as a float, or None for any other ``score_type``
    """
    n, d = X.shape

    RSS_ls = []
    for i in range(d):
        y_ = X[:, [i]]
        parent_mask = np.abs(g[i]) > 0.1

        if parent_mask.any():
            y_pred = _fit_predict(X[:, parent_mask], y_, reg_type)
        else:
            y_pred = np.mean(y_)

        RSSi = np.sum(np.square(y_ - y_pred))
        if reg_type == 'GPR':
            # Constant offset added to every GPR residual sum.
            RSSi = RSSi + 1.0
        RSS_ls.append(RSSi)

    if score_type == 'BIC':
        return np.log(np.sum(RSS_ls) / n + 1e-8)
    if score_type == 'BIC_different_var':
        # The epsilon belongs inside the log so that a perfectly fitted node
        # (RSS == 0) yields a finite value instead of -inf.
        return np.sum(np.log(np.array(RSS_ls) / n + 1e-8))
    # Unknown score types (e.g. 'generalized_score') have no BIC value here.
    return None
    
    
def BIC_lambdas(X, config, gl=None, gu=None, gtrue=None, reg_type='LR', score_type='BIC'):
    """
    Compute reference scores for a lower-bound, an upper-bound and the true graph.

    :param X: dataset of shape (n, d)
    :param config: settings object forwarded to generalized_score_input_score
    :param gl: input graph to get score lower bound (default: complete graph
        without self-loops)
    :param gu: input graph to get score upper bound (default: empty graph)
    :param gtrue: input true graph; when given, diagnostics are printed
    :param reg_type: regression type, only used for the printed BIC diagnostic
    :param score_type: 'BIC', 'BIC_different_var' or 'generalized_score';
        selects the per-edge penalty added to the true-graph score
    :return: score lower bound, score upper bound, true score (only for monitoring)
    """
    n, d = X.shape

    # Per-edge penalty; left unbound for any other score_type.
    if score_type == 'generalized_score':
        bic_penalty = 0
    elif score_type == 'BIC_different_var':
        bic_penalty = np.log(n) / n
    elif score_type == 'BIC':
        bic_penalty = np.log(n) / (n * d)

    # default gl: complete graph (except diagonals)
    lower_graph = (1.0 - np.eye(d)) if gl is None else gl
    # default gu: empty graph
    upper_graph = np.zeros((d, d)) if gu is None else gu

    sl = generalized_score_input_score(X, lower_graph, config)
    su = generalized_score_input_score(X, upper_graph, config)

    if gtrue is None:
        # No ground truth available: report a placeholder below the lower bound.
        return sl, su, sl - 10

    print(BIC_input_graph(X, gtrue, reg_type, score_type))
    print(gtrue)
    print(bic_penalty)
    true_fit = generalized_score_input_score(X, gtrue, config)
    strue = true_fit + np.sum(gtrue) * bic_penalty

    return sl, su, strue

def generalized_score_input_score(data, graph, config):
    """Sum the local generalized scores of every node in ``graph``.

    For each row ``i`` of ``graph`` the parents are the column indices whose
    entry exceeds 0.5; the node is scored with ``local_score_CV_general`` and
    the per-node scores are added up.

    :param data: dataset passed through to ``local_score_CV_general``
    :param graph: square array; row ``i`` marks the parents of node ``i``
    :param config: object providing ``regression_lambda``,
        ``kernel_coe_data`` and ``kernel_coe_index``
    :return: total score over all nodes (0 for an empty graph)
    """
    # NOTE(review): the leading 10 is presumably the number of CV folds;
    # confirm against local_score_CV_general.
    cv_params = (10, config.regression_lambda)
    n_nodes = graph.shape[0]

    return sum(
        local_score_CV_general(
            data,
            [node],
            np.where(graph[node] > 0.5)[0],
            cv_params,
            config.kernel_coe_data,
            config.kernel_coe_index,
        )
        for node in range(n_nodes)
    )