# TODO: heapq in numba does not yet support Typed Lists, so we can't move to them yet...
import heapq

import numba.typed
import numpy as np
from numba import njit

from shap._serializable import Deserializer, Serializer
from shap.utils import assert_import, record_import_error, safe_isinstance
from shap.utils._exceptions import DimensionError
from shap.maskers._masker import Masker

try:
    import cv2
except ImportError as e:
    record_import_error("cv2", "cv2 could not be imported!", e)


class Image(Masker):
    """Masks out image regions with a fixed value, blurring, inpainting or a contrasting color.

    The masker is called as ``masker(mask, x)`` where ``mask`` is a flat boolean array
    (``True`` keeps the original value, ``False`` hides it) and ``x`` is the image, either
    flattened or in its full shape.
    """

    def __init__(self, mask_value, shape=None):
        """Build a new Image masker with the given masking value.

        Parameters
        ----------
        mask_value : np.array, scalar, "blur(kernel_xsize, kernel_ysize)", "inpaint_telea", "inpaint_ns", "mdp", or "median_mdp"
            The value used to mask hidden regions of the image. An array is used as a fixed
            background, a scalar is broadcast to a constant background, and a string selects
            a background that is computed from the image being masked.
        shape : None or tuple
            If the mask_value is an auto-generated masker instead of a dataset then the input
            image shape needs to be provided.

        Raises
        ------
        TypeError
            If ``mask_value`` is a string and ``shape`` is not given.

        """
        if shape is None:
            if isinstance(mask_value, str):
                raise TypeError("When the mask_value is a string the shape parameter must be given!")
            self.input_shape = mask_value.shape
        else:
            self.input_shape = shape

        # kept unmodified so that save() can round-trip the constructor arguments
        self.input_mask_value = mask_value

        # This is the shape of the masks we expect; the (1, ...) is because we only
        # return a single masked sample to average over
        self.shape = (1, np.prod(self.input_shape))

        self.image_data = True

        self.blur_kernel = None
        self._blur_value_cache = None
        # the object the cached blur was computed from; holding a reference keeps its
        # id() from being recycled by a different image while the cache is alive
        self._blur_source = None
        if isinstance(mask_value, np.ndarray):
            self.mask_value = mask_value.flatten()
        elif isinstance(mask_value, str):
            # only the blur and inpaint modes call into cv2
            if mask_value.startswith(("blur(", "inpaint_")):
                assert_import("cv2")
            self.mask_value = mask_value
            if mask_value.startswith("blur("):
                # "blur(3,5)" -> (3, 5)
                self.blur_kernel = tuple(map(int, mask_value[5:-1].split(",")))
        else:
            self.mask_value = np.ones(self.input_shape).flatten() * mask_value
        self.build_partition_tree()

        # note if this masker can use different background for different samples
        self.fixed_background = not isinstance(self.mask_value, str)

        self.last_xid = None

        # flag that we return outputs that will not get changed by later masking calls
        self.immutable_outputs = True

    def __call__(self, mask, x):
        """Return a copy of ``x`` with every position where ``mask`` is False replaced.

        Parameters
        ----------
        mask : np.array of bool or None
            Flat mask over the image elements; ``None`` hides the whole image.
        x : np.array, torch.Tensor, or single-element list of either
            The image to mask, flattened or full-shaped.

        Returns
        -------
        tuple
            A one-element tuple holding the masked image with shape ``(1, *x.shape)``.

        Raises
        ------
        DimensionError
            If the number of elements in ``x`` does not match the constructor shape.
        ValueError
            If the masker was built with a string ``mask_value`` that is not supported.

        """
        # unwrap single element lists (which are how single input models look in multi-input
        # format); this has to happen before anything touches x.shape
        if isinstance(x, list) and len(x) == 1:
            x = x[0]

        # the caller's own object identifies the image for the blur cache; the arrays
        # derived from it below are temporaries that are recreated on every call
        source = x

        if safe_isinstance(x, "torch.Tensor"):
            x = x.cpu().numpy()

        if np.prod(x.shape) != np.prod(self.input_shape):
            raise DimensionError(
                "The length of the image to be masked must match the shape given in the "
                "ImageMasker constructor: " + " * ".join(str(i) for i in x.shape)
                + " != " + " * ".join(str(i) for i in self.input_shape)
            )

        # we preserve flattened inputs as flattened and full-shaped inputs as their original shape
        in_shape = x.shape
        if len(x.shape) > 1:
            x = x.ravel()

        # if mask is not given then we mask the whole image
        if mask is None:
            mask = np.zeros(x.shape, dtype=bool)
        hidden = ~mask

        if not isinstance(self.mask_value, str):
            out = x.copy()
            out[hidden] = self.mask_value[hidden]
        elif self.blur_kernel is not None:
            # recompute the blurred image only when we are handed a different image object
            # (resetting last_xid from outside still forces a refresh)
            if self._blur_source is not source or self.last_xid != id(source):
                self._blur_value_cache = cv2.blur(x.reshape(self.input_shape), self.blur_kernel).ravel()
                self._blur_source = source
                self.last_xid = id(source)
            out = x.copy()
            out[hidden] = self._blur_value_cache[hidden]
        elif self.mask_value == "inpaint_telea":
            out = self.inpaint(x, hidden, "INPAINT_TELEA")
        elif self.mask_value == "inpaint_ns":
            out = self.inpaint(x, hidden, "INPAINT_NS")
        elif self.mask_value == "mdp":
            out = self._fill_with_farthest_corner(x, hidden, np.mean)
        elif self.mask_value == "median_mdp":
            out = self._fill_with_farthest_corner(x, hidden, np.median)
        else:
            raise ValueError(f"Unsupported mask_value string for the Image masker: {self.mask_value!r}")

        return (out.reshape(1, *in_shape),)

    def _fill_with_farthest_corner(self, x, hidden, reducer):
        """Replace the hidden values of ``x`` with the RGB cube corner farthest from their average.

        ``reducer`` (``np.mean`` or ``np.median``) summarises the hidden values per channel;
        the hidden values are read as consecutive (R, G, B) triples.
        """
        out = x.copy()
        hidden_values = out[hidden].reshape(-1, 3)
        if hidden_values.shape[0] == 0:
            # nothing is hidden, so there is nothing to average or overwrite
            return out
        hidden_values[:] = self.farthest_rgb(reducer(hidden_values, axis=0))
        out[hidden] = hidden_values.ravel()
        return out

    def inpaint(self, x, mask, method):
        """Fill in the masked parts of the image through inpainting.

        Parameters
        ----------
        x : np.array
            The flattened image.
        mask : np.array of bool
            Flat array that is True where the image should be filled in.
        method : str
            Name of the cv2 inpainting flag to use, e.g. "INPAINT_TELEA".

        Returns
        -------
        np.array
            The flattened, inpainted image.

        """
        # a pixel is inpainted if any of its channels is masked
        reshaped_mask = mask.reshape(self.input_shape).astype(np.uint8).max(2)

        # when every pixel is masked, fall back to the per-channel mean of the image
        if reshaped_mask.sum() == np.prod(self.input_shape[:-1]):
            out = x.reshape(self.input_shape).copy()
            out[:] = out.mean((0, 1))
            return out.ravel()

        return cv2.inpaint(
            x.reshape(self.input_shape).astype(np.uint8),
            reshaped_mask,
            inpaintRadius=3,
            flags=getattr(cv2, method)
        ).astype(x.dtype).ravel()

    def farthest_rgb(self, mean_rgb):
        """Return the corner of the unit RGB cube with the largest Euclidean distance to ``mean_rgb``.

        NOTE(review): the corners use 0/1 coordinates, so this presumably expects images
        scaled to [0, 1] -- confirm for 0-255 inputs.
        """
        # all 8 corners of the RGB cube
        corners = np.array([
            [0, 0, 0],  # Black
            [1, 0, 0],  # Red
            [0, 1, 0],  # Green
            [0, 0, 1],  # Blue
            [1, 1, 0],  # Yellow
            [1, 0, 1],  # Magenta
            [0, 1, 1],  # Cyan
            [1, 1, 1],  # White
        ])
        # distance from the given color to each corner
        distances = np.linalg.norm(corners - mean_rgb, axis=1)

        # ties resolve to the first corner in the list above (np.argmax semantics)
        return corners[np.argmax(distances)]

    def build_partition_tree(self):
        """Partition the image into a hierarchical clustering based on axis-aligned splits.

        The result is stored in ``self.clustering``, an array with one row per internal
        node of the tree. The first three entries of ``self.input_shape`` are used as the
        extents of the volume being split.
        """
        xmin = 0
        xmax = self.input_shape[0]
        ymin = 0
        ymax = self.input_shape[1]
        zmin = 0
        zmax = self.input_shape[2]
        total_ywidth = ymax - ymin
        total_zwidth = zmax - zmin
        # the heap starts out holding only the root region that covers the whole image
        q = numba.typed.List([(0, xmin, xmax, ymin, ymax, zmin, zmax, -1, False)])
        M = int((xmax - xmin) * (ymax - ymin) * (zmax - zmin))
        clustering = np.zeros((M - 1, 4))
        _jit_build_partition_tree(xmin, xmax, ymin, ymax, zmin, zmax, total_ywidth, total_zwidth, M, clustering, q)
        self.clustering = clustering

    def save(self, out_file):
        """Write an Image masker to a file stream."""
        super().save(out_file)

        # Increment the version number when the encoding changes!
        with Serializer(out_file, "shap.maskers.Image", version=0) as s:
            s.save("mask_value", self.input_mask_value)
            s.save("shape", self.input_shape)

    @classmethod
    def load(cls, in_file, instantiate=True):
        """Load an Image masker from a file stream.

        When ``instantiate`` is False the constructor keyword arguments are returned
        instead of a new masker.
        """
        if instantiate:
            return cls._instantiated_load(in_file)

        kwargs = super().load(in_file, instantiate=False)
        with Deserializer(in_file, "shap.maskers.Image", min_version=0, max_version=0) as s:
            kwargs["mask_value"] = s.load("mask_value")
            kwargs["shape"] = s.load("shape")
        return kwargs

