import numpy as np


def remove(image, attribution, mean, percentile, keep=False, return_mask=False, gray=False):
    """
    Replace a percentile of pixels, ranked by attribution, with per-channel mean values.

    image        : numpy array of shape [C,H,W]; it is left untouched, a modified
                   copy is returned
    attribution  : if gray is False, array with as many elements as image
                   (reshaped to [C,H,W]); if gray is True, only attribution[0]
                   is used and it must hold H*W elements
    mean         : per-channel replacement values, read as mean[channel];
                   needs one entry per image channel
    percentile   : scalar between 0 and 100, inclusive; the number of replaced
                   entries is int(percentile * N / 100), where N is H*W if gray
                   is True and C*H*W otherwise
    keep         : if False, the entries with the HIGHEST attribution are
                   replaced; if True, the entries with the LOWEST attribution
                   are replaced (so the most relevant ones are kept)
    return_mask  : if True, return (modified_image, mask) instead of just
                   modified_image; mask is a boolean array that is True where
                   values were replaced, of shape [H,W] if gray else [C,H,W]
    gray         : if True, rank spatial positions once and replace the selected
                   positions in every channel; otherwise rank every
                   (channel, y, x) entry independently
    """
    modified_image = np.copy(image)
    num_channels = modified_image.shape[0]

    # Flatten the attribution scores to a 1D numpy array for ranking.
    if gray:
        scores = np.ravel(np.asarray(attribution[0]))
        num_candidates = modified_image[0].size
    else:
        scores = np.ravel(np.asarray(attribution))
        num_candidates = modified_image.size

    num_replace = int(percentile * num_candidates / 100)

    mask = np.zeros(scores.shape, dtype=bool)
    # Guard the zero case explicitly: order[-0:] would select *every* index
    # and wipe the whole image when nothing should be replaced.
    if num_replace > 0:
        order = scores.argsort()  # ascending attribution
        if keep:
            selected = order[:num_replace]   # lowest attributions
        else:
            selected = order[-num_replace:]  # highest attributions
        mask[selected] = True

    if gray:
        # One spatial mask shared by all channels.
        mask = mask.reshape(modified_image[0].shape)
        for channel in range(num_channels):
            modified_image[channel, mask] = mean[channel]
    else:
        # Each channel has its own slice of the mask.
        mask = mask.reshape(modified_image.shape)
        for channel in range(num_channels):
            modified_image[channel, mask[channel]] = mean[channel]

    if return_mask:
        return modified_image, mask
    return modified_image