import numpy as np
from pyscipopt import Model, quicksum

from rabit._utils import Y_TARGET, find_best_actions, compute_action_indicators



class RecourseExplainer():
    """Search for recourse actions over the positive leaf regions of a tree ensemble.

    The explainer collects candidate regions from every member of
    ``estimator.estimators_``, asks ``action`` to enumerate candidate actions
    towards those regions, filters the candidates by predicted confidence (and
    optionally by plausibility score) and keeps the best one per instance via
    ``find_best_actions``.

    Parameters
    ----------
    estimator
        Ensemble exposing ``estimators_`` (each member with ``tree_`` and
        ``regions_``), ``predict`` and ``predict_proba``.
    action
        Object exposing ``enumerate_actions``, ``cost_budget`` and
        ``_plausibility.score_samples``.
    max_features
        Forwarded unchanged to ``action.enumerate_actions``.
    confidence
        Minimum ``predict_proba[:, Y_TARGET]`` for a candidate to be valid.
    plausibility
        When below ``1.0``, candidates whose negated ``score_samples`` value
        exceeds this threshold are rejected; at ``1.0`` or above no
        plausibility filtering is applied.
    max_search
        Upper bound on the number of regions kept by ``_get_regions``.
    """

    def __init__(
        self,
        estimator,
        action,
        max_features=-1,
        confidence=0.5,
        plausibility=1.0,
        max_search=1000,
    ):

        self.estimator = estimator
        self.action = action
        self.max_features = max_features
        self.confidence = confidence
        self.plausibility = plausibility
        self.max_search = max_search

    @staticmethod
    def _get_candidates(X, A):
        """Return every candidate point ``x_i + a_ij`` stacked row-wise.

        Each row of ``X`` is repeated ``A.shape[1]`` times and the first axis
        of ``A`` is concatenated away, so the result lines up with
        ``A.reshape(-1, ...)``.
        """
        return np.repeat(X, A.shape[1], axis=0) + np.concatenate(A, axis=0)

    def _get_regions(self):
        """Collect the regions of all positive-valued leaves in the ensemble.

        A node counts as a leaf here when its ``tree_.feature`` entry is
        negative; only leaves whose ``tree_.value`` entry is positive are kept.

        Returns
        -------
        numpy.ndarray or list
            The concatenated ``regions_`` rows of the selected leaves, randomly
            sub-sampled without replacement (global NumPy RNG) down to
            ``max_search`` rows when there are more. An empty list when no
            member contributes a leaf.
        """
        regions = []
        for estimator in self.estimator.estimators_:
            tree = estimator.tree_
            leaves = np.array([
                j for j in range(tree.node_count)
                if tree.feature[j] < 0 and tree.value[j] > 0
            ])
            if len(leaves) == 0:
                continue
            regions.append(estimator.regions_[leaves])
        if len(regions) == 0:
            return []
        regions = np.concatenate(regions, axis=0)
        if regions.shape[0] > self.max_search:
            picked = np.random.choice(regions.shape[0], size=self.max_search, replace=False)
            regions = regions[picked]
        return regions

    def _get_validity(self, X, A, F):
        """Mask the feasibility indicator ``F`` down to valid candidates.

        A candidate stays valid when it is feasible, its predicted probability
        for ``Y_TARGET`` reaches ``self.confidence`` and, if
        ``self.plausibility < 1.0``, its negated plausibility score does not
        exceed ``self.plausibility``. The result has the shape of ``F``.
        """
        # Build the candidate matrix once; both filters score the same points.
        candidates = self._get_candidates(X, A)
        confident = self.estimator.predict_proba(candidates)[:, Y_TARGET] >= self.confidence
        V = F * confident.reshape(F.shape)
        if self.plausibility < 1.0:
            P = -1 * self.action._plausibility.score_samples(candidates)
            V = V * (P <= self.plausibility).reshape(F.shape)
        return V

    def _get_counterfactuals(self, X, A_opt):
        """Return the counterfactual points ``X + A_opt``."""
        X_cf = X + A_opt
        return X_cf

    def explain_recourse(self, X):
        """Compute one recourse action per row of ``X``.

        Returns
        -------
        Recourse
            Bundle of the inputs, chosen actions, counterfactuals, costs,
            validity flags and negated plausibility scores. When the ensemble
            yields no candidate region, the zero action is reported for every
            instance with zero cost and ``valid`` all ``False``.
        """
        is_target = (self.estimator.predict(X) == Y_TARGET)
        regions = self._get_regions()

        if len(regions) == 0:
            results = {
                'budget': self.action.cost_budget,
                'is_target': is_target,
                'X': X,
                'action': np.zeros_like(X),
                'X_cf': X,
                'cost': np.zeros(X.shape[0]),
                'valid': np.zeros(X.shape[0], dtype=np.bool_),
                # Recourse requires a plausibility entry; the counterfactual is
                # the unchanged input, so score X exactly as the regular path
                # scores X_cf.
                'plausibility': -1 * self.action._plausibility.score_samples(X),
            }
            return Recourse(**results)

        A, C, F = self.action.enumerate_actions(X, regions, self.max_features)
        V = self._get_validity(X, A, F)

        # find_best_actions returns one row per instance: cost in column 0,
        # the selected action in the remaining columns.
        CA = find_best_actions(X, A, V, C)
        A_opt, C_opt = CA[:, 1:], CA[:, 0]

        X_cf = self._get_counterfactuals(X, A_opt)
        V_opt = (self.estimator.predict(X_cf) == Y_TARGET)
        P_opt = -1 * self.action._plausibility.score_samples(X_cf)

        results = {
            'budget': self.action.cost_budget,
            'is_target': is_target,
            'X': X,
            'action': A_opt,
            'X_cf': X_cf,
            'cost': C_opt,
            'valid': V_opt,
            'plausibility': P_opt,
        }
        return Recourse(**results)

    def generate_recourse_calibration_samples(self, X):
        """Return, per row of ``X``, the within-budget candidate point selected
        by ``find_best_actions`` when ranking with ``1 - P(Y_TARGET)``.

        Candidates whose cost exceeds ``action.cost_budget`` are masked out of
        the feasibility indicator before the selection.
        """
        regions = self._get_regions()
        A, C, F = self.action.enumerate_actions(X, regions, self.max_features)
        F = F * (C <= self.action.cost_budget)
        candidates = self._get_candidates(X, A)
        probabilities = self.estimator.predict_proba(candidates)[:, Y_TARGET].reshape(F.shape)
        # Column 0 of the result carries the ranking score; the rest is the action.
        PA = find_best_actions(X, A, F, 1 - probabilities)
        X_cf = X + PA[:, 1:]
        return X_cf

    
    

