import datetime
import itertools
import os
from time import time

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rcParams
from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.linear_model import LogisticRegression, PassiveAggressiveClassifier
from sklearn.tree import DecisionTreeClassifier

from classifiers.mnb import create_mnb_clf
from classifiers.sgd import create_sgd_clf
from classifiers.tfidf import create_tfidf_vectorizer


def run_prog_validation(xx_train, **kwargs):
    """Build the benchmark classifiers and plot their progressive validation.

    :param xx_train: training dataset, forwarded unchanged to
        ``plot_progressive_learn``.
    :param kwargs: keyword arguments forwarded to ``create_tfidf_vectorizer``.
    """
    tfidf = create_tfidf_vectorizer(**kwargs)

    # Insertion order is kept on purpose: it defines the iteration order
    # used by ``plot_progressive_learn``.
    classifiers = {}
    # NOTE(review): both SGD entries are created with identical arguments;
    # confirm whether 'SGD-ElasticNet' should be configured differently.
    classifiers['SGD-L2'] = create_sgd_clf(max_iter=1000)
    classifiers['SGD-ElasticNet'] = create_sgd_clf(max_iter=1000)
    classifiers['NBMultinomial'] = create_mnb_clf()
    classifiers['PassiveAggClf'] = PassiveAggressiveClassifier(n_jobs=-1)
    classifiers['DecisionTree'] = DecisionTreeClassifier()
    classifiers['LogisticRegression'] = LogisticRegression(max_iter=500, n_jobs=-1)

    plot_progressive_learn(classifiers, tfidf, xx_train)


# Names of the classifiers that are trained once with ``fit`` on the whole
# training split instead of incrementally with ``partial_fit``.
_FULL_BATCH_CLASSIFIERS = ("LogisticRegression", "DecisionTree")

# Colours handed to ``plt.bar`` for the runtime bar charts.
_BAR_COLORS = ['b', 'g', 'r', 'c', 'm', 'y']


def _record_history(stats, total_vect_time):
    """Append the current accuracy of ``stats`` to both of its history series."""
    stats['accuracy_history'].append((stats['accuracy'], stats['n_train']))
    stats['runtime_history'].append(
        (stats['accuracy'], total_vect_time + stats['total_fit_time']))


def _save_figure(suffix):
    """Save the current figure as ``images/<script>_progressive_validation_<suffix>_<epoch>.png``."""
    # savefig does not create missing directories.
    os.makedirs("images", exist_ok=True)
    # ``strftime('%s')`` is a platform-specific extension; ``timestamp()``
    # yields the same epoch seconds portably.
    epoch_seconds = int(datetime.datetime.now().timestamp())
    filename = "%s_progressive_validation_%s_%d.png" % (
        os.path.basename(__file__), suffix, epoch_seconds)
    plt.savefig(os.path.join("images", filename), dpi=250, bbox_inches='tight')


def _plot_runtime_bars(labels, runtimes, title, tick_fontsize, file_suffix):
    """Draw, save and show one labelled bar chart of runtimes (in seconds)."""
    plt.figure()
    ax = plt.subplot(111)
    rectangles = plt.bar(range(len(labels)), runtimes, width=0.5,
                         color=_BAR_COLORS)

    ax.set_xticks(np.linspace(0, len(labels) - 1, len(labels)))
    ax.set_xticklabels(labels, fontsize=tick_fontsize)
    plt.setp(plt.xticks()[1], rotation=30)
    # Fall back to a unit-high axis when every runtime is zero, so the
    # y-limits are never identical.
    ax.set_ylim((0, max(runtimes) * 1.2 or 1.0))
    ax.set_ylabel('runtime (s)')
    ax.set_title(title)

    autolabel(ax, rectangles)
    plt.tight_layout()
    _save_figure(file_suffix)
    plt.show()


