import numpy as np
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.ensemble import RandomForestClassifier
from collections import Counter
from bonsai import BonsaiForest
from scipy.sparse import csr_matrix, find

class Layer:
    """One layer of a multi-forest ensemble for multi-label prediction.

    A layer owns ``num_forest`` forests of the same ``tree_type`` and averages
    their per-sample label scores at prediction time.

    Supported ``tree_type`` values:

    * ``'craftml'`` -- builds ``CraftMLForest`` objects kept in ``self.model``.
    * ``'bonsai'``  -- builds ``BonsaiForest`` objects kept in ``self.model``.
    * anything else (default ``'fastxml'``) -- builds a ``Trainer`` per forest
      and saves it under ``models/layer_<layer_index>_forest_<forest_index>``;
      nothing is stored in ``self.model`` for this back-end.

    NOTE(review): ``CraftMLForest``, ``Trainer`` and ``Inferencer`` are not
    among the imports visible at the top of this file; the corresponding
    branches raise ``NameError`` unless those names are provided elsewhere.
    """

    def __init__(self, layer_index, num_forest, n_estimator, max_depth=8, tree_type='fastxml', n_jobs=8, nnC=0, weights=None):
        """Store the layer configuration; no forest is built until ``train``.

        Args:
            layer_index: Index of this layer; used in the CraftML random seed
                and in the on-disk model path of the FastXML back-end.
            num_forest: Number of forests trained and averaged by the layer.
            n_estimator: Number of trees per forest (passed as ``num_tree`` /
                ``n_trees`` to the back-end).
            max_depth: Passed as ``md`` to the CraftML and Bonsai forests; not
                used by the FastXML branch.
            tree_type: Back-end selector, see the class docstring.
            n_jobs: Parallelism setting forwarded to every back-end.
            nnC: Forwarded to ``BonsaiForest`` only.
            weights: Forwarded to ``BonsaiForest`` only.
        """
        self.layer_index = layer_index
        self.n_estimator = n_estimator
        self.num_forest = num_forest
        self.n_jobs = n_jobs
        self.max_depth = max_depth
        self.tree_type = tree_type
        self.weights = weights
        self.nnC = nnC
        self.model = []

    def _model_path(self, forest_index):
        """Return the on-disk location used by the FastXML back-end."""
        return f'models/layer_{self.layer_index}_forest_{forest_index}'

    def getNorm(self):
        """Return ``getNorm()`` of the first in-memory forest.

        Raises:
            IndexError: If ``self.model`` is empty, i.e. before training or
                when the FastXML back-end (which keeps no in-memory models)
                is used.
        """
        return self.model[0].getNorm()

    def train(self, train_data, train_label):
        """Fit ``num_forest`` forests on the given data.

        Args:
            train_data: Feature matrix; must expose ``shape[0]`` and row
                indexing for the FastXML branch.
            train_label: Label matrix; the FastXML branch converts it to
                per-row label lists with ``csr2list``.
        """
        # Start from a clean slate: predict() and getNorm() read indices
        # 0..num_forest-1, so appending to an already populated list would
        # make a re-trained layer silently keep using the old forests.
        self.model = []

        for forest_index in range(self.num_forest):
            if self.tree_type == 'craftml':
                # Seed varies with both layer and forest index.
                clf = CraftMLForest(num_tree=self.n_estimator, dx=500, dy=500, ns=20000,
                                    nl=10, nc=10, md=self.max_depth, n_jobs=self.n_jobs,
                                    rand_seed=(12345 + self.layer_index) * forest_index)
                clf.fit(train_data, train_label)
                self.model.append(clf)
            elif self.tree_type == 'bonsai':
                clf = BonsaiForest(num_tree=self.n_estimator, nc=100, md=self.max_depth,
                                   n_jobs=self.n_jobs, nnC=self.nnC, weights=self.weights)
                clf.fit(train_data, train_label)
                self.model.append(clf)
            else:
                # FastXML: the trainer takes a list of rows and a list of
                # label lists, and the fitted model is persisted to disk
                # rather than kept in self.model.
                clf = Trainer(n_trees=self.n_estimator, n_jobs=self.n_jobs)
                X = [train_data[i] for i in range(train_data.shape[0])]
                clf.fit(X, self.csr2list(train_label))
                clf.save(self._model_path(forest_index))

    def predict(self, test_data, topk=5):
        """Average the forests' label scores and rank the top labels.

        Args:
            test_data: Feature matrix; must expose ``shape[0]`` (and row
                indexing for the FastXML branch).
            topk: Number of label indices to report per sample.

        Returns:
            A tuple ``(final_pred, predict_prob)``:

            * ``final_pred`` -- float array of shape ``(n_samples, topk)``
              holding, per sample, the label indices with the highest
              averaged score in descending order. Rows with fewer than
              ``topk`` stored scores are padded with ``0``.
            * ``predict_prob`` -- ``csr_matrix`` with the mean of the
              per-forest score matrices.
        """
        num_samples = test_data.shape[0]
        forest_probs = []

        for forest_index in range(self.num_forest):
            if self.tree_type in ('craftml', 'bonsai'):
                # Only the score matrix is needed; the ranking is derived
                # from the ensemble average below.
                _, prob = self.model[forest_index].predict(test_data)
            else:
                clf = Inferencer(self._model_path(forest_index))
                X = [test_data[i] for i in range(num_samples)]
                prob = clf.predict(X)
            forest_probs.append(prob)

        predict_prob = csr_matrix(sum(forest_probs) / self.num_forest)

        # Previously an all-zero array was returned here and the per-forest
        # rankings were thrown away; rank the averaged scores instead.
        final_pred = self._topk_labels(predict_prob, topk)
        return final_pred, predict_prob

    @staticmethod
    def _topk_labels(prob, topk):
        """Return the ``topk`` best-scoring column indices of each CSR row."""
        num_rows = prob.shape[0]
        labels = np.zeros((num_rows, topk))
        indptr, indices, data = prob.indptr, prob.indices, prob.data

        for row in range(num_rows):
            start, end = indptr[row], indptr[row + 1]
            if start == end:
                # No stored scores for this sample: keep the zero padding.
                continue
            # Stable sort keeps storage order among equal scores.
            order = np.argsort(-data[start:end], kind='stable')[:topk]
            labels[row, :len(order)] = indices[start:end][order]
        return labels

    def train_and_predict(self, train_data, train_label, test_data):
        """Train on the given data, then predict ``test_data`` (default top-k).

        Returns:
            The ``(final_pred, predict_prob)`` tuple produced by ``predict``.
        """
        self.train(train_data, train_label)
        return self.predict(test_data)

    def csr2list(self, M):
        """Convert a sparse label matrix into per-row lists of column indices.

        Args:
            M: Sparse matrix accepted by ``scipy.sparse.find``.

        Returns:
            A list with ``M.shape[0]`` entries; entry ``r`` lists the column
            indices of the nonzero elements in row ``r``.
        """
        rows, cols, _ = find(M)
        labels_per_row = [[] for _ in range(M.shape[0])]
        for r, c in zip(rows, cols):
            labels_per_row[r].append(c)
        return labels_per_row

