import sys
import re
import os
import subprocess
from sklearn import tree
from sklearn.ensemble import RandomForestClassifier
import graphviz 
from random import *

# need a newer version of discrete one

def render_tree(clf,id=0):
    """Export a fitted tree to Graphviz DOT and render it.

    The DOT text comes from ``tree.export_graphviz`` and is rendered through
    ``graphviz.Source.render`` under the file name ``"iris" + str(id)``.

    Args:
        clf: fitted decision tree accepted by ``tree.export_graphviz``.
        id: suffix appended to the ``"iris"`` output name (default 0).
    """
    dot_source = graphviz.Source(tree.export_graphviz(clf, out_file=None))
    output_name = "iris" + str(id)
    dot_source.render(output_name)

def render_each_clf(list_clf):
    """Render every classifier in ``list_clf`` via ``render_tree``.

    Each classifier's position in the list is passed as the render id, so
    the outputs do not overwrite one another.
    """
    for position, classifier in enumerate(list_clf):
        render_tree(classifier, position)

def decision_tree(sample_,label_,max_depth=-1,cascade=False):
    """Fit an entropy-based decision tree and render it.

    Args:
        sample_: training samples passed to ``fit``.
        label_: training labels passed to ``fit``.
        max_depth: depth limit for the tree; ``-1`` (the default) means
            unlimited. Ignored when ``cascade`` is true.
        cascade: when true, the tree is always limited to depth 2.

    Returns:
        The fitted ``tree.DecisionTreeClassifier``.

    Side effects:
        Calls ``render_tree(clf)``, which renders the tree with the default
        id.
    """
    if cascade:
        depth = 2
    elif max_depth == -1:
        depth = None
    else:
        # Honour the caller's limit; this branch used to hard-code 3 and
        # silently discard the requested depth.
        depth = max_depth

    # use random_state to get deterministic behaviour
    clf = tree.DecisionTreeClassifier(criterion='entropy', random_state=1, max_depth=depth)
    clf = clf.fit(sample_, label_)
    render_tree(clf)
    return clf

def prune_trainingset(sample,label,list_of_TP):
    """Drop the samples whose indices appear in ``list_of_TP``.

    Args:
        sample: indexable sequence of samples.
        label: labels, indexed in parallel with ``sample``.
        list_of_TP: indices (positions in ``sample``) to remove.

    Returns:
        A ``(samples, labels)`` tuple of new lists holding the entries that
        were kept, in their original order.
    """
    # A set makes each membership test O(1) instead of scanning the list
    # once per sample.
    try:
        excluded = set(list_of_TP)
    except TypeError:
        # Unhashable entries: fall back to the container's own lookup.
        excluded = list_of_TP

    temp_sample = []
    temp_label = []
    for index, row in enumerate(sample):
        if index not in excluded:
            temp_sample.append(row)
            temp_label.append(label[index])
    return temp_sample,temp_label

def binary_evaluate_TP_TN_result(clf,orisample,orilabel,threshold):
    """Sort sample indices into TP / FP / TN / FN lists for a binary model.

    A sample is predicted negative when the first column returned by
    ``clf.predict_proba`` is strictly greater than ``1 - threshold`` and
    positive otherwise. Ground-truth labels are compared against the strings
    ``"0"`` and ``"1"``; a sample carrying any other label is left out of
    all four lists.

    Args:
        clf: fitted classifier exposing ``predict_proba`` and
            ``decision_path``.
        orisample: sequence of samples.
        orilabel: labels, indexed in parallel with ``orisample``.
        threshold: confidence threshold; the cut-off applied to the first
            probability column is ``1 - threshold``.

    Returns:
        ``(TP_commit, FP_commit, TN_commit, FN_commit)``, each a list of
        indices into ``orisample``.

    Side effects:
        Prints per-sample debug lines and the four resulting lists.
    """
    TP_commit = []
    TN_commit = []
    FP_commit = []
    FN_commit = []
    cutoff = 1 - threshold

    for index, row in enumerate(orisample):
        truth = orilabel[index]
        # Query the model once per sample instead of once per comparison.
        first_proba = clf.predict_proba([row])[0][0]
        path = clf.decision_path([row])

        # Single-argument prints produce the same text on Python 2 and 3.
        print("proba %s %s %s" % (first_proba, cutoff, truth))
        print("decision_path %s" % (path.getnnz(),))
        print(path)

        predicted_negative = first_proba > cutoff
        predicted_positive = first_proba <= cutoff
        if predicted_negative and truth == "0":
            TN_commit.append(index)
        elif predicted_positive and truth == "1":
            TP_commit.append(index)
        elif predicted_negative and truth == "1":
            FN_commit.append(index)
        elif predicted_positive and truth == "0":
            FP_commit.append(index)

    print("TP_commit %s" % (TP_commit,))
    print("FP_commit %s" % (FP_commit,))
    print("TN_commit %s" % (TN_commit,))
    print("FN_commit %s" % (FN_commit,))
    return TP_commit,FP_commit,TN_commit,FN_commit

