import numpy as np


def generate_masks(N=1000, d=10, seed=None):
    """Draw random binary masks, one set per feature dimension.

    For every dimension ``dim`` in ``range(d)`` an ``(N, d)`` array of
    independent fair 0/1 draws is generated, column ``dim`` is forced to 1,
    and duplicate rows are removed.

    Args:
        N: Number of rows drawn per dimension before de-duplication.
        d: Number of feature dimensions (mask width and number of mask sets).
        seed: Optional source of randomness. ``None`` (default) draws from
            the global ``np.random`` state, exactly as before this parameter
            existed. An ``int`` (or anything ``np.random.RandomState``
            accepts) seeds a private generator so the result is reproducible
            without touching global state. An existing
            ``np.random.RandomState`` instance is used as-is.

    Returns:
        dict mapping each ``dim`` to a 2-D integer array with ``d`` columns
        and at most ``N`` rows. Rows are unique and come back in the sorted
        order produced by ``np.unique``; column ``dim`` is all ones.
    """
    if seed is None:
        # Legacy path: same calls on the global state, so code that relies on
        # a prior np.random.seed(...) keeps producing the same masks.
        rng = np.random
    elif isinstance(seed, np.random.RandomState):
        rng = seed
    else:
        rng = np.random.RandomState(seed)

    masks = {}
    for dim in range(d):
        mask = rng.choice([0, 1], size=(N, d), p=[1. / 2, 1. / 2])
        # The feature under study is always kept in its own mask set.
        mask[:, dim] = 1
        # Forcing the column creates duplicates; keep distinct rows only.
        masks[dim] = np.unique(mask, axis=0)
    return masks

def explain(model, input, masks, batch_size=1024):
    """Score every feature of every row by averaging masked predictions.

    For feature ``d`` the model is queried once per mask in ``masks[d]`` with
    ``input * mask``. The predictions are averaged over the masks, and the
    maximum along axis 1 of that average is stored as the score.

    Args:
        model: Object exposing ``predict(x, batch_size=...)``.
            Assumes the return value is array-like with rows along axis 0 and
            outputs along axis 1 -- TODO confirm against the model in use.
        input: Array whose ``shape[0]`` is the number of rows and whose
            ``shape[1]`` is the number of features.
        masks: Mapping (or sequence) indexed by feature number; each entry is
            an iterable of masks that can be multiplied with ``input``.
        batch_size: Forwarded unchanged to ``model.predict``.

    Returns:
        Float array of shape ``(input.shape[0], input.shape[1])`` where entry
        ``[i, d]`` is the score of feature ``d`` for row ``i``.
    """
    n_rows = input.shape[0]
    n_features = input.shape[1]
    explanations = np.zeros(shape=(n_rows, n_features))

    for feature in range(n_features):
        # One model call per mask; stacking puts the masks on axis 0.
        stacked = np.array([
            model.predict(input * m, batch_size=batch_size)
            for m in masks[feature]
        ])
        mean_over_masks = stacked.mean(axis=0)
        explanations[:, feature] = mean_over_masks.max(axis=1)

    return explanations