import datetime
import operator
import os
from time import time

import eli5
import matplotlib.pyplot as plt
import numpy as np
from joblib import load, dump
from sklearn import metrics
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, balanced_accuracy_score, \
    f1_score, make_scorer

import config
from classifiers.tfidf import create_tfidf_vectorizer
from util.categories import categories_list
from util.corpus import fetch_ice, fetch_itaki
from sklearn.inspection import permutation_importance

def train_vectorizer(vectorizer, x):
    """Fit ``vectorizer`` on the raw training documents ``x`` and time it.

    Progress messages are printed unless ``config.is_silent`` is set.

    :param vectorizer: an unfitted vectorizer exposing ``fit_transform``.
    :param x: the training documents.
    :return: tuple ``(trainX, duration_train)`` - the transformed training
        matrix and the wall-clock fitting time in seconds.
    """
    def log(message):
        if not config.is_silent:
            print(message)

    log("vectorizing train model...")
    started = time()
    matrix = vectorizer.fit_transform(x)
    elapsed = time() - started
    log("done in %.2fs" % elapsed)
    return matrix, elapsed


def get_vectorizer(xx_train, **kwargs):
    """Return a TF-IDF vectorizer and its training matrix, loading or fitting it.

    If ``kwargs["load_vec"]`` is given, the saved ``{"vectorizer", "x_train"}``
    dict is read from ``models/<load_vec>.joblib`` and ``xx_train`` is ignored.
    Otherwise a new vectorizer is built with ``create_tfidf_vectorizer(**kwargs)``,
    fitted on ``xx_train`` and saved through ``write_model_to_disk`` under the
    name ``<category>_vec`` together with ``analyzer``, ``ngram_range`` and
    ``t_prefix``.

    :return: tuple ``(x_train, duration_train, vectorizer)``; ``duration_train``
        is 0 when the vectorizer was loaded from disk.
    """
    load_vec = kwargs.get("load_vec")

    if load_vec is not None:
        # Reuse a previously fitted vectorizer and its cached training matrix.
        saved = read_model_from_disk("models/" + load_vec)
        vectorizer = saved['vectorizer']
        x_train = saved['x_train']
        if not config.is_silent:
            print("Vectorizer loaded from disk!")
        return x_train, 0, vectorizer

    vectorizer = create_tfidf_vectorizer(**kwargs)
    x_train, elapsed = train_vectorizer(vectorizer, xx_train)
    to_save = {"vectorizer": vectorizer, "x_train": x_train}
    write_model_to_disk(to_save,
                        str(kwargs.get("category")) + "_vec",
                        kwargs.get("analyzer"),
                        kwargs.get("ngram_range"),
                        kwargs.get("t_prefix"))
    return x_train, 0 + elapsed, vectorizer


def train_classifier(estimator, train_x, train_y):
    """Fit ``estimator`` in place on ``(train_x, train_y)``.

    Progress messages are printed unless ``config.is_silent`` is set.

    :return: the wall-clock training time in seconds.
    """
    def log(message):
        if not config.is_silent:
            print(message)

    log("training model...")
    start = time()
    estimator.fit(train_x, train_y)
    elapsed = time() - start
    log("done in %.2fs" % elapsed)
    return elapsed


def test_classifier(estimator, testX):
    """Predict labels for ``testX`` with a fitted ``estimator`` and time it.

    Progress messages are printed unless ``config.is_silent`` is set.

    :return: tuple ``(predicted_target, duration_test)`` - the predictions and
        the wall-clock prediction time in seconds.
    """
    def log(message):
        if not config.is_silent:
            print(message)

    log("testing model...")
    began = time()
    predictions = estimator.predict(testX)
    elapsed = time() - began
    log("done in %.2fs" % elapsed)
    return predictions, elapsed


