import datetime
from multiprocessing import Pool
from multiprocessing import cpu_count
from time import time

from sklearn.feature_extraction import DictVectorizer
from sklearn.feature_selection import VarianceThreshold

from classifiers.common import print_report, train_vectorizer
from classifiers.common import write_model_to_disk, read_model_from_disk, test_classifier, train_classifier
from classifiers.sgd import create_sgd_clf
import config
from util.feature_extractor import extract_features


def _format_feature_cell(val):
    """Render one aggregated feature value as a report cell (without separator)."""
    if isinstance(val, str):
        return '"' + val + '"'
    if isinstance(val, float):
        return "{:.2f}".format(val)
    return "{}".format(val)


def print_feature_report(x_train_features, target, categories):
    """Write a per-language summary of the extracted features to ``features.txt``.

    Every document's feature dict is merged into an accumulator for its
    language (``categories[t]``):

    * string values are concatenated, each wrapped in ``||...||`` with commas
      replaced by ``;`` so they cannot break the comma-separated layout;
    * all other values are summed and then divided by the number of documents
      seen for that language, i.e. reported as a per-document mean.

    The report has one header line (``Language`` followed by every feature key
    in first-seen order) and one line per language. Every cell is followed by
    ``", "``. A feature that never occurred for a language is written as an
    empty cell so that columns stay aligned with the header.

    Args:
        x_train_features: iterable of feature dicts, one per document.
        target: iterable of indices into ``categories``, parallel to
            ``x_train_features``.
        categories: sequence mapping a target index to its language name
            (names are expected to be strings).

    Side effects:
        Prints a progress message and (over)writes ``features.txt`` in the
        current working directory.
    """
    dict_by_lang = {}
    docs_count_by_lang = {}
    for feature_dict, t in zip(x_train_features, target):
        lang = categories[t]
        docs_count_by_lang[lang] = docs_count_by_lang.get(lang, 0) + 1

        # accumulate values
        lang_dict = dict_by_lang.setdefault(lang, {})
        for key, n_val in feature_dict.items():
            if isinstance(n_val, str):
                lang_dict[key] = (lang_dict.get(key, '')
                                  + '||' + n_val.replace(",", ";") + '||')
            else:
                lang_dict[key] = lang_dict.get(key, 0) + n_val

    # Turn the numeric sums into per-document means for each language.
    for lang, lang_dict in dict_by_lang.items():
        docs_count = docs_count_by_lang[lang]
        for key, val in list(lang_dict.items()):
            if not isinstance(val, str):
                lang_dict[key] = val / docs_count

    # Languages need not share the same feature keys, so the column set is the
    # union of all keys, kept in first-seen order.
    columns = []
    seen = set()
    for lang_dict in dict_by_lang.values():
        for key in lang_dict:
            if key not in seen:
                seen.add(key)
                columns.append(key)

    header = "Language, " + "".join(key + ", " for key in columns)
    rows = []
    for lang, lang_dict in dict_by_lang.items():
        cells = [
            _format_feature_cell(lang_dict[key]) if key in lang_dict else ""
            for key in columns
        ]
        rows.append(lang + ", " + "".join(cell + ", " for cell in cells))

    print("writing features.txt")
    # The context manager guarantees the report is flushed and closed.
    with open("features.txt", "w", encoding="utf-8") as f:
        f.write(header + "\n")
        f.writelines(row + "\n" for row in rows)


