from time import time

from classifiers.common import read_model_from_disk, test_classifier
from util.categories import *


def flat_category(kwargs):
    """Normalise ``kwargs["categories"]`` to a flat list of labels.

    When the categories are a list of lists, only the first element of each
    inner list is kept (``[["a", "x"], ["b"]]`` -> ``["a", "b"]``); an empty
    inner list contributes nothing. A flat list is stored back unchanged, and
    an empty or missing value is passed through instead of raising.

    Note: ``kwargs`` is modified in place and the same mapping is returned.

    :param kwargs: mapping expected to hold a ``"categories"`` entry.
    :return: the same ``kwargs`` mapping with ``"categories"`` flattened.
    """
    categories = kwargs.get("categories")
    # Only probe the first element when there is one; previously an empty list
    # raised IndexError and a missing value raised TypeError here.
    if categories and isinstance(categories[0], list):
        categories = [item for sublist in categories for item in sublist[:1]]
    kwargs["categories"] = categories
    return kwargs


class DecisionPoint:
    """A node in a hierarchical (cascading) classification tree.

    A node constructed with both a classifier (``clf``) and a vectorizer
    (``vec``) model name is a decision point. ``_score`` loads those models
    from ``models/``, predicts a class for each sample, and routes each sample
    to the child whose ``name`` matches the predicted category. A child
    without models is treated as a leaf by its parent and must list exactly
    one supported country.

    :param name: node name; must appear in the parent's category list.
    :param supported_countries: country labels reachable through this node.
    :param category_id: key into ``categories_list`` selecting this node's
        classifier label set.
    :param clf: model file name (under ``models/``) of the classifier, or None.
    :param vec: model file name (under ``models/``) of the vectorizer, or None.
    """

    def __init__(self, name, supported_countries, category_id=99999999, clf=None, vec=None):
        self.name = name
        self.category_id = category_id
        self.load_clf = clf
        self.load_vec = vec
        self.supported_countries = supported_countries
        # Presumably lets scikit-learn helpers treat this object as a
        # classifier (see predict()) — TODO confirm against the caller.
        self._estimator_type = "classifier"

    def is_non_terminal(self):
        """Return True when both a classifier and a vectorizer are configured."""
        return self.load_clf is not None and self.load_vec is not None

    def has_children(self):
        """Return True once at least one child has been added."""
        return hasattr(self, "children")

    def add_child(self, node):
        """Append ``node`` to this node's children (created lazily)."""
        if not hasattr(self, "children"):
            self.children = []
        self.children.append(node)

    def print(self, level=1):
        """Return a multi-line text rendering of this subtree.

        Each node is shown as ``name[count:country,country,...]``; children are
        nested inside ``{ ... }`` and indented four more spaces.
        """
        result = (" " * level) + self.name
        # if self.load_clf and self.load_vec:
        #     result += "[" + self.load_clf + "," + self.load_vec + "]"
        result += "[" + str(len(self.supported_countries)) + ":" + ",".join(self.supported_countries) + "]"
        if hasattr(self, "children"):
            result += " {\n"
            level += 4
            result += "\n".join([child.print(level) for child in self.children])
            result += "\n" + (" " * level) + "}"
        return result

    def validate(self):
        """Check that no country is claimed by more than one direct child.

        Only the immediate children are inspected, not deeper descendants.

        :raises Exception: if two children share a supported country.
        """
        if not self.has_children():
            return
        screened = set()
        for child in self.children:
            # Compare individual countries. The old check asked whether the
            # child's whole list was an element of a list of country strings,
            # which is never true, so duplicates were never reported.
            if screened.intersection(child.supported_countries):
                raise Exception("Duplicated countries in between children of %s" % (self.name))
            screened.update(child.supported_countries)

    def predict(self, xx_test):
        """Return the predictions saved by the last ``score`` call.

        ``xx_test`` is ignored; ``score`` must have been called first.
        """
        # return saved prediction to spoof conf_matrix in scikit
        return self.predicted_target

    def score(self, xx_test, **kwargs):
        """Validate the tree, then classify ``xx_test`` through it.

        :param xx_test: samples to classify.
        :param kwargs: must contain ``categories``, the label list (flat, or a
            list of lists reduced to first elements) that final predictions
            index into; also forwarded to every node's ``_score``.
        :return: ``(elapsed_seconds, predictions)``; the predictions are also
            stored on ``self.predicted_target`` for ``predict``.
        """
        self.validate()
        t0 = time()
        self.predicted_target, _ = self._score(xx_test, list(range(len(xx_test))),
                                               flat_category(kwargs)["categories"], 0,
                                               **kwargs)
        _duration_test = time() - t0
        return _duration_test, self.predicted_target

    def _score(self, xx_test, yy_indexes, categories_all, child_index, **kwargs):
        """Classify ``xx_test`` at this node and refine through the children.

        :param xx_test: samples that reached this node.
        :param yy_indexes: for each sample, its position in the parent's sample
            list; returned unchanged so the parent can write results back.
        :param categories_all: flat label list; predictions are rewritten to
            indices into this list.
        :param child_index: unused (overwritten); kept for signature
            compatibility.
        :param kwargs: forwarded to children; ``"categories"`` is overridden
            for children that have children of their own.
        :return: ``(predictions, yy_indexes)``, where ``predictions[i]`` is the
            ``categories_all`` index chosen for ``xx_test[i]``.
        :raises Exception: if a child's name is not among this node's
            categories, a model-less child does not list exactly one country,
            or a childless decision point's categories differ from its
            supported countries.
        """
        print("*** processing " + self.name + "[" + str(len(self.supported_countries)) + ":" + ",".join(
            self.supported_countries) + "]")
        # Label names of this node's classifier; a label's position is the
        # class id the classifier predicts for it.
        categories_for_id = [i[0] for i in categories_list[self.category_id]]

        print("scoring with: " + self.load_clf)
        clf = read_model_from_disk("models/" + self.load_clf)

        vect_model = read_model_from_disk("models/" + self.load_vec)
        vectorizer = vect_model['vectorizer']
        del vect_model  # only the vectorizer is needed; release the rest

        x_test = vectorizer.transform(xx_test)
        _predicted_target, _duration_test = test_classifier(clf, x_test)

        # Group samples, and their positions in xx_test, by predicted class id.
        xx_test_for_children = {}
        yy_indexes_for_children = {}
        for i, _pred in enumerate(_predicted_target):
            xx_test_for_children.setdefault(_pred, []).append(xx_test[i])
            yy_indexes_for_children.setdefault(_pred, []).append(i)

        if self.has_children():
            for child in self.children:
                if child.name not in categories_for_id:
                    raise Exception(
                        "child name: %s cannot find in parent's %s categories" % (child.name, self.name))
                child_index = categories_for_id.index(child.name)
                # A class may receive no samples; .get() avoids the KeyError
                # the direct lookup raised in that case.
                xx_test_for_child = xx_test_for_children.get(child_index, [])
                yy_indexes_for_child = yy_indexes_for_children.get(child_index, [])
                if child.has_children():
                    # override categories for report-gen
                    kwargs["categories"] = categories_list[child.category_id]
                if child.is_non_terminal():
                    if not xx_test_for_child:
                        # Nothing routed here: don't load or run the child's models.
                        continue
                    child_pred, child_pred_indexes = child._score(xx_test_for_child, yy_indexes_for_child,
                                                                  categories_all, child_index, **kwargs)
                    for index, pred in zip(child_pred_indexes, child_pred):
                        _predicted_target[index] = pred
                else:
                    if len(child.supported_countries) != 1:
                        raise Exception("terminal node cannot have more than one country!")
                    if yy_indexes_for_child:
                        label = categories_all.index(child.supported_countries[0])
                        for index in yy_indexes_for_child:
                            _predicted_target[index] = label
        else:
            # A decision point without children must predict countries directly.
            if categories_for_id != self.supported_countries:
                raise Exception("non-terminal node cannot have zero children!")
            for class_id, country in enumerate(categories_for_id):
                for index in yy_indexes_for_children.get(class_id, []):
                    _predicted_target[index] = categories_all.index(country)
        return _predicted_target, yy_indexes


# ---- italki hierarchy -------------------------------------------------------
# Nodes given a category id plus classifier/vectorizer model names are decision
# points that load those models in _score(); nodes given only a name and a
# country list are leaves.
# NOTE(review): _score() raises unless every model-less child lists exactly one
# country; confirm the germanic / uralic / austroasiatic / afro / sinotibetan
# lists below are single-element.
all_countries = cat1 + cat2 + cat3

italki_root = DecisionPoint(
    "root", all_countries,
    category_id=1,
    clf="1_sgd_clf_char_1_9_1592452650",
    vec="1_vec_char_1_9_1592452498",
)

# Level 1: native vs. non-native. (Variable names here sit one level "ahead"
# of the node names: is_indo_euro_node is the "non_native" node.)
native_node = DecisionPoint("native", ["English"])
is_indo_euro_node = DecisionPoint(
    "non_native", non_native,
    category_id=2,
    clf="2_sgd_clf_char_1_9_1595684678",
    vec="2_vec_char_1_9_1595684593",
)
italki_root.add_child(native_node)
italki_root.add_child(is_indo_euro_node)

# Level 2: non-native split into Indo-European vs. non-Indo-European.
is_indo_aryan_node = DecisionPoint(
    "non_native_indo_european", non_native_indo_european,
    category_id=7,
    clf="7_mnb_clf_char_1_11_1595685507",
    vec="7_vec_char_1_11_1595685467",
)
non_indo_euro = DecisionPoint(
    "non_native_non_indo_european", non_native_non_indo_european,
    category_id=3,
    clf="3_sgd_clf_char_1_10_1595685123",
    vec="3_vec_char_1_10_1595684862",
)
is_indo_euro_node.add_child(is_indo_aryan_node)
is_indo_euro_node.add_child(non_indo_euro)

# Indo-Aryan: Sinhala vs. Hindi.
is_sinhala_node = DecisionPoint(
    "non_native_indo_european_indo_aryan", ["Sinhala", "Hindi"],
    category_id=11,
    clf="11_sgd_clf_char_1_9_1595685576",
    vec="11_vec_char_1_9_1595685565",
)
non_native_indo_european_indo_aryan_sinhala_node = DecisionPoint(
    "non_native_indo_european_indo_aryan_sinhala", ["Sinhala"]
)
non_native_indo_european_indo_aryan_non_sinhala_node = DecisionPoint(
    "non_native_indo_european_indo_aryan_non_sinhala", ["Hindi"]
)
is_sinhala_node.add_child(non_native_indo_european_indo_aryan_sinhala_node)
is_sinhala_node.add_child(non_native_indo_european_indo_aryan_non_sinhala_node)

# Indo-European but not Indo-Aryan: Germanic / Balto-Slavic / Romance.
non_indo_aryan_node = DecisionPoint(
    "non_native_indo_european_non_indo_aryan",
    (non_native_indo_european_non_indo_aryan_germanic
     + non_native_indo_european_non_indo_aryan_balto_slavic
     + non_native_indo_european_non_indo_aryan_romance),
    category_id=8,
    clf="8_sgd_clf_char_1_6_1595685542",
    vec="8_vec_char_1_6_1595685534",
)
is_indo_aryan_node.add_child(is_sinhala_node)
is_indo_aryan_node.add_child(non_indo_aryan_node)

germanic_node = DecisionPoint(
    "non_native_indo_european_non_indo_aryan_germanic",
    non_native_indo_european_non_indo_aryan_germanic,
)
balto_slavic_node = DecisionPoint(
    "non_native_indo_european_non_indo_aryan_balto_slavic",
    non_native_indo_european_non_indo_aryan_balto_slavic,
    category_id=10,
    clf="10_sgd_clf_char_1_7_1595685556",
    vec="10_vec_char_1_7_1595685551",
)
romance_node = DecisionPoint(
    "non_native_indo_european_non_indo_aryan_romance",
    non_native_indo_european_non_indo_aryan_romance,
    category_id=9,
    clf="9_sgd_clf_char_1_8_1592453298",
    vec="9_vec_char_1_8_1592453253",
)
non_indo_aryan_node.add_child(germanic_node)
non_indo_aryan_node.add_child(balto_slavic_node)
non_indo_aryan_node.add_child(romance_node)

# Non-Indo-European families. Node-name spellings ("dravadian",
# "austronisian") are matched against category labels, so they stay as-is.
altaic_node = DecisionPoint(
    "non_native_non_indo_european_altaic", non_native_non_indo_european_altaic,
    category_id=4,
    clf="4_sgd_clf_char_1_10_1595685389",
    vec="4_vec_char_1_10_1595685189",
)
dravidian_node = DecisionPoint(
    "non_native_non_indo_european_dravadian", dravidian,
    category_id=5,
    clf="5_sgd_clf_char_1_7_1592453090",
    vec="5_vec_char_1_7_1592453083",
)
austronisian_node = DecisionPoint(
    "non_native_non_indo_european_austronisian", non_native_non_indo_european_austronisian,
    category_id=6,
    clf="6_sgd_clf_char_1_7_1592453119",
    vec="6_vec_char_1_7_1592453099",
)
uralic_node = DecisionPoint("non_native_non_indo_european_uralic", non_native_non_indo_european_uralic)
austroasiatic_node = DecisionPoint(
    "non_native_non_indo_european_austroasiatic", non_native_non_indo_european_austroasiatic
)
afro_node = DecisionPoint("non_native_non_indo_european_afro", non_native_non_indo_european_afro)
sinotibetan_node = DecisionPoint(
    "non_native_non_indo_european_sinotibetan", non_native_non_indo_european_sinotibetan
)
non_indo_euro.add_child(altaic_node)
non_indo_euro.add_child(dravidian_node)
non_indo_euro.add_child(austronisian_node)
non_indo_euro.add_child(uralic_node)
non_indo_euro.add_child(austroasiatic_node)
non_indo_euro.add_child(afro_node)
non_indo_euro.add_child(sinotibetan_node)

# print(italki_root.print())

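# Example: score() takes the samples plus a `categories` keyword listing the
# labels that predictions index into; true labels such as yy are not passed.
# yy = [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
# xx = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]
# italki_root.score(xx, categories=all_countries)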

########### ICE
# Two alternative hierarchies over the same label set (ice_all).

# Tree 1 ("kachrus"): native (inner circle) vs. non-native, then the
# non-native node splits into outer vs. expanded circle.
ice_root_kachrus = DecisionPoint(
    "root", ice_all,
    category_id=13,
    clf="13_mnb_clf_word_1_3_1596431293",
    vec="13_vec_word_1_3_1596431262",
)
ice_native_node = DecisionPoint(
    "native", inner_circle,
    category_id=14,
    clf="14_mnb_clf_word_1_3_1596431325",
    vec="14_vec_word_1_3_1596431308",
)
ice_non_native_node = DecisionPoint(
    "non_native", outer_circle + expanded_circle,
    category_id=15,
    clf="15_sgd_clf_char_1_6_1596431397",
    vec="15_vec_char_1_6_1596431351",
)
ice_root_kachrus.add_child(ice_native_node)
ice_root_kachrus.add_child(ice_non_native_node)

ice_outer_node = DecisionPoint(
    "outer_circle", outer_circle,
    category_id=16,
    clf="16_sgd_clf_char_1_6_1596431476",
    vec="16_vec_char_1_6_1596431421",
)
ice_expand_node = DecisionPoint(
    "expanded_circle", expanded_circle,
    category_id=17,
    clf="17_sgd_clf_char_1_9_1596431594",
    vec="17_vec_char_1_9_1596431507",
)
ice_non_native_node.add_child(ice_outer_node)
ice_non_native_node.add_child(ice_expand_node)
# print(ice_root_kachrus.print())

# Tree 2 ("geo"): asia vs. non-asia, then finer regional splits.
ice_root_geo = DecisionPoint(
    "root", ice_all,
    category_id=18,
    clf="18_mnb_clf_word_1_3_1596431706",
    vec="18_vec_word_1_3_1596431656",
)
ice_asia_node = DecisionPoint(
    "asia", asia,
    category_id=19,
    clf="19_mnb_clf_word_1_2_1596431737",
    vec="19_vec_word_1_2_1596431728",
)
ice_non_asia_node = DecisionPoint(
    "non_asia", non_asia,
    category_id=22,
    clf="22_sgd_clf_char_1_6_1596431858",
    vec="22_vec_char_1_6_1596431824",
)
ice_root_geo.add_child(ice_asia_node)
ice_root_geo.add_child(ice_non_asia_node)

# NOTE(review): europe/africa are model-less leaves; _score() raises unless
# each lists exactly one country — confirm these lists are single-element.
europe_node = DecisionPoint("europe", europe)
# Backward-compatible alias for the original, misspelled module-level name.
qurope_node = europe_node
africa_node = DecisionPoint("africa", africa)
north_america_node = DecisionPoint(
    "north_america", north_america,
    category_id=23,
    clf="23_mnb_clf_word_1_3_1596431896",
    vec="23_vec_word_1_3_1596431879",
)
ice_non_asia_node.add_child(europe_node)
ice_non_asia_node.add_child(africa_node)
ice_non_asia_node.add_child(north_america_node)

south_asia_node = DecisionPoint(
    "south_asia", south_asia,
    category_id=20,
    clf="20_mnb_clf_word_1_2_1596431753",
    vec="20_vec_word_1_2_1596431748",
)
non_south_asia_node = DecisionPoint(
    "non_south_asia", non_south_asia,
    category_id=21,
    clf="21_mnb_clf_char_1_6_1596431805",
    vec="21_vec_char_1_6_1596431768",
)
ice_asia_node.add_child(south_asia_node)
ice_asia_node.add_child(non_south_asia_node)

# print(ice_root_geo.print())