def write_model_to_disk(estimator, clf="", analyzer="", ngram=(0, 0), prefix="", file_name=None, force=False):
    """Serialize ``estimator`` to ``<file_name>.joblib`` with joblib.

    Nothing is written when ``config.no_dump`` is set, unless ``force`` is True.

    :param estimator: any joblib-serializable object (model, dict, ...).
    :param clf: leading component of the generated file name.
    :param analyzer: analyzer component of the generated file name.
    :param ngram: ``(min_n, max_n)`` pair included in the generated file name.
    :param prefix: trailing component of the generated file name.
    :param file_name: explicit path without extension; when given it replaces
        the generated ``models/<clf>_<analyzer>_<min_n>_<max_n>_<prefix>`` name.
    :param force: write even when ``config.no_dump`` is set.
    """
    if config.no_dump and not force:
        return
    if file_name is None:
        # format() stringifies every component, so a non-string analyzer or
        # prefix (e.g. None for a missing option, or an integer timestamp)
        # no longer raises TypeError; string inputs give the same name as before.
        file_name = "models/{}_{}_{}_{}_{}".format(clf, analyzer, ngram[0], ngram[1], prefix)
    target = file_name + '.joblib'
    if not config.is_silent:
        print("MODEL: writting {} to the disk".format(file_name))
    # Delete any existing file at the target path before dumping.
    if os.path.exists(target):
        os.remove(target)
    dump(estimator, target)
    if not config.is_silent:
        print("DONE!")


def read_model_from_disk(file_name):
    """Load and return the object stored in ``<file_name>.joblib``.

    :param file_name: path without the ``.joblib`` extension.
    :return: the deserialized object.
    """
    path = file_name + '.joblib'
    chatty = not config.is_silent
    if chatty:
        print("MODEL: loading {} from disk".format(path))
    obj = load(path)
    if chatty:
        print("DONE!")
    return obj


def print_report(predicted_target, test_actual_target, duration_train, duration_test, conf_matrix=None, **kwargs):
    """Print evaluation metrics for one train/test run plus a CSV-like summary row.

    Unless ``config.is_silent`` is set, prints accuracy, balanced accuracy and
    the per-class ``classification_report``. The confusion matrix is printed
    whenever ``conf_matrix`` is truthy. The header list is printed unless
    ``config.no_header`` is set; the comma-separated summary row is always
    printed (commas inside values are replaced by ``|``).

    Recognized keyword arguments: ``train_precentage``, ``strip_accents``,
    ``lemmatization``, ``analyzer``, ``ngram_range``, ``stop_words``,
    ``max_iter``, ``trim_by_category`` and ``categories`` (required; when its
    items are lists, the first element of each is used as the class name).
    """
    categories = kwargs.get("categories")
    if isinstance(categories[0], list):
        # Nested category groups: keep the first name of each group.
        report_categories = [first for group in categories for first in group[:1]]
    else:
        report_categories = categories
    label_ids = list(range(len(report_categories)))

    if not config.is_silent:
        print("Accuracy: %.4f" % np.mean(predicted_target == test_actual_target))
        print("Balanced Accuracy: %.4f" % balanced_accuracy_score(test_actual_target, predicted_target))
        print(metrics.classification_report(test_actual_target, predicted_target,
                                            labels=label_ids,
                                            target_names=report_categories))
    if conf_matrix:
        print("Confusion Matrix")
        print(confusion_matrix(test_actual_target, predicted_target))

    summary = metrics.classification_report(test_actual_target, predicted_target,
                                            labels=label_ids,
                                            target_names=report_categories, output_dict=True)
    if not config.no_header:
        print(["train_precentage", "strip_accents", "lemmatization", "analyzer", "ngram_range",
               "stop_words", "max_iter", "languages", "trim_by_category", "training time", "testing time",
               "accuracy",
               "precision", "recall", "f1-score"])

    settings = [kwargs.get("train_precentage"), kwargs.get("strip_accents"), kwargs.get("lemmatization"),
                kwargs.get("analyzer"), kwargs.get("ngram_range"), kwargs.get("stop_words") is not None,
                kwargs.get("max_iter"), report_categories, kwargs.get("trim_by_category")]
    weighted = summary["weighted avg"]
    # The "accuracy" entry can be absent from the dict report; -0.0 marks that case.
    accuracy = summary.get("accuracy", -0.0)

    line = ", ".join(str(value).replace(",", "|") for value in settings)
    line += ",{:.2f}s, {:.2f}s".format(duration_train, duration_test)
    line += ",{:.4f}, {:.2f}, {:.2f}, {:.2f}".format(accuracy, weighted["precision"], weighted["recall"],
                                                     weighted["f1-score"])
    print(line)


