import cv2
import imutils
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import PiecewiseAffineTransform, warp

def resize(img, target_size=(512, 512), keep_aspect_ratio=True):
    """
    Resize an image to exactly ``target_size``.

    When ``keep_aspect_ratio`` is true, the image is scaled uniformly so that it
    fits entirely inside the target frame. It is then centered on a zero-filled
    canvas of the target size (letterboxing). This works for any combination of
    image and target aspect ratios. When ``keep_aspect_ratio`` is false, the
    image is simply stretched to the target size.

    Args:
        img: ndarray of shape (height, width) or (height, width, channels).
        target_size: tuple (height, width), optional. Default: (512, 512)
        keep_aspect_ratio: bool, optional. Default: True

    Returns:
        ndarray whose first two dimensions equal ``target_size``. In the
        aspect-preserving path the padding canvas uses the resized image's
        dtype, so a uint8 input yields a uint8 output.

    Raises:
        ValueError: if ``target_size`` has a non-positive height or width, or
            if ``img`` has zero height or width.
    """
    ht, wt = target_size  # target height and width
    hi, wi = img.shape[:2]  # input height and width

    if ht <= 0 or wt <= 0:
        raise ValueError(f"target_size must be positive, got {target_size}")
    if hi == 0 or wi == 0:
        raise ValueError(f"cannot resize an empty image of shape {img.shape}")

    if not keep_aspect_ratio:
        # cv2 expects the destination size as (width, height).
        return cv2.resize(img, (wt, ht))

    # The smaller of the two ratios guarantees the scaled image fits inside
    # the target on BOTH axes. Scaling only the longer side can overflow the
    # canvas when the target is not square.
    scale = min(ht / hi, wt / wi)
    new_h = min(ht, max(1, int(round(hi * scale))))
    new_w = min(wt, max(1, int(round(wi * scale))))

    # INTER_AREA is the interpolation imutils.resize applies by default, so
    # resampling quality is unchanged from the previous implementation.
    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

    # cv2.resize may drop a trailing singleton channel axis, so take the
    # channel dims from the resized output rather than from the input.
    # Reusing its dtype keeps uint8 images uint8 (np.zeros defaults to float64).
    canvas = np.zeros((ht, wt) + resized.shape[2:], dtype=resized.dtype)

    # Center the image. An odd leftover pixel goes to the bottom/right edge.
    top = (ht - new_h) // 2
    left = (wt - new_w) // 2
    canvas[top:top + new_h, left:left + new_w] = resized

    return canvas