import numpy as np
from sklearn.linear_model import LinearRegression
from scipy.linalg import expm as matrix_exponential
from scipy.spatial.distance import pdist, squareform
from sklearn.gaussian_process import GaussianProcessRegressor as GPR
from sklearn.preprocessing import PolynomialFeatures
import logging
from rewards.generalized_score import local_score_CV_general

class get_Reward_GS(object):
    """Compute rewards for candidate directed graphs using a generalized score.

    A graph is a square adjacency matrix whose row ``i`` marks the parents of
    variable ``i`` (entries > 0.5 are treated as edges when selecting parents).
    For each graph the per-variable scores returned by
    ``local_score_CV_general`` are summed and divided by the number of samples.
    The reward adds an acyclicity penalty based on
    ``cycness = trace(expm(A)) - d`` and, optionally, an L1 penalty on the sum
    of the adjacency entries.  Per-variable scores (``d_RSS``) and per-graph
    ``(score, cycness)`` pairs (``d``) are memoised on the instance.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, config, batch_num, dim, inputdata, sl, su, lambda1_upper,
                 score_type='BIC', reg_type='LR', l1_graph_reg=0.0, verbose_flag=True):
        """Store configuration and data used for scoring.

        Args:
            config: object providing ``max_length`` (number of variables d),
                ``kernel_coe_data``, ``kernel_coe_index`` and
                ``regression_lambda``.
            batch_num: stored as-is.
            dim: stored as-is.
            inputdata: data array; ``inputdata.shape[0]`` is the sample count.
            sl, su, lambda1_upper: bounds/scale used by ``score_transform``.
            score_type: ``'BIC'`` or ``'BIC_different_var'``.
            reg_type: ``'LR'``, ``'QR'`` or ``'GPR'`` (used by ``calculate_yerr``).
            l1_graph_reg: weight of the L1 penalty on the adjacency matrix;
                applied only when > 0.
            verbose_flag: log each newly computed graph score when true.

        Raises:
            ValueError: if ``score_type`` or ``reg_type`` is not supported.
        """
        self.batch_num = batch_num
        self.maxlen = config.max_length  # =d: number of vars
        self.dim = dim
        # Row codes are < 2**d, so ``baseint * i + code`` uniquely identifies
        # (variable i, parent row) in ``d_RSS``.
        self.baseint = 2**config.max_length
        self.d = {}  # graph key (tuple of row codes) -> (score, cycness)
        self.d_RSS = {}  # (variable, parent row) key -> local score, reused across graphs
        self.inputdata = inputdata
        self.n_samples = inputdata.shape[0]
        self.l1_graph_reg = l1_graph_reg
        self.verbose = verbose_flag
        self.sl = sl
        self.su = su
        self.lambda1_upper = lambda1_upper
        self.bic_penalty = np.log(inputdata.shape[0]) / inputdata.shape[0]

        if score_type not in ('BIC', 'BIC_different_var'):
            raise ValueError('Reward type not supported.')
        if reg_type not in ('LR', 'QR', 'GPR'):
            raise ValueError('Reg type not supported')
        self.score_type = score_type
        self.reg_type = reg_type

        # Intercept column appended to the design matrix in calculate_LR.
        self.ones = np.ones((inputdata.shape[0], 1), dtype=np.float32)
        self.poly = PolynomialFeatures()
        self.kernel_coe_data = config.kernel_coe_data
        self.kernel_coe_index = config.kernel_coe_index
        self.regression_lambda = config.regression_lambda

    def cal_rewards(self, graphs, lambda1, lambda2):
        """Score every graph in ``graphs``.

        Returns:
            ``np.array`` built from one ``(reward, score, cycness)`` tuple per
            graph, i.e. shape ``(len(graphs), 3)``.
        """
        return np.array([
            self.calculate_reward_single_graph(graph, lambda1, lambda2)
            for graph in graphs
        ])

    ####### regression

    def calculate_yerr(self, X_train, y_train):
        """Return regression residuals using the configured ``reg_type``.

        Raises:
            ValueError: if ``self.reg_type`` is not ``'LR'``, ``'QR'`` or ``'GPR'``.
        """
        if self.reg_type == 'LR':
            return self.calculate_LR(X_train, y_train)
        if self.reg_type == 'QR':
            return self.calculate_QR(X_train, y_train)
        if self.reg_type == 'GPR':
            return self.calculate_GPR(X_train, y_train)
        # Explicit exception rather than ``assert``: asserts vanish under -O.
        raise ValueError('Regressor not supported')

    # faster than LinearRegression() from sklearn
    def calculate_LR(self, X_train, y_train):
        """Least-squares fit with intercept via the normal equations.

        Returns the residuals ``X @ theta - y_train``, where ``X`` is
        ``X_train`` with ``self.ones`` appended as an intercept column.
        ``np.linalg.solve`` raises ``LinAlgError`` if ``X.T @ X`` is singular.
        """
        X = np.hstack((X_train, self.ones))
        theta = np.linalg.solve(X.T.dot(X), X.T.dot(y_train))
        return X.dot(theta) - y_train

    def calculate_QR(self, X_train, y_train):
        """Linear regression on polynomial features of ``X_train``.

        The first feature column is dropped because ``calculate_LR`` adds its
        own intercept (it is the constant column under PolynomialFeatures
        defaults).
        """
        X_train = self.poly.fit_transform(X_train)[:, 1:]
        return self.calculate_LR(X_train, y_train)

    def calculate_GPR(self, X_train, y_train):
        """Gaussian-process regression residuals ``y - prediction`` as a column.

        Inputs are rescaled by the median pairwise Euclidean distance. Note the
        residual sign is the opposite of ``calculate_LR``.
        """
        med_w = np.median(pdist(X_train, 'euclidean'))
        # A single sample gives an empty distance set (median = NaN) and
        # identical rows give a median of 0; either would poison the scaling.
        if not np.isfinite(med_w) or med_w <= 0:
            med_w = 1.0
        gpr = GPR().fit(X_train / med_w, y_train)
        return y_train.reshape(-1, 1) - gpr.predict(X_train / med_w).reshape(-1, 1)

    ####### score calculations

    @staticmethod
    def _row_to_int(row):
        """Encode an adjacency row (truncated to int32) as an int, first entry = MSB.

        Assumes the truncated entries are 0/1; other digits make ``int(..., 2)``
        raise ``ValueError``.
        """
        return int(''.join(str(bit) for bit in np.int32(row)), 2)

    def calculate_reward_single_graph(self, graph_batch, lambda1, lambda2):
        """Return ``(reward, score, cycness)`` for a single adjacency matrix.

        NOTE: the diagonal of ``graph_batch`` is set to 0 in place.

        ``score`` is the sum of per-variable local scores divided by the number
        of samples (plus the L1 term when ``l1_graph_reg > 0``).
        ``cycness = trace(expm(A)) - maxlen``.
        ``reward = score + lambda1 * [cycness > 1e-5] + lambda2 * cycness``.
        Graphs seen before are answered from ``self.d`` with the current
        ``lambda1``/``lambda2``.
        """
        row_codes = []
        for i in range(self.maxlen):
            graph_batch[i][i] = 0
            row_codes.append(self._row_to_int(graph_batch[i]))

        graph_key = tuple(row_codes)
        if graph_key in self.d:
            score_cyc = self.d[graph_key]
            return self.penalized_score(score_cyc, lambda1, lambda2), score_cyc[0], score_cyc[1]

        # (number of CV folds, regression lambda) passed to local_score_CV_general.
        params = (10, self.regression_lambda)
        RSS_ls = []
        for i, code in enumerate(row_codes):
            rss_key = self.baseint * i + code
            if rss_key not in self.d_RSS:
                parents = np.where(graph_batch[i] > 0.5)[0]
                self.d_RSS[rss_key] = local_score_CV_general(
                    self.inputdata, [i], parents, params,
                    self.kernel_coe_data, self.kernel_coe_index)
            RSS_ls.append(self.d_RSS[rss_key])

        score = np.sum(RSS_ls) / self.inputdata.shape[0]
        cycness = np.trace(matrix_exponential(np.array(graph_batch))) - self.maxlen
        cycness_penalty = lambda1 * float(cycness > 1e-5) + lambda2 * cycness
        reward = score + cycness_penalty

        if self.l1_graph_reg > 0:
            l1_penalty = self.l1_graph_reg * np.sum(graph_batch)
            reward = reward + l1_penalty
            score = score + l1_penalty

        self.d[graph_key] = (score, cycness)

        if self.verbose:
            self._logger.info('BIC: %s, cycness: %s, returned reward: %s', score, cycness, reward)
        return reward, score, cycness

    #### helper

    def score_transform(self, s):
        """Linearly map ``s`` from ``[sl, su]`` onto ``[0, lambda1_upper]``."""
        self._logger.debug('s: %s, sl: %s, su: %s', s, self.sl, self.su)
        return (s - self.sl) / (self.su - self.sl) * self.lambda1_upper

    def penalized_score(self, score_cyc, lambda1, lambda2):
        """Return ``score + lambda1 * [cyc > 1e-5] + lambda2 * cyc``."""
        score, cyc = score_cyc
        return score + lambda1 * float(cyc > 1e-5) + lambda2 * cyc

    def update_scores(self, score_cycs, lambda1, lambda2):
        """Re-penalize a sequence of ``(score, cycness)`` pairs."""
        return [self.penalized_score(sc, lambda1, lambda2) for sc in score_cycs]

    def update_all_scores(self, lambda1, lambda2):
        """Re-penalize every cached graph with new ``lambda1``/``lambda2``.

        Returns:
            List of ``(graph_key, (penalized, score, cycness))`` sorted in
            ascending order of the penalized score.
        """
        entries = [
            (graph_key, (self.penalized_score(sc, lambda1, lambda2), sc[0], sc[1]))
            for graph_key, sc in self.d.items()
        ]
        return sorted(entries, key=lambda item: item[1][0])
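# NOTE: The commented-out code below is an inactive, inline version of
# local_score_CV_general and its helpers (kernel, dist2, pdinv). The function
# actually used by this module is the one imported from
# rewards.generalized_score at the top of the file.
# def local_score_CV_general(Data, Xi, PAi, parameters, kernel_coe_data, kernel_coe_index=None):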
#     Data_ = np.copy(Data)
#     np.random.shuffle(Data_)
#     T = Data_.shape[0]
#     d = Data_.shape[1]
#     # X = np.expand_dims(Data[:, Xi], axis=1)
#     X = Data_[:, Xi]
#     regression_lambda = parameters[1]
#     k = parameters[0]
#     n0 = np.int(np.floor(T/k))
#     gamma = 0.01
#     Thresh = 1e-5
#     CV = 0
#     index_include = False
#     if PAi.shape[0] != 0:
#         if (PAi[-1]+1 == d) and (kernel_coe_index != None):  # index included
#             index_include = True
#         PA = Data_[:, PAi]
#         # set the kernel for X
#         GX = np.multiply(X, X)
#         Q = np.tile(GX, (1, T))
#         R = np.tile(GX.T, (T, 1))
#         dists = Q + R - 2*X.dot(X.T)
#         dists = dists-np.tril(dists)
#         dists = np.reshape(dists, (T**2, 1), order='F')
#         width = kernel_coe_data*np.sqrt(0.5*np.median(dists[dists>0]))
#         width = width*2
#         theta = 1/(width**2)
#
#         Kx = kernel(X, X, (theta,1))
#
#         H0 =  np.eye(T) - np.ones([T,T])/(T)
#         Kx = H0.dot(Kx).dot(H0)
#
#         eig_Kx = np.sort(np.linalg.eigvals((Kx+Kx.T)/2))[::-1][0:min(400, np.int(np.floor(T/2)))]
#         IIx = (eig_Kx > np.max(eig_Kx) * Thresh).nonzero()[0]
#         eig_Kx = eig_Kx[IIx]
#         mx = IIx.shape[0]
#
#         # set the kernel for PA
#         Kpa = np.ones([T,T])
#         for m in range(PA.shape[1]):
#             G = PA[:,[m]]**2
#             Q = np.tile(G,(1,T))
#             R = np.tile(G.T,(T,1))
#             dists = Q + R - 2*PA[:,[m]].dot(PA[:,[m]].T)
#             dists = dists-np.tril(dists)
#             dists = np.reshape(dists,(T**2,1), order='F')
#             if index_include and m+1 == PA.shape[1]:
#                 width = kernel_coe_index*np.sqrt(0.5*np.median(dists[dists>0]))
#             else:
#                 width = kernel_coe_data*np.sqrt(0.5*np.median(dists[dists>0]))
#
#             width = width*2
#             theta = 1/(width**2)
#             Kpa = Kpa*kernel(PA[:,[m]], PA[:,[m]], (theta,1))
#         H0 =  np.eye(T) - np.ones([T,T])/T  # for centering of the data in feature space
#         Kpa = H0.dot(Kpa).dot(H0)  # kernel matrix for PA
#
#         for kk in range(k):
#             if(kk==0):
#                 Kx_te = Kx[kk*n0:(kk+1)*n0,kk*n0:(kk+1)*n0]
#                 Kx_tr = Kx[(kk+1)*n0:T,(kk+1)*n0:T]
#                 Kx_tr_te = Kx[(kk+1)*n0:T,kk*n0:(kk+1)*n0]
#                 Kpa_te = Kpa[kk*n0:(kk+1)*n0,kk*n0:(kk+1)*n0]
#                 Kpa_tr = Kpa[(kk+1)*n0:T,(kk+1)*n0:T]
#                 Kpa_tr_te = Kpa[(kk+1)*n0:T,kk*n0:(kk+1)*n0]
#                 nv = n0  # sample size of validated data
#             if(kk==k-1):
#                 Kx_te = Kx[kk*n0:T,kk*n0:T]
#                 Kx_tr = Kx[0:kk*n0,0:kk*n0]
#                 Kx_tr_te = Kx[0:kk*n0,kk*n0:T]
#                 Kpa_te = Kpa[kk*n0:T,kk*n0:T]
#                 Kpa_tr = Kpa[0:kk*n0,0:kk*n0]
#                 Kpa_tr_te = Kpa[0:kk*n0,kk*n0:T]
#                 nv = T-kk*n0
#             if(kk<k-1 and kk>0):
#                 Kx_te = Kx[kk*n0:(kk+1)*n0,kk*n0:(kk+1)*n0]
#                 Kx_tr = Kx[np.concatenate((np.arange(kk*n0),np.arange((kk+1)*n0, T))), :][:, np.concatenate((np.arange(kk*n0),np.arange((kk+1)*n0, T)))]
#                 Kx_tr_te = Kx[np.concatenate((np.arange(kk*n0),np.arange((kk+1)*n0, T))),kk*n0:(kk+1)*n0]
#                 Kpa_te = Kpa[kk*n0:(kk+1)*n0,kk*n0:(kk+1)*n0]
#                 Kpa_tr = Kpa[np.concatenate((np.arange(kk*n0),np.arange((kk+1)*n0, T))), :][:, np.concatenate((np.arange(kk*n0),np.arange((kk+1)*n0, T)))]
#                 Kpa_tr_te = Kpa[np.concatenate((np.arange(kk*n0),np.arange((kk+1)*n0, T))),kk*n0:(kk+1)*n0]
#                 nv = n0
#             n1 = T-nv
#             tmp1 = pdinv(Kpa_tr + n1*regression_lambda*np.eye(n1))
#             tmp2 = tmp1.dot(Kx_tr).dot(tmp1)
#
#             tmp3 = tmp1.dot(pdinv(np.eye(n1) + n1*regression_lambda**2/gamma*tmp2)).dot(tmp1)
#             # print('##test##: ', Kpa_tr_te.T, '\n',tmp2,'\n', tmp1)
#             A = (Kx_te + Kpa_tr_te.T.dot(tmp2).dot(Kpa_tr_te) - 2*Kx_tr_te.T.dot(tmp1).dot(Kpa_tr_te)\
#                 - n1*regression_lambda**2/gamma*Kx_tr_te.T.dot(tmp3).dot(Kx_tr_te)\
#                 - n1*regression_lambda**2/gamma*Kpa_tr_te.T.dot(tmp1).dot(Kx_tr).dot(tmp3).dot(Kx_tr).dot(tmp1).dot(Kpa_tr_te)\
#                 + 2*n1*regression_lambda**2/gamma*Kx_tr_te.T.dot(tmp3).dot(Kx_tr).dot(tmp1).dot(Kpa_tr_te))/gamma
#
#             B = n1*regression_lambda**2/gamma * tmp2 + np.eye(n1)
#             L = np.linalg.cholesky(B)
#             C = np.sum(np.log(np.diag(L)))
#             CV = CV + (nv*nv*np.log(2*np.pi) + nv*C + np.trace(A))/2
#         CV =CV/k
#
#     else:
#         GX = np.multiply(X, X)
#         Q = np.tile(GX, (1, T))
#         R = np.tile(GX.T, (T, 1))
#         dists = Q + R - 2*X.dot(X.T)
#         dists = dists-np.tril(dists)
#         dists = np.reshape(dists, (T**2, 1), order='F')
#         width = kernel_coe_data*np.sqrt(0.5*np.median(dists[dists>0]))
#         width = width*2
#         theta = 1/(width**2)
#
#         Kx = kernel(X, X, (theta,1))
#         H0 =  np.eye(T) - np.ones([T,T])/(T)
#         Kx = H0.dot(Kx).dot(H0)
#         eig_Kx = np.sort(np.linalg.eigvals((Kx+Kx.T)/2))[::-1][0:min(400, np.int(np.floor(T/2)))]
#         IIx = (eig_Kx > np.max(eig_Kx) * Thresh).nonzero()[0]
#         eig_Kx = eig_Kx[IIx]
#         mx = IIx.shape[0]
#
#         for kk in range(k):
#             if(kk==0):
#                 Kx_te = Kx[kk*n0:(kk+1)*n0,kk*n0:(kk+1)*n0]
#                 Kx_tr = Kx[(kk+1)*n0:T,(kk+1)*n0:T]
#                 Kx_tr_te = Kx[(kk+1)*n0:T,kk*n0:(kk+1)*n0]
#                 nv = n0
#             if(kk==k-1):
#                 Kx_te = Kx[kk*n0:T,kk*n0:T]
#                 Kx_tr = Kx[0:kk*n0,0:kk*n0]
#                 # print('##test###: ', Kx_tr)
#                 Kx_tr_te = Kx[0:kk*n0,kk*n0:T]
#                 nv = T-kk*n0
#             if(kk<k-1 and kk>0):
#                 Kx_te = Kx[kk*n0:(kk+1)*n0,kk*n0:(kk+1)*n0]
#                 Kx_tr = Kx[np.concatenate((np.arange(kk*n0),np.arange((kk+1)*n0, T))), :][:, np.concatenate((np.arange(kk*n0),np.arange((kk+1)*n0, T)))]
#                 Kx_tr_te = Kx[np.concatenate((np.arange(kk*n0),np.arange((kk+1)*n0, T))),kk*n0:(kk+1)*n0]
#                 nv = n0
#             n1 = T-nv
#             A = (Kx_te - 1/(gamma*n1)*Kx_tr_te.T.dot(pdinv(np.eye(n1)+1/(gamma*n1)*Kx_tr)).dot(Kx_tr_te))/gamma
#             B = 1/(gamma*n1)*Kx_tr + np.eye(n1)
#             L = np.linalg.cholesky(B)
#             # print('###L###: ', gamma, n1, Kx_tr)
#             C = np.sum(np.log(np.diag(L)))
#
#             CV = CV + (nv*nv*np.log(2*np.pi) + nv*C + np.trace(A))/2
#         CV =CV/k
#     return CV
#
#
# def kernel(x, xKern, theta):
#     n2 = dist2(x, xKern)
#     if theta[0]==0:
#         theta[0]=2/np.median(n2[np.tril(n2)>0])
#         theta_new=theta[0]
#     wi2 = theta[0]/2
#     kx = theta[1]*np.exp(-n2*wi2)
#     bw_new=1/theta[0]
#     return kx
# def dist2(x, c):
#     ndata = x.shape[0]
#     ncentres = c.shape[0]
#     # assert dimx == dimc
#
#     n2 = (np.ones([ncentres, 1]) * np.sum((x**2).T, 0)).T +\
#     np.ones([ndata, 1]) * np.sum((c**2).T,0) -\
#     2.*(x.dot(c.T))
#
#     if np.any(n2<0):
#         n2[n2<0] = 0
#
#     return n2
#
# def pdinv(mat):
#     d = mat.shape[0]
#     U = np.linalg.cholesky(mat)
#     invU = np.linalg.solve(U, np.eye(d)).T
#     return invU.dot(invU.T)
#     # except np.linalg.LinAlgError:
#     #     print("matrix is not positive definite")
#     #     return np.linalg.inv(mat)
#     # else:
#     #     raise np.linalg.LinAlgError