# Named scorers; presumably passed as ``scoring=`` to scikit-learn cross
# validation - print_cv_report reads the matching ``test_<name>`` entries.
# Precision, recall and F1 are macro-averaged over the classes.
custom_scorer = dict(
    accuracy=make_scorer(accuracy_score),
    balanced_accuracy=make_scorer(balanced_accuracy_score),
    precision=make_scorer(precision_score, average='macro'),
    recall=make_scorer(recall_score, average='macro'),
    f1=make_scorer(f1_score, average='macro'),
)


def print_cv_report(cv_results, **kwargs):
    """Print cross-validation metrics and a CSV-like summary row.

    ``cv_results`` must provide ``fit_time``, ``score_time`` and
    ``test_<metric>`` arrays for accuracy, balanced_accuracy, precision, recall
    and f1. Each metric is reported as ``mean (+/- 2*std)``. The header list is
    printed unless ``config.no_header`` is set; the summary row is always printed.

    Recognized keyword arguments: ``train_precentage``, ``strip_accents``,
    ``lemmatization``, ``analyzer``, ``ngram_range``, ``stop_words``,
    ``max_iter``, ``trim_by_category`` and ``categories`` (required; when its
    items are lists, the first element of each is used as the class name).
    """
    categories = kwargs.get("categories")
    if isinstance(categories[0], list):
        # Nested category groups: keep the first name of each group.
        report_categories = [first for group in categories for first in group[:1]]
    else:
        report_categories = categories

    def spread(metric):
        # Mean score with a +/- two-standard-deviation band.
        scores = cv_results['test_' + metric]
        return "%.4f (+/- %0.2f)" % (scores.mean(), scores.std() * 2)

    acc, b_acc, prec, recall, f1 = (spread(metric) for metric in
                                    ("accuracy", "balanced_accuracy", "precision", "recall", "f1"))
    if not config.is_silent:
        print("Accuracy: " + acc)
    if not config.no_header:
        print(["train_precentage", "strip_accents", "lemmatization", "analyzer", "ngram_range",
               "stop_words", "max_iter", "languages", "trim_by_category", "training time", "testing time",
               "accuracy", "balanced_accuracy", "precision", "recall", "f1-score"])

    settings = [kwargs.get("train_precentage"), kwargs.get("strip_accents"), kwargs.get("lemmatization"),
                kwargs.get("analyzer"), kwargs.get("ngram_range"), kwargs.get("stop_words") is not None,
                kwargs.get("max_iter"), report_categories, kwargs.get("trim_by_category")]
    line = ", ".join(str(value).replace(",", "|") for value in settings)
    line += ",{:.2f}s, {:.2f}s".format(cv_results['fit_time'].mean(), cv_results['score_time'].mean())
    line += ",{}, {}, {}, {}, {}".format(acc, b_acc, prec, recall, f1)
    print(line)