class ExactRecourseExplainer(RecourseExplainer):
    """Recourse explainer that solves one mixed-integer program per instance.

    For every row of ``X`` the candidate per-feature changes are derived from
    the split thresholds of the ensemble, and a SCIP model (via PySCIPOpt)
    picks one candidate per feature so that the cost is minimised subject to
    the forest-output constraint built in ``_get_milo_model``.
    """

    def explain_recourse(self, X, time_limit=60, verbose=False):
        """Compute one recourse action per row of ``X`` with the MILO model.

        Parameters
        ----------
        X
            2-D array of instances; rows are processed one at a time.
        time_limit
            Value passed to SCIP's ``limits/time`` parameter for each instance.
        verbose
            When ``False`` the solver output is hidden.

        Returns
        -------
        Recourse
            Inputs, chosen actions, counterfactuals, costs, validity flags and
            negated plausibility scores. Instances for which the solver has no
            solution (infeasible, or the time limit expired before any
            incumbent was found) get the zero action with cost ``0.0``.
        """
        is_target = (self.estimator.predict(X) == Y_TARGET)
        thresholds, feature_pointer = self._get_thresholds()
        regions, weights, tree_pointer = self._get_regions_and_weights()
        A_opt = []
        X_cf = []
        C_opt = []

        for x in X:
            X_i = x.reshape(1, -1)
            A = self.action._get_action(X_i, thresholds)
            C = self.action._get_cost(X_i, A, feature_pointer)
            A = A[0]; C = C[0]

            # Interleave one extra all-zero-change row per feature in front of
            # that feature's threshold-derived candidates: column 0 holds the
            # feature index, the remaining columns and the cost stay zero.
            A_ins = np.zeros((A.shape[0] + self.action.n_features, A.shape[1]), dtype=np.float64)
            C_ins = np.zeros(A_ins.shape[0], dtype=np.float64)
            for d in range(self.action.n_features):
                A_ins[(feature_pointer[d]+d), 0] = d
                A_ins[(feature_pointer[d]+d+1):(feature_pointer[d+1]+d+1)] = A[feature_pointer[d]:feature_pointer[d+1]]
                C_ins[feature_pointer[d]+d+1:feature_pointer[d+1]+d+1] = C[feature_pointer[d]:feature_pointer[d+1]]
            # Candidates with infinite cost are dropped before modelling.
            finite = (C_ins != np.inf)
            A = A_ins[finite]; C = C_ins[finite]

            # One indicator matrix per tree, computed against that tree's leaves.
            I = []
            for t in range(self.estimator.n_estimators):
                regions_t = regions[tree_pointer[t]:tree_pointer[t+1]]
                I_t = compute_action_indicators(x, A, regions_t)
                I.append(I_t)

            model = self._get_milo_model(x, A, C, I, weights, self.max_features)
            model.hideOutput(not verbose)
            model.setParam('limits/time', time_limit)
            model.optimize()

            # A time limit is set above, so the solver may stop without any
            # incumbent; in that case there are no variable values to read and
            # the instance is treated like an infeasible one.
            if model.getStatus() == 'infeasible' or model.getNSols() == 0:
                a = np.zeros(self.action.n_features, dtype=np.float64)
                c = 0.0
            else:
                var_dict = model.getVarDict()
                a = np.array([var_dict['action_{:04d}'.format(d)] for d in range(self.action.n_features)])
                c = var_dict['cost']
            A_opt.append(a)
            C_opt.append(c)
            X_cf.append(x + a)

        A_opt = np.array(A_opt)
        C_opt = np.array(C_opt)
        X_cf = np.array(X_cf)
        V_opt = (self.estimator.predict(X_cf) == Y_TARGET)
        P_opt = -1 * self.action._plausibility.score_samples(X_cf)

        results = {
            'budget': self.action.cost_budget,
            'is_target': is_target,
            'X': X,
            'action': A_opt,
            'X_cf': X_cf,
            'cost': C_opt,
            'valid': V_opt,
            'plausibility': P_opt,
        }
        return Recourse(**results)

    def _get_thresholds(self):
        """Gather the distinct split thresholds used by the ensemble, per feature.

        Returns
        -------
        thresholds : numpy.ndarray
            Rows of ``(feature index, threshold)``, grouped by feature and
            sorted by threshold within each feature.
        feature_pointer : list of int
            Offsets such that the rows of feature ``d`` are
            ``thresholds[feature_pointer[d]:feature_pointer[d + 1]]``.
        """
        thresholds = []
        feature_pointer = [ 0 ]
        for d in range(self.action.n_features):
            thresholds_d = []
            for n_estimator in range(self.estimator.n_estimators):
                tree = self.estimator.estimators_[n_estimator].tree_
                thresholds_d += tree.threshold[tree.feature == d].tolist()
            thresholds_d = sorted(list(set(thresholds_d)))
            if len(thresholds_d) > 0:
                thresholds.append(list(zip([d] * len(thresholds_d), thresholds_d)))
            feature_pointer.append(feature_pointer[-1] + len(thresholds_d))
        thresholds = np.concatenate(thresholds, axis=0)
        return thresholds, feature_pointer

    def _get_regions_and_weights(self):
        """Collect the regions and values of every leaf of every tree.

        Unlike ``RecourseExplainer._get_regions`` this keeps all leaves
        (nodes with a negative ``tree_.feature`` entry), regardless of value.

        Returns
        -------
        regions : numpy.ndarray
            Concatenated ``regions_`` rows of the leaves.
        weights : numpy.ndarray
            Concatenated ``tree_.value`` entries of the same leaves.
        tree_pointer : list of int
            Offsets delimiting each tree's leaves inside the two arrays.
        """
        regions = []
        weights = []
        tree_pointer = [ 0 ]
        for estimator in self.estimator.estimators_:
            leaves = np.array([ j for j in range(estimator.tree_.node_count) if estimator.tree_.feature[j] < 0])
            if len(leaves) == 0: continue
            regions.append(estimator.regions_[leaves])
            weights.append(estimator.tree_.value[leaves])
            tree_pointer.append(tree_pointer[-1] + len(leaves))
        regions = np.concatenate(regions, axis=0)
        weights = np.concatenate(weights, axis=0)
        return regions, weights, tree_pointer

    def _get_milo_model(self, x, A, C, I, weights, max_features):
        """Build the SCIP model that selects one candidate change per feature.

        Parameters
        ----------
        x
            The instance being explained (not used while building the model).
        A
            Candidate rows; column 0 is the feature index, column 1 the change.
        C
            Cost of each candidate row of ``A``.
        I
            One indicator matrix per tree; row ``l`` of ``I[t]`` is used as the
            coefficient vector over the flattened candidate selectors.
        weights
            Leaf values, ordered like the flattened leaf selectors.
        max_features
            Maximum number of non-zero changes; non-positive means no limit
            beyond the number of features.

        Returns
        -------
        pyscipopt.Model
            Un-optimised model with variables ``action_dddd``, ``cost``,
            ``pi_dddd_jjjj`` (candidate selectors) and ``phi_tttt_llll``
            (leaf selectors).
        """
        n_features = self.action.n_features
        max_features = max_features if max_features > 0 else n_features

        As = [A[A[:, 0] == d] for d in range(n_features)]
        Cs = [C[A[:, 0] == d] for d in range(n_features)]
        J = [len(A_d) for A_d in As]
        lb = [min(A_d[:, 1]) for A_d in As]
        ub = [max(A_d[:, 1]) for A_d in As]
        L = [I_t.shape[0] for I_t in I]

        model = Model()
        def LinSum(Vars): return quicksum(Vars)
        def LinExpr(Coeffs, Vars): return quicksum(Coeffs[i] * Vars[i] for i in range(len(Coeffs)))
        def flatten(nested): return sum(nested, [])

        action = [
            model.addVar(name='action_{:04d}'.format(d), vtype='C', lb=lb[d], ub=ub[d]) for d in range(n_features)
        ]
        cost = model.addVar(name='cost', vtype='C', lb=0)
        pi = [
            [model.addVar(name='pi_{:04d}_{:04d}'.format(d, j), vtype='B') for j in range(J[d])] for d in range(n_features)
        ]
        phi  = [
            [model.addVar(name='phi_{:04d}_{:04d}'.format(t, l), vtype='B') for l in range(L[t])] for t in range(self.estimator.n_estimators)
        ]
        # Flattened once: the constraints below reuse these lists many times.
        pi_flat = flatten(pi)
        phi_flat = flatten(phi)

        model.setObjective(cost, sense='minimize')
        if self.action.cost_type == 'MPS':
            # 'MPS': cost bounds every considered feature's own selected cost
            # from below, so minimising it minimises the largest of them.
            categorical = flatten(self.action.categories)
            for d in range(n_features):
                # Skipped: immutable features, and category members that have
                # a negative candidate change.
                if (d in categorical and np.min(As[d][:, 1]) < 0) or self.action.is_immutable[d]:
                    continue
                model.addCons(cost - LinExpr(Cs[d], pi[d]) >= 0, name='C_cost_{:04d}'.format(d))
        else:
            # Otherwise cost equals the summed cost of all selected candidates.
            model.addCons(cost - LinExpr(C, pi_flat) == 0, name='C_cost')

        for d in range(n_features):
            # Exactly one candidate per feature, and action[d] equals its change.
            model.addCons(LinSum(pi[d]) == 1, name='C_basic_pi_{:04d}'.format(d))
            model.addCons(action[d] - LinExpr(As[d][:, 1], pi[d]) == 0, name='C_basic_act_{:04d}'.format(d))

        # Sparsity: at most max_features selected candidates with a non-zero change.
        nonzeros = (A[:, 1] != 0)
        model.addCons(LinExpr(nonzeros, pi_flat) <= max_features, name='C_basic_sparsity')

        # Changes inside each category group must sum to zero.
        for i, G in enumerate(self.action.categories):
            model.addCons(LinSum([action[d] for d in G]) == 0, name='C_basic_category_{:04d}'.format(i))

        # The summed value of the selected leaves must be strictly positive.
        model.addCons(LinExpr(weights, phi_flat) >= 1e-8, name='C_loss')

        for t in range(self.estimator.n_estimators):
            # Exactly one leaf per tree; a leaf may only be selected when the
            # indicator-weighted count of selected candidates reaches n_features.
            model.addCons(LinSum(phi[t]) == 1, name='C_forest_leaf_{:04d}'.format(t))
            for l in range(L[t]):
                model.addCons(n_features * phi[t][l] - LinExpr(I[t][l], pi_flat) <= 0, name='C_forest_decision_{:04d}_{:04d}'.format(t, l))

        return model





