import datetime

import matplotlib.pyplot as plt
from sklearn.metrics import plot_confusion_matrix
from sklearn.model_selection import cross_validate
from sklearn.naive_bayes import MultinomialNB

import config
from classifiers.common import print_report, print_cv_report, custom_scorer
from classifiers.common import write_model_to_disk, read_model_from_disk, test_classifier, train_classifier
from learning_curve import plot_learning_curve


def create_mnb_clf():
    """Return a new, unfitted ``MultinomialNB`` configured with ``alpha=0.01``."""
    smoothing = 0.01
    classifier = MultinomialNB(alpha=smoothing)
    return classifier


def run_naive_classifier(vectorizer, x_train, yy_train, xx_test, yy_test, **kwargs):
    run_type = kwargs.get("run_type")
    ngram_range = kwargs.get("ngram_range")
    analyzer = kwargs.get("analyzer")
    conf_matrix = kwargs.get("conf_matrix")
    load_clf = kwargs.get("load_clf")
    category = kwargs.get("category")
    categories = kwargs.get("categories")
    t_prefix = datetime.datetime.now().strftime('%s')

    report_categories = categories
    if isinstance(categories[0], list):
        report_categories = [item for sublist in categories for item in sublist[:1]]

    # naive_classifier = create_mnb_clf()
    # vectorizer = get_tfidf_vectorizer(**kwargs)

    # pipeline = Pipeline(
    #     [('tfidf', vectorizer),
    #      ('clf', naive_classifier)
    #      ])
    if run_type == "default":
        if not config.is_silent:
            print("Cross Validation!: Using {:,} train documents".format(len(yy_train)))
        clf = create_mnb_clf()
        cv_results = cross_validate(clf, x_train, yy_train, cv=3, scoring=custom_scorer)
        print_cv_report(cv_results, **kwargs)
    elif run_type == "dump":
        clf = create_mnb_clf()
        _duration_train = train_classifier(clf, x_train, yy_train)
        write_model_to_disk(clf, str(category) + "_mnb_clf", analyzer, ngram_range, t_prefix)
    elif run_type == "eval":
        # Train Classifier
        if load_clf is not None:
            # load clf
            clf = read_model_from_disk("models/" + load_clf)
            _duration_train = 0
        else:
            # train clf
            clf = create_mnb_clf()
            _duration_train = train_classifier(clf, x_train, yy_train)
            write_model_to_disk(clf, str(category) + "_mnb_clf", analyzer, ngram_range, t_prefix)

        # Test
        X_test = vectorizer.transform(xx_test)
        _predicted_target, _duration_test = test_classifier(clf, X_test)
        # Print
        if not config.is_silent:
            print("Classifier Score: {:.2f}".format(clf.score(X_test, yy_test)))
        print_report(_predicted_target, yy_test, _duration_train, _duration_test, conf_matrix, **kwargs)
        plt.figure(figsize=(20,20))
        disp = plot_confusion_matrix(clf, X_test, yy_test, display_labels=report_categories, cmap=plt.cm.Blues,
                                     normalize=None)
        disp.ax_.set_title("Confusion matrix for " + " vs. ".join(report_categories))
        plt.tight_layout()
        plt.savefig("images/" + str(category) + "_confusion_matrix_" + datetime.datetime.now().strftime(
            '%s') + ".png", dpi=(250), bbox_inches='tight')
        # if not kwargs.get("no_cat_reports"):
        #     for top_category in categories_list:
        #         print("Testing for {}".format(top_category))
        #         new_test_data, new_test_target = clearup_categories(xx_test, yy_test, categories,
        #                                                             top_category)
        #         # Test for top category
        #         X_test = vectorizer.transform(new_test_data)
        #         _predicted_target, _duration_test = test_classifier(clf, X_test)
        #         print_report(_predicted_target, new_test_target, _duration_train, _duration_test, top_category)
    elif run_type == "learning_curve":
        # create new vectorizer and classifer
        clf = create_mnb_clf()
        plot_learning_curve(clf, "MultinomialNB Classifier " + str(category), x_train, yy_train, cv=3, category=category)
        # # Test
        # X_test = vectorizer.transform(xx_test)
        # _predicted_target, _duration_test = test_classifier(clf, X_test)
        # # persist model
        # vect_model = {"vectorizer": vectorizer,
        #               "Xtrain": x_train}
        # write_model_to_disk(vect_model, "vec_lc", analyzer, ngram_range, t_prefix)
        # write_model_to_disk(clf, "sgd_lc_clf", analyzer, ngram_range, t_prefix)
    else:
        raise Exception("Unsupported `run_type`!")
