import datetime

import matplotlib.pyplot as plt
from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import SGDClassifier, LogisticRegression
from sklearn.metrics import plot_confusion_matrix
from sklearn.model_selection import GridSearchCV, cross_validate
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier
import numpy as np
import config
from classifiers.common import write_model_to_disk, read_model_from_disk, \
    test_classifier, train_classifier, \
    print_report, print_cv_report, custom_scorer, dump_features
from classifiers.tfidf import create_tfidf_vectorizer
from learning_curve import plot_learning_curve


def create_sgd_clf(max_iter, penalty='l2', loss='hinge', alpha=0.0001):
    """Return an unfitted ``SGDClassifier`` configured for this project.

    Args:
        max_iter: number of passes (epochs) over the training data. Because
            ``tol`` is ``None``, no early-stopping criterion is applied.
        penalty: regularization term (e.g. ``'l2'``, ``'elasticnet'``).
        loss: loss function (e.g. ``'hinge'``, ``'modified_huber'``).
        alpha: regularization strength.

    The seed is fixed (``random_state=42``) so runs are reproducible, and
    ``n_jobs=-1`` lets sklearn use all CPUs for its multi-class computation.
    """
    # Variants explored earlier: penalty elasticnet/l2, loss modified_huber/hinge, max_iter=1000.
    settings = dict(
        alpha=alpha,
        loss=loss,
        penalty=penalty,
        random_state=42,
        max_iter=max_iter,
        tol=None,
        n_jobs=-1,
    )
    return SGDClassifier(**settings)


# 0.0010
# 0.0001