def most_informative_feature_for_class(vectorizer, classifier, categories, n=10):
    """Print the ``n`` highest-weighted features of a linear classifier per class.

    For every class label ``0 .. len(categories) - 1`` the matching row of
    ``classifier.coef_`` is paired with the vectorizer's feature names, and the
    ``n`` largest ``(coefficient, feature)`` pairs are printed in ascending
    order as ``<category name> <feature> <coef>``, where the category name is
    ``categories[label][0]``. When ``coef_`` has a single row (binary model),
    only the first class is printed.

    :param vectorizer: fitted vectorizer exposing ``get_feature_names_out`` or
        the older ``get_feature_names``.
    :param classifier: fitted linear model with ``classes_`` and ``coef_``.
    :param categories: per-class name lists, indexed by class label.
    :param n: number of features to print per class; 0 prints nothing.
    """
    # Feature names do not depend on the class, so fetch them once.
    if hasattr(vectorizer, "get_feature_names_out"):
        # get_feature_names was removed in scikit-learn 1.2.
        feature_names = list(vectorizer.get_feature_names_out())
    else:
        feature_names = vectorizer.get_feature_names()
    class_order = list(classifier.classes_)

    for classlabel in range(len(categories)):
        labelid = class_order.index(classlabel)
        ranked = sorted(zip(classifier.coef_[labelid], feature_names))
        # Explicit start index: a plain [-n:] would return everything for n == 0.
        topn = ranked[max(len(ranked) - n, 0):]

        for coef, feat in topn:
            print(categories[classlabel][0], feat, coef)

        if classifier.coef_.shape[0] == 1:
            break


def plot_coefficients(classifier, vectorizer, categories, top_features=20):
    """Show a bar chart of the most negative and most positive coefficients per class.

    One figure is drawn for each row of ``classifier.coef_``. It shows the
    ``top_features`` lowest coefficients followed by the ``top_features``
    highest, coloured red (< 0) or blue, with their feature names as tick
    labels. Each figure is titled ``label = <categories[i][0]>`` and displayed
    with ``plt.show()``.

    :param classifier: fitted linear model exposing a dense ``coef_``.
    :param vectorizer: fitted vectorizer exposing ``get_feature_names_out`` or
        the older ``get_feature_names``.
    :param categories: per-class name lists, indexed like the ``coef_`` rows.
    :param top_features: number of features to show at each end.
    """
    # Feature names do not depend on the class, so fetch them once.
    if hasattr(vectorizer, "get_feature_names_out"):
        feature_names = np.asarray(vectorizer.get_feature_names_out())
    else:
        feature_names = np.array(vectorizer.get_feature_names())

    for i, coef in enumerate(classifier.coef_):
        order = np.argsort(coef)
        # Explicit start index: a plain [-top_features:] would select everything for 0.
        top_coefficients = np.hstack([order[:top_features], order[max(len(order) - top_features, 0):]])
        values = coef[top_coefficients]
        # Bars and tick labels share the same x positions (the labels used to be
        # shifted one bar to the right); sizing by the actual selection also
        # keeps lengths consistent when the vocabulary is smaller than top_features.
        positions = np.arange(len(top_coefficients))

        plt.figure(figsize=(15, 5))
        plt.bar(positions, values, color=['red' if c < 0 else 'blue' for c in values])
        plt.title("label = " + str(categories[i][0]), fontsize=16)
        plt.xticks(positions, feature_names[top_coefficients], rotation=60, ha='right')
        plt.show()


def dump_features(vectorizer, clf, category, ngram, x_train, yy_train):
    """Write an eli5 HTML explanation of ``clf``'s weights to ``features.html``.

    Class names are taken from ``categories_list[category]`` (first element of
    each item). Underscore-separated words shared by every class name are
    dropped to shorten them, unless that would leave some name empty, in which
    case the full names are used.

    ``ngram``, ``x_train`` and ``yy_train`` are not used by the current
    implementation; they are kept so existing callers keep working.
    """
    # eli5.explain_prediction(clf, x_train, vectorized=True, vec=vectorizer)

    full_names = [item[0] for item in categories_list[category]]
    # Words (split on "_") that occur in every class name.
    common = set(full_names[0].split("_"))
    for name in full_names:
        common &= set(name.split("_"))
    names = [" ".join(part for part in name.split("_") if part not in common) for name in full_names]
    if "" in names:
        names = full_names

    # results = permutation_importance(clf, x_train.todense(), yy_train, scoring='accuracy')
    # # get importance
    # importance = results.importances_mean
    # # summarize feature importance
    # for i, v in enumerate(importance):
    #     print('Feature: %0d, Score: %.5f' % (i, v))

    expl = eli5.explain_weights(clf, vec=vectorizer, target_names=names, top=(20, 5))
    # Feature names come from the corpus and may be non-ASCII, so write UTF-8
    # explicitly instead of relying on the platform default encoding.
    with open("features.html", "w", encoding="utf-8") as f:
        f.write(eli5.format_as_html(expl))

    # print_200(category, clf, ngram, vectorizer, x_train, yy_train, names)
    # plot_100(category, clf, ngram, vectorizer, x_train, yy_train, names)

