import datetime

from sklearn.model_selection import cross_validate
from sklearn.tree import DecisionTreeClassifier

import config
from classifiers.common import print_report, print_cv_report, custom_scorer
from classifiers.common import write_model_to_disk, read_model_from_disk, test_classifier, train_classifier
from learning_curve import plot_learning_curve


def run_decision_tree_classifier(vectorizer, x_train, yy_train, xx_test, yy_test, **kwargs):
    """Run a decision-tree experiment in the mode selected by ``run_type``.

    Supported values of ``kwargs["run_type"]``:

    * ``"default"`` -- cross-validate a fresh ``DecisionTreeClassifier`` on the
      training data (``cv=3``, scored with ``custom_scorer``) and print the
      result via ``print_cv_report``.
    * ``"eval"`` -- load a classifier from ``"models/" + load_clf`` when
      ``load_clf`` is given, otherwise train a new one and write it to disk;
      then predict on the transformed test documents and print a report.
    * ``"learning_curve"`` -- plot a learning curve for a fresh classifier,
      run it against the test documents and persist both the vectorizer
      (together with ``x_train``) and the classifier.

    Args:
        vectorizer: Object whose ``transform`` is applied to ``xx_test``.
        x_train: Training features handed to the classifier as-is
            (presumably already vectorized -- confirm against the caller).
        yy_train: Training targets.
        xx_test: Test documents; passed through ``vectorizer.transform``.
        yy_test: Test targets.
        **kwargs: Reads ``run_type``, ``ngram_range``, ``analyzer``,
            ``load_clf`` and ``conf_matrix``. The complete mapping is also
            forwarded to ``print_cv_report`` / ``print_report``.

    Raises:
        ValueError: If ``run_type`` is missing or not one of the values above.
    """
    run_type = kwargs.get("run_type")
    ngram_range = kwargs.get("ngram_range")
    analyzer = kwargs.get("analyzer")
    load_clf = kwargs.get("load_clf")
    conf_matrix = kwargs.get("conf_matrix")
    # Epoch seconds, used to tag persisted models. ``strftime('%s')`` is a
    # platform-specific extension, so derive the value from ``timestamp()``.
    t_prefix = str(int(datetime.datetime.now().timestamp()))

    _duration_train = 0
    if run_type == "default":
        if not config.is_silent:
            print("Cross Validation!: Using {:,} train documents".format(len(yy_train)))
        clf = DecisionTreeClassifier()
        cv_results = cross_validate(clf, x_train, yy_train, cv=3, scoring=custom_scorer)
        print_cv_report(cv_results, **kwargs)
    elif run_type == "eval":
        if load_clf is not None:
            # Reuse a previously persisted classifier; training duration stays 0.
            clf = read_model_from_disk("models/" + load_clf)
        else:
            # Train from scratch and persist the result for later runs.
            clf = DecisionTreeClassifier()
            _duration_train = train_classifier(clf, x_train, yy_train)
            write_model_to_disk(clf, "dst_clf", analyzer, ngram_range, t_prefix)

        # Test
        x_test = vectorizer.transform(xx_test)
        _predicted_target, _duration_test = test_classifier(clf, x_test)
        # Print
        if not config.is_silent:
            print("Classifier Score: {:.2f}".format(clf.score(x_test, yy_test)))
        print_report(_predicted_target, yy_test, _duration_train, _duration_test, conf_matrix, **kwargs)
        # TODO: per-top-category reports (gated on ``no_cat_reports``) were
        # disabled here; reinstate them once the category helpers exist.
    elif run_type == "learning_curve":
        # Fresh classifier for the learning-curve plot.
        clf = DecisionTreeClassifier()
        plot_learning_curve(clf, "DecisionTree Classifier", x_train, yy_train, cv=3)
        # NOTE(review): if ``plot_learning_curve`` only fits clones of ``clf``
        # (as ``sklearn.model_selection.learning_curve`` does), ``clf`` is
        # still unfitted when tested and persisted below -- confirm.
        x_test = vectorizer.transform(xx_test)
        # The prediction and timing are not used in this mode.
        test_classifier(clf, x_test)
        # Persist the vectorizer alongside the training matrix, then the model.
        vect_model = {
            "vectorizer": vectorizer,
            "Xtrain": x_train,
        }
        write_model_to_disk(vect_model, "vec_lc", analyzer, ngram_range, t_prefix)
        write_model_to_disk(clf, "dst_lc_clf", analyzer, ngram_range, t_prefix)
    else:
        # ValueError subclasses Exception, so existing handlers keep working.
        raise ValueError("Unsupported `run_type`!")
