# Aug.31
import os
import shutil
import re
import sys
import time
import random

import cas_dt
import display
import crossvalid

from os.path import exists

def clean_enter(str_):
    """Normalise the label field taken from the end of a data line.

    One trailing line ending (a newline, or a carriage return followed by a
    newline) is removed first. If what remains is "?" (missing value), the
    integer -1 is returned; otherwise the stripped string is returned.

    The line ending must be stripped *before* the "?" comparison: the
    previous version compared first, so a missing label "?" followed by a
    newline was returned as "?" instead of -1.

    Note: -1 is returned as an int (unlike clean_questionmark, which yields
    the string "-1"); kept as-is for backward compatibility.
    """
    if str_.endswith("\r\n"):
        str_ = str_[:-2]
    elif str_.endswith("\n"):
        str_ = str_[:-1]
    if str_ == "?":
        return -1
    return str_
        
def clean_questionmark(list_):
    """Return a new list with every "?" entry replaced by the string "-1".

    All other entries are copied unchanged and the input is not modified.
    """
    return ["-1" if value == "?" else value for value in list_]

def _parse_lines(lines):
    """Split raw comma-separated lines into parallel (samples, labels) lists.

    The last field of each line is passed through clean_enter (label); the
    remaining fields go through clean_questionmark (features).
    Whitespace-only lines are skipped: they would otherwise produce an empty
    feature list and an empty-string label.
    """
    sample = []
    label = []
    for line_sample in lines:
        if not line_sample.strip():
            continue
        fields = line_sample.split(',')
        sample.append(clean_questionmark(fields[:-1]))
        label.append(clean_enter(fields[-1]))
    return sample, label


def main():
    """Run the experiment over the breast-cancer data file.

    1. Delete result.csv / current.data left from an earlier run (both are
       relative to the current working directory).
    2. Write the CSV header row into a fresh result.csv.
    3. Recreate an empty "res" directory next to this script.
    4. Walk "<script dir>/data". For every file whose path contains
       "breast_cancer_cleaned.data":
       - append its raw lines to current.data;
       - parse the lines into features and labels;
       - call crossvalid.cross_validation(samples, labels, file_name, 0.8)
         (0.8 presumably a split ratio -- confirm in crossvalid).
    5. Call display.pretty_print().
    """
    for stale in ("result.csv", "current.data"):
        if os.path.isfile(stale):
            os.remove(stale)

    with open("result.csv", 'ab') as file__:
        file__.write("Explanation Depth,Accuracy,Runtime,TP,TN,FP,FN,Precision,Recall,F-1 Score\n")

    dir_path = os.path.dirname(os.path.realpath(__file__))
    respath = dir_path + "/res"
    if os.path.exists(respath):
        shutil.rmtree(respath)
    if not os.path.exists(respath):
        os.makedirs(respath)

    for root, _dirs, files in os.walk(dir_path + "/data"):
        for file_ in files:
            path = os.path.join(root, file_)
            # Skip macOS Finder metadata; only the cleaned breast-cancer file is used.
            if file_ == ".DS_Store" or "breast_cancer_cleaned.data" not in path:
                continue
            file_name = file_
            # Close the data file as soon as it has been read.
            with open(path) as data_file:
                list_raw_data = data_file.readlines()
            #random.shuffle(list_raw_data)
            with open("current.data", 'ab') as file__:
                file__.writelines(list_raw_data)
            sample, label = _parse_lines(list_raw_data)
            crossvalid.cross_validation(sample, label, file_name, 0.8)

    display.pretty_print()