def plot_features(strip_accents, train_precentage, trim_by_category, **kwargs):
    """Print and plot the top chi2-selected features of previously saved models.

    For each hard-coded (classifier, vectorizer, category, ngram) entry below,
    the classifier and vectorizer are loaded from ``models/``. The category's
    training split is then fetched: ``fetch_itaki`` is used for categories
    <= 11 and for 24, 28 and 29, and ``fetch_ice`` otherwise. The split is
    vectorized and passed to ``print_200`` and ``plot_100``.

    Ends by raising ``SystemExit(0)``. ``kwargs`` is accepted but unused.
    """
    import sys  # sys.exit: the bare exit() builtin is only installed by the site module

    # Earlier model sets, commented out; presumably kept so they can be swapped back in.
    # classifiers = [
    #     "2_sgd_clf_char_1_9_1595684678",
    #     "3_sgd_clf_char_1_10_1595685123",
    #     "4_sgd_clf_char_1_10_1595685389",
    #     "7_mnb_clf_char_1_11_1595685507",
    #     "8_sgd_clf_char_1_6_1595685542",
    #     "10_sgd_clf_char_1_7_1595685556",
    #     "11_sgd_clf_char_1_9_1595685576",
    # ]
    # vectorizers = [
    #     "2_vec_char_1_9_1595684593",
    #     "3_vec_char_1_10_1595684862",
    #     "4_vec_char_1_10_1595685189",
    #     "7_vec_char_1_11_1595685467",
    #     "8_vec_char_1_6_1595685534",
    #     "10_vec_char_1_7_1595685551",
    #     "11_vec_char_1_9_1595685565",
    # ]
    # categories = [2, 3, 4, 7, 8, 10, 11]
    # ngrams = ["1,9", "1,10", "1,10", "1,11", "1,6", "1,7", "1,9"]
    # classifiers = [
    #     "13_mnb_clf_word_1_3_1596431293",
    #     "14_mnb_clf_word_1_3_1596431325",
    #     "15_sgd_clf_char_1_6_1596431397",
    #     "16_sgd_clf_char_1_6_1596431476",
    #     "17_sgd_clf_char_1_9_1596431594",
    #     "18_mnb_clf_word_1_3_1596431706",
    #     "19_mnb_clf_word_1_2_1596431737",
    #     "22_sgd_clf_char_1_6_1596431858",
    #     "23_mnb_clf_word_1_3_1596431896",
    #     "20_mnb_clf_word_1_2_1596431753",
    #     "21_mnb_clf_char_1_6_1596431805",
    # ]
    # vectorizers = [
    #     "13_vec_word_1_3_1596431262",
    #     "14_vec_word_1_3_1596431308",
    #     "15_vec_char_1_6_1596431351",
    #     "16_vec_char_1_6_1596431421",
    #     "17_vec_char_1_9_1596431507",
    #     "18_vec_word_1_3_1596431656",
    #     "19_vec_word_1_2_1596431728",
    #     "22_vec_char_1_6_1596431824",
    #     "23_vec_word_1_3_1596431879",
    #     "20_vec_word_1_2_1596431748",
    #     "21_vec_char_1_6_1596431768",
    # ]
    # categories = [13, 14, 15, 16, 17, 18, 19, 22, 23, 20, 21]
    # ngrams = ["1,3", "1,3", "1,6", "1,6", "1,9", "1,3", "1,2", "1,6", "1,3", "1,2", "1,6"]
    # classifiers = [
    #     "28_mnb_clf_word_1_4_1596621482",
    #     "29_sgd_clf_word_1_2_1596621493",
    #     "30_mnb_clf_word_1_2_1596621518",
    #     "31_mnb_clf_word_1_4_1596621555",
    # ]
    # vectorizers = [
    #     "28_vec_word_1_4_1596621480",
    #     "29_vec_word_1_2_1596621492",
    #     "30_vec_word_1_2_1596621512",
    #     "31_vec_word_1_4_1596621532",
    # ]
    # categories = [28, 29, 30, 31]
    # ngrams = ["1,4", "1,2", "1,2", "1,4"]
    classifiers = [
        "12_sgd_clf_word_1_1_1633450769",
    ]
    vectorizers = [
        "12_vec_word_1_1_1633450767",
    ]
    categories = [12]
    ngrams = ["1,1"]
    # classifiers = [
    #     "12_sgd_clf_char_1_9_1596430964",
    # ]
    # vectorizers = [
    #     "12_vec_char_1_9_1596430579",
    # ]
    # categories = [12]
    # ngrams = ["1,9"]

    for load_vec, load_clf, category, ngram in zip(vectorizers, classifiers, categories, ngrams):
        clf = read_model_from_disk("models/" + load_clf)

        vect_model = read_model_from_disk("models/" + load_vec)
        vectorizer = vect_model['vectorizer']
        del vect_model  # drop the rest of the saved dict (it also holds x_train)

        if category <= 11 or category in (24, 28, 29):
            fetch = fetch_itaki
        else:
            fetch = fetch_ice
        dataset = fetch(subsets=['train', 'test'], categories=categories_list[category],
                        strip_accents=strip_accents, train_precentage=train_precentage,
                        trim_by_category=trim_by_category, verbose=False)

        train_dataset = dataset["train"]
        del dataset  # free-memory (the unused test split is released too)
        xx_train = train_dataset.data
        yy_train = train_dataset.target

        x_train = vectorizer.transform(xx_train)

        # Shorten class names by dropping "_"-separated words shared by all of
        # them, unless that would leave some name empty.
        full_names = [item[0] for item in categories_list[category]]
        common = set(full_names[0].split("_"))
        for name in full_names:
            common &= set(name.split("_"))
        names = [" ".join(part for part in name.split("_") if part not in common) for name in full_names]
        if "" in names:
            names = full_names

        print_200(category, clf, ngram, vectorizer, x_train, yy_train, names)
        plot_100(category, clf, ngram, vectorizer, x_train, yy_train, names)
    print("DONE!")
    sys.exit(0)


