import time
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from scipy.sparse import csr_matrix
from auxClasses import Clause, Rule

class RUXClassifier:
    
        def __init__(self, rf=None, 
                     ada=None, 
                     eps=0.01,
                     threshold=0.0,
                     use_ada_weights=False,
                     rule_length_cost=False, 
                     false_negative_cost=False,
                     negative_label=1.0, # For identifying false negatives
                     random_state=2516):
            
            self.eps = eps
            self.threshold = threshold
            self.wscale = 1.0
            self.vscale = 1.0
            self.fittedDTs = {}            
            self.randomState = random_state
            self.initNumOfRules = 0
            self.rules = {}      
            self.ruleInfo = {}
            self.K = None # number of classes
            self.labelToInteger = {} # mapping classes to integers
            self.integerToLabel= {} # mapping integers to classes
            self.vecY = None        
            self.majorityClass = None # which class is the majority
            self.missedXvals = None        
            self.numOfMissed = None
            # The following three vectors keep the Abar and A matrices
            # Used for CSR sparse matrices
            self.yvals = np.empty(shape=(0), dtype=np.float)
            self.rows = np.empty(shape=(0), dtype=np.int32)
            self.cols = np.empty(shape=(0), dtype=np.int32) 
            # The cost of each rule is stored
            self.costs = np.empty(shape=(0), dtype=np.float)
            self.ruleLengthCost = rule_length_cost
            self.falseNegativeCost = false_negative_cost
            self.negativeLabel = negative_label
            self.useADAWeights = use_ada_weights
            self.estimatorWeights = [] # Used with AdaBoost
            # For time keeping        
            self.fitTime = 0
            self.predictTime = 0
            # Classifier type
            self.classifier = ''
            
            self._checkOptions(rf, ada)

        def _checkOptions(self, rf, ada):
            
            if (rf == None and ada == None):
                print('RF or ADA should be provided')
                print('Exiting...')
                return None
            
            if (rf != None and ada != None):
                print('Both RF and ADA are provided')
                print('Proceeding with RF')
                ada = None
                
            if (rf != None):
                self.classifier = 'RF'
                
                if (self.useADAWeights):
                    print('Estimator weights work only with ADA')
                    self.useADAWeights = False
                    
                if (np.sum([self.ruleLengthCost, self.falseNegativeCost]) > 1):
                    print('Works with only one type of cost')
                    print('Proceeding with rule length')
                    self.falseNegativeCost = False
                    self.ruleLengthCost = True

                for treeno, fitTree in enumerate(rf.estimators_):
                    self.initNumOfRules += fitTree.get_n_leaves()
                    self.fittedDTs[treeno] = fitTree
                
            if (ada != None):
                self.classifier = 'ADA'
                
                if (ada.get_params()['algorithm'] != 'SAMME' and self.useADAWeights):
                    print('Estimator weights only work with SAMME algorithm of ADA')
                    print('Proceeding without estimator weights')
                    self.useADAWeights = False
                
                if (np.sum([self.ruleLengthCost, 
                            self.falseNegativeCost, 
                            self.useADAWeights]) > 1):
                    print('Works with only one type of cost')
                    print('Proceeding with estimator weights')
                    self.falseNegativeCost = False
                    self.ruleLengthCost = False
                    self.useADAWeights = True

                if (self.useADAWeights):
                    self.estimatorWeights = (1.0/(ada.estimator_weights_+1.0e-4))
                    
                for treeno, fitTree in enumerate(ada.estimators_):
                    self.initNumOfRules += fitTree.get_n_leaves()
                    self.fittedDTs[treeno] = fitTree

        def _cleanup(self):
            
            self.fittedDTs = {}   
            self.rules = {}      
            self.ruleInfo = {}
            self.labelToInteger = {} 
            self.integerToLabel= {}
            self.missedXvals = None        
            self.numOfMissed = None
            self.yvals = np.empty(shape=(0), dtype=np.float)
            self.rows = np.empty(shape=(0), dtype=np.int32)
            self.cols = np.empty(shape=(0), dtype=np.int32) 
            self.costs = np.empty(shape=(0), dtype=np.float)
            self.estimatorWeights = []            
                    
        def _getRule(self, fitTree, nodeid):
            """Build the Rule describing the path from the root to `nodeid`.

            The path is traced upwards: starting at `nodeid`, the parent of
            the current node is located in the tree's child arrays and one
            Clause on the parent's split feature is appended - an upper bound
            when the node is a left child, a lower bound when it is a right
            child - until the root (node 0) is reached.

            A tree whose root has feature index -2 yields an empty Rule.
            """
            tree = fitTree.tree_
            if (tree.feature[0] == -2): # No rule case
                return Rule()

            leftKids = tree.children_left
            rightKids = tree.children_right
            splitValues = tree.threshold

            pathRule = Rule()
            node = nodeid
            while True:
                if node in leftKids:
                    # left child: bounded above by the parent's threshold
                    parent = np.where(leftKids == node)[0].item()
                    pathRule.addClause(Clause(feature=tree.feature[parent],
                                              ub=splitValues[parent]))
                else:
                    # right child: bounded below by the parent's threshold
                    parent = np.where(rightKids == node)[0].item()
                    pathRule.addClause(Clause(feature=tree.feature[parent],
                                              lb=splitValues[parent]))
                if parent == 0:
                    break
                node = parent

            return pathRule
        
        
        def _getMatrix(self, X, y, fitTree, treeno):

            if (len(self.cols) == 0):
                col = 0
            else:
                col = max(self.cols) + 1 # Next column        
            y_rules = fitTree.apply(X) # Tells us which sample is in which leaf            
            for leafno in np.unique(y_rules):
                covers = np.where(y_rules == leafno)[0]
                leafYvals = y[covers] # y values of the samples in the leaf
                uniqueLabels, counts = np.unique(leafYvals, return_counts=True)
                label = uniqueLabels[np.argmax(counts)] # majority class in the leaf
                labelVector = np.ones(self.K)*(-1/(self.K-1))
                labelVector[self.labelToInteger[label]] = 1
                fillAhat = np.dot(self.vecY[:, covers].T, labelVector)
                self.rows = np.hstack((self.rows, covers))
                self.cols = np.hstack((self.cols, np.ones(len(covers), dtype=np.int32)*col))
                self.yvals = np.hstack((self.yvals, np.ones(len(covers), dtype=np.float)*fillAhat))
                if (self.falseNegativeCost):
                    cost = 1.0
                    if (label != self.negativeLabel and self.negativeLabel in leafYvals):
                        cost += np.exp(counts[int(self.negativeLabel)]/np.sum(counts))
                elif (self.ruleLengthCost):
                    tempRule = self._getRule(fitTree, leafno)
                    cost = tempRule.length()
                elif (self.useADAWeights):
                    cost = self.estimatorWeights[treeno]
                else:
                    cost = 1.0
                self.costs = np.append(self.costs, cost)
                self.ruleInfo[col] = (treeno, leafno, label)
                col += 1            

        def _getMatrices(self, X, y):
            """Run `_getMatrix` for every registered tree, in stored order.

            The tree number handed to `_getMatrix` is the tree's position in
            `self.fittedDTs`, counted from zero.
            """
            treeno = 0
            for fitTree in self.fittedDTs.values():
                self._getMatrix(X, y, fitTree, treeno)
                treeno += 1

        def _preprocess(self, X, y):
            
            classes, classCounts = np.unique(y, return_counts=True)
            self.majorityClass = classes[np.argmax(classCounts)]
            for i, c in enumerate(classes):
                self.labelToInteger[c] = i
                self.integerToLabel[i] = c
            self.K = len(classes)
            n = len(y)
            self.vscale = 1.0
            self.vecY = np.ones((self.K, n))*(-1/(self.K-1))
            for i, c in enumerate(y):
                self.vecY[self.labelToInteger[c], i] = 1        
            
        def _fillRules(self, weights):
            
            weights = weights/np.max(weights) # Scaled weights
            selectedColumns = np.where(weights > self.threshold)[0] # Selected columns
            weightOrder = np.argsort(-weights[selectedColumns]) # Ordered weights
            orderedColumns = selectedColumns[weightOrder] # Ordered indices
            
            for i, col in enumerate(orderedColumns):
                treeno, leafno, label = self.ruleInfo[col]
                fitTree = self.fittedDTs[treeno]
                if (fitTree.get_n_leaves()==1):
                    self.rules[i] = Rule(label=self.majorityClass,
                                         clauses=[],
                                         weight=weights[col]) # No rule
                else:
                    self.rules[i] = self._getRule(fitTree, leafno)
                    self.rules[i].label = label
                    self.rules[i].weight = weights[col]                
                    self.rules[i]._cleanRule()                
    
        def _solvePrimal(self):
            """Solve the rule-weighting linear program with Gurobi.

            Variables are `vs` (one per sample) and `ws` (one per rule
            column). The model minimises
            `sum(vscale * vs) + costs @ ws` subject to
            `((K-1)/K) * Ahat @ ws + vs >= 1` and `A @ ws >= eps`, where
            `Ahat` holds the values in `self.yvals` and `A` is the 0/1
            coverage matrix with the same sparsity pattern.

            Side effect: `self.costs` is rescaled in place so that its
            maximum is 1, and the factor is stored in `self.wscale`.

            Returns
            -------
            (ws.X, betas, gammas) : the optimal rule weights followed by the
                dual values of the first and second group of n constraints.
            """
            # np.float64 is used because the `np.float` alias no longer
            # exists in recent NumPy releases.
            Ahat = csr_matrix((self.yvals, (self.rows, self.cols)), dtype=np.float64)
            data = np.ones(len(self.rows), dtype=np.int32)
            A = csr_matrix((data, (self.rows, self.cols)), dtype=np.int32)

            n = int(np.max(self.rows)) + 1
            m = int(np.max(self.cols)) + 1
            self.wscale = 1.0/np.max(self.costs)
            self.costs *= self.wscale
            # Primal Model
            modprimal = gp.Model('RUG Primal')
            modprimal.setParam('OutputFlag', False)
            # variables
            vs = modprimal.addMVar(shape=n, name='vs')
            ws = modprimal.addMVar(shape=m, name='ws')
            # objective
            modprimal.setObjective((np.ones(n)*self.vscale) @ vs +
                                   self.costs @ ws, GRB.MINIMIZE)
            # constraints
            modprimal.addConstr((((self.K - 1.0)/self.K)*Ahat) @ ws + vs >= 1.0,
                                name='Ahat Constraints')
            modprimal.addConstr(A @ ws >= self.eps, name='A Constraints')
            modprimal.optimize()
            # Fetch the duals once; they are ordered as the constraints were added
            duals = np.array(modprimal.getAttr(GRB.Attr.Pi))
            betas = duals[:n]
            gammas = duals[n:n+n]

            return ws.X, betas, gammas
        
        def printRules(self, indices=[]):
            
            if (len(indices) == 0):
                indices = self.rules.keys()
            elif (np.max(indices) > len(self.rules)):
                print('\n### WARNING! printRules() ###\n')
                print('Do not have that many rules')                
                return
            
            for indx in indices:
                rule = self.rules[indx]
                print('RULE %d:' % (indx))
                if (rule == 'NR'):
                    print('==> No Rule: Set Majority Class')
                else:
                    rule.printRule()
                print('Class: %.0f' % rule.label)
                print('Scaled rule weight: %.4f\n' % rule.weight)
    
        def printWeights(self, indices=[]):
    
            if (len(indices) == 0):
                indices = self.rules.keys()
            elif (np.max(indices) > len(self.rules)):
                print('\n### WARNING!: printWeights() ###\n')
                print('Do not have that many rules')                
                return
            
            for indx in indices:
                rule = self.rules[indx]
                print('RULE %d:' % (indx))
                print('Class: %.0f' % rule.label)
                print('Scaled rule weight: %.4f\n' % rule.weight)
                
        def getWeights(self, indices=[]):
    
            if (len(indices) == 0):
                indices = self.rules.keys()
            elif (np.max(indices) > len(self.rules)):
                print('\n### WARNING!: getWeights() ###\n')
                print('Do not have that many rules')                
                return None 
            
            return [self.rules[indx].weight for indx in indices]    
                
        def predict(self, X, indices=[]):       
            
            if (len(indices) == 0):
                indices = self.rules.keys()
            elif (np.max(indices) > len(self.rules)):
                print('\n### WARNING!: predict() ###\n')
                print('Do not have that many rules')
                return None  
    
            self.missedXvals = []        
            self.numOfMissed = 0
            
            startTime = time.time()
            # TODO: Can be done in parallel
            returnPrediction = []
            for x0 in X:
                sumClassWeights = np.zeros(self.K)
                for indx in indices:
                    rule = self.rules[indx]
                    if (rule != 'NR'):
                        if(rule.checkRule(x0)):
                            lab2int = self.labelToInteger[rule.label]
                            sumClassWeights[lab2int] += rule.weight
                
                if (np.sum(sumClassWeights) == 0):
                    # Unclassified test sample
                    self.numOfMissed += 1
                    self.missedXvals.append(x0)
                    getClass = self.fittedDTs[0].predict(x0.reshape(1, -1))[0]
                    returnPrediction.append(getClass)
                else:
                    sel_label_indx = np.argmax(sumClassWeights)
                    int2lab = self.integerToLabel[sel_label_indx]
                    returnPrediction.append(int2lab)
    
            endTime = time.time()
            self.predictTime = endTime - startTime
            
            return returnPrediction
    
        def getAvgRuleLength(self):
            
            return np.mean([rule.length() for rule in self.rules.values()])
            
        def getNumOfRules(self):
            """Return how many rules are currently stored in `self.rules`."""
            ruleCount = len(self.rules)
            return ruleCount
    
        def getInitNumOfRules(self):
            """Return the total leaf count accumulated over the registered trees."""
            initialCount = self.initNumOfRules
            return initialCount
    
        def getNumOfMissed(self):
            """Return the stored count of samples that no rule covered.

            The value is `None` until `predict` has set it.
            """
            missedCount = self.numOfMissed
            return missedCount
    
        def getFitTime(self):
            """Return the duration, in seconds, recorded by the last `fit` call."""
            elapsed = self.fitTime
            return elapsed
    
        def getPredictTime(self):
            """Return the duration, in seconds, recorded by the last `predict` call."""
            elapsed = self.predictTime
            return elapsed
        
        def fit(self, X, y):
            
            if (len(self.cols) != 0):
                self._cleanup()
       
            startTime = time.time()
            
            self._preprocess(X, y)
            self._getMatrices(X, y)
            ws, betas, gammas = self._solvePrimal()
            self._fillRules(ws)
            
            endTime = time.time()
            
            self.fitTime = endTime - startTime