#!/usr/local/bin/python3
import collections
import datetime
import math
from optparse import OptionParser
from time import time

import matplotlib.pyplot as plt
from sklearn.feature_extraction import text
from sklearn.metrics import plot_confusion_matrix
import numpy as np
import config
from classifiers.clusters import run_unsupervised_clustering
from classifiers.common import write_model_to_disk, get_vectorizer, read_model_from_disk, print_report, plot_features
from classifiers.cust import run_cust_classifier
from classifiers.dtree import run_decision_tree_classifier
from classifiers.lgr import run_lgr_classifier
from classifiers.mnb import run_naive_classifier
from classifiers.rfc import run_rfc_classifier
from classifiers.sgd import run_sgd_classifier
from classifiers.svd import run_svd_classifier
from classifiers.tfidf import run_tfidf_insights
from solutions.solution import italki_root, ice_root_kachrus, ice_root_geo
from solutions.cross_corpus import cross_corpus
from util.categories import categories_list
from util.corpus import fetch_itaki, fetch_ice


def get_comma_separated_args(option, opt, value, _parser):
    """optparse callback that stores a comma-separated value as a list.

    ``value`` is split on ``,`` and the resulting list is assigned to
    ``_parser.values`` under the attribute named by ``option.dest``.
    Nothing is stored when ``value`` is ``None``. ``opt`` is part of the
    callback signature and is not used here.
    """
    if value is None:
        return
    pieces = value.split(',')
    setattr(_parser.values, option.dest, pieces)


def _report_solution(model, x_test, y_test, labels, title, category, params):
    """Score ``model`` on the test split, print its report and save a confusion matrix.

    ``model.score(x_test, **params)`` is expected to return ``(duration, predicted_targets)``,
    which are handed to ``print_report``. The confusion-matrix figure is written to
    ``images/<category>_confusion_matrix_<epoch-seconds>.png``.
    """
    _duration_test, predicted_target = model.score(x_test, **params)
    print_report(predicted_target, y_test, 0, _duration_test, True, **params)
    plt.figure(figsize=(20, 20))
    disp = plot_confusion_matrix(model, x_test, y_test, display_labels=labels, cmap=plt.cm.Blues,
                                 normalize=None)
    disp.ax_.set_title(title)
    tick_marks = np.arange(len(labels))
    plt.xticks(tick_marks, labels, rotation=90)
    plt.yticks(tick_marks, labels)
    plt.tight_layout()
    # int(time()) is the portable spelling of strftime('%s'), which is a platform-specific extension.
    plt.savefig("images/" + str(category) + "_confusion_matrix_" + str(int(time())) + ".png",
                dpi=250, bbox_inches='tight')