def print_200(category, clf, ngram, vectorizer, x_train, yy_train, names):
    """Print chi2-selected features of ``clf`` grouped by coefficient sign.

    ``SelectKBest(chi2, k=500)`` fitted on ``(x_train, yy_train)`` chooses the
    features. Each feature is printed with its ``clf.coef_`` value, and within
    each group the largest-magnitude coefficients come first.

    - Binary model (one ``coef_`` row): negative-coefficient features are listed
      under ``names[0]`` and positive ones under ``names[1]``.
    - Multiclass model: every class gets a "Positives" and a "Negatives" list.

    NOTE(review): the header text and function name say 200, but k=500
    features are selected.
    """
    selector = SelectKBest(chi2, k=500)  # select k top features meta-classifier
    selector.fit(x_train, yy_train)  # only the selected indices are needed
    selected = selector.get_support(indices=True)
    if hasattr(vectorizer, "get_feature_names_out"):
        # get_feature_names was removed in scikit-learn 1.2.
        feature_names = list(vectorizer.get_feature_names_out())
    else:
        feature_names = vectorizer.get_feature_names()
    all_names = (" vs. ".join(names)).title()

    def split_by_sign(coef):
        # (feature, coefficient) pairs for the selected features; strongest first per sign.
        pairs = [(feature_names[idx], float(coef[idx])) for idx in selected]
        positives = sorted((p for p in pairs if p[1] >= 0), key=operator.itemgetter(1), reverse=True)
        negatives = sorted((p for p in pairs if p[1] < 0), key=operator.itemgetter(1))
        return positives, negatives

    def print_pairs(pairs):
        for feature, value in pairs:
            print("`{}` [{:.4f}]".format(feature, value))

    if len(clf.coef_) == 1:
        positives, negatives = split_by_sign(clf.coef_[0])
        print(str(category) + "Top 200 Features for " + all_names)
        print("Char n-gram(" + ngram + ") features sorted by Classifier coef_ value")
        print(names[0].title())
        print("--------")
        print_pairs(negatives)
        print(names[1].title())
        print("--------")
        print_pairs(positives)
    else:
        for i, coef in enumerate(clf.coef_):
            # Same strongest-first ordering as the binary branch (the reverse
            # flags used to be swapped here, listing the weakest features first).
            positives, negatives = split_by_sign(coef)
            print(str(category) + "Top 200 Features for " + names[i].title() + " in (" + all_names + ")")
            print("Char n-gram(" + ngram + ") features sorted by Classifier coef_ value")
            print(names[i].title() + " Positives")
            print("--------")
            print_pairs(positives)
            print(names[i].title() + " Negatives")
            print("--------")
            print_pairs(negatives)


