import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
import random
from utils.utils import cvtColor, preprocess_input


class Dataset(Dataset):
    """Samples pairing a query image, a satellite image and a drone-view image.

    Each entry of ``data_lines`` is an 8-item sequence. Only four items are
    read here: index 1 (query image file name), index 2 (satellite image file
    name), index 4 (click position ``[x, y]`` on the query image) and index 5
    (target box ``[x1, y1, x2, y2]`` on the satellite image).

    All ``input_shape_*`` values are ``(height, width)`` pairs.

    NOTE: the class deliberately keeps its original name, which shadows the
    imported ``torch.utils.data`` base class after this definition.
    """

    # (height, width) the drone-view image is letterboxed to in __getitem__.
    BEV_INPUT_SHAPE = (256, 256)

    def __init__(self, root_path, data_lines, input_shape_reimg, input_shape_qimg, num_classes, train):
        """Store the configuration; no file is touched until ``__getitem__``.

        Args:
            root_path: Directory containing ``query/``, ``masks/`` and ``satellite/``.
            data_lines: Sequence of 8-item annotation records (see class docstring).
            input_shape_reimg: ``(height, width)`` of the satellite network input.
            input_shape_qimg: ``(height, width)`` of the query network input.
            num_classes: Stored as-is; not used inside this class.
            train: When true the satellite image is randomly augmented.
        """
        super().__init__()
        self.data_lines = data_lines
        self.input_shape_reimg = input_shape_reimg
        self.input_shape_qimg = input_shape_qimg
        self.num_classes = num_classes
        self.length = len(self.data_lines)
        self.train = train
        self.root_path = root_path

    def __len__(self):
        """Return the number of annotation records."""
        return self.length

    def __getitem__(self, index):
        """Load and encode one sample.

        Returns:
            An 8-tuple ``(q_img, reimg, q_mask, box, re_mask, q_xys, q_xy, bev_query)``:

            * ``q_img``, ``reimg``, ``bev_query`` -- channel-first images after
              ``preprocess_input`` (query, satellite and drone view).
            * ``q_mask`` -- click-distance map stacked with the query mask
              along axis 0.
            * ``box`` -- ``(N, 5)`` rows of ``[cx, cy, w, h, 0]`` normalised to
              the satellite input size.
            * ``re_mask`` -- ``(1, H, W)`` map that is 1 inside the boxes.
            * ``q_xys`` -- click as ``[y / H, x / W]`` of the query input.
            * ``q_xy`` -- click as ``[x, y]`` in query-input pixels.
        """
        index = index % self.length
        line = self.data_lines[index]

        # Only the satellite image is augmented; query and drone views are
        # always letterboxed deterministically.
        reimg, box, _ = self.get_random_data(self.root_path, line, self.input_shape_reimg,
                                             random=self.train, query=False)
        q_img, q_xy, q_mask = self.get_random_data(self.root_path, line, self.input_shape_qimg,
                                                   random=False, query=True)
        bev_query, _, _ = self.get_random_data(self.root_path, line, self.BEV_INPUT_SHAPE, random=False,
                                               query=False, drone=True)

        reimg = self._to_chw(reimg)
        q_img = self._to_chw(q_img)
        bev_query = self._to_chw(bev_query)
        box = np.array(box, dtype=np.float32)

        box, re_mask = self.process_box(box, query=False)
        qmask_train = self.process_box(q_xy, query=True)
        q_mask = np.concatenate((qmask_train, np.expand_dims(q_mask, 0)), 0)
        q_xys = np.array([q_xy[1] / self.input_shape_qimg[0], q_xy[0] / self.input_shape_qimg[1]])

        return q_img, reimg, q_mask, box, re_mask, q_xys, q_xy, bev_query

    @staticmethod
    def _to_chw(image):
        """Convert an H x W x C image to float32, preprocess it and move channels first."""
        return np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))

    def process_box(self, box, query=True):
        """Encode a click or a set of boxes as network targets.

        Args:
            box: With ``query=True`` a click ``[x, y]`` in query-input pixels.
                With ``query=False`` an ``(N, 5)`` array of
                ``[x1, y1, x2, y2, label]`` in satellite-input pixels; it is
                modified in place.
            query: Selects which of the two encodings is produced.

        Returns:
            ``query=True``: a float32 ``(1, H, W)`` map whose value at each
            pixel is ``(1 - distance_to_click / image_diagonal) ** 2``.

            ``query=False``: a tuple ``(box, mask)`` where ``box`` holds
            ``[cx, cy, w, h, 0]`` normalised by the satellite input size and
            ``mask`` is a ``(1, H, W)`` array set to 1 inside every box.
        """
        if query:
            return self._click_distance_map(box)
        return self._normalise_boxes(box)

    def _click_distance_map(self, click):
        """Build the ``(1, H, W)`` closeness-to-click map for the query image."""
        height, width = self.input_shape_qimg[0], self.input_shape_qimg[1]
        click_row, click_col = int(click[1]), int(click[0])

        # The click is only used as a distance origin, never as an array index,
        # so a click sitting exactly on the far edge (x == W or y == H) is safe.
        row_sq = (np.arange(height, dtype=np.float64) - click_row) ** 2
        col_sq = (np.arange(width, dtype=np.float64) - click_col) ** 2
        diagonal = np.sqrt(float(height * height + width * width))

        closeness = 1.0 - np.sqrt(row_sq[:, None] + col_sq[None, :]) / diagonal
        return (closeness * closeness)[None].astype(np.float32)

    def _normalise_boxes(self, box):
        """Rasterise corner boxes into a mask and convert them to normalised centre form."""
        # input_shape_reimg is (height, width): rows follow y, columns follow x.
        height, width = self.input_shape_reimg[0], self.input_shape_reimg[1]

        mask = np.zeros((1, height, width))
        for single_box in box:
            mask[0, int(single_box[1]):int(single_box[3]), int(single_box[0]):int(single_box[2])] = 1

        if len(box) != 0:
            box[:, [0, 2]] = box[:, [0, 2]] / width
            box[:, [1, 3]] = box[:, [1, 3]] / height
            # Corner form (x1, y1, x2, y2) -> centre form (cx, cy, w, h).
            box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
            box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2
            box[:, 4] = 0
        else:
            box = np.zeros([0, 5], np.float32)
        return box, mask

    def rand(self, a=0, b=1):
        """Return a uniform random float in ``[a, b)``."""
        return np.random.rand() * (b - a) + a

    @staticmethod
    def _place_click(click, nw, nh, iw, ih, dx, dy, w, h, flip=False):
        """Map a click ``[x, y]`` from source pixels onto the ``w`` x ``h`` canvas, in place."""
        click[0] = click[0] * nw / iw + dx
        click[1] = click[1] * nh / ih + dy
        if flip:
            click[0] = w - click[0]
        click[0] = min(max(click[0], 0), w)
        click[1] = min(max(click[1], 0), h)
        return click

    @staticmethod
    def _place_boxes(boxes, nw, nh, iw, ih, dx, dy, w, h, flip=False):
        """Map corner boxes onto the ``w`` x ``h`` canvas and drop boxes thinner than 2 px."""
        if len(boxes) == 0:
            return boxes
        boxes[:, [0, 2]] = boxes[:, [0, 2]] * nw / iw + dx
        boxes[:, [1, 3]] = boxes[:, [1, 3]] * nh / ih + dy
        if flip:
            # Mirroring swaps the roles of the left and right edges.
            boxes[:, [0, 2]] = w - boxes[:, [2, 0]]
        boxes[:, 0:2][boxes[:, 0:2] < 0] = 0
        boxes[:, 2][boxes[:, 2] > w] = w
        boxes[:, 3][boxes[:, 3] > h] = h
        box_w = boxes[:, 2] - boxes[:, 0]
        box_h = boxes[:, 3] - boxes[:, 1]
        return boxes[np.logical_and(box_w > 1, box_h > 1)]  # discard invalid box

    def get_random_data(self, root_path, data_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True,
                        query=True, drone=False):
        """Load one image of a record and fit it to ``input_shape``.

        Args:
            root_path: Dataset root directory.
            data_line: 8-item annotation record (see class docstring).
            input_shape: Target ``(height, width)``.
            jitter: Aspect-ratio jitter amplitude for random augmentation.
            hue, sat, val: HSV gain amplitudes for random augmentation.
            random: ``False`` letterboxes deterministically; ``True`` applies
                random scale, placement, horizontal flip and HSV jitter.
            query: Load the query image and its mask; the annotation is the click.
            drone: With ``query=False``, load the drone view instead of the
                satellite image. The path is ``root_path`` with ``'SVI'``
                replaced by ``'DroneAerial'``.

        Returns:
            ``(image_data, box, mask)``. ``image_data`` is an H x W x 3 array
            (float32 when ``random`` is false, uint8 otherwise). ``box`` is the
            click ``[x, y]`` for query images, else an ``(N, 5)`` array of
            ``[x1, y1, x2, y2, 0]``. ``mask`` is ``None`` unless ``query``.
        """
        _, queryimg_name, rsimg_name, _, click_xy, bbox, _, _ = data_line

        mask = None
        if query:
            image = Image.open(root_path + '/query/' + queryimg_name)
            mask = Image.open(root_path + '/masks/' + queryimg_name[:-4] + '.png')
        elif drone:
            image = Image.open(root_path.replace('SVI', 'DroneAerial') + '/query/' + queryimg_name)
        else:
            image = Image.open(root_path + '/satellite/' + rsimg_name)

        image = cvtColor(image)
        iw, ih = image.size
        h, w = input_shape

        if query:
            box = [click_xy[0], click_xy[1]]
            # Binarise the mask: every non-zero pixel becomes 1.
            mask = np.array(mask, np.float32)
            mask[mask > 0] = 1
            mask = Image.fromarray(mask)
        else:
            # Build the row from a copy: appending to ``bbox`` itself would grow
            # the caller's annotation record on every load.
            box = np.array([list(bbox) + [0]])

        if not random:
            # Letterbox: uniform scale, centred on a grey canvas.
            scale = min(w / iw, h / ih)
            nw = int(iw * scale)
            nh = int(ih * scale)
            dx = (w - nw) // 2
            dy = (h - nh) // 2

            canvas = Image.new('RGB', (w, h), (128, 128, 128))
            canvas.paste(image.resize((nw, nh), Image.BICUBIC), (dx, dy))
            image_data = np.array(canvas, np.float32)

            if query:
                mask_canvas = Image.new('L', (w, h), 0)
                mask_canvas.paste(mask.resize((nw, nh), Image.NEAREST), (dx, dy))
                mask = np.array(mask_canvas, np.float32)
                box = self._place_click(box, nw, nh, iw, ih, dx, dy, w, h)
            else:
                box = self._place_boxes(box, nw, nh, iw, ih, dx, dy, w, h)
            return image_data, box, mask

        # Random augmentation. The order of the random draws below is kept
        # fixed so seeded runs stay reproducible.
        new_ar = iw / ih * self.rand(1 - jitter, 1 + jitter) / self.rand(1 - jitter, 1 + jitter)
        scale = self.rand(.25, 2)
        if new_ar < 1:
            nh = int(scale * h)
            nw = int(nh * new_ar)
        else:
            nw = int(scale * w)
            nh = int(nw / new_ar)
        image = image.resize((nw, nh), Image.BICUBIC)

        dx = int(self.rand(0, w - nw))
        dy = int(self.rand(0, h - nh))
        canvas = Image.new('RGB', (w, h), (128, 128, 128))
        canvas.paste(image, (dx, dy))
        image = canvas

        flip = self.rand() < .5
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        image_data = np.array(image, np.uint8)

        # Per-channel multiplicative HSV gains applied through lookup tables.
        gains = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
        hue_ch, sat_ch, val_ch = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
        pixel_dtype = image_data.dtype

        levels = np.arange(0, 256, dtype=gains.dtype)
        lut_hue = ((levels * gains[0]) % 180).astype(pixel_dtype)
        lut_sat = np.clip(levels * gains[1], 0, 255).astype(pixel_dtype)
        lut_val = np.clip(levels * gains[2], 0, 255).astype(pixel_dtype)

        image_data = cv2.merge((cv2.LUT(hue_ch, lut_hue), cv2.LUT(sat_ch, lut_sat), cv2.LUT(val_ch, lut_val)))
        image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)

        # NOTE(review): in this branch the query mask is returned un-resized and
        # un-flipped; __getitem__ never requests random=True for query images.
        if query:
            box = self._place_click(box, nw, nh, iw, ih, dx, dy, w, h, flip=flip)
        else:
            box = self._place_boxes(box, nw, nh, iw, ih, dx, dy, w, h, flip=flip)

        return image_data, box, mask