def discrete_evaluate_TP_TN_result(clf,orisample,orilabel,threshold):
    """Split sample indices into correctly and incorrectly predicted ones.

    For each sample, the predicted label is ``float(j + 1)`` where ``j`` is
    the position of the first probability returned by ``clf.predict_proba``
    that is strictly greater than ``threshold``. If no probability exceeds
    the threshold, ``j`` is the number of probability columns.

    Args:
        clf: fitted classifier exposing ``predict_proba``.
        orisample: sequence of samples.
        orilabel: labels, indexed in parallel with ``orisample``; they are
            compared with ``==`` against a float (presumably numeric,
            1-based class labels -- verify against the caller).
        threshold: minimum probability for a class to be selected.

    Returns:
        ``(TP_commit, FN_commit)``: indices whose predicted label equals the
        ground truth, and indices whose predicted label differs.
    """
    TP_commit = []
    FN_commit = []
    for index, row in enumerate(orisample):
        # One model call per sample; the scan below reuses its result.
        probabilities = clf.predict_proba([row])[0]
        chosen = len(probabilities)
        for position, probability in enumerate(probabilities):
            if probability > threshold:
                chosen = position
                break

        if float(chosen + 1) == orilabel[index]:
            TP_commit.append(index)
        else:
            FN_commit.append(index)

    return TP_commit,FN_commit


################################################################################
#
# Main Cascading Decision Tree Algorithm
#
################################################################################

#Todo support both less sample and less label

                          
# This one has no maximum depth.
def cascading_dt_training(sample,label,max_depth=2,threshold=0.8,loop=5,earlybreak=1,discrete=True):
    """Train a cascade of decision trees.

    Each stage fits a tree on the samples that earlier stages have not yet
    classified as true positives, evaluates it on the full training set with
    ``binary_evaluate_TP_TN_result``, and adds the newly found true
    positives to the exclusion list for the next stage.

    Args:
        sample: training samples.
        label: training labels, indexed in parallel with ``sample``.
        max_depth: forwarded to ``decision_tree`` together with
            ``cascade=True``.
        threshold: confidence threshold forwarded to the evaluator.
        loop: maximum number of cascade stages.
        earlybreak: when equal to 1, stop as soon as a stage after the
            first finds no new true positives.
        discrete: must be true; the alternative evaluation path has never
            been implemented.

    Returns:
        List of fitted classifiers in cascade order. When all ``loop``
        stages have run, one extra tree fitted on the remaining samples is
        appended.

    Raises:
        NotImplementedError: if ``discrete`` is false.

    Side effects:
        Prints progress information and renders every returned classifier
        via ``render_each_clf``.
    """
    if not discrete:
        # This path used to fall through with the evaluation results
        # undefined and crashed with a NameError after training a tree.
        raise NotImplementedError(
            "cascading_dt_training only supports discrete=True")

    number_loop = 0
    history_of_TP = []          # ordered indices, handed to prune_trainingset
    seen_TP = set()             # same indices, for O(1) membership tests
    classifier_buffer = []

    while number_loop < loop:
        number_loop += 1
        pruned_sample, pruned_label = prune_trainingset(sample, label, history_of_TP)
        clf = decision_tree(pruned_sample, pruned_label, max_depth, cascade=True)
        classifier_buffer.append(clf)

        list_of_TP, list_of_FP, list_of_TN, list_of_FN = binary_evaluate_TP_TN_result(clf, sample, label, threshold)
        # Single-argument prints produce the same text on Python 2 and 3.
        print("list_of_TP %s" % (list_of_TP,))
        print("list_of_FP %s" % (list_of_FP,))

        new_of_TP = [index for index in list_of_TP if index not in seen_TP]

        # Early termination: a later stage that adds nothing new.
        # NOTE(review): if this fires on the last permitted iteration,
        # number_loop == loop still holds below and one more tree is fitted
        # on the same pruned set.
        if earlybreak == 1 and not new_of_TP and number_loop > 1:
            break

        history_of_TP.extend(new_of_TP)
        seen_TP.update(new_of_TP)
        print("history_of_TP %s" % (history_of_TP,))

    if number_loop == loop:
        # Reached the stage limit: record a last tree on what remains.
        pruned_sample, pruned_label = prune_trainingset(sample, label, history_of_TP)
        clf = decision_tree(pruned_sample, pruned_label, max_depth, cascade=True)
        classifier_buffer.append(clf)

    render_each_clf(classifier_buffer)
    return classifier_buffer
    
