import numpy as np
from sklearn.linear_model import LinearRegression
from scipy.linalg import expm as matrix_exponential
from scipy.spatial.distance import pdist, squareform
from sklearn.gaussian_process import GaussianProcessRegressor as GPR
from sklearn.preprocessing import PolynomialFeatures
import logging
from rewards.generalized_score import local_score_CV_general
import matlab.engine
# from rewards.MDS_score import infer_nonsta_dir

class get_Reward_MDS(object):
    """Reward calculator for candidate graphs.

    For each graph (a square 0/1 adjacency array where row ``i`` marks the
    parents of variable ``i``) the reward combines:

    * the sum of ``local_score_CV_general`` local scores divided by the
      number of samples,
    * an "MDS" term obtained from the MATLAB routine ``infer_nonsta_dir2``,
      averaged over the variables for which it is non-zero,
    * an acyclicity penalty based on ``trace(expm(A)) - d``,
    * an optional L1 penalty on the number of edges.

    Results are memoised per graph (``self.d``) and per
    (variable, parent set) pair (``self.d_RSS`` / ``self.d_MDS``).

    NOTE(review): the last variable (index ``max_length - 1``) is treated
    specially -- it must have no parents and is stripped from parent sets
    before the MATLAB call. It presumably is a surrogate domain/time index;
    confirm against the data pipeline.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, config, batch_num, dim, inputdata, sl, su, lambda1_upper,
                 score_type='BIC', reg_type='LR', l1_graph_reg=0.0, verbose_flag=True):
        """Store configuration, initialise caches and start the MATLAB engine.

        Args:
            config: object providing ``max_length``, ``kernel_coe_data``,
                ``n_datasets`` and ``regression_lambda``.
            batch_num: stored as ``self.batch_num`` (not used in this class).
            dim: stored as ``self.dim`` (not used in this class).
            inputdata: 2-D array, one row per sample.
            sl, su, lambda1_upper: bounds/scale used by ``score_transform``.
            score_type: ``'BIC'`` or ``'BIC_different_var'``.
            reg_type: ``'LR'``, ``'QR'`` or ``'GPR'``.
            l1_graph_reg: weight of the edge-count penalty (0 disables it).
            verbose_flag: log every freshly computed reward when true.

        Raises:
            ValueError: if ``score_type`` or ``reg_type`` is not supported.
        """
        # Validate first so a bad argument does not cost a MATLAB start-up.
        if score_type not in ('BIC', 'BIC_different_var'):
            raise ValueError('Reward type not supported.')
        if reg_type not in ('LR', 'QR', 'GPR'):
            raise ValueError('Reg type not supported')

        self.batch_num = batch_num
        self.maxlen = config.max_length  # =d: number of vars
        self.dim = dim
        # Offset that makes (variable index, parent bitstring) keys unique.
        self.baseint = 2**self.maxlen
        self.d = {}      # graph key -> (score + lambda3 * MDS penalty, cycness)
        self.d_RSS = {}  # (variable, parent set) key -> local score
        self.d_MDS = {}  # (variable, parent set) key -> MDS term
        self.inputdata = inputdata
        self.n_samples = inputdata.shape[0]
        self.l1_graph_reg = l1_graph_reg
        self.verbose = verbose_flag
        self.sl = sl
        self.su = su
        self.lambda1_upper = lambda1_upper
        self.bic_penalty = np.log(inputdata.shape[0])/inputdata.shape[0]
        self.score_type = score_type
        self.reg_type = reg_type

        self.ones = np.ones((inputdata.shape[0], 1), dtype=np.float32)
        self.poly = PolynomialFeatures()
        self.kernel_coe_data = config.kernel_coe_data
        self.kernel_coe_index = None
        self.n_datasets = config.n_datasets
        self.regression_lambda = config.regression_lambda
        self.eng = matlab.engine.start_matlab()
        self.eng.addpath(
            self.eng.genpath('/home/dingchenwei/MDS/baseline/Causal-Discovery-from-Nonstationary-Heterogeneous-Data-master'))

    def cal_rewards(self, graphs, nons_vars, lambda1, lambda2, lambda3):
        """Evaluate every graph in ``graphs``.

        Returns:
            ``np.ndarray`` built from one ``(reward, score, cycness)`` tuple
            per graph, in input order.
        """
        return np.array([
            self.calculate_reward_single_graph(graph, nons_vars, lambda1, lambda2, lambda3)
            for graph in graphs
        ])

    ####### regression

    def calculate_yerr(self, X_train, y_train):
        """Dispatch to the regressor selected by ``self.reg_type``.

        Raises:
            ValueError: if ``self.reg_type`` is not 'LR', 'QR' or 'GPR'.
        """
        if self.reg_type == 'LR':
            return self.calculate_LR(X_train, y_train)
        if self.reg_type == 'QR':
            return self.calculate_QR(X_train, y_train)
        if self.reg_type == 'GPR':
            return self.calculate_GPR(X_train, y_train)
        # An assert would vanish under ``python -O`` and silently return None.
        raise ValueError('Regressor not supported')

    # faster than LinearRegression() from sklearn
    def calculate_LR(self, X_train, y_train):
        """Return residuals (prediction - target) of a least-squares fit.

        A column of ones (``self.ones``) is appended as intercept and the
        normal equations are solved directly, so a singular design matrix
        raises ``numpy.linalg.LinAlgError``.
        """
        X = np.hstack((X_train, self.ones))
        theta = np.linalg.solve(X.T.dot(X), X.T.dot(y_train))
        return X.dot(theta) - y_train

    def calculate_QR(self, X_train, y_train):
        """Linear regression on ``self.poly`` features, bias column dropped."""
        X_poly = self.poly.fit_transform(X_train)[:, 1:]
        return self.calculate_LR(X_poly, y_train)

    def calculate_GPR(self, X_train, y_train):
        """Return residuals (target - prediction) of a default GPR fit.

        Inputs are rescaled by the median pairwise Euclidean distance.
        """
        med_w = np.median(pdist(X_train, 'euclidean'))
        X_scaled = X_train/med_w
        gpr = GPR().fit(X_scaled, y_train)
        return y_train.reshape(-1, 1) - gpr.predict(X_scaled).reshape(-1, 1)

    ####### score calculations

    def _mds_local_score(self, i, parents, nons_vars):
        """Return the MDS term of variable ``i`` given its parent indices.

        The last variable may not have parents. If it appears as a parent it
        is removed (and ``nons_vars[i]`` must be 1); with no remaining
        parents the term is 0, otherwise MATLAB ``infer_nonsta_dir2`` is
        called on the parent and child data columns.
        """
        if i == self.maxlen-1:
            assert parents.shape[0] == 0
        if parents.shape[0] == 0:  # no parents
            return 0
        if parents[-1] == self.maxlen-1:  # if parents contains index
            assert nons_vars[i] == 1
            parents = parents[:-1]
        if parents.shape[0] == 0:
            return 0
        m_inputdata_i = matlab.double(self.inputdata[:, [i]].tolist())
        m_inputdata_pari = matlab.double(self.inputdata[:, parents].tolist())
        return self.eng.infer_nonsta_dir2(m_inputdata_pari, m_inputdata_i, matlab.double([0.1]),
                                          matlab.double([1]), matlab.double([self.n_datasets]))

    def calculate_reward_single_graph(self, graph_batch, nons_vars, lambda1, lambda2, lambda3):
        """Compute (or fetch from cache) the reward of one graph.

        Args:
            graph_batch: square 0/1 array; row ``i`` flags the parents of
                variable ``i``. Its diagonal is zeroed IN PLACE.
            nons_vars: per-variable flags; entry ``i`` must be 1 when the
                last variable is a parent of ``i``.
            lambda1: weight of the indicator penalty ``cycness > 1e-5``.
            lambda2: weight of the raw cycness value.
            lambda3: weight of the MDS penalty.

        Returns:
            Tuple ``(reward, score + lambda3 * MDS penalty, cycness)``.

        NOTE(review): the graph cache stores the score with ``lambda3``
        already applied, so a cache hit ignores the ``lambda3`` passed in.
        """
        row_keys = []   # unique per (variable, parent set): cache keys
        row_codes = []  # parent bitstring of each row: forms the graph key
        for i in range(self.maxlen):
            graph_batch[i][i] = 0  # drop self loops (mutates the caller's array)
            bits = ''.join(str(int(bit)) for bit in np.int32(graph_batch[i]))
            code = int(bits, 2)
            row_codes.append(code)
            row_keys.append(self.baseint * i + code)

        graph_key = tuple(row_codes)
        if graph_key in self.d:
            score_cyc = self.d[graph_key]
            return self.penalized_score(score_cyc, lambda1, lambda2), score_cyc[0], score_cyc[1]

        RSS_ls = []
        MDS_ls = []
        params = (10, self.regression_lambda)
        for i in range(self.maxlen):
            key = row_keys[i]
            if key in self.d_RSS:
                RSS_ls.append(self.d_RSS[key])
                MDS_ls.append(self.d_MDS[key])
                continue
            parents = np.where(graph_batch[i] > 0.5)[0]
            score_ = local_score_CV_general(self.inputdata, [i], parents, params,
                                            self.kernel_coe_data, self.kernel_coe_index)
            MDS_ = self._mds_local_score(i, parents, nons_vars)
            RSS_ls.append(score_)
            MDS_ls.append(MDS_)
            self.d_RSS[key] = score_
            self.d_MDS[key] = MDS_

        score = np.sum(RSS_ls)/self.inputdata.shape[0]
        # Average only over variables that actually have an MDS term; with
        # none, the penalty is 0 instead of the NaN produced by 0/0.
        n_mds = np.count_nonzero(MDS_ls)
        MDS_penalty = np.sum(MDS_ls)/n_mds if n_mds else 0.0
        # trace(expm(A)) - d is 0 for an acyclic graph and positive otherwise.
        cycness = np.trace(matrix_exponential(np.array(graph_batch))) - self.maxlen
        cycness_penalty = lambda1*float(cycness > 1e-5) + lambda2*cycness
        reward = score + cycness_penalty + lambda3 * MDS_penalty

        if self.l1_graph_reg > 0:
            l1_term = self.l1_graph_reg * np.sum(graph_batch)
            reward = reward + l1_term
            score = score + l1_term

        self.d[graph_key] = (score+lambda3*MDS_penalty, cycness)

        if self.verbose:
            self._logger.info('BIC: %s, cycness: %s, returned reward: %s', score, cycness, reward)
        return reward, score+lambda3*MDS_penalty, cycness

    #### helper

    def score_transform(self, s):
        """Map ``s`` linearly from [sl, su] onto [0, lambda1_upper]."""
        print("s: {}, sl: {}, su: {}".format(s, self.sl, self.su))
        return (s-self.sl)/(self.su-self.sl)*self.lambda1_upper

    def penalized_score(self, score_cyc, lambda1, lambda2):
        """Add the acyclicity penalties to a cached ``(score, cycness)`` pair."""
        score, cyc = score_cyc
        return score + lambda1*float(cyc > 1e-5) + lambda2*cyc

    def update_scores(self, score_cycs, lambda1, lambda2):
        """Return the penalized score of every ``(score, cycness)`` pair."""
        return [self.penalized_score(score_cyc, lambda1, lambda2) for score_cyc in score_cycs]

    def update_all_scores(self, lambda1, lambda2):
        """Re-penalize every cached graph with new lambdas.

        Returns:
            List of ``(graph_key, (penalized, score, cycness))`` sorted by
            ascending penalized score.
        """
        rescored = [
            (graph_int, (self.penalized_score(score_cyc, lambda1, lambda2), score_cyc[0], score_cyc[1]))
            for graph_int, score_cyc in self.d.items()
        ]
        return sorted(rescored, key=lambda x: x[1][0])