def plot_100(category, clf, ngram, vectorizer, x_train, yy_train, names):
    """Save bar charts of classifier coefficients for the 100 chi2-best features.

    ``SelectKBest(chi2, k=100)`` fitted on ``(x_train, yy_train)`` picks the
    features. Their ``clf.coef_`` values are drawn in the order returned by
    ``get_support(indices=True)``.

    - Binary model (one ``coef_`` row): one image,
      ``images/topk_features_<category>_<epoch seconds>.png``.
    - Multiclass model: one image per class,
      ``images/topk_features_<category>_<i>_<epoch seconds>.png``.
    """
    selector = SelectKBest(chi2, k=100)  # select k top features meta-classifier
    selector.fit(x_train, yy_train)  # only the selected indices are needed
    selected = selector.get_support(indices=True)
    if hasattr(vectorizer, "get_feature_names_out"):
        # get_feature_names was removed in scikit-learn 1.2.
        feature_names = list(vectorizer.get_feature_names_out())
    else:
        feature_names = vectorizer.get_feature_names()
    labels = [feature_names[idx] for idx in selected]
    positions = list(range(len(labels)))
    all_names = (" vs. ".join(names)).title()

    def save_chart(coef, title, suffix):
        fig = plt.figure(figsize=(13, 3))
        plt.bar(positions, [float(coef[idx]) for idx in selected])
        plt.xticks(positions, labels, fontsize=8, rotation=90)
        plt.xlabel("Char n-gram(" + ngram + ") feature")
        plt.ylabel("Classifier coef_ value")
        plt.title(title)
        fig.tight_layout()
        # Portable epoch timestamp; strftime('%s') is a glibc-only extension
        # that fails on Windows.
        stamp = str(int(datetime.datetime.now().timestamp()))
        plt.savefig("images/" + "topk_features_" + str(category) + "_" + suffix + stamp + ".png",
                    dpi=250, bbox_inches='tight')
        # Close the figure; otherwise every chart stays open in pyplot and
        # memory grows with each call.
        plt.close(fig)

    if len(clf.coef_) == 1:
        save_chart(clf.coef_[0], "Top 100 Features for " + all_names, "")
    else:
        for i, coef in enumerate(clf.coef_):
            save_chart(coef, "Top 100 Features for " + names[i].title() + " in (" + all_names + ")",
                       str(i) + "_")