@njit
def _jit_build_partition_tree(xmin, xmax, ymin, ymax, zmin, zmax, total_ywidth, total_zwidth, M, clustering, q):
    """Fill ``clustering`` in place with a hierarchical partition built from axis-aligned splits.

    Regions are popped from the heap ``q`` one at a time. Every region that is larger than a
    single element is cut in two along one axis and both halves are pushed back onto the heap.

    Parameters
    ----------
    xmin, xmax, ymin, ymax, zmin, zmax
        Region bounds. The values passed in are overwritten by the first heap pop before
        they are read, so the bounds actually used come from the entries of ``q``.
    total_ywidth, total_zwidth
        Full extents of the y and z axes, used only to compute flattened indexes in the
        consistency assertion.
    M
        Offset added to ``ind`` to form node ids; ids below ``M`` are treated as single
        elements (size 1) when the group sizes are filled in.
    clustering
        2D array that is written in place: columns 0 and 1 receive the ids of the two
        children of each row and column 3 the number of elements below it. Column 2 is
        not written by this function.
    q
        List used as a heap of ``(-size, xmin, xmax, ymin, ymax, zmin, zmax, parent_ind, is_left)``
        tuples. Nothing is done if it is empty on entry.
    """
    # ind is the row (and, offset by M, the id) given to the next popped region; it counts
    # down, so regions popped later always get smaller ids than the ones popped earlier
    ind = len(clustering) - 1
    while len(q) > 0: # keep going until every region has been popped
        # heappop returns the smallest tuple; the first entry is the negated size, so the
        # largest remaining region comes first and ties fall back to the bounds that follow
        _, xmin, xmax, ymin, ymax, zmin, zmax, parent_ind, is_left =  heapq.heappop(q)

        # record this region's id in the left (0) or right (1) child slot of its parent;
        # the initial entry is expected to carry parent_ind = -1 and so is skipped
        if parent_ind >= 0:
            clustering[parent_ind, 0 if is_left else 1] = ind + M

        # make sure we line up with a flattened indexing scheme
        # NOTE(review): the id stored above for such a region is ind + M, i.e.
        # M - 1 - flat_index rather than flat_index itself -- confirm consumers expect that
        if ind < 0:
            assert -ind - 1 == xmin * total_ywidth * total_zwidth + ymin * total_zwidth + zmin

        xwidth = xmax - xmin
        ywidth = ymax - ymin
        zwidth = zmax - zmin
        if xwidth == 1 and ywidth == 1 and zwidth == 1:
            # a single element cannot be split any further
            pass
        else:

            # by default our ranges remain unchanged
            lxmin = rxmin = xmin
            lxmax = rxmax = xmax
            lymin = rymin = ymin
            lymax = rymax = ymax
            lzmin = rzmin = zmin
            lzmax = rzmax = zmax

            # split the x axis when it is at least as wide as the y axis (z is not compared)
            if xwidth >= ywidth and xwidth > 1:
                xmid = xmin + xwidth // 2
                lxmax = xmid
                rxmin = xmid

            # split the yaxis
            elif ywidth > 1:
                ymid = ymin + ywidth // 2
                lymax = ymid
                rymin = ymid

            # split the zaxis only when the other ranges are already width 1
            else:
                zmid = zmin + zwidth // 2
                lzmax = zmid
                rzmin = zmid

            lsize = (lxmax - lxmin) * (lymax - lymin) * (lzmax - lzmin)
            rsize = (rxmax - rxmin) * (rymax - rymin) * (rzmax - rzmin)

            # sizes are negated because heapq is a min-heap; both halves remember the row of
            # the region they came from and which side of the split they are
            heapq.heappush(q, (-lsize, lxmin, lxmax, lymin, lymax, lzmin, lzmax, ind, True))
            heapq.heappush(q, (-rsize, rxmin, rxmax, rymin, rymax, rzmin, rzmax, ind, False))

        ind -= 1

    # fill in the group sizes; children always sit in lower rows than their parent (they
    # are popped later, when ind is smaller), so a single ascending pass is sufficient
    for i in range(len(clustering)):
        li = int(clustering[i, 0])
        ri = int(clustering[i, 1])
        # ids below M count as one element, anything else reuses the size stored in its row
        lsize = 1 if li < M else clustering[li-M, 3]
        rsize = 1 if ri < M else clustering[ri-M, 3]
        clustering[i, 3] = lsize + rsize