def plot_progressive_learn(partial_fit_classifiers, vectorizer, train_dataset):
    """Train the classifiers progressively and plot accuracy and runtimes.

    The first 1000 documents of ``train_dataset`` are held out as the test
    set.  Classifiers named in ``_FULL_BATCH_CLASSIFIERS`` are fitted once on
    the remaining documents; every other classifier is updated with
    ``partial_fit`` on mini-batches of 1000 documents and re-scored after
    each batch.  Four figures are written below ``images/``: accuracy versus
    number of training examples, accuracy versus runtime, total training
    times and prediction times.

    :param partial_fit_classifiers: mapping of display name to estimator.
    :param vectorizer: ignored; kept for interface compatibility.  A
        ``HashingVectorizer`` is always used instead (behaviour kept from the
        original implementation).
    :param train_dataset: object exposing ``data``, ``target`` and
        ``target_names``.
    """
    vectorizer = HashingVectorizer(analyzer='char', ngram_range=(2, 10), decode_error='ignore', n_features=2 ** 18,
                                   alternate_sign=False)
    cls_stats = {}
    for cls_name in partial_fit_classifiers:
        # 'prediction_time' starts at 0.0 so the prediction-time chart still
        # works for a classifier that never received a mini-batch.
        cls_stats[cls_name] = {
            'n_train': 0, 'n_train_pos': 0,
            'accuracy': 0.0, 'accuracy_history': [(0, 0)], 't0': time(),
            'runtime_history': [(0, 0)], 'total_fit_time': 0.0,
            'prediction_time': 0.0}

    all_classes = np.arange(len(train_dataset.target_names))
    n_test_documents = 1000

    # Iterator over parsed docs
    data_stream = stream_reuters_documents(train_dataset)

    # The first documents of the stream become the held-out test set; taking
    # them here also removes them from what the mini-batches will see.
    tick = time()
    X_test_text, y_test = get_minibatch(data_stream, n_test_documents)
    parsing_time = time() - tick
    tick = time()
    X_test = vectorizer.transform(X_test_text)
    vectorizing_time = time() - tick

    # We will feed the classifier with mini-batches of 1000 documents; this means
    # we have at most 1000 docs in memory at any time.  The smaller the document
    # batch, the bigger the relative overhead of the partial fit methods.
    minibatch_size = 1000
    minibatch_iterators = iter_minibatches(data_stream, minibatch_size)
    total_vect_time = 0.0

    # Full-batch classifiers: fit once, on the documents that are NOT part of
    # the test set, so their accuracy is not inflated by test-set leakage.
    full_batch_names = [name for name in _FULL_BATCH_CLASSIFIERS
                        if name in partial_fit_classifiers]
    if full_batch_names:
        full_train_text = list(train_dataset.data)[n_test_documents:]
        full_train_target = np.asarray(train_dataset.target)[n_test_documents:]
        # Vectorize once, outside the timed region, so every full-batch
        # classifier's fit time measures fitting only.
        full_train_X = vectorizer.transform(full_train_text)
        for cls_name in full_batch_names:
            clf = partial_fit_classifiers[cls_name]
            stats = cls_stats[cls_name]
            tick = time()
            clf.fit(full_train_X, full_train_target)
            stats['total_fit_time'] += time() - tick
            stats['n_train'] = full_train_X.shape[0]
            stats['n_train_pos'] = full_train_X.shape[0]
            tick = time()
            stats['accuracy'] = clf.score(X_test, y_test)
            stats['prediction_time'] = time() - tick
            _record_history(stats, total_vect_time)

    # Progress is reported once every ``report_every`` mini-batches.
    report_every = max(len(partial_fit_classifiers), 1)

    # Main loop : iterate on mini-batches of examples
    for batch_idx, (X_train_text, y_train) in enumerate(minibatch_iterators):

        tick = time()
        X_train = vectorizer.transform(X_train_text)
        total_vect_time += time() - tick

        for cls_name, cls in partial_fit_classifiers.items():
            stats = cls_stats[cls_name]
            if cls_name in _FULL_BATCH_CLASSIFIERS:
                # Already fitted: repeat the last point so the history has
                # one entry per mini-batch like the incremental classifiers
                # (the point for the first batch was recorded after the fit).
                if batch_idx != 0:
                    _record_history(stats, total_vect_time)
            else:
                tick = time()
                # update estimator with examples in the current mini-batch
                cls.partial_fit(X_train, y_train, classes=all_classes)
                # accumulate test accuracy stats
                stats['total_fit_time'] += time() - tick
                stats['n_train'] += X_train.shape[0]
                # NOTE(review): this is the sum of the labels, which equals
                # the number of positives only for binary 0/1 targets.
                stats['n_train_pos'] += int(np.sum(y_train))
                tick = time()
                stats['accuracy'] = cls.score(X_test, y_test)
                stats['prediction_time'] = time() - tick
                _record_history(stats, total_vect_time)

            if batch_idx % report_every == 0:
                print(progress(cls_name, stats))
        if batch_idx % report_every == 0:
            print('\n')

    rcParams['legend.fontsize'] = 10
    sorted_stats = sorted(cls_stats.items())
    cls_names = [name for name, _ in sorted_stats]

    # Plot accuracy evolution with #examples
    plt.figure()
    for _, stats in sorted_stats:
        accuracy, n_examples = zip(*stats['accuracy_history'])
        plot_accuracy(n_examples, accuracy, "training examples (#)")
        plt.gca().set_ylim((0.4, 1))
    plt.legend(cls_names, loc='best')
    _save_figure("accuracy")

    # Plot accuracy evolution with runtime
    plt.figure()
    for _, stats in sorted_stats:
        accuracy, runtime = zip(*stats['runtime_history'])
        plot_accuracy(runtime, accuracy, 'runtime (s)')
        plt.gca().set_ylim((0.4, 1))
    plt.legend(cls_names, loc='best')
    _save_figure("runtime")

    # Plot fitting times
    fit_times = [stats['total_fit_time'] for _, stats in sorted_stats]
    _plot_runtime_bars(cls_names + ['Vectorization'],
                       fit_times + [total_vect_time],
                       'Training Times', 10, "training_time")

    # Plot prediction times
    prediction_times = [stats['prediction_time'] for _, stats in sorted_stats]
    _plot_runtime_bars(cls_names + ['Read/Parse\n+Feat.Extr.', 'Hashing\n+Vect.'],
                       prediction_times + [parsing_time, vectorizing_time],
                       'Prediction Times (%d instances)' % X_test.shape[0],
                       8, "pred_time")


