import numpy as np
from sklearn.linear_model import LinearRegression
from scipy.linalg import expm as matrix_exponential
from scipy.spatial.distance import pdist, squareform
from sklearn.gaussian_process import GaussianProcessRegressor as GPR
from sklearn.preprocessing import PolynomialFeatures
import logging


class get_Reward_MDS(object):
    """Reward calculator for candidate graphs over multi-domain data.

    Each candidate adjacency matrix is scored by regressing every variable
    on its parents within each domain, turning the residual sums of squares
    into a BIC-style score that is averaged over domains and rescaled by
    ``score_transform``.  Penalties for cyclicity and for correlated
    per-domain regression offsets between nonstationary parent/child pairs
    (the "MDS" penalty) are added on top.  Results are memoised in
    ``self.d`` (per graph) and ``self.d_RSS`` (per node fit).
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, batch_num, maxlen, dim, inputdata, sl, su, lambda1_upper, 
                 score_type='BIC', reg_type='LR', l1_graph_reg=0.0, verbose_flag=True, n_domains=1):
        """Store the data, scoring options and empty caches.

        Args:
            batch_num: Stored as ``self.batch_num``; not used by the methods
                visible in this class.
            maxlen: Number of variables (graph nodes).
            dim: Stored as ``self.dim``; not used by the methods visible in
                this class.
            inputdata: 2-D array.  The last column is split off into
                ``self.domaindata``; the remaining columns are the
                variables.  Rows are later sliced into ``n_domains``
                contiguous chunks of equal size, so they are expected to be
                grouped by domain.
            sl, su, lambda1_upper: Bounds and scale used by
                ``score_transform``.
            score_type: ``'BIC'`` or ``'BIC_different_var'``.
            reg_type: ``'LR'``, ``'QR'`` or ``'GPR'``.
            l1_graph_reg: Weight of the optional edge-count penalty.
            verbose_flag: When true, each freshly scored graph is logged.
            n_domains: Number of domains the rows are split into.

        Raises:
            ValueError: If ``score_type`` or ``reg_type`` is not supported.
        """
        # Validate first so a bad option never leaves a half-built object.
        if score_type not in ('BIC', 'BIC_different_var'):
            raise ValueError('Reward type not supported.')
        if reg_type not in ('LR', 'QR', 'GPR'):
            raise ValueError('Reg type not supported')
        self.score_type = score_type
        self.reg_type = reg_type

        self.batch_num = batch_num
        self.maxlen = maxlen  # =d: number of vars
        self.dim = dim
        self.baseint = 2**maxlen
        self.d = {}  # graph-level cache: parent-set tuple -> (score, cycness)
        self.d_RSS = {}  # node-level cache of regression results
        self.inputdata = inputdata[:, 0:-1]
        self.domaindata = inputdata[:, -1]
        self.n_domains = n_domains
        self.n_samples = inputdata.shape[0]
        # Builtin int: the ``np.int`` alias was removed in NumPy 1.24.
        self.n_samples_each_domain = int(self.inputdata.shape[0]/self.n_domains)
        self.l1_graph_reg = l1_graph_reg
        self.verbose = verbose_flag
        self.sl = sl
        self.su = su
        self.lambda1_upper = lambda1_upper
        self.bic_penalty = np.log(self.n_samples_each_domain)/(self.n_samples_each_domain)

        # Constant column appended to every design matrix in calculate_LR.
        self.ones = np.ones((self.n_samples_each_domain, 1), dtype=np.float32)
        self.poly = PolynomialFeatures()

    def cal_rewards(self, graphs, nons_vars, lambda1, lambda2, lambda3):
        """Score every graph in ``graphs`` and stack the results.

        Each entry of the returned array is what
        ``calculate_reward_MDS_single_graph`` returns for the corresponding
        graph, in input order.
        """
        per_graph = [
            self.calculate_reward_MDS_single_graph(
                candidate, nons_vars, lambda1, lambda2, lambda3)
            for candidate in graphs
        ]
        return np.array(per_graph)


    ####### regression 

    def calculate_yerr(self, X_train, y_train):
        """Run the regressor selected by ``self.reg_type``.

        Returns whatever the selected ``calculate_LR`` / ``calculate_QR`` /
        ``calculate_GPR`` method returns.

        Raises:
            ValueError: If ``self.reg_type`` is not ``'LR'``, ``'QR'`` or
                ``'GPR'``.  An explicit raise is used instead of ``assert``
                so the check survives ``python -O``.
        """
        if self.reg_type == 'LR':
            return self.calculate_LR(X_train, y_train)
        if self.reg_type == 'QR':
            return self.calculate_QR(X_train, y_train)
        if self.reg_type == 'GPR':
            return self.calculate_GPR(X_train, y_train)
        raise ValueError('Regressor not supported')

    # faster than LinearRegression() from sklearn
    def calculate_LR(self, X_train, y_train):
        """Least-squares fit with an intercept via the normal equations.

        ``self.ones`` is appended to ``X_train`` as the last design column,
        so the last coefficient is the intercept.

        Returns:
            ``(residuals, intercept)`` where residuals are
            ``prediction - y_train``.
        """
        design = np.hstack((X_train, self.ones))
        design_t = design.T
        gram = design_t @ design
        moment = design_t @ y_train
        coef = np.linalg.solve(gram, moment)
        residuals = design @ coef - y_train
        intercept = coef[-1]  # assume "b" varies
        return residuals, intercept

    def calculate_QR(self, X_train, y_train):
        """Expand ``X_train`` with ``self.poly`` and fit it with ``calculate_LR``.

        The first column of the expansion is dropped before fitting
        (presumably the constant bias column of the default
        ``PolynomialFeatures`` -- confirm against scikit-learn's docs);
        ``calculate_LR`` appends its own constant column.
        """
        expanded = self.poly.fit_transform(X_train)
        features = expanded[:, 1:]
        return self.calculate_LR(features, y_train)
    
    def calculate_GPR(self, X_train, y_train):
        """Fit a default ``GPR`` on median-distance-scaled inputs.

        Inputs are divided by the median pairwise Euclidean distance between
        the rows of ``X_train`` before fitting and predicting.

        Returns:
            Column vector ``y_train - prediction``.  Unlike ``calculate_LR``
            this is a single array, not a ``(residuals, intercept)`` pair.
        """
        median_dist = np.median(pdist(X_train, 'euclidean'))
        scaled = X_train/median_dist
        model = GPR().fit(scaled, y_train)
        target = y_train.reshape(-1, 1)
        predicted = model.predict(X_train/median_dist).reshape(-1, 1)
        return target - predicted

    def calculate_MDS_penalty(self, nons_theta, nons_vars, graph_batch):
        """Mean absolute correlation between offsets of nonstationary parent/child pairs.

        For every ``i`` flagged in ``nons_vars`` and every ``j`` with
        ``graph_batch[i, j] == 1`` that is also flagged, the absolute
        correlation between columns ``i`` and ``j`` of ``nons_theta`` is
        collected (NaN counts as 0).  Returns their mean, or 0 when no such
        pair exists.
        """
        correlations = []
        for child in range(self.maxlen):
            if nons_vars[child] != 1:  # only nonstationary children matter
                continue
            for parent in range(self.maxlen):
                # need an edge from a parent that is itself nonstationary
                if graph_batch[child, parent] != 1 or nons_vars[parent] != 1:
                    continue
                rho = np.abs(np.corrcoef(nons_theta[:, child], nons_theta[:, parent])[0, 1])
                # NaN arises in extreme cases, e.g. when stage 1 wrongly
                # flagged a stationary variable as nonstationary
                correlations.append(0 if np.isnan(rho) else rho)
        if not correlations:
            return 0
        return np.mean(correlations)



    ####### score calculations

    def calculate_reward_MDS_single_graph(self, graph_batch, nons_vars, lambda1, lambda2, lambda3):
        """Score one candidate graph.

        Row ``i`` of ``graph_batch`` marks the parents of variable ``i``.
        For every domain, each variable is regressed on its parents (or
        centred on its mean when it has none) and the residual sums of
        squares are combined into a BIC-style score.  The domain scores are
        averaged, rescaled with ``score_transform`` and combined with the
        cyclicity and MDS penalties.

        Args:
            graph_batch: Square 0/1 adjacency array.  Its diagonal is zeroed
                in place.
            nons_vars: Per-variable flags; ``1`` marks a nonstationary
                variable (see ``calculate_MDS_penalty``).
            lambda1: Weight of the "graph has a cycle" indicator.
            lambda2: Weight of the cyclicity value.
            lambda3: Weight of the MDS penalty.

        Returns:
            ``(reward, score, cycness)`` where ``score`` already includes
            ``lambda3 * MDS_penalty`` (and the L1 term if enabled) and
            ``reward`` additionally includes the cyclicity penalties.

        Raises:
            ValueError: If ``self.score_type`` is not supported.

        Notes:
            * Graph-level results are cached in ``self.d`` keyed by parent
              sets only, so a cache hit reuses the MDS term computed with
              the ``nons_vars`` / ``lambda3`` of the first call.
            * ``self.d_RSS`` is keyed by ``(domain_index, node_key)`` and
              stores ``(RSS, theta)``.  Keying by domain stops one domain's
              fit from being reused for another, and storing theta keeps
              the MDS penalty correct on node-level cache hits.
        """
        node_keys = []     # unique per (node, parent set); used by d_RSS
        parent_codes = []  # parent set of each node encoded as an int

        for i in range(self.maxlen):
            graph_batch[i][i] = 0  # no self-loops (mutates the caller's array)
            row_bits = ''.join(str(bit) for bit in np.int32(graph_batch[i]))
            code = int(row_bits, 2)  # builtin int: np.int was removed in NumPy 1.24
            node_keys.append(self.baseint * i + code)
            parent_codes.append(code)

        graph_key = tuple(parent_codes)

        if graph_key in self.d:
            score_cyc = self.d[graph_key]
            return self.penalized_score(score_cyc, lambda1, lambda2), score_cyc[0], score_cyc[1]

        n_per_domain = self.n_samples_each_domain
        n_edges = np.sum(graph_batch)

        # Compute the BIC of each domain and collect the per-domain offsets.
        nons_theta = np.zeros([self.n_domains, self.maxlen])
        BIC_list = []
        for domain_index in range(self.n_domains):
            start = domain_index * n_per_domain
            domain_data = self.inputdata[start:start + n_per_domain, :]
            RSS_ls = []
            for i in range(self.maxlen):
                cache_key = (domain_index, node_keys[i])
                fit = self.d_RSS.get(cache_key)

                if fit is None:
                    col = graph_batch[i]
                    if np.sum(col) < 0.1:
                        # no parents, then simply use mean
                        y_train = domain_data[:, i]
                        theta = np.mean(y_train)
                        y_err = y_train - theta
                    else:
                        X_train = domain_data[:, col > 0.5]
                        y_train = domain_data[:, i]
                        result = self.calculate_yerr(X_train, y_train)
                        if isinstance(result, tuple):
                            y_err, theta = result
                        else:
                            # calculate_GPR returns residuals only; there is
                            # no offset parameter, so leave theta at 0.
                            y_err, theta = result, 0.0

                    RSSi = np.sum(np.square(y_err))

                    # If the regressors include the true parents, GPR can
                    # give very small values, e.g. 1e-13, so we add 1.0,
                    # which does not affect the monotonicity of the score.
                    if self.reg_type == 'GPR':
                        RSSi += 1.0

                    fit = (RSSi, theta)
                    self.d_RSS[cache_key] = fit

                RSS_ls.append(fit[0])
                nons_theta[domain_index, i] = fit[1]

            if self.score_type == 'BIC':
                BIC = np.log(np.sum(RSS_ls)/n_per_domain+1e-8) \
                      + n_edges*self.bic_penalty/self.maxlen
            elif self.score_type == 'BIC_different_var':
                BIC = np.sum(np.log(np.array(RSS_ls)/n_per_domain+1e-8)) \
                      + n_edges*self.bic_penalty
            else:
                raise ValueError('Reward type not supported.')
            BIC_list.append(BIC)

        mean_BIC = np.mean(BIC_list)
        score = self.score_transform(mean_BIC)
        cycness = np.trace(matrix_exponential(np.array(graph_batch))) - self.maxlen
        MDS_penalty = self.calculate_MDS_penalty(nons_theta, nons_vars, graph_batch)
        reward = score + lambda1*float(cycness > 1e-5) + lambda2*cycness + lambda3*MDS_penalty

        if self.l1_graph_reg > 0:
            l1_term = self.l1_graph_reg * n_edges
            reward = reward + l1_term
            score = score + l1_term

        penalized = score + lambda3*MDS_penalty
        self.d[graph_key] = (penalized, cycness)

        if self.verbose:
            # Report the domain-averaged BIC, which is what the score uses.
            self._logger.info('BIC: {}, cycness: {}, returned reward: {}'.format(mean_BIC, cycness, reward))
        return reward, penalized, cycness

    #### helper
    
    def score_transform(self, s):
        """Map ``s`` linearly so that ``sl`` -> 0 and ``su`` -> ``lambda1_upper``."""
        offset = s - self.sl
        span = self.su - self.sl
        return offset/span*self.lambda1_upper

    def penalized_score(self, score_cyc, lambda1, lambda2):
        """Add the cyclicity penalties to a cached ``(score, cycness)`` pair.

        Returns ``score + lambda1 * [cyc > 1e-5] + lambda2 * cyc``.
        """
        score, cyc = score_cyc
        # Builtin float: the ``np.float`` alias was removed in NumPy 1.24.
        return score + lambda1*float(cyc > 1e-5) + lambda2*cyc
    
    def update_scores(self, score_cycs, lambda1, lambda2):
        """Return the penalized score of every ``(score, cycness)`` pair, in order."""
        return [
            self.penalized_score(pair, lambda1, lambda2)
            for pair in score_cycs
        ]
    
    def update_all_scores(self, lambda1, lambda2):
        """Re-penalize every cached graph and return them sorted by penalized score.

        Each element is ``(graph_key, (penalized, score, cycness))`` built
        from the entries of ``self.d``; the list is in ascending order of
        the penalized score.
        """
        rescored = [
            (graph_key, (self.penalized_score(pair, lambda1, lambda2), pair[0], pair[1]))
            for graph_key, pair in list(self.d.items())
        ]
        rescored.sort(key=lambda entry: entry[1][0])
        return rescored
