"""
 Copyright (c) 2022 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import numpy as np
import random
from PIL import Image


def rand(a=0, b=1):
    """
    Draw a single uniform random float between a and b.

    Uses numpy's global random state (``np.random.rand``), so results are
    reproducible via ``np.random.seed``.

    :param a: lower end of the sampling interval.
    :param b: upper end of the sampling interval.
    :return: float sample ``a + U * (b - a)`` with ``U`` drawn from [0, 1).
    """
    span = b - a
    return np.random.rand()*span + a


def letterbox_resize(image, target_size, return_padding_info=False):
    """
    Resize image with unchanged aspect ratio using padding

    The image is scaled by the largest factor that still fits it inside
    ``target_size`` and pasted, centered, onto a gray (128, 128, 128) RGB
    canvas of exactly ``target_size``.

    :param image: origin image to be resize
            PIL Image object containing image data
    :param target_size: target image size,
            tuple of format (width, height).
    :param return_padding_info: whether to return padding size & offset info
            Boolean flag to control return value. Defaults to False, in
            which case only the resized image is returned.
    :return new_image: resized PIL Image object.
    :return padding_size: padding image size (keep aspect ratio).
            Only returned when return_padding_info is True.
            will be used to reshape the ground truth bounding box
    :return offset: top-left offset in target image padding.
            Only returned when return_padding_info is True.
            will be used to reshape the ground truth bounding box
    """
    src_w, src_h = image.size
    target_w, target_h = target_size

    # calculate padding scale and padding offset
    scale = min(target_w/src_w, target_h/src_h)
    # clamp to 1 px: for extreme aspect ratios int() can round a side down
    # to 0, and PIL refuses to resize to a zero-sized image
    padding_w = max(1, int(src_w * scale))
    padding_h = max(1, int(src_h * scale))
    padding_size = (padding_w, padding_h)

    # center the resized image inside the target canvas
    dx = (target_w - padding_w)//2
    dy = (target_h - padding_h)//2
    offset = (dx, dy)

    # create letterbox resized image
    image = image.resize(padding_size, Image.BICUBIC)
    new_image = Image.new('RGB', target_size, (128,128,128))
    new_image.paste(image, offset)

    if return_padding_info:
        return new_image, padding_size, offset
    return new_image


def random_resize_crop_pad(image, target_size, aspect_ratio_jitter=0.1, scale_jitter=0.7):
    """
    Randomly resize an image, then place it at a random position on a gray
    canvas of the target size (padding when smaller, cropping when larger).
    Intended for training-time data augmentation.

    :param image: origin image to be resize
            PIL Image object containing image data
    :param target_size: target image size,
            tuple of format (width, height).
    :param aspect_ratio_jitter: jitter range for random aspect ratio;
            the target aspect ratio is multiplied by one random factor in
            [1-j, 1+j) and divided by another.
    :param scale_jitter: jitter range for random resize scale;
            the resize scale is drawn between scale_jitter and 1/scale_jitter.
    :return new_image: target sized PIL Image object.
    :return padding_size: random generated resized image size (width, height).
            will be used to reshape the ground truth bounding box
    :return padding_offset: random generated top-left offset (dx, dy) of the
            resized image on the canvas.
            will be used to reshape the ground truth bounding box
    """
    target_w, target_h = target_size
    jitter_low = 1 - aspect_ratio_jitter
    jitter_high = 1 + aspect_ratio_jitter

    # two independent jitter draws: one multiplies, one divides the aspect ratio
    stretched = target_w/target_h * rand(jitter_low, jitter_high)
    rand_aspect_ratio = stretched / rand(jitter_low, jitter_high)
    rand_scale = rand(scale_jitter, 1/scale_jitter)

    # fix one side from the random scale and derive the other from the ratio
    if rand_aspect_ratio < 1:
        new_h = int(rand_scale * target_h)
        new_w = int(new_h * rand_aspect_ratio)
    else:
        new_w = int(rand_scale * target_w)
        new_h = int(new_w / rand_aspect_ratio)
    padding_size = (new_w, new_h)

    resized = image.resize(padding_size, Image.BICUBIC)

    # random placement; if the resized side exceeds the target the range is
    # negative, giving a non-positive offset for that axis
    padding_offset = (int(rand(0, target_w - new_w)), int(rand(0, target_h - new_h)))

    canvas = Image.new('RGB', (target_w, target_h), (128,128,128))
    canvas.paste(resized, padding_offset)
    return canvas, padding_size, padding_offset


def reshape_boxes(boxes, src_shape, target_shape, padding_shape, offset, horizontal_flip=False, vertical_flip=False):
    """
    Reshape bounding boxes from src_shape image to target_shape image,
    usually for training data preprocess

    Boxes are shuffled (row order randomized with ``np.random.shuffle``),
    scaled from the source image to the padding image size, shifted by
    ``offset``, optionally flipped, clipped to the target image, and boxes
    whose width or height is not greater than 1 pixel are discarded.
    The caller's ``boxes`` array is left untouched; a new array is returned.

    :param boxes: Ground truth object bounding boxes,
            numpy array of shape (num_boxes, 5),
            box format (xmin, ymin, xmax, ymax, cls_id).
    :param src_shape: origin image shape,
            tuple of format (width, height).
    :param target_shape: target image shape,
            tuple of format (width, height).
    :param padding_shape: padding image shape,
            tuple of format (width, height).
    :param offset: top-left offset when padding target image.
            tuple of format (dx, dy).
    :param horizontal_flip: whether to do horizontal flip.
            boolean flag.
    :param vertical_flip: whether to do vertical flip.
            boolean flag.
    :return boxes: reshaped bounding box numpy array
    """
    if len(boxes) == 0:
        return boxes

    src_w, src_h = src_shape
    target_w, target_h = target_shape
    padding_w, padding_h = padding_shape
    dx, dy = offset

    # work on a copy: shuffling and the coordinate rewrites below are
    # in-place and must not corrupt the caller's annotations
    boxes = np.array(boxes)

    # shuffle and reshape boxes
    np.random.shuffle(boxes)
    boxes[:, [0, 2]] = boxes[:, [0, 2]]*padding_w/src_w + dx
    boxes[:, [1, 3]] = boxes[:, [1, 3]]*padding_h/src_h + dy
    # flip: the new min edge is the mirrored old max edge (and vice versa)
    if horizontal_flip:
        boxes[:, [0, 2]] = target_w - boxes[:, [2, 0]]
    if vertical_flip:
        boxes[:, [1, 3]] = target_h - boxes[:, [3, 1]]

    # clip box coordinates into the target image
    boxes[:, 0:2] = np.maximum(boxes[:, 0:2], 0)
    boxes[:, 2] = np.minimum(boxes[:, 2], target_w)
    boxes[:, 3] = np.minimum(boxes[:, 3], target_h)

    # discard boxes that became degenerate (<= 1 px wide or tall)
    boxes_w = boxes[:, 2] - boxes[:, 0]
    boxes_h = boxes[:, 3] - boxes[:, 1]
    return boxes[(boxes_w > 1) & (boxes_h > 1)]


def random_horizontal_flip(image, prob=.5):
    """
    Mirror an image left-to-right with probability ``prob``.

    :param image: PIL Image object to (maybe) flip.
    :param prob: probability of performing the flip.
    :return image: the flipped PIL Image, or the input image unchanged.
    :return flip: True when the flip was applied, so callers can flip the
            matching bounding boxes as well.
    """
    do_flip = rand() < prob
    flipped = image.transpose(Image.FLIP_LEFT_RIGHT) if do_flip else image
    return flipped, do_flip


def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1):  # box1(4,n), box2(4,n)
    """
    Select boxes that are still usable after an augmentation.

    Boxes are given column-wise: row 0..3 hold xmin, ymin, xmax, ymax and
    each column is one box.

    :param box1: boxes before augment, array of shape (4, n).
    :param box2: the same boxes after augment, array of shape (4, n).
    :param wh_thr: minimum width and height (pixels) after augment.
    :param ar_thr: maximum aspect ratio (long side / short side) after augment.
    :param area_thr: minimum ratio of area after augment to area before.
    :return: boolean array of shape (n,), True for boxes passing all checks.
    """
    eps = 1e-16  # avoids division by zero for degenerate boxes
    w_before = box1[2] - box1[0]
    h_before = box1[3] - box1[1]
    w_after = box2[2] - box2[0]
    h_after = box2[3] - box2[1]

    aspect = np.maximum(w_after / (h_after + eps), h_after / (w_after + eps))
    area_ratio = w_after * h_after / (w_before * h_before + eps)

    big_enough = (w_after > wh_thr) & (h_after > wh_thr)
    return big_enough & (area_ratio > area_thr) & (aspect < ar_thr)


def merge_mosaic_bboxes(bboxes, crop_x, crop_y, image_size): # pylint: disable=R0912
    """
    Clip and merge the bboxes of 4 mosaic samples into one padded bbox array.

    Sample ``i`` contributes the image area below, split at pixel column
    ``crop_x`` and pixel row ``crop_y``:
    -----------
    |     |   |
    |  0  | 3 |
    |     |   |
    -----------
    |  1  | 2 |
    -----------
    Boxes of a sample that do not overlap its area are dropped (a box that
    only touches the cut line counts as not overlapping); boxes crossing a
    cut line are clipped to it. Clipped boxes narrower or shorter than
    max(10, 1% of the image side) are dropped.

    :param bboxes: numpy array of shape (4, max_boxes, 5), one zero-padded
            box list per mosaic sample, box format (xmin, ymin, xmax, ymax, cls_id).
    :param crop_x: x coordinate of the vertical cut line.
    :param crop_y: y coordinate of the horizontal cut line.
    :param image_size: mosaic image size, tuple of format (height, width).
    :return box_data: numpy array of shape (max_boxes, 5), merged boxes first,
            remaining rows zero; extra boxes beyond max_boxes are truncated.
    """
    assert bboxes.shape[0] == 4, 'mosaic sample number should be 4'
    max_boxes = bboxes.shape[1]
    height, width = image_size
    # minimum kept box size, hoisted out of the loop
    min_w = max(10, width*0.01)
    min_h = max(10, height*0.01)

    merge_bbox = []
    for i, sample_boxes in enumerate(bboxes):
        is_top = i in (0, 3)    # areas 0 and 3 lie above crop_y
        is_left = i in (0, 1)   # areas 0 and 1 lie left of crop_x
        for box in sample_boxes:
            x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]

            if is_top:
                if y_min >= crop_y:  # entirely at/below the horizontal cut
                    continue
                if y_max > crop_y:
                    y_max = crop_y
            else:
                if y_max <= crop_y:  # entirely at/above the horizontal cut
                    continue
                if y_min < crop_y:
                    y_min = crop_y

            if is_left:
                if x_min >= crop_x:  # entirely at/right of the vertical cut
                    continue
                if x_max > crop_x:
                    x_max = crop_x
            else:
                if x_max <= crop_x:  # entirely at/left of the vertical cut
                    continue
                if x_min < crop_x:
                    x_min = crop_x

            # if the adjusted bbox is too small (incl. zero padding rows), bypass it
            if abs(x_max-x_min) < min_w or abs(y_max-y_min) < min_h:
                continue

            merge_bbox.append([x_min, y_min, x_max, y_max, box[4]])

    merge_bbox = merge_bbox[:max_boxes]

    box_data = np.zeros((max_boxes, 5))
    if merge_bbox:
        box_data[:len(merge_bbox)] = merge_bbox
    return box_data


def random_mosaic_augment(image_data, boxes_data, prob=.1):
    """
    With probability ``prob``, replace the whole batch by mosaic images
    (YOLOv4-style): each output image is stitched from 4 randomly chosen
    batch images, cut at a random (crop_x, crop_y) point.
    reference:
        https://github.com/klauspa/Yolov4-tensorflow/blob/master/data.py
        https://github.com/clovaai/CutMix-PyTorch
        https://github.com/AlexeyAB/darknet

    Layout of the 4 sampled images in each mosaic:
    -----------
    |     |   |
    |  0  | 3 |
    |     |   |
    -----------
    |  1  | 2 |
    -----------

    :param image_data: batch images, each indexed as (height, width, channels);
            all images are assumed to share the size of the first one.
    :param boxes_data: batch bboxes, one padded box array per image.
    :param prob: probability for applying the augment.
    :return image_data: augmented (or original) batch image data.
    :return boxes_data: augmented (or original) batch bboxes data.
    """
    do_augment = rand() < prob
    if not do_augment:
        return image_data, boxes_data

    batch_size = len(image_data)
    assert batch_size >= 4, 'mosaic augment need batch size >= 4'

    # the cut point is kept at least 20% away from every image border
    min_offset = 0.2
    height, width = image_data[0].shape[:2]
    x_low, x_high = int(width*min_offset), int(width*(1-min_offset))
    y_low, y_high = int(height*min_offset), int(height*(1 - min_offset))

    mosaic_images = []
    mosaic_boxes = []
    # build batch_size mosaics so the output batch keeps its size
    for _ in range(batch_size):
        picks = random.sample(range(batch_size), 4)
        samples = [image_data[k] for k in picks]
        sample_boxes = np.array([boxes_data[k] for k in picks])

        crop_x = int(random.uniform(x_low, x_high))  # nosec
        crop_y = int(random.uniform(y_low, y_high))  # nosec

        mosaic_boxes.append(
            merge_mosaic_bboxes(sample_boxes, crop_x, crop_y, image_size=(height, width)))

        # stitch row by row: top = areas 0|3, bottom = areas 1|2
        top_row = np.concatenate(
            [samples[0][:crop_y, :crop_x, :], samples[3][:crop_y, crop_x:, :]], axis=1)
        bottom_row = np.concatenate(
            [samples[1][crop_y:, :crop_x, :], samples[2][crop_y:, crop_x:, :]], axis=1)
        mosaic_images.append(np.concatenate([top_row, bottom_row], axis=0))

    return np.stack(mosaic_images), np.array(mosaic_boxes)

def normalize_image(image):
    """
    Scale pixel values from the 0 ~ 255 range into 0.0 ~ 1.0.

    :param image: numpy image array with values in 0.0 ~ 255.0.
    :return: new float array with values in 0.0 ~ 1.0 (input not modified).
    """
    max_pixel_value = 255.0
    return image / max_pixel_value