def run_sgd_classifier(vectorizer, x_train, yy_train, xx_test, yy_test, **kwargs):
    """Train, evaluate, persist or tune an SGD classifier according to ``run_type``.

    Args:
        vectorizer: fitted vectorizer; transforms ``xx_test`` in the ``"eval"``
            mode and is passed to ``dump_features``.
        x_train: training features handed directly to the SGD classifier, i.e.
            already vectorized (see the note in ``"grid_search"``).
        yy_train: training targets.
        xx_test: test input that ``vectorizer`` (or the tfidf pipeline in
            ``"grid_search"``) can transform.
        yy_test: test targets.
        **kwargs: run options. Keys read here: ``run_type``, ``max_iter``,
            ``alpha``, ``loss``, ``penalty``, ``conf_matrix``, ``ngram_range``,
            ``analyzer``, ``category``, ``categories``, ``load_clf``,
            ``load_fs``. The whole dict is also forwarded to the report helpers
            and to ``create_tfidf_vectorizer``.

    ``run_type`` values:
        "default"        3-fold cross validation, then fit on all training data
                         and dump features.
        "dump"           fit, write the model to disk, dump features.
        "eval"           fit (or load ``models/<load_clf>``), report on the test
                         set and save a confusion-matrix image under ``images/``.
        "voting"         fit a voting ensemble of four classifiers and print it.
        "grid_search"    tune alpha/max_iter/penalty on a tfidf+SGD pipeline (or
                         load ``models/<load_fs>``) and report on the test set.
        "learning_curve" plot a learning curve for the configured classifier.

    Raises:
        ValueError: if ``run_type`` is not one of the values above.
    """
    conf_matrix = kwargs.get("conf_matrix")
    max_iter = kwargs.get("max_iter")
    run_type = kwargs.get("run_type")
    ngram_range = kwargs.get("ngram_range")
    analyzer = kwargs.get("analyzer")
    category = kwargs.get("category")
    load_clf = kwargs.get("load_clf")
    categories = kwargs.get("categories")

    # Fall back to create_sgd_clf's defaults when an option is missing (None)
    # or empty (""); passing None through would break SGDClassifier and the
    # "%f" formatting below.
    alpha = kwargs.get("alpha")
    if alpha is None or alpha == "":
        alpha = 0.0001
    loss = kwargs.get("loss") or 'hinge'
    penalty = kwargs.get("penalty") or 'l2'

    # Epoch seconds shared by every artifact (model files, images) of this run.
    # strftime('%s') is a non-portable glibc extension, so compute it directly.
    t_prefix = str(int(datetime.datetime.now().timestamp()))

    # Categories may be given as groups (list of lists); each group is labelled
    # by its first entry in reports and plots.
    report_categories = categories
    if categories and isinstance(categories[0], list):
        report_categories = [item for sublist in categories for item in sublist[:1]]

    if run_type == "default":
        if not config.is_silent:
            print("Cross Validation!: Using {:,} train documents".format(len(yy_train)))
            print("penalty: %s, loss: %s, alpha: %f" % (penalty, loss, alpha))
        clf = create_sgd_clf(max_iter, penalty=penalty, loss=loss, alpha=alpha)
        cv_results = cross_validate(clf, x_train, yy_train, cv=3, scoring=custom_scorer)
        print_cv_report(cv_results, **kwargs)
        # Refit on the full training set so the features can be dumped.
        _duration_train = train_classifier(clf, x_train, yy_train)
        dump_features(vectorizer, clf, category, str(ngram_range), x_train, yy_train)
    elif run_type == "dump":
        clf = create_sgd_clf(max_iter, penalty=penalty, loss=loss, alpha=alpha)
        _duration_train = train_classifier(clf, x_train, yy_train)
        write_model_to_disk(clf, str(category) + "_sgd_clf", analyzer, ngram_range, t_prefix)
        dump_features(vectorizer, clf, category, str(ngram_range), x_train, yy_train)
    elif run_type == "eval":
        if load_clf is not None:
            clf = read_model_from_disk("models/" + load_clf)
            _duration_train = 0
        else:
            clf = create_sgd_clf(max_iter, penalty=penalty, loss=loss, alpha=alpha)
            _duration_train = train_classifier(clf, x_train, yy_train)
            write_model_to_disk(clf, str(category) + "_sgd_clf", analyzer, ngram_range, t_prefix)

        x_test = vectorizer.transform(xx_test)
        _predicted_target, _duration_test = test_classifier(clf, x_test)
        print_report(_predicted_target, yy_test, _duration_train, _duration_test, conf_matrix, **kwargs)
        _save_confusion_matrix(clf, x_test, yy_test, report_categories, category, t_prefix)
    elif run_type == "voting":
        classifiers = [
            ('SGD', create_sgd_clf(max_iter)),
            ('NBMultinomial', MultinomialNB(alpha=0.01)),
            ('DecisionTree', DecisionTreeClassifier()),
            ('LogisticRegression', LogisticRegression()),
        ]
        if not config.is_silent:
            print("vectorizing...")
            print("training model...")
        clf = VotingClassifier(classifiers, n_jobs=-1)
        clf.fit(x_train, yy_train)
        print(clf)
    elif run_type == "grid_search":
        params = {
            # 1e-05 / 1e-06 (previously spelled as their 17-digit reprs).
            'clf__alpha': (1e-05, 1e-06),
            'clf__max_iter': (10, 50, 80),
            'clf__penalty': ('l2', 'elasticnet'),
        }

        # NOTE(review): this pipeline re-vectorizes its input with tfidf, so
        # x_train presumably has to be raw documents in this mode (unlike the
        # other modes) — confirm against the caller.
        pipeline = Pipeline(
            [('tfidf', create_tfidf_vectorizer(**kwargs)),
             ('clf', create_sgd_clf(max_iter))
             ])

        grid = GridSearchCV(
            pipeline,
            params,
            refit=True,  # refit the best parameter combination on all training data
            n_jobs=-1,  # use all cores
            scoring='accuracy',
        )
        load_fs = kwargs.get("load_fs")
        if load_fs is None:
            # train_classifier is used everywhere else as returning only the
            # training duration; the fitted model is `grid` itself.
            _duration_train = train_classifier(grid, x_train, yy_train)
            write_model_to_disk(grid, str(category) + "_sgd_grid", analyzer, ngram_range, t_prefix)
        else:
            grid = read_model_from_disk("models/" + load_fs)
            _duration_train = 0

        _predicted_target, _duration_test = test_classifier(grid, xx_test)
        print_report(_predicted_target, yy_test, _duration_train, _duration_test, conf_matrix, **kwargs)
    elif run_type == "learning_curve":
        clf = create_sgd_clf(max_iter, penalty=penalty, loss=loss, alpha=alpha)
        plot_learning_curve(clf, "SGD Classifier " + str(category), x_train, yy_train, cv=3, category=category)
    else:
        raise ValueError("Unsupported `run_type`!")


def _save_confusion_matrix(clf, x_test, yy_test, report_categories, category, t_prefix):
    """Plot ``clf``'s confusion matrix on the test set and save it under ``images/``."""
    # Create the sized figure explicitly and hand its axes to sklearn; without
    # `ax`, plot_confusion_matrix opens its own default-size figure.
    fig, ax = plt.subplots(figsize=(20, 20))
    try:
        plot_confusion_matrix(clf, x_test, yy_test, display_labels=report_categories, cmap=plt.cm.Blues,
                              normalize=None, ax=ax)
        ax.set_title("Confusion matrix for " + " vs. ".join(map(str, report_categories)))
        tick_marks = np.arange(len(report_categories))
        ax.set_xticks(tick_marks)
        ax.set_xticklabels(report_categories, rotation=90)
        ax.set_yticks(tick_marks)
        ax.set_yticklabels(report_categories)
        fig.tight_layout()
        fig.savefig("images/" + str(category) + "_confusion_matrix_" + t_prefix + ".png",
                    dpi=250, bbox_inches='tight')
    finally:
        # Release the figure so repeated runs don't accumulate open figures.
        plt.close(fig)