if __name__ == "__main__":
    ####################################################################################################################
    #                                             Language Categories
    ####################################################################################################################
    # Category presets are looked up in util.categories.categories_list via --category;
    # an explicit language list can be given with --languages instead.

    ####################################################################################################################
    #                                                Read Parameters
    ####################################################################################################################
    # ------- Parse Terminal Input Params --------
    parser = OptionParser(usage="usage: %prog [options] arg1 arg2")
    parser.add_option("-l", "--languages", dest="languages", type="string",
                      help="languages", action="callback", callback=get_comma_separated_args)
    parser.add_option("--without-langs", dest="without_langs", type="string",
                      help="without_langs", action="callback", callback=get_comma_separated_args)
    parser.add_option("-c", "--category", dest="category",
                      help="category 0, 1, 2...14", type="int")
    parser.add_option("-t", "--train", dest="train_precentage", type="float",
                      help="train_precentage", metavar="0.7")
    parser.add_option("-s", "--strip-accents", dest="strip_accents", help="strip accents", action="store_true",
                      default=False)
    parser.add_option("--lemma", "--lemmatization", dest="lemmatization", help="lemmatization", action="store_true",
                      default=False)
    parser.add_option("-a", "--analyzer", dest="analyzer", help="analyzer char or word", default="char", metavar="char")
    parser.add_option("-n", "--ngram-range", type="string",
                      help="ngram range", action="callback", callback=get_comma_separated_args)
    parser.add_option("--stop-words", dest="stop_words", help="stop words", action="store_true", default=False)
    parser.add_option("--no-trim", dest="trim", help="trim words", action="store_false", default=True)
    parser.add_option("-i", "--max-iter", dest="max_iter", help="max_iter", type="int", default=50)
    parser.add_option("--alpha", dest="alpha", help="alpha", type="float", default=0.0001)
    parser.add_option("--loss", dest="loss", help="loss", type="string", default="")
    parser.add_option("--penalty", dest="penalty", help="penalty", type="string", default="")
    parser.add_option("--silent", dest="silent", help="silent", action="store_true", default=False)
    parser.add_option("--no-header", dest="no_header", help="no_header", action="store_true", default=False)
    # The metavar lists exactly the values dispatched on at the bottom of this script.
    parser.add_option("--clf", dest="clf", help="classifier", type="string",
                      metavar="tf-idf|dst|lgr|sgd|rfc|naive|mnb|svd|custom|cluster")
    parser.add_option("--run-type", dest="run_type", help="run_type", type="string", default="default",
                      metavar="default|grid_search|learning_curve|progressive_validation|minmax|features|"
                              "cross-corpus|solution")
    parser.add_option("--load-vec", dest="load_vec", help="load_vec", type="string")
    parser.add_option("--load-clf", dest="load_clf", help="load_clf", type="string")
    parser.add_option("--write-corpus", dest="write_corpus", help="write_corpus", action="store_true", default=False)
    parser.add_option("--read-corpus", dest="read_corpus", help="read_corpus", action="store_true", default=False)
    parser.add_option("--conf-matrix", dest="conf_matrix", help="conf_matrix", action="store_true", default=False)
    parser.add_option("--no-dump", dest="no_dump", help="no_dump", action="store_true", default=False)
    parser.add_option("--ice-corpus", dest="ice_corpus", help="ice_corpus", action="store_true", default=False)
    parser.add_option("--mini-data", dest="mini_data", type="int", help="mini_data")
    (options, args) = parser.parse_args()
    config.is_silent = options.silent
    config.no_header = options.no_header
    config.no_dump = options.no_dump

    # ------------------- Set Params -----------------
    # ------ Dataset -------------
    _categories = []
    if options.languages:
        _categories = options.languages

    if options.category is not None:
        # Copy the preset so --without-langs never mutates the shared list in categories_list.
        _categories = list(categories_list[options.category])

    if options.without_langs:
        for rem_lang in options.without_langs:
            if rem_lang in _categories:
                _categories.remove(rem_lang)
        print("Removing Languages '" + ",".join(options.without_langs) + "' !")

    # A missing or zero value falls back to the default, as before.
    _train_precentage = options.train_precentage or 0.7
    _strip_accents = options.strip_accents

    # ------ TF-IDF -------------
    _lemmatization = options.lemmatization
    _analyzer = options.analyzer or 'char'

    _ngram_range = (1, 12)
    if options.ngram_range:
        try:
            _ngram_range = (int(options.ngram_range[0]), int(options.ngram_range[1]))
        except (IndexError, ValueError):
            # Report a usage error instead of a raw traceback on input such as "-n 3" or "-n a,b".
            parser.error("--ngram-range expects two comma-separated integers, e.g. 1,12")

    _stop_words = text.ENGLISH_STOP_WORDS if options.stop_words else None

    # ------ SGD Classifier -----
    _max_iter = options.max_iter or 50
    _alpha = options.alpha or 0.0001

    _run_type = options.run_type
    _trim_by_category = options.trim
    _load_clf = options.load_clf
    _load_vec = options.load_vec
    _loss = options.loss
    _penalty = options.penalty

    ####################################################################################################################
    #                                                  Read Corpus
    ####################################################################################################################
    if not config.is_silent:
        print("########### Variables ################")
        print("train_precentage: {:.2f}\n_strip_accents: {}\n_lemmatization: {}\n_analyzer: {}\n_ngram_range: {}\n"
              "_stop_words: {}\n_max_iter: {}\n_loss: {}\n_penalty: {}\n_categories: {}\n_clf: {}\n_run_type: {}\n"
              "_trim_by_category: {}\n_load_clf: {}\n_load_vec: {}\n"
              .format(_train_precentage, _strip_accents, _lemmatization, _analyzer, _ngram_range,
                      _stop_words is not None, _max_iter, _loss, _penalty, _categories, options.clf, _run_type,
                      _trim_by_category, _load_clf, _load_vec))
        print("######### Loading Corpora ############")

    t0 = time()
    if options.read_corpus:
        # Reuse a corpus previously stored with --write-corpus for the same --category.
        corpus_model = read_model_from_disk(f"util/lang_corpus_{options.category}")
        if not config.is_silent:
            print("Corpus loaded from disk...")
            print("categories: {}\nstrip_accents: {}\ntrain_precentage: {}\ntrim_by_category: {}\n".format(
                corpus_model['categories'], corpus_model['strip_accents'],
                corpus_model['train_precentage'], corpus_model['trim_by_category']))
        dataset = corpus_model['dataset']
    else:
        # Both fetchers take the same keyword arguments; only the corpus differs.
        fetch_corpus = fetch_ice if options.ice_corpus else fetch_itaki
        dataset = fetch_corpus(subsets=['train', 'test'], categories=_categories,
                               strip_accents=_strip_accents, train_precentage=_train_precentage,
                               trim_by_category=_trim_by_category, verbose=False)

    if options.write_corpus:
        corpus_model = {"categories": _categories,
                        "strip_accents": _strip_accents,
                        "train_precentage": _train_precentage,
                        "trim_by_category": _trim_by_category,
                        "dataset": dataset
                        }
        write_model_to_disk(corpus_model, file_name=f"util/lang_corpus_{options.category}", force=True)

    train_dataset = dataset["train"]
    test_dataset = dataset["test"]
    del dataset  # free-memory

    duration = time() - t0
    if not config.is_silent:
        print("done in %fs" % duration)

    ####################################################################################################################
    #                                             Train and Test Classifiers
    ####################################################################################################################
    if not config.is_silent:
        print("########### Running Classifier ##############")

    params = {"lemmatization": _lemmatization, "ngram_range": _ngram_range,
              "analyzer": _analyzer, "stop_words": _stop_words, "max_iter": _max_iter, "categories": _categories,
              "run_type": _run_type, "trim_by_category": _trim_by_category, "load_clf": _load_clf,
              "strip_accents": _strip_accents, "alpha": _alpha, "category": options.category,
              "load_vec": _load_vec, "loss": _loss, "penalty": _penalty, "train_precentage": _train_precentage}

    _xx_train = train_dataset.data
    _yy_train = train_dataset.target
    _xx_test = test_dataset.data
    _yy_test = test_dataset.target

    if options.mini_data is not None:
        mini_data_size = options.mini_data
        train_size = int(mini_data_size * _train_precentage)
        # Use the remainder rather than int(n * (1 - p)): float error in (1 - p) can
        # truncate one document away (e.g. p=0.8, n=100 gives 19 instead of 20).
        test_size = mini_data_size - train_size
        _xx_train = _xx_train[:train_size]
        _yy_train = _yy_train[:train_size]
        _xx_test = _xx_test[:test_size]
        _yy_test = _yy_test[:test_size]
        if not config.is_silent:
            print("MINI_DATA!: Using {:,} training and {:,} testing docs".format(train_size, test_size))

    # Epoch seconds as a string; strftime('%s') is not available on every platform.
    params['t_prefix'] = str(int(time()))

    if options.run_type == "progressive_validation":
        params['run_type'] = "default"
        params['no_cat_reports'] = True

        # Train the vectorizer once; every classifier below receives its own copy of the matrix.
        _x_train, _duration_train, _vectorizer = get_vectorizer(_xx_train, **params)

        print("-----Running DecisionTree------")
        run_decision_tree_classifier(_vectorizer, _x_train.copy(), _yy_train, _xx_test, _yy_test, **params)

        # LogisticRegression is currently left out of this sweep; use --clf lgr to run it.

        # (label, penalty, loss) combinations, run in this order.
        sgd_variants = (("L2, Hinge", "l2", "hinge"),
                        ("elasticnet, ModHuber", "elasticnet", "modified_huber"),
                        ("elasticnet, Hinge", "elasticnet", "hinge"),
                        ("L2, ModHuber", "l2", "modified_huber"))
        for sgd_label, sgd_penalty, sgd_loss in sgd_variants:
            print("-----Running SGD " + sgd_label + "------")
            params["penalty"] = sgd_penalty
            params["loss"] = sgd_loss
            run_sgd_classifier(_vectorizer, _x_train.copy(), _yy_train, _xx_test, _yy_test, **params)

        print("-----Running Naive------")
        params["penalty"] = None
        params["loss"] = None
        run_naive_classifier(_vectorizer, _x_train.copy(), _yy_train, _xx_test, _yy_test, **params)
        print("DONE!")
        raise SystemExit(0)
    elif options.run_type == "minmax":
        # Bucket training documents by word count, rounded up to the next multiple of 10.
        doc_freq_by_category = collections.Counter(
            math.ceil(len(doc.split()) / 10) * 10 for doc in _xx_train)
        ordered_buckets = sorted(doc_freq_by_category.items())
        cats = [str(bucket) for bucket, _ in ordered_buckets]
        freqs = [count for _, count in ordered_buckets]

        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.bar(cats, freqs)
        # Tick.label is deprecated in Matplotlib; style the tick labels through the Axes instead.
        ax.tick_params(axis='x', labelsize=8, labelrotation=90)
        for a, b in zip(cats, freqs):
            ax.text(a, b, str(b), fontsize=8, rotation=90)
        ax.set_title('Word Frequencies by Documents')
        plt.xlabel('Word Frequency Category')
        plt.ylabel('Number of Documents')
        plt.show()
        print("DONE!")
        raise SystemExit(0)
    elif options.run_type == "features":
        # NOTE(review): unlike the branches above, this one does not exit, so execution
        # continues to the --clf dispatch below -- confirm that this is intended.
        plot_features(**params)
    elif options.run_type == "cross-corpus":
        # NOTE(review): also falls through to the --clf dispatch below -- confirm intent.
        cross_corpus(**params)
    elif options.run_type == "solution":
        report_categories = _categories
        # Nested category lists are reported by the first entry of each group.
        if _categories and isinstance(_categories[0], list):
            report_categories = [item for sublist in _categories for item in sublist[:1]]

        # Italki (disabled):
        # _report_solution(italki_root, _xx_test, _yy_test, report_categories,
        #                  "Confusion Matrix for All 23 Countries", options.category, params)

        # ICE Kachru's
        _report_solution(ice_root_kachrus, _xx_test, _yy_test, report_categories,
                         "Confusion Matrix for All 10 Countries", options.category, params)

        # ICE Geo
        _report_solution(ice_root_geo, _xx_test, _yy_test, report_categories,
                         "Confusion Matrix for All 10 Countries", options.category, params)

        print("DONE!")
        raise SystemExit(0)

    # Classifiers that share the "fit a vectorizer, then train/test" call shape.
    vectorized_classifiers = {
        "dst": run_decision_tree_classifier,
        "lgr": run_lgr_classifier,
        "sgd": run_sgd_classifier,
        "rfc": run_rfc_classifier,
        "naive": run_naive_classifier,
        "mnb": run_naive_classifier,
        "svd": run_svd_classifier,
    }

    t0 = time()
    if options.clf == "tf-idf":
        run_tfidf_insights(_xx_train, _yy_train, **params)
    elif options.clf in vectorized_classifiers:
        # Train Vectorizer
        _x_train, _duration_train, _vectorizer = get_vectorizer(_xx_train, **params)
        vectorized_classifiers[options.clf](_vectorizer, _x_train, _yy_train, _xx_test, _yy_test, **params)
    elif options.clf == "custom":
        run_cust_classifier(_xx_train, _yy_train, _xx_test, _yy_test, **params)
    elif options.clf == "cluster":
        run_unsupervised_clustering(_xx_train, _yy_train, **params)
    else:
        raise Exception("Invalid --clf parameter value!")
    duration = time() - t0
    if not config.is_silent:
        print("ran classifier for %fs" % duration)


def clearup_categories(test_data, test_target, categories, top_level_categories):
    """Re-label test documents with the index of their top-level category.

    For each position ``i``, the name ``categories[test_target[i]]`` is looked up
    in the second element of every entry of ``top_level_categories``, in order.
    The first entry that contains it wins: the document ``test_data[i]`` is kept
    and its new target is that entry's index. Documents whose category matches
    no entry are dropped.

    Returns a ``(documents, targets)`` pair of equally long lists.
    """
    kept_docs = []
    kept_targets = []
    for position, original_target in enumerate(test_target):
        category_name = categories[original_target]
        # Index of the first top-level group containing this category, if any.
        group_index = next(
            (index for index, group in enumerate(top_level_categories) if category_name in group[1]),
            None,
        )
        if group_index is None:
            continue
        kept_docs.append(test_data[position])
        kept_targets.append(group_index)
    return kept_docs, kept_targets
