import numpy as np
import GPy
import copy 
from GPy.util.univariate_Gaussian import inv_std_norm_cdf
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeRegressor 
from gurobipy import * 

class SurrogateGaussianProcess:
    
    def __init__(self, column_name): 
        self.column_name = column_name
    
    def learn_GP(self, X, y, kernel, m=100, X_normalize=True, likelihood_function='Gaussian', tau=True, foo=1e-5):
        self.X_normalize = X_normalize 
        self.likelihood_function = likelihood_function 
        self.d = X.shape[1]        
        if self.X_normalize:
            self.SS = StandardScaler()
            X = self.SS.fit_transform(X)
        if self.likelihood_function == 'Gaussian':
            likelihood = GPy.likelihoods.Gaussian() 
            mean_function = GPy.mappings.Constant(self.d, 1, value=np.mean(y)) 
        elif self.likelihood_function == 'Bernoulli':
            likelihood = GPy.likelihoods.Bernoulli()
            mean_function = GPy.mappings.Constant(self.d, 1, value=inv_std_norm_cdf(min(max(np.mean(y), foo), 1 - foo))) 
        elif self.likelihood_function == 'Poisson':
            likelihood = GPy.likelihoods.Poisson()                        
            mean_function = GPy.mappings.Constant(self.d, 1, value=np.log(max(np.mean(y), foo))) 
        Z = KMeans(n_clusters=m).fit(X).cluster_centers_       
        if tau:
            self.GPR = GPy.core.SVGP(X=X, Y=y[:,None], Z=Z, kernel=kernel, likelihood=likelihood, mean_function=mean_function) 
        else:
            self.GPR = GPy.core.SVGP(X=X, Y=y[:,None], Z=Z, kernel=kernel, likelihood=likelihood) 
        self.GPR.optimize(optimizer="LBFGS") 
        
    def make_posterior(self, X_): 
        self.n_ = len(X_) 
        self.X_ = self.SS.transform(X_) if self.X_normalize else copy.copy(X_) 
        mu, Sigma = self.GPR._raw_predict(self.X_, full_cov=True) 
        self.mu = mu[:,0] 
        self.weight = [1 / Sigma[i,i,0] for i in range(self.n_)]
    
    def get_score(self, omega):
        vhat = []
        for j in range(max(omega) + 1):
            vhat_ = [self.mu[i] for i in range(self.n_) if omega[i] == j]
            if len(vhat_) > 0:
                vhat += [np.mean(vhat_)]
            else:
                vhat += [0]
        score = (sum([((self.mu[i] - vhat[omega[i]]) ** 2) * self.weight[i] for i in range(self.n_)]) / sum(self.weight)) ** 0.5
        return score 
    
    def CART(self, depth=2, n_min=1): 
        initial_tree = DecisionTreeRegressor(criterion='squared_error', splitter='best', max_depth=depth, min_samples_leaf=n_min) 
        initial_tree.fit(self.X_, self.mu, sample_weight=self.weight)                        
        omega = initial_tree.apply(self.X_) 
        omega_set = list(set(omega))
        left = list(initial_tree.tree_.children_left)
        right = list(initial_tree.tree_.children_right)         
        leaf = []
        for node in omega_set:
            node_ = node 
            node__ = - 1 
            direction = []
            while node__ != node_:
                node__ = node_            
                if node_ in left:
                    node_ = left.index(node_)
                    direction = [0] + direction
                elif node_ in right: 
                    node_ = right.index(node_)
                    direction = [1] + direction
            leaf += [sum([direction[i] * (2 ** (depth - i - 1)) for i in range(len(direction))])]
        omega = [leaf[omega_set.index(omega[i])] for i in range(self.n_)]
        return omega  
    
    def find_cluster(self, l, depth=2, time_limit=600, n_min=1, surrogate_model="tree", label=None, neighbors_list=None, display=True): 
        if label is None:
            label = [i for i in range(self.n_)] 
        n_granularity = len(set(label))  
        n_list = [sum([1 for i in range(self.n_) if label[i] == j]) for j in range(n_granularity)]
        l_ = 2 ** depth if surrogate_model == "tree" else l 
        M = max(max(self.mu), max(- self.mu))
        model = gurobipy.Model("qp") 
        w = [[model.addVar(vtype=GRB.BINARY) for j in range(l_)] for k in range(n_granularity)] 
        v = [model.addVar(lb=-gurobipy.GRB.INFINITY, ub=gurobipy.GRB.INFINITY, vtype=GRB.CONTINUOUS) for j in range(l_)]
        vbar = [[model.addVar(lb=-gurobipy.GRB.INFINITY, ub=gurobipy.GRB.INFINITY, vtype=GRB.CONTINUOUS) for j in range(l_)] for k in range(n_granularity)] 
        alpha = [model.addVar(vtype=GRB.BINARY) for j in range(l_)] 
        dummy = [quicksum([vbar[k][j] for j in range(l_)]) for k in range(n_granularity)] 
        model.setObjective(0.5 * quicksum(((dummy[label[i]] - self.mu[i]) ** 2) * self.weight[i] for i in range(self.n_)), GRB.MINIMIZE) 
        model.addConstr(quicksum(alpha[j] for j in range(l_)) <= l) 
        for k in range(n_granularity): 
            for j in range(l_):
                model.addConstr(vbar[k][j] <= M * w[k][j]) 
                model.addConstr(vbar[k][j] >= - M * w[k][j]) 
                model.addConstr(vbar[k][j] <= v[j] + M * (1 - w[k][j])) 
                model.addConstr(vbar[k][j] >= v[j] - M * (1 - w[k][j])) 
        for k in range(n_granularity):
            model.addConstr(quicksum(w[k][j] for j in range(l_)) == 1) 
        for j in range(l_):
            model.addConstr(n_list[k] * quicksum(w[k][j] for k in range(n_granularity)) <= self.n_ * alpha[j]) 
            model.addConstr(n_list[k] * quicksum(w[k][j] for k in range(n_granularity)) >= n_min * alpha[j]) 
        if surrogate_model == "tree":     
            interval = [] 
            for j in range(self.d):
                X__ = np.sort(self.X_[:,j])
                interval_ = [X__[i + 1] - X__[i] for i in range(self.n_ - 1) if X__[i + 1] != X__[i]]
                if len(interval_) > 0:
                    interval += [min(interval_)]                
            epsilon = 0.5 * min(interval)            
            delta = max([max(self.X_[:,j]) - min(self.X_[:,j]) for j in range(self.d)]) + epsilon 
            s = [[[model.addVar(lb=-gurobipy.GRB.INFINITY, ub=gurobipy.GRB.INFINITY, vtype=GRB.CONTINUOUS) for i in range(self.d)] for j in range(2 ** k)] for k in range(depth + 1)]
            t = [[[model.addVar(lb=-gurobipy.GRB.INFINITY, ub=gurobipy.GRB.INFINITY, vtype=GRB.CONTINUOUS) for i in range(self.d)] for j in range(2 ** k)] for k in range(depth + 1)]
            gamma = [[[model.addVar(vtype=GRB.BINARY) for i in range(self.d)] for j in range(2 ** k)] for k in range(depth)]                    
            for k in range(self.n_):
                for j in range(l_):
                    for i in range(self.d):
                        model.addConstr(s[depth][j][i] - delta * (1 - w[label[k]][j]) <= self.X_[k,i])                                   
                        model.addConstr(t[depth][j][i] + delta * (1 - w[label[k]][j]) - epsilon >= self.X_[k,i])                             
            for i in range(self.d):
                model.addConstr(s[0][0][i] == min([self.X_[k,i] for k in range(self.n_)]))
                model.addConstr(t[0][0][i] == max([self.X_[k,i] for k in range(self.n_)]) + epsilon)
            for k in range(depth):
                for j in range(2 ** k):
                    model.addConstr(quicksum(gamma[k][j][i] for i in range(self.d)) == 1)                
                    for i in range(self.d):
                        model.addConstr(s[k][j][i] <= t[k + 1][2 * j + 0][i]) 
                        model.addConstr(t[k + 1][2 * j + 0][i] <= t[k][j][i])
                        model.addConstr(s[k][j][i] <= s[k + 1][2 * j + 1][i])
                        model.addConstr(s[k + 1][2 * j + 1][i] <= t[k][j][i]) 
                        model.addConstr(s[k + 1][2 * j + 1][i] <= t[k + 1][2 * j + 0][i])
                        model.addConstr(s[k + 1][2 * j + 0][i] == s[k][j][i])
                        model.addConstr(t[k + 1][2 * j + 1][i] == t[k][j][i])
                        model.addConstr(s[k + 1][2 * j + 1][i] - s[k][j][i] <= delta * gamma[k][j][i])
                        model.addConstr(t[k][j][i] - t[k + 1][2 * j + 0][i] <= delta * gamma[k][j][i])
                        model.addConstr(t[k + 1][2 * j + 0][i] - s[k + 1][2 * j + 1][i] <= delta * (1 - gamma[k][j][i]))  
        elif surrogate_model == "DAG": 
            e = [[model.addVar(vtype=GRB.BINARY) for j in range(n_granularity)] for i in range(n_granularity)] 
            r = [[model.addVar(vtype=GRB.BINARY) for j in range(n_granularity)] for i in range(n_granularity)] 
            beta = [model.addVar(vtype=GRB.BINARY) for i in range(n_granularity)] 
            model.addConstr(quicksum(beta[i] for i in range(n_granularity)) == quicksum(alpha[j] for j in range(l_)))  
            for i in range(n_granularity): 
               for j in range(n_granularity): 
                    if i != j:
                        model.addConstr(- l_ * (1 - e[i][j]) <= quicksum((k + 1) * (w[i][k] - w[j][k]) for k in range(l_))) 
                        model.addConstr(l_ * (1 - e[i][j]) >= quicksum((k + 1) * (w[i][k] - w[j][k]) for k in range(l_))) 
                        model.addConstr(e[j][i] <= 1 - r[i][j]) 
                        model.addConstr(r[j][i] == 1 - r[i][j]) 
            for i in range(n_granularity - 2):
                for j in range(i + 1, n_granularity - 1):
                    for k in range(j + 1, n_granularity):
                        model.addConstr(0 <= r[i][j] + r[j][k] - r[i][k]) 
                        model.addConstr(1 >= r[i][j] + r[j][k] - r[i][k]) 
            if neighbors_list is None:
                for i in range(n_granularity):                 
                    model.addConstr(1 - beta[i] <= quicksum(e[j][i] for j in range(n_granularity) if j != i)) 
                    model.addConstr(e[i][i] == 0)                     
            else:
                for i in range(n_granularity): 
                    model.addConstr(1 - beta[i] <= quicksum(e[j][i] for j in neighbors_list[i])) 
                    for j in range(n_granularity): 
                        if j not in neighbors_list[i]: 
                            model.addConstr(e[j][i] == 0) 
        model.Params.OutputFlag = 0  
        model.Params.TimeLimit = time_limit 
        model.optimize() 
        if display: 
            if surrogate_model == "tree":  
                for k in range(depth): 
                    for j in range(2 ** k): 
                        for i in range(self.d): 
                            if round(gamma[k][j][i].x) == 1: 
                                print("depth:", k + 1, "node:", j + 1, "variable:", self.column_name[i], t[k + 1][2 * j + 0][i].x)                                 
        omega = [j for i in range(self.n_) for j in range(l_) if round(w[label[i]][j].x) == 1] 
        return omega  