################################################################################
#
# End of Cascading Decision Tree Algorithm
#
################################################################################


################################################################################
#
# How to Use Cascading Decision Tree Algorithm
#
################################################################################

"""
cas_dt will return to you a list of classifiers.
If the testing sample's probability is greater than the threshold, 
It is a positive.
Otherwise, go for the next classifier.
If the last classifier is used, then the threshold is set manually = 0.5
"""

def original_dt_testing(single_sample,clf,threshold=0.8):
    """Classify one sample with a single tree.

    The sample is reported positive (1) when the first column of
    ``clf.predict_proba`` is at most ``1 - threshold``, and negative (0)
    otherwise.

    Returns:
        ``(prediction, path_length)`` where ``path_length`` is the number of
        non-zero entries of ``clf.decision_path`` for the sample, minus one.
    """
    batch = [single_sample]
    is_positive = clf.predict_proba(batch)[0][0] <= 1 - threshold
    path_length = clf.decision_path(batch).getnnz() - 1
    prediction = 1 if is_positive else 0
    return prediction, path_length

def cascading_dt_testing(single_sample,classifier_buffer_,threshold=0.8):
    """Classify one sample with a cascade of trees.

    The classifiers are consulted in order. A tree reports the sample as
    positive, and the cascade stops, when the first column of its
    ``predict_proba`` is at most ``1 - threshold``. If no tree does so, the
    last tree decides with the ordinary 0.5 cut-off.

    Args:
        single_sample: one sample (a single feature row).
        classifier_buffer_: non-empty sequence of fitted classifiers
            exposing ``predict_proba`` and ``decision_path``.
        threshold: confidence required for an early positive decision.

    Returns:
        ``(prediction, path_length, trees_used)``: ``prediction`` is 1 or 0,
        ``path_length`` is the non-zero count of the deciding tree's
        ``decision_path`` minus one, and ``trees_used`` is the 1-based
        position of the deciding tree.

    Raises:
        ValueError: if ``classifier_buffer_`` is empty.
    """
    if len(classifier_buffer_) == 0:
        # An empty cascade used to fall through and return None implicitly.
        raise ValueError("classifier_buffer_ must contain at least one classifier")

    cutoff = 1 - threshold
    last_index = len(classifier_buffer_) - 1

    for number_tree, clf in enumerate(classifier_buffer_):
        # Query the model once per tree and reuse the result below.
        first_proba = clf.predict_proba([single_sample])[0][0]

        if first_proba <= cutoff:
            # Confident positive: stop immediately and output the result.
            path = clf.decision_path([single_sample])
            print("level1 %s" % (path,))
            return 1, path.getnnz() - 1, number_tree + 1

        if number_tree == last_index:
            # No more trees: decide with the normal threshold of 0.5. The
            # previous code re-tested the cascade cut-off here, which could
            # never succeed at this point.
            path = clf.decision_path([single_sample])
            if first_proba <= 0.5:
                print("level2 %s" % (path,))
                return 1, path.getnnz() - 1, number_tree + 1
            print("level3 %s" % (path,))
            return 0, path.getnnz() - 1, number_tree + 1

        # Not confident enough here: go to the next tree.

   
