import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from scipy import ndimage as ndi
from skimage.measure import label


def resize_image(img, factor, interpolation=Image.BILINEAR):
    """Scale a PIL image by ``factor``, forcing both sides to be even.

    Each side is scaled and truncated to an int. Any odd result is then
    increased by one, so the returned image always has even width and height.

    :param img: PIL image
    :param factor: float resize factor applied to width and height
    :param interpolation: resampling filter handed to ``img.resize``
    :return: re-sized PIL image
    """
    def _to_even(length):
        # odd lengths get bumped by one; even lengths are left alone
        return length + (length % 2)

    width, height = img.size
    new_size = (_to_even(int(factor * width)), _to_even(int(factor * height)))
    return img.resize(new_size, interpolation)


def compute_boundary_weights(mask, factor=5.0, verbose=False):
    """
    Compute boundary weight mask from binary mask image.
    (required for SegmentationDataset)

    Every background pixel receives ``1 + factor * sum_k exp(-d_k)``.
    Here ``d_k`` is the Euclidean distance from the pixel to the k-th
    connected foreground component, using full connectivity (8-connectivity
    in 2-D). Foreground pixels receive weight 1. Background pixels that sit
    close to one or more components therefore get the largest weights.

    Parameters:
    -----------
        mask: numpy ndarray (or array-like), binary mask image; any non-zero value is foreground
        factor: float, scales the influence of boundary weights
        verbose: boolean, if True show weight map with matplotlib

    Returns:
    --------
        dist_img: float32 numpy ndarray with the same shape as mask containing per-pixel weights.
    """
    # np.int was removed in NumPy 1.24; any integer dtype suffices for a 0/1 mask
    mask = (np.asarray(mask) != 0).astype(np.int32)

    # label connected foreground components with full connectivity
    # (the same default as skimage.measure.label); background stays 0
    structure = np.ones((3,) * mask.ndim, dtype=bool)
    split_label_img, n_components = ndi.label(mask, structure=structure)

    dist_img = np.zeros(mask.shape, dtype=np.float32)
    for l in range(1, n_components + 1):
        # exact Euclidean distance of every pixel to the nearest pixel of
        # component l; edt yields the same values as distance_transform_bf
        # but in linear instead of quadratic time
        l_dist = ndi.distance_transform_edt(split_label_img != l)
        dist_img += factor * np.exp(-l_dist)

    # foreground pixels only get the base weight
    dist_img[mask == 1] = 0
    dist_img += 1

    if verbose:
        plt.figure("Image")
        plt.clf()

        plt.subplot(121)
        plt.imshow(mask)
        plt.colorbar()

        plt.subplot(122)
        plt.imshow(dist_img, cmap="viridis", interpolation="nearest")
        plt.colorbar()

        plt.show(block=True)

    return dist_img


def get_weighted_crop_params(mask, output_size, factor=0.1, power=4.0, verbose=False):
    """
    Over-sample regions containing foreground pixels.
    (required for SegmentationDataset)
    (useful in combination with transforms.functional.crop)

    A crop center is drawn from a density that is maximal on foreground
    pixels and decays with the Euclidean distance to the nearest foreground
    pixel. The crop is then shifted so that it lies inside the mask. If the
    (down-scaled) mask has no foreground, or is entirely foreground, the
    center is drawn uniformly.

    Parameters:
    -----------
        mask: numpy ndarray or PIL image, binary mask image
        output_size: integer tuple, crop size (height, width)
        factor: float, resize factor applied before computing the density (keeps the distance transform cheap)
        power: float, exponent applied to the density; larger values concentrate samples closer to foreground
        verbose: boolean, if True show weight map with matplotlib

    Returns:
    --------
        crop: integer tuple (i, j, th, tw), i ... row index, j ... column index, th ... height, tw ... width

    Raises:
    -------
        ValueError: if the requested crop is larger than the mask
    """

    # accept numpy arrays as well as PIL images
    if not isinstance(mask, Image.Image):
        mask = Image.fromarray(mask)

    w, h = mask.size
    th, tw = output_size

    if w == tw and h == th:
        return 0, 0, h, w
    if th > h or tw > w:
        raise ValueError(
            "Required crop size {} is larger than mask size {}".format((th, tw), (h, w)))

    # work on a down-scaled copy so the distance transform stays cheap
    mask_rsz = np.array(resize_image(mask, factor=factor))
    mask_rsz = (mask_rsz > 0).astype(np.uint8)

    # sampling density: maximal on foreground, decaying with the distance to
    # the nearest foreground pixel, sharpened by `power`
    dist = ndi.distance_transform_edt(1 - mask_rsz)
    density = (dist.max() - dist) ** power
    total = density.sum()
    if mask_rsz.any() and np.isfinite(total) and total > 0:
        density /= total
    else:
        # no foreground or a constant density (e.g. all foreground):
        # normalising would divide by zero, so sample uniformly instead
        density = np.full(density.shape, 1.0 / density.size)

    # draw one flat index and split it into (row, column) of the resized grid
    choice = int(np.random.choice(density.size, p=density.ravel()))
    ci_rsz, cj_rsz = divmod(choice, density.shape[1])   # row / column center

    # map the center back to full resolution
    ci = ci_rsz * (1.0 / factor)    # row center
    cj = cj_rsz * (1.0 / factor)    # column center

    i = int(ci - th // 2)   # first row index
    j = int(cj - tw // 2)   # first column index

    # clip indices so that the crop lies inside the mask
    i = int(np.clip(i, 0, h - th))
    j = int(np.clip(j, 0, w - tw))

    if verbose:
        plt.figure("Image")
        plt.clf()

        plt.subplot(121)
        plt.imshow(mask)
        plt.plot([j, j+tw], [i, i+th], "mo-", alpha=1.0)
        plt.plot(cj, ci, "mo", alpha=1.0)
        plt.colorbar()

        plt.subplot(122)
        plt.imshow(density, cmap="viridis", interpolation="nearest", vmin=0)
        plt.plot(cj_rsz, ci_rsz, "mo", alpha=1.0)
        plt.colorbar()

        plt.show(block=True)

    return i, j, th, tw


def pil_to_convnet(img):
    """
    Turn a PIL Image into a channel-first numpy array for conv-net input.

    :param img: PIL image (a 2-D or 3-D numpy array is accepted as well)
    :return: numpy array of shape (channels, height, width); single-channel
             images get a leading channel axis of length 1
    """
    arr = np.array(img)

    # single-channel images carry no channel axis yet -> append one
    if arr.ndim == 2:
        arr = np.expand_dims(arr, axis=-1)

    # (height, width, channels) -> (channels, height, width)
    return arr.transpose(2, 0, 1)
