import numpy as np
import math
from itertools import chain, combinations

def powerset(iterable):
    """Iterate over every subset of ``iterable`` as a tuple.

    Subsets are produced in order of increasing size, and within one size in
    the order ``itertools.combinations`` yields them, e.g.
    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)

    The input is copied into a list up front, so one-shot iterators are
    accepted.
    """
    pool = list(iterable)
    # One combinations iterator per subset size 0..len(pool), built lazily.
    groups_by_size = (
        combinations(pool, size) for size in range(len(pool) + 1)
    )
    return chain.from_iterable(groups_by_size)



def general_single_datapoint(f, x, feat, weight):
    """Accumulate the weighted marginal contributions of one feature.

    For every subset S of the features other than ``feat``, ``f`` is
    evaluated on a copy of ``x`` in which every feature outside S is replaced
    by ``np.nan`` -- first with ``feat`` masked as well, then with ``feat``
    restored. The difference of the two results is multiplied by
    ``weight(len(S), n)`` and added to the running total.

    Args:
        f: Callable receiving a length-``n`` list whose entries are either
            the float value of a feature or ``np.nan`` for a masked feature.
        x: Sequence of ``n`` feature values convertible to float.
        feat: Index of the feature being attributed. Negative indices count
            from the end, as in ordinary sequence indexing.
        weight: Callable ``weight(s, n)`` giving the coefficient applied to a
            subset of size ``s`` when there are ``n`` features in total.

    Returns:
        The accumulated weighted sum of ``f`` differences.

    Raises:
        IndexError: If ``feat`` is not a valid index into ``x``.
    """
    n = len(x)
    x = np.array(x, dtype=float)

    # Validate before doing any work: a negative index left as-is would keep
    # the attributed feature among the candidates and silently corrupt the
    # result, and an out-of-range index would only fail after calling ``f``.
    if not -n <= feat < n:
        raise IndexError(
            "feature index %r is out of range for %d features" % (feat, n)
        )
    if feat < 0:
        feat += n

    others = [i for i in range(n) if i != feat]
    ret = 0.
    # Same enumeration order as a powerset: by size, then lexicographic.
    for size in range(len(others) + 1):
        for subset in combinations(others, size):
            # Start fully masked and reveal only the subset's features.
            x_cut = [np.nan] * n
            for i in subset:
                x_cut[i] = x[i]

            v2 = f(x_cut)

            # Reveal the attributed feature on top of the same subset.
            x_cut[feat] = x[feat]
            v1 = f(x_cut)

            ret += (v1 - v2) * weight(size, n)
    return ret


# shap value for single datapoint
def shap_single_datapoint(f, x, feat):
    """Return the Shapley value of feature ``feat`` for the datapoint ``x``.

    Delegates to ``general_single_datapoint`` using the coefficient
    ``s! * (n - s - 1)! / n!`` for a subset of size ``s`` out of ``n``
    features.
    """
    def shapley_weight(s, n):
        numerator = math.factorial(s) * math.factorial(n - s - 1)
        return numerator / math.factorial(n)

    return general_single_datapoint(f, x, feat, shapley_weight)

# banzhaf value for single datapoint
def banzhaf_single_datapoint(f, x, feat):
    """Return the Banzhaf value of feature ``feat`` for the datapoint ``x``.

    Delegates to ``general_single_datapoint`` with the same coefficient
    ``1 / 2**(len(x) - 1)`` for every subset, regardless of its size.
    """
    # The coefficient depends only on len(x), so compute it once.
    coefficient = 1. / 2 ** (len(x) - 1)

    def constant_weight(s, n):
        return coefficient

    return general_single_datapoint(f, x, feat, constant_weight)

def shap_vector(f, matrix, feat):
    """Return a list with the Shapley value of ``feat`` for each row.

    ``matrix`` is iterated once; the result has one entry per row, in
    iteration order.
    """
    values = []
    for row in matrix:
        values.append(shap_single_datapoint(f, row, feat))
    return values


def banzhaf_vector(f, matrix, feat):
    """Return a list with the Banzhaf value of ``feat`` for each row.

    ``matrix`` is iterated once; the result has one entry per row, in
    iteration order.
    """
    values = []
    for row in matrix:
        values.append(banzhaf_single_datapoint(f, row, feat))
    return values



def shap_abs(f, matrix, feat):
    """Return the mean absolute Shapley value of ``feat`` over ``matrix``.

    Computes ``shap_vector`` for every row, takes absolute values and
    averages them with ``np.mean``.
    """
    per_row = shap_vector(f, matrix, feat)
    magnitudes = np.abs(per_row)
    return np.mean(magnitudes)


# mean absolute banzhaf value over the dataset
def banzhaf_abs(f, matrix, feat):
    """Return the mean absolute Banzhaf value of ``feat`` over ``matrix``.

    Computes ``banzhaf_vector`` for every row, takes absolute values and
    averages them with ``np.mean``.
    """
    per_row = banzhaf_vector(f, matrix, feat)
    magnitudes = np.abs(per_row)
    return np.mean(magnitudes)