def stream_reuters_documents(data_set):
    """Return a one-shot iterator of ``(document, target)`` pairs.

    Pairs the entries of ``data_set.data`` with those of ``data_set.target``
    position by position.
    """
    documents = data_set.data
    labels = data_set.target
    return zip(documents, labels)


def get_minibatch(doc_iter, size):
    """Pull up to ``size`` documents from ``doc_iter``.

    Each item of ``doc_iter`` is indexed as ``item[0]`` (text) and
    ``item[1]`` (label).  Returns ``(texts, labels)`` where ``texts`` is a
    tuple and ``labels`` an integer numpy array; when the iterator is
    exhausted both are empty integer arrays.
    """
    chunk = list(itertools.islice(doc_iter, size))
    if len(chunk) == 0:
        return np.asarray([], dtype=int), np.asarray([], dtype=int)

    texts = tuple(item[0] for item in chunk)
    labels = [item[1] for item in chunk]
    return texts, np.asarray(labels, dtype=int)


def iter_minibatches(doc_iter, minibatch_size):
    """Yield ``(texts, labels)`` mini-batches until ``doc_iter`` runs dry.

    Batches are produced by ``get_minibatch``; iteration stops at the first
    empty batch.
    """
    while True:
        batch_text, batch_labels = get_minibatch(doc_iter, minibatch_size)
        if not len(batch_text):
            return
        yield batch_text, batch_labels


def progress(cls_name, stats):
    """Report progress information, return a string.

    :param cls_name: display name of the classifier.
    :param stats: mapping providing ``t0`` (start time as returned by
        ``time()``), ``n_train``, ``n_train_pos`` and ``accuracy``.
    :return: a one-line summary with the training counts, the accuracy, the
        elapsed time since ``t0`` and the training throughput.
    """
    duration = time() - stats['t0']
    # A coarse timer can report no elapsed time (and a clock adjustment a
    # negative one); report a rate of 0 instead of dividing by it.
    docs_per_second = stats['n_train'] / duration if duration > 0 else 0
    s = "%20s classifier : \t" % cls_name
    s += "%(n_train)6d train docs (%(n_train_pos)6d positive) " % stats
    s += "accuracy: %(accuracy).3f " % stats
    s += "in %.2fs (%5d docs/s)" % (duration, docs_per_second)
    return s


def autolabel(ax, rectangles):
    """Attach a height label above each bar and tilt the x tick labels.

    :param ax: axes the bars were drawn on; the text and the tick-label
        rotation are applied to these axes.
    :param rectangles: iterable of bar patches exposing ``get_height``,
        ``get_x`` and ``get_width``.
    """
    for rect in rectangles:
        height = rect.get_height()
        ax.text(rect.get_x() + rect.get_width() / 2.,
                1.05 * height, '%.4f' % height,
                ha='center', va='bottom')
    # Rotating the tick labels does not depend on the bar, so do it once,
    # and on the axes that were passed in rather than pyplot's current axes.
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(30)


def plot_accuracy(x, y, x_legend):
    """Draw one accuracy curve on the current pyplot figure.

    ``x`` and ``y`` are converted to numpy arrays and plotted against each
    other; ``x_legend`` names the x quantity in the title and x-axis label.
    """
    xs, ys = np.array(x), np.array(y)
    heading = 'Classification accuracy as a function of %s' % x_legend
    x_caption = '%s' % x_legend

    plt.title(heading)
    plt.xlabel(x_caption)
    plt.ylabel('Accuracy')
    plt.grid(True)
    plt.plot(xs, ys)
