from socket import LOCAL_PEERCRED
import numpy as np
import scipy as sp
import math
from sklearn.linear_model import Ridge, lars_path, lasso_path
from sklearn.utils import check_random_state
from itertools import chain, combinations

from explainer.WSVRss import weighted_svr
from svsvr import *

class SvsvlExpBase(object):
    """Learn a local, sparse explanation model from perturbed data.

    Features are first ranked with a weighted SVR trained on a Shapley
    kernel.  Depending on the requested method, the explanation is either
    that ranking itself (``'svr'``) or a lasso path fitted on subset
    (interaction) features built from the top-ranked candidates (``'lasso'``).
    """

    # Regularisation constant handed to ``weighted_svr``.
    _SVR_C = 100
    # At least this many SVR-ranked features are kept as lasso candidates
    # (capped by the number of features actually available).
    _MIN_CANDIDATES = 10
    # Interaction indices whose magnitude does not exceed this value are
    # dropped from the lasso explanation.  NOTE: 10e-5 is 1e-4.
    _INTERACTION_TOL = 10e-5
    # Values accepted for the ``method`` argument of ``feature_selection``.
    _METHODS = ('svr', 'lasso')

    def __init__(self,
                 kernel_fn,
                 verbose=False,
                 random_state=None):
        """Init function

        Args:
            kernel_fn: function that transforms an array of distances into an
                        array of proximity values (floats).
            verbose: if true, print local prediction values from linear model.
            random_state: an integer or numpy.RandomState that will be used to
                generate random numbers. If None, the random state will be
                initialized using the internal numpy seed.
        """
        self.kernel_fn = kernel_fn
        self.verbose = verbose
        self.random_state = check_random_state(random_state)

    @staticmethod
    def generate_lars_path(weighted_data, weighted_labels):
        """Generates the lars path for weighted data.

        Args:
            weighted_data: data that has been weighted by kernel
            weighted_labels: labels, weighted by kernel

        Returns:
            (alphas, coefs), both are arrays corresponding to the
            regularization parameter and coefficients, respectively
        """
        alphas, _, coefs = lars_path(weighted_data,
                                     weighted_labels,
                                     method='lasso',
                                     eps=0.01,
                                     verbose=False)
        return alphas, coefs

    def feature_selection(self, data, labels, weights, num_features, method):
        """Rank features with a weighted SVR and build the explanation.

        Args:
            data: 2-D array of perturbed samples (samples x features).
                Row 0 is used for the local prediction -- presumably the
                instance under explanation; confirm against the caller.
            labels: 1-D array with the target value of every sample.
            weights: 1-D array of per-sample proximity weights.
            num_features: maximum number of features in the explanation.
            method: ``'svr'`` to explain directly with the SVR Shapley
                values, ``'lasso'`` to fit a lasso path on subset features of
                the top-ranked candidates.

        Returns:
            A 6-tuple ``(explanation, local_pred, m, intercept, mobius,
            all_importance)``:

            * ``explanation``: list of ``(feature_tuple, value)`` pairs sorted
              by decreasing magnitude.
            * ``local_pred``: prediction of the local model for row 0.
            * ``m``: model coefficients (lasso coefficients per subset, or
              the SVR Shapley values for ``'svr'``).
            * ``intercept``: intercept of the local model.
            * ``mobius``: list of ``(subset, coefficient)`` pairs (empty for
              ``'svr'``).
            * ``all_importance``: list of ``(subset, importance)`` pairs
              (empty for ``'svr'``).

        Raises:
            ValueError: if ``method`` is neither ``'svr'`` nor ``'lasso'``.
        """
        # Fail early with a clear message instead of returning None and
        # letting the caller trip over the tuple unpacking.
        if method not in self._METHODS:
            raise ValueError(
                "Unknown feature selection method %r; expected 'svr' or "
                "'lasso'." % (method,))

        sample_no, feature_no = data.shape

        gmat = Shapley_kernel(data, data, feature_type='binary')
        alpha, intercept = weighted_svr(gmat, labels, self._SVR_C, weights)
        # Every sample is passed as a candidate support vector.
        sv = Shapley_value(data, alpha, list(range(sample_no)))

        # Candidate features: the ones with the largest absolute value.
        sf_no = np.min((feature_no,
                        np.max((num_features, self._MIN_CANDIDATES))))
        sv_flat = sv.squeeze()
        sf = np.argsort(-np.abs(sv_flat))[:sf_no]

        if method == 'svr':
            return self._svr_explanation(gmat, alpha, intercept, sv, sv_flat,
                                         sf, num_features)
        return self._lasso_explanation(data, labels, weights, num_features,
                                       sf, sf_no)

    @staticmethod
    def _svr_explanation(gmat, alpha, intercept, sv, sv_flat, sf,
                         num_features):
        """Package the SVR ranking in the common 6-tuple result format."""
        # ``sf`` is already sorted by decreasing magnitude.
        used_features = sf[:num_features]
        # NOTE(review): assumes ``alpha`` holds one signed dual coefficient
        # per sample, so the prediction for row 0 is the kernel expansion
        # over ``gmat[0]`` -- confirm against ``weighted_svr``.
        local_pred = np.inner(np.ravel(gmat[0]), np.ravel(alpha)) + intercept
        # Same (feature_tuple, value) layout as the lasso explanation.
        explanation = [((feature,), sv_flat[feature])
                       for feature in used_features]
        # No subset coefficients exist on this path.
        return explanation, local_pred, sv, intercept, [], []

    def _lasso_explanation(self, data, labels, weights, num_features,
                           sf, sf_no):
        """Fit a lasso path on subset features of the candidates ``sf``."""
        # All non-empty subsets of the candidates, ordered by size; the
        # first ``sf_no`` entries are therefore the singletons, in ``sf``
        # order.
        pow_set_sf = list(chain.from_iterable(
            combinations(sf, r) for r in range(1, sf_no + 1)))
        data_hat_sf, _ = data_transformation(data, pow_set=pow_set_sf)

        # Weighted centring followed by sqrt-weight scaling turns the
        # weighted least-squares problem into an ordinary one.
        root_weights = np.sqrt(weights)
        weighted_data = ((data_hat_sf
                          - np.average(data_hat_sf, axis=0, weights=weights))
                         * root_weights[:, np.newaxis])
        weighted_labels = ((labels - np.average(labels, weights=weights))
                           * root_weights)

        _, coefs = self.generate_lars_path(weighted_data, weighted_labels)
        path = coefs.T

        # Frozensets make the repeated subset tests cheap.
        subsets = [frozenset(p) for p in pow_set_sf]
        sizes = [len(p) for p in pow_set_sf]
        lens = np.array([1 / size for size in sizes], dtype=float)

        # For every candidate feature, the indices of all subsets containing
        # it.  This does not depend on the path step, so compute it once.
        containing = [
            [k for k, members in enumerate(subsets) if subsets[j] <= members]
            for j in range(sf_no)
        ]

        # Walk the path from the densest solution towards the sparsest
        # (the all-zero start at index 0 is never considered) and stop at
        # the first step that uses at most ``num_features`` features.
        # Fallback for a path with a single entry, where the loop is empty.
        m = path[-1]
        for step in range(len(path) - 1, 0, -1):
            m = path[step]
            # Per-feature value: each subset coefficient is shared equally
            # among the members of the subset.
            shapley = np.array([np.inner(m[idx], lens[idx])
                                for idx in containing])
            if np.count_nonzero(shapley) <= num_features:
                break

        # NOTE(review): the model was fitted on centred data but the
        # intercept is reported as 0 and not added here -- confirm intended.
        local_pred = np.inner(m, data_hat_sf[0, :])
        intercept = 0

        n_sets = len(m)
        interaction = np.zeros((n_sets,))
        all_importance = np.zeros((n_sets, 1))

        # Interaction value of subset j: coefficients of all supersets k of
        # j, each divided by (|k| - |j| + 1).
        for j in range(n_sets):
            values = [m[k] / (sizes[k] - sizes[j] + 1)
                      for k, members in enumerate(subsets)
                      if subsets[j] <= members]
            interaction[j] = np.sum(values)

        # Importance of subset j: sum of the interaction values of all of
        # its (non-empty) subsets, itself included.
        for j in range(n_sets):
            values = [interaction[k]
                      for k, members in enumerate(subsets)
                      if members <= subsets[j]]
            all_importance[j] = np.sum(values)

        # Largest magnitude first; ties are broken in favour of larger sets.
        ranked = sorted(zip(pow_set_sf, interaction),
                        key=lambda item: (np.abs(item[1]), len(item[0])),
                        reverse=True)
        explanation = [item for item in ranked
                       if np.abs(item[1]) > self._INTERACTION_TOL]
        mobius = list(zip(pow_set_sf, m))
        all_importance = list(zip(pow_set_sf, all_importance))

        return explanation, local_pred, m, intercept, mobius, all_importance

    def explain_instance_with_data(self,
                                   neighborhood_data,
                                   neighborhood_labels,
                                   distances,
                                   label,
                                   num_features,
                                   feature_selection='svr',
                                   model_regressor=None):
        """Explain one label of an instance from its perturbed neighborhood.

        Args:
            neighborhood_data: 2-D array of perturbed samples.
            neighborhood_labels: 2-D array of model outputs; column ``label``
                is the target being explained.
            distances: distances of the samples to the explained instance,
                converted to weights with ``kernel_fn``.
            label: column index into ``neighborhood_labels``.
            num_features: maximum number of features in the explanation.
            feature_selection: ``'svr'`` or ``'lasso'``, forwarded to
                ``feature_selection``.
            model_regressor: accepted for interface compatibility; unused.

        Returns:
            ``(explanation, m, intercept, prediction_score, local_pred,
            mobius, all_importance)``.  ``prediction_score`` is always 0;
            the other entries are described in ``feature_selection``.
        """
        weights = self.kernel_fn(distances)
        labels_column = neighborhood_labels[:, label]

        (explanation, local_pred, m, intercept_,
         mobius, all_importance) = self.feature_selection(neighborhood_data,
                                                          labels_column,
                                                          weights,
                                                          num_features,
                                                          feature_selection)
        # No goodness-of-fit score is computed for the local model.
        prediction_score = 0

        return (explanation, m, intercept_, prediction_score, local_pred,
                mobius, all_importance)
