import numpy as np
from scipy.special import expit
from scipy import linalg
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso
from sklearn.linear_model import Ridge
from sklearn.neural_network import MLPRegressor
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import normalize
from sklearn.kernel_ridge import KernelRidge
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor

from sklearn.neighbors import NearestNeighbors

import utils

# Floor added to mean-squared-error estimates so that the inverse-MSE
# trust / weight computations below never divide by zero.
tol = 1e-8


class Ensemble:
    """Ensemble of per-agent regressors combined through trust-based weights.

    Agent ``k`` owns the local training set ``training_data[k]``, where
    ``training_data[k][0]`` holds the features and ``training_data[k][1]`` the
    targets, and fits its own model on it.  Aggregation weights are derived
    from how well each agent's model performs on the other agents' data.
    """

    def __init__(self, num_agents, training_data, sparse=False):
        """
        Args:
            num_agents: number of agents, i.e. number of local models.
            training_data: sequence indexed by agent; ``training_data[k][0]``
                are agent ``k``'s features and ``training_data[k][1]`` its
                targets.
            sparse: when truthy, test points are handed to the models
                unchanged (presumably already 2-D rows such as a sparse row
                matrix -- TODO confirm); otherwise each 1-D test point is
                wrapped into a single-row batch.
        """
        self.training_data = training_data
        self.num_agents = num_agents
        # NOTE(review): not updated by any method of this class.
        self.T = None
        self.sparse_dataformat = sparse
        # Filled by train_local_models(); initialised so that the attribute
        # always exists.
        self.local_models = []

    @staticmethod
    def _make_estimator(model_type, params):
        """Return a new, unfitted estimator of family ``model_type``."""
        if model_type == 'linear':
            return LinearRegression()
        if model_type == 'ridge':
            return Ridge(alpha=params[0])
        if model_type == 'lasso':
            return Lasso(alpha=params[0])
        if model_type == 'NN':
            return MLPRegressor(alpha=params[0], solver=params[1],
                                max_iter=params[2],
                                hidden_layer_sizes=params[3])
        if model_type == 'DTR':
            return DecisionTreeRegressor(max_depth=params[0])
        raise ValueError(f"specify valid model type, got {model_type!r}")

    def train_local_models(self, type='linear', params=(1.0,)):
        """Fit one fresh model per agent on that agent's local data.

        Args:
            type: model family, one of 'linear', 'ridge', 'lasso', 'NN' or
                'DTR'.  (The name shadows the builtin; kept for backward
                compatibility with keyword callers.)
            params: hyper-parameters for the chosen family:
                'ridge' / 'lasso' -> ``(alpha,)``;
                'DTR' -> ``(max_depth,)``;
                'NN' -> ``(alpha, solver, max_iter, hidden_layer_sizes)``;
                'linear' -> ignored.

        Raises:
            ValueError: if ``type`` is not a supported model family.  The
                previously trained models are left in place in that case.
        """
        models = []
        for k in range(self.num_agents):
            features, targets = self.training_data[k][0], self.training_data[k][1]
            models.append(self._make_estimator(type, params).fit(features, targets))
        self.local_models = models

    def predict(self, test_point, index=None):
        """Evaluate the local model(s) on a single test point.

        Args:
            test_point: one sample.  This is a 1-D feature vector in dense
                mode, or an input passed through unchanged when
                ``sparse_dataformat`` is set.
            index: agent whose model to evaluate.  ``None`` (or ``[]``, the
                legacy default) evaluates every agent.

        Returns:
            With ``index`` unset, a float array of shape ``(num_agents,)``
            holding each agent's prediction.  Otherwise, the raw output of
            ``local_models[index].predict`` for the point.
        """
        if not self.sparse_dataformat:
            test_point = [test_point]

        if index is None or (isinstance(index, list) and not index):
            predictions = np.zeros(self.num_agents)
            for k in range(self.num_agents):
                # .item() extracts the single prediction explicitly (storing
                # a size-1 array into a scalar slot is deprecated in NumPy)
                # and raises if the model returned more than one value.
                predictions[k] = np.asarray(
                    self.local_models[k].predict(test_point)).item()
            return predictions

        return self.local_models[index].predict(test_point)

    def get_consensus_weights(self, test_point, num_neighbors, num_degroot_iter=30):
        """Iterative consensus-finding procedure (Algorithm 1 in NeurIPS submission).

        For every agent ``k``, the ``num_neighbors`` local training points
        closest to ``test_point`` are found and every agent ``j``'s model is
        scored on them.  This gives ``local_mse[k, j]``, which is floored by
        ``tol``.  These errors are inverted and row-normalised into the trust
        matrix ``T``.  ``T`` is then raised to the power ``num_degroot_iter``
        (power iterations) to obtain the consensus weights.

        Returns:
            ``[consensus_weights, mean_weights, limT, T, mse_weights]``:

            * ``consensus_weights``: column means of ``T ** num_degroot_iter``.
            * ``mean_weights``: column means of ``T`` (tau-avg).
            * ``limT``: the matrix power ``T ** num_degroot_iter``.
            * ``T``: the row-normalised trust matrix.
            * ``mse_weights``: normalised inverse of each model's summed local
              MSE (MSE-avg).
        """
        local_mse = np.zeros([self.num_agents, self.num_agents])
        # kneighbors expects a 2-D query; wrap a dense 1-D point into one row.
        query = test_point if self.sparse_dataformat else [test_point]

        for k in range(self.num_agents):
            features, targets = self.training_data[k][0], self.training_data[k][1]

            # Find the num_neighbors local points nearest to the test point.
            neigh = NearestNeighbors(n_neighbors=num_neighbors).fit(features)
            _, neighbor_idx = neigh.kneighbors(query, return_distance=True)
            neighbor_idx = neighbor_idx[0]

            # Gather the neighbourhood once; it is the same for every model.
            if self.sparse_dataformat:
                x_nb = features[neighbor_idx]
            else:
                x_nb = np.asarray(features)[neighbor_idx]
            y_nb = np.ravel(np.asarray(targets)[neighbor_idx])

            # Score every agent's model on agent k's neighbourhood with a
            # single batched predict.  The tol term keeps the MSE strictly
            # positive for the inversion below.
            for j in range(self.num_agents):
                pred = np.ravel(self.local_models[j].predict(x_nb))
                local_mse[k, j] = (tol + np.sum((pred - y_nb) ** 2)) / num_neighbors

        # Trust matrix: inverse local MSE, each row L1-normalised (entries are
        # positive, so every row sums to one).
        T = normalize(1.0 / local_mse, norm="l1")

        # Power iterations; column means of the result are the consensus weights.
        limT = np.linalg.matrix_power(T, num_degroot_iter)
        consensus_weights = np.mean(limT, axis=0)

        # Alternative aggregation scheme MSE-avg: inverse of each model's total
        # local MSE across all agents, normalised to sum to one.
        mse_weights = 1.0 / np.sum(local_mse, axis=0)
        mse_weights = mse_weights / np.sum(mse_weights)

        # Alternative aggregation scheme tau-avg: column means of T.
        mean_weights = np.mean(T, axis=0)

        return [consensus_weights, mean_weights, limT, T, mse_weights]

    def get_inverse_mse_weights(self, X_val, Y_val):
        """Weight each agent by the inverse of its MSE on a validation set.

        Args:
            X_val: validation features with ``X_val.shape[0]`` samples.  This
                is a 2-D array, or in sparse mode a row-indexable sparse
                matrix.
            Y_val: validation targets, one per sample; flattened before use
                so a column vector does not broadcast against the predictions.

        Returns:
            Array of shape ``(num_agents,)`` with positive weights summing to
            one.  Each model's MSE is floored at ``tol`` before inversion.

        Raises:
            ValueError: if ``X_val`` contains no samples (the MSE would be NaN).
        """
        num_samples_val = X_val.shape[0]
        if num_samples_val == 0:
            raise ValueError("X_val must contain at least one sample")

        y_val = np.ravel(Y_val)

        # One batched predict per model instead of one call per sample.
        predictions = np.zeros([self.num_agents, num_samples_val])
        for k in range(self.num_agents):
            predictions[k] = np.ravel(self.local_models[k].predict(X_val))

        mse = np.maximum(np.mean((predictions - y_val) ** 2, axis=1), tol)
        weights = 1.0 / mse
        return weights / np.sum(weights)