def test(clf, testing_set, testing_label, start):
    """Evaluate a single decision tree on a test set and record the metrics.

    Each sample goes through cas_dt.original_dt_testing(sample, clf,
    threshold=0.8), which returns (prediction, decision_depth). The
    prediction is compared with int(label). Labels 1 and 0 are the positive
    and negative classes; samples with any other label are not counted.

    One row is appended to result.csv: average depth, accuracy, runtime,
    TP, TN, FP, FN, precision, recall, F1. The same figures are printed.
    The average depth is the mean decision_depth over TP and FP samples.

    Metrics with a zero denominator are reported as 0.0 rather than raising
    ZeroDivisionError. Examples: precision when nothing was predicted
    positive, or accuracy on an empty test set.

    Args:
        clf: classifier handed to cas_dt.original_dt_testing.
        testing_set: sequence of feature vectors.
        testing_label: labels aligned with testing_set; each must be
            accepted by int().
        start: time.time() value at which the run began; the CSV runtime
            column is measured from it.
    """
    def _ratio(num, den):
        # A zero denominator yields 0.0 so a degenerate split still writes
        # its results instead of crashing the experiment.
        if den:
            return num / den
        return 0.0

    res = []  # (prediction, depth) of every TP and FP sample
    correct = wrong = 0.0
    TP = TN = FP = FN = 0.0
    for i, sample_ in enumerate(testing_set):
        predict_res, decision_depth = cas_dt.original_dt_testing(sample_, clf, threshold=0.8)
        print("predict_res %s decision_depth %s" % (predict_res, decision_depth))
        truth = int(testing_label[i])
        hit = int(predict_res) == truth
        if truth == 1:
            if hit:
                correct += 1
                TP += 1
                res.append((predict_res, decision_depth))
            else:
                wrong += 1
                FN += 1
        elif truth == 0:
            if hit:
                correct += 1
                TN += 1
            else:
                wrong += 1
                FP += 1
                res.append((predict_res, decision_depth))
        # Samples with any other label value are deliberately not counted.

    accuracy = _ratio(correct, wrong + correct)
    precision = _ratio(TP, TP + FP)
    recall = _ratio(TP, TP + FN)
    f1 = 2 * _ratio(precision * recall, precision + recall)
    avg_depth = compute_ave(res) if res else 0.0
    runtime = time.time() - start

    row = (avg_depth, accuracy, runtime, TP, TN, FP, FN, precision, recall, f1)
    with open("result.csv", 'ab') as file__:
        file__.write(",".join(str(value) for value in row) + " \n")

    print("correct %s" % correct)
    print("wrong %s" % wrong)
    print("accuracy %s" % accuracy)
    print("TP %s" % TP)
    print("TN %s" % TN)
    print("FP %s" % FP)
    print("FN %s" % FN)
    print("precision %s" % precision)
    print("recall %s" % recall)
    print("f1 score %s" % f1)
    print("average_depth %s" % avg_depth)

def compute_ave(list_):
    """Return the mean of item[1] over the items of list_.

    In this file the items are (prediction, decision_depth) tuples, so this
    is the average decision depth. The sum starts at 0.0, so the division
    is floating-point even under Python 2 with integer depths.

    An empty input returns 0.0 instead of raising ZeroDivisionError.
    """
    if len(list_) == 0:
        return 0.0
    return sum((item[1] for item in list_), 0.0) / len(list_)


def cas_test(clf_buffer, testing_set, testing_label, start):
    """Evaluate the cascading decision trees on a test set and record metrics.

    Each sample goes through cas_dt.cascading_dt_testing(sample, clf_buffer,
    threshold=0.8), which returns (prediction, decision_depth, id_clf). The
    prediction is compared with int(label). Labels 1 and 0 are the positive
    and negative classes; samples with any other label are not counted.

    One row is appended to result.csv: average depth, accuracy, runtime,
    TP, TN, FP, FN, precision, recall, F1. The same figures are printed.
    The average depth is the mean decision_depth over TP and FP samples.

    Metrics with a zero denominator are reported as 0.0 rather than raising
    ZeroDivisionError. Examples: precision when nothing was predicted
    positive, or accuracy on an empty test set.

    Args:
        clf_buffer: collection of classifiers handed to
            cas_dt.cascading_dt_testing.
        testing_set: sequence of feature vectors.
        testing_label: labels aligned with testing_set; each must be
            accepted by int().
        start: time.time() value at which the run began; the CSV runtime
            column is measured from it.
    """
    def _ratio(num, den):
        # A zero denominator yields 0.0 so a degenerate split still writes
        # its results instead of crashing the experiment.
        if den:
            return num / den
        return 0.0

    res = []  # (prediction, depth) of every TP and FP sample
    correct = wrong = 0.0
    TP = TN = FP = FN = 0.0
    for i, sample_ in enumerate(testing_set):
        predict_res, decision_depth, id_clf = cas_dt.cascading_dt_testing(sample_, clf_buffer, threshold=0.8)
        print("predict_res %s decision_depth %s id_clf %s" % (predict_res, decision_depth, id_clf))
        truth = int(testing_label[i])
        hit = int(predict_res) == truth
        if truth == 1:
            if hit:
                correct += 1
                TP += 1
                res.append((predict_res, decision_depth))
            else:
                wrong += 1
                FN += 1
        elif truth == 0:
            if hit:
                correct += 1
                TN += 1
            else:
                wrong += 1
                FP += 1
                res.append((predict_res, decision_depth))
        # Samples with any other label value are deliberately not counted.

    accuracy = _ratio(correct, wrong + correct)
    precision = _ratio(TP, TP + FP)
    recall = _ratio(TP, TP + FN)
    f1 = 2 * _ratio(precision * recall, precision + recall)
    avg_depth = compute_ave(res) if res else 0.0
    runtime = time.time() - start

    row = (avg_depth, accuracy, runtime, TP, TN, FP, FN, precision, recall, f1)
    with open("result.csv", 'ab') as file__:
        file__.write(",".join(str(value) for value in row) + " \n")

    print("correct %s" % correct)
    print("wrong %s" % wrong)
    print("accuracy %s" % accuracy)
    print("TP %s" % TP)
    print("TN %s" % TN)
    print("FP %s" % FP)
    print("FN %s" % FN)
    print("precision %s" % precision)
    print("recall %s" % recall)
    print("f1 score %s" % f1)
    print("average_depth %s" % avg_depth)



if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()