import datetime

from sklearn.ensemble import VotingClassifier, RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, cross_validate
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier

import config
from classifiers.common import write_model_to_disk, read_model_from_disk, test_classifier, train_classifier, \
    print_report, print_cv_report, custom_scorer
from classifiers.tfidf import create_tfidf_vectorizer
from learning_curve import plot_learning_curve


def create_rfc_clf():
    """Build a new, unfitted RandomForestClassifier.

    ``n_jobs=-1`` asks scikit-learn to use every available processor
    when fitting and predicting.
    """
    forest = RandomForestClassifier(n_jobs=-1)
    return forest


def run_rfc_classifier(vectorizer, x_train, yy_train, xx_test, yy_test, **kwargs):
    """Run one random-forest experiment, selected by ``kwargs["run_type"]``.

    Supported run types:

    * ``"default"``: 3-fold cross-validation of a fresh forest on the training data,
      reported via ``print_cv_report``.
    * ``"eval"``: train a forest (or load ``models/<load_clf>``), persist it when freshly
      trained, then score it on ``vectorizer.transform(xx_test)``.
    * ``"voting"``: fit a VotingClassifier (random forest, multinomial NB, decision tree,
      logistic regression) on the training data and print it.
    * ``"grid_search"``: grid-search forest hyper-parameters inside a TF-IDF + forest
      pipeline (or load a fitted search from ``models/<load_fs>``), then report on the
      raw test documents.
    * ``"learning_curve"``: plot a learning curve and persist the vectorizer, the
      training matrix and the estimator.

    :param vectorizer: vectorizer used to transform ``xx_test`` in ``"eval"`` and
        persisted in ``"learning_curve"``.
    :param x_train: training inputs (fed directly to the estimator, except in
        ``"grid_search"`` where they go through the pipeline's TF-IDF step).
    :param yy_train: training targets.
    :param xx_test: raw test documents.
    :param yy_test: test targets.
    :param kwargs: ``run_type``, ``conf_matrix``, ``ngram_range``, ``analyzer``,
        ``load_clf``, ``load_fs``; also forwarded to the report helpers and to
        ``create_tfidf_vectorizer``.
    :raises ValueError: if ``run_type`` is not one of the values above.
    """
    conf_matrix = kwargs.get("conf_matrix")
    run_type = kwargs.get("run_type")
    ngram_range = kwargs.get("ngram_range")
    analyzer = kwargs.get("analyzer")
    load_clf = kwargs.get("load_clf")
    # Unix epoch seconds used to tag persisted models. strftime('%s') is a glibc-only
    # extension (it raises on Windows); timestamp() gives the same value portably.
    t_prefix = str(int(datetime.datetime.now().timestamp()))

    if run_type == "default":
        if not config.is_silent:
            print("Cross Validation!: Using {:,} train documents".format(len(yy_train)))
        clf = create_rfc_clf()
        cv_results = cross_validate(clf, x_train, yy_train, cv=3, scoring=custom_scorer)
        print_cv_report(cv_results, **kwargs)
    elif run_type == "eval":
        if load_clf is not None:
            # Reuse a previously persisted model; there is no training time to report.
            clf = read_model_from_disk("models/" + load_clf)
            _duration_train = 0
        else:
            clf = create_rfc_clf()
            # NOTE(review): the result is used here as a bare duration, while the
            # "grid_search" branch unpacks it as a 2-tuple; confirm train_classifier's
            # return type in classifiers.common.
            _duration_train = train_classifier(clf, x_train, yy_train)
            # "rfc" prefix so these files are not mistaken for SGD models.
            write_model_to_disk(clf, "rfc_clf", analyzer, ngram_range, t_prefix)

        x_test = vectorizer.transform(xx_test)
        _predicted_target, _duration_test = test_classifier(clf, x_test)
        if not config.is_silent:
            print("Classifier Score: {:.2f}".format(clf.score(x_test, yy_test)))
        print_report(_predicted_target, yy_test, _duration_train, _duration_test, conf_matrix, **kwargs)
    elif run_type == "voting":
        classifiers = [
            ('RandomForest', create_rfc_clf()),
            ('NBMultinomial', MultinomialNB(alpha=0.01)),
            ('DecisionTree', DecisionTreeClassifier()),
            ('LogisticRegression', LogisticRegression()),
        ]
        if not config.is_silent:
            print("training model...")
        clf = VotingClassifier(classifiers, n_jobs=-1)
        clf.fit(x_train, yy_train)
        print(clf)
    elif run_type == "grid_search":
        # Keys are "<pipeline step>__<parameter>"; the "clf" step is a
        # RandomForestClassifier, so only its parameters are valid here (GridSearchCV
        # raises on unknown ones). Vectorizer settings can be tuned the same way with
        # the "tfidf__" prefix, e.g. 'tfidf__ngram_range'.
        params = {
            'clf__n_estimators': (10, 50, 100),
            'clf__max_depth': (None, 50, 100),
            'clf__min_samples_leaf': (1, 2),
        }

        pipeline = Pipeline(
            [('tfidf', create_tfidf_vectorizer(**kwargs)),
             ('clf', create_rfc_clf())
             ])

        grid = GridSearchCV(
            pipeline,  # pipeline from above
            params,  # parameters to tune via cross validation
            refit=True,  # fit using all available data at the end, on the best found param combination
            n_jobs=-1,  # number of cores to use for parallelization; -1 for "all cores"
            scoring='accuracy',  # what score are we optimizing?
        )
        load_fs = kwargs.get("load_fs")
        if load_fs is None:
            # NOTE(review): the pipeline starts with a TF-IDF vectorizer, so x_train
            # presumably has to be raw documents for this run type; confirm with caller.
            # NOTE(review): unpacked as (result, duration) here but used as a bare
            # duration in "eval"; confirm train_classifier's return type.
            _, _duration_train = train_classifier(grid, x_train, yy_train)
            write_model_to_disk(grid, "rfc_grid", analyzer, ngram_range, t_prefix)
        else:
            grid = read_model_from_disk("models/" + load_fs)
            _duration_train = 0

        # The pipeline vectorizes internally, so it is given the raw test documents.
        _predicted_target, _duration_test = test_classifier(grid, xx_test)
        print_report(_predicted_target, yy_test, _duration_train, _duration_test, conf_matrix, **kwargs)
    elif run_type == "learning_curve":
        clf = create_rfc_clf()
        plot_learning_curve(clf, "Random Forest Classifier", x_train, yy_train, cv=3)
        vect_model = {"vectorizer": vectorizer,
                      "Xtrain": x_train}
        write_model_to_disk(vect_model, "vec", analyzer, ngram_range, t_prefix)
        # NOTE(review): if plot_learning_curve clones the estimator (as sklearn's
        # learning_curve does), `clf` is still unfitted when persisted; confirm.
        write_model_to_disk(clf, "rfc_lc_clf", analyzer, ngram_range, t_prefix)
    else:
        # ValueError subclasses Exception, so existing `except Exception` callers still work.
        raise ValueError("Unsupported `run_type`!")