def run_cust_classifier(xx_train, yy_train, xx_test, yy_test, **kwargs):
    """Train (or load) the custom-feature pipeline and classifier, then evaluate it.

    Pipeline: ``extract_features`` is mapped over the raw documents in a
    process pool, the resulting feature dicts are vectorized with a dense
    ``DictVectorizer``, low-variance columns are dropped with
    ``VarianceThreshold`` and the classifier from ``create_sgd_clf`` is fitted
    on the result. The test documents go through the same fitted
    vectorizer/selector before prediction.

    Args:
        xx_train: training documents, passed one by one to ``extract_features``.
            Not used when ``load_vec`` is given.
        yy_train: training targets (indices into ``categories``).
        xx_test: test documents.
        yy_test: test targets.
        **kwargs: options read here (the full dict is also forwarded to
            ``print_report``):
            ngram_range, analyzer: only passed on to ``write_model_to_disk``
                when the classifier is saved.
            max_iter: forwarded to ``create_sgd_clf``.
            load_vec: file name under ``models/`` of a stored dict with the
                keys ``vectorizer``, ``Xtrain`` and ``featureSelector``; when
                set, training-side feature extraction and fitting are skipped.
            load_clf: file name under ``models/`` of a stored classifier; when
                set, no classifier is trained or written.
            conf_matrix: forwarded to ``print_report``.
            categories: maps target indices to names; used for the feature
                report written by ``print_feature_report``.

    Side effects:
        Prints progress/report output, writes ``features.txt`` when the
        vectorizer is trained and ``categories`` is given, and writes the
        trained classifier to disk when it is not loaded.
    """
    ngram_range = kwargs.get("ngram_range")
    analyzer = kwargs.get("analyzer")
    max_iter = kwargs.get("max_iter")
    load_vec = kwargs.get("load_vec")
    load_clf = kwargs.get("load_clf")
    conf_matrix = kwargs.get("conf_matrix")
    categories = kwargs.get("categories")
    # Epoch seconds used as the model-file prefix. strftime('%s') is a
    # platform-specific extension, so derive the value from time() instead.
    t_prefix = str(int(time()))

    if not config.is_silent:
        print("extracting features...")

    with Pool(cpu_count()) as pool:
        x_test_features = pool.map(extract_features, xx_test)

    _duration_train = 0
    # Train Vectorizer
    if load_vec is not None:
        # load vec
        vect_model = read_model_from_disk("models/" + load_vec)
        vectorizer = vect_model['vectorizer']
        x_train = vect_model['Xtrain']
        selector = vect_model['featureSelector']
    else:
        t0 = time()
        with Pool(cpu_count()) as pool:
            XTrain_features = pool.map(extract_features, xx_train)
        duration = time() - t0
        if not config.is_silent:
            print("done in %.2fs" % duration)

        # train vec
        vectorizer = DictVectorizer(sparse=False)
        # Variance of a Bernoulli feature with p = 0.8: removes features that
        # take the same value in more than ~80% of the samples.
        selector = VarianceThreshold(threshold=(.8 * (1 - .8)))
        x_train, t_ = train_vectorizer(vectorizer, XTrain_features)
        x_train = selector.fit_transform(x_train)
        print("selected features...")
        print(selector.get_support(indices=True))
        _duration_train += t_
        # NOTE: same layout as the dict read in the load_vec branch; persisting
        # it is currently disabled.
        vect_model = {"vectorizer": vectorizer, "featureSelector": selector,
                      "Xtrain": x_train}
        # write_model_to_disk(vect_model, "custvec", "na", (0, 0), t_prefix)
        # The report indexes `categories` by target, so skip it when the
        # caller did not provide the mapping instead of failing on None.
        if categories is not None:
            print_feature_report(XTrain_features, yy_train, categories)

    # ---------Classifier------
    if load_clf is not None:
        # load clf
        clf = read_model_from_disk("models/" + load_clf)
    else:
        # train clf
        clf = create_sgd_clf(max_iter)
        # Add to (rather than overwrite) the vectorizer time accumulated above.
        # Assumes train_classifier returns its duration as a number, as the
        # original assignment to _duration_train implies.
        _duration_train += train_classifier(clf, x_train, yy_train)
        write_model_to_disk(clf, "custom_clf", analyzer, ngram_range, t_prefix)

    # Test
    x_test = vectorizer.transform(x_test_features)
    x_test = selector.transform(x_test)
    _predicted_target, _duration_test = test_classifier(clf, x_test)
    # Print
    if not config.is_silent:
        print("Classifier Score: {:.2f}".format(clf.score(x_test, yy_test)))
    print_report(_predicted_target, yy_test, _duration_train, _duration_test, conf_matrix, **kwargs)
    # NOTE: per-category reports are disabled; the names categories_list and
    # clearup_categories used below are not defined in this module.
    # if not kwargs.get("no_cat_reports"):
    #     for top_category in categories_list:
    #         print("Testing for {}".format(top_category))
    #         new_test_data, new_test_target = clearup_categories(xx_test, yy_test, categories,
    #                                                             top_category)
    #         # Test for top category
    #         x_test = vectorizer.transform(new_test_data)
    #         _predicted_target, _duration_test = test_classifier(clf, x_test)
    #         print_report(_predicted_target, new_test_target, _duration_train, _duration_test, top_category)