class Recourse():
    """Container for a batch of recourse results plus summary metrics.

    Parameters
    ----------
    budget
        Cost budget; an action counts as recourse only if ``cost <= budget``.
    is_target
        Boolean array marking instances already predicted as the target class.
    X
        Original instances, one row per instance.
    action
        Chosen action per instance, same layout as ``X``.
    X_cf
        Counterfactual instances.
    cost
        Cost of each chosen action.
    valid
        Boolean array marking counterfactuals predicted as the target class.
    plausibility
        Optional per-instance plausibility score of the counterfactuals.
        Defaults to ``None`` so that results produced without plausibility
        scores can still be wrapped.

    Notes
    -----
    The averaging getters return NaN (NumPy's mean of an empty selection) when
    no instance matches the requested subset.
    """

    def __init__(
        self,
        budget,
        is_target,
        X,
        action,
        X_cf,
        cost,
        valid,
        plausibility=None,
    ):

        self.budget = budget
        self.is_target = is_target
        self.X = X
        self.action = action
        self.X_cf = X_cf
        self.cost = cost
        self.valid = valid
        self.plausibility = plausibility

    def _subset(self, valid):
        """Mask of non-target instances, restricted to valid ones if requested."""
        mask = ~self.is_target
        if valid:
            mask = mask & self.valid
        return mask

    def get_recourse(self, budget=True):
        """Fraction of all instances with a valid (and, if ``budget``, affordable) action."""
        if budget:
            return (self.valid & (self.cost <= self.budget)).mean()
        else:
            return self.valid.mean()

    def get_validity(self, budget=True):
        """Same as ``get_recourse`` but restricted to non-target instances."""
        if budget:
            return (self.valid[~self.is_target] & (self.cost[~self.is_target] <= self.budget)).mean()
        else:
            return self.valid[~self.is_target].mean()

    def get_cost(self, valid=True):
        """Mean cost over non-target instances (only the valid ones if ``valid``)."""
        return self.cost[self._subset(valid)].mean()

    def get_plausibility(self, valid=True):
        """Mean plausibility over non-target instances (only valid ones if ``valid``).

        Returns NaN when no plausibility scores were stored.
        """
        if self.plausibility is None:
            return float('nan')
        return self.plausibility[self._subset(valid)].mean()

    def get_sparsity(self, valid=True):
        """Mean number of non-zero action entries over the selected instances."""
        non_zeros = np.count_nonzero(self.action[self._subset(valid)], axis=1)
        return non_zeros.mean()

    def get_recourse_for_each_group(self, sensitive_indices):
        """Within-budget recourse rate for each group column of ``X``.

        Group ``i`` consists of the rows where column ``sensitive_indices[i]``
        of ``X`` equals 1; an empty group is reported as ``0.0``.
        """
        Z = self.X[:, sensitive_indices]
        recourse = (self.valid & (self.cost <= self.budget))
        rates = []
        for i in range(Z.shape[1]):
            members = (Z[:, i] == 1)
            rates.append(recourse[members].mean() if members.sum() > 0 else 0.0)
        recourse_group = np.array(rates)
        return recourse_group

    def get_unfairness(self, sensitive_indices):
        """Gap between the highest and lowest per-group recourse rate."""
        recourse_group = self.get_recourse_for_each_group(sensitive_indices)
        return recourse_group.max() - recourse_group.min()