import datetime

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import plot_confusion_matrix

from classifiers.common import read_model_from_disk, test_classifier, print_report
from util.categories import *
from util.corpus import fetch_itaki


def _save_confusion_matrix(clf, x_test, y_true, labels, title, path):
    """Plot the confusion matrix of ``clf`` on ``x_test`` and save it to ``path``.

    The figure is created up front and passed to ``plot_confusion_matrix`` via
    ``ax`` so the requested 20x20 size is actually used (without ``ax`` the
    function opens its own default-size figure). The figure is always closed
    afterwards so figures do not pile up across calls.
    """
    fig, ax = plt.subplots(figsize=(20, 20))
    try:
        # NOTE(review): plot_confusion_matrix was removed in scikit-learn 1.2;
        # ConfusionMatrixDisplay.from_estimator is the replacement.
        disp = plot_confusion_matrix(clf, x_test, y_true, display_labels=labels,
                                     cmap=plt.cm.Blues, normalize=None, ax=ax)
        disp.ax_.set_title(title)
        fig.tight_layout()
        fig.savefig(path, dpi=250, bbox_inches='tight')
    finally:
        plt.close(fig)


def cross_corpus(**kwargs):
    """Evaluate pre-trained ICE-corpus models on the italki corpus, country by country.

    For each (classifier, vectorizer, category) triple, every italki country
    listed in the matching ``country_to_targets`` mapping is tested separately.
    All samples of that country are given the same expected target, the model's
    predictions are reported via ``print_report``, and a confusion-matrix PNG
    is written to ``images/``.

    Keyword arguments are forwarded to ``print_report``. Their ``categories``
    entry is overwritten with the report labels of the current category.

    Raises SystemExit(0) once all models have been evaluated.
    """
    import sys  # local import: the module header does not import sys

    classifiers = ["19_mnb_clf_char_2_6_1592751395", "25_sgd_clf_char_2_6_1592770313"]
    vectorizers = ["19_vec_char_2_6_1592751344", "25_vec_char_2_6_1592770284"]
    categories = [19, 25]
    # Per model: italki country -> expected target assigned to all of that country's samples.
    country_to_targets = [
        {"Sinhala": 0, "Hindi": 0, "Telugu": 0, "Malayalam": 0, "Tamil": 0, "Malay": 1},
        {"Sinhala": 0, "Hindi": 1, "Telugu": 1, "Malayalam": 1},
    ]

    # The italki test data is requested with the same arguments for every model,
    # so load it once instead of once per model.
    dataset = fetch_itaki(subsets=['all'], categories=categories_list[26],
                          strip_accents=True, train_precentage=0.7,
                          trim_by_category=False, verbose=False)
    test_dataset = dataset["all"]
    del dataset  # free-memory
    xx_test = test_dataset.data
    yy_test = test_dataset.target

    for load_clf, load_vec, category, country_to_target in zip(
            classifiers, vectorizers, categories, country_to_targets):
        report_categories = categories_list[category]
        if isinstance(report_categories[0], list):
            # Grouped categories: label each group by its first member.
            report_categories = [item for group in report_categories for item in group[:1]]
        kwargs["categories"] = report_categories

        # Models were trained on the ICE corpus and are only loaded here.
        clf = read_model_from_disk("models/" + load_clf)
        vect_model = read_model_from_disk("models/" + load_vec)
        vectorizer = vect_model['vectorizer']
        del vect_model
        _duration_train = 0  # nothing is trained in this function

        for country, target in country_to_target.items():
            print("Testing " + country)
            country_index = test_dataset.target_names.index(country)
            xx_test_filtered = [x for x, y in zip(xx_test, yy_test) if y == country_index]
            if not xx_test_filtered:
                # Predicting on an empty matrix would raise inside the classifier.
                print("No italki samples for " + country + "; skipping")
                continue
            yy_test_filtered = [target] * len(xx_test_filtered)

            x_test = vectorizer.transform(xx_test_filtered)
            _predicted_target, _duration_test = test_classifier(clf, x_test)
            print_report(_predicted_target, yy_test_filtered, _duration_train, _duration_test, True, **kwargs)

            # Portable epoch seconds (strftime('%s') is glibc-only). The country is
            # part of the name so runs finishing in the same second don't collide.
            timestamp = int(datetime.datetime.now().timestamp())
            _save_confusion_matrix(
                clf, x_test, yy_test_filtered, report_categories,
                title="Confusion matrix for " + country + " in " + " vs. ".join(report_categories),
                path="images/{}_confusion_matrix_{}_{}.png".format(category, country, timestamp),
            )

    print("DONE!")
    sys.exit(0)