def dataset_collate(batch):
    """Merge a list of dataset samples into batched tensors.

    Each sample is an 8-tuple; position 3 holds the per-sample box array and
    is turned into a list of ``{"labels", "boxes"}`` dicts instead of being
    stacked. Every other position is stacked into one float tensor.

    Returns:
        ``(q_imgs, re_imgs, q_masks, targets, re_masks, q_xy_masks, q_xys,
        bev_querys)`` in the same positional order as the sample fields.
    """

    def as_float_tensor(items):
        # Stack through NumPy first, then cast to float32 on the CPU.
        return torch.from_numpy(np.array(items)).float()

    # One bucket per sample field, filled in field order.
    buckets = [[] for _ in range(8)]
    for q_img, reimg, q_mask, box, re_mask, q_xy_mask, q_xy, bev_query in batch:
        fields = (q_img, reimg, q_mask, box, re_mask, q_xy_mask, q_xy, bev_query)
        for bucket, value in zip(buckets, fields):
            bucket.append(value)
    q_img_list, re_img_list, q_mask_list, box_list, re_mask_list, q_xy_mask_list, q_xy_list, bev_list = buckets

    targets = []
    for ann in box_list:
        # Column 4 is the class label; columns 0-3 are the box coordinates.
        targets.append({
            "labels": torch.from_numpy(ann[:, 4]).long(),
            "boxes": torch.from_numpy(ann[:, :4]).float(),
        })

    return (
        as_float_tensor(q_img_list),
        as_float_tensor(re_img_list),
        as_float_tensor(q_mask_list),
        targets,
        as_float_tensor(re_mask_list),
        as_float_tensor(q_xy_mask_list),
        as_float_tensor(q_xy_list),
        as_float_tensor(bev_list),
    )
