"""The data layer used during training to train a Fast R-CNN network.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch.utils.data as data
from PIL import Image
import torch

from model.utils.config import cfg
from roi_data_layer.minibatch import get_minibatch, get_minibatch
from model.rpn.bbox_transform import bbox_transform_inv, clip_boxes

import numpy as np
import random
import time
import copy
import pdb


class roibatchLoader(data.Dataset):
    """Dataset yielding one image with object boxes, region boxes and relation triples.

    Box tensors use columns 0-3 as (x1, y1, x2, y2): columns 0/2 are clamped
    to the image width and columns 1/3 to the image height.

    Training mode (``training=True``):
      * sample ``index`` is mapped to a roidb entry through ``ratio_index``;
      * boxes and relation triples are shuffled;
      * the image is cropped (when the roidb entry has ``need_crop``) and
        zero-padded so that its aspect ratio (width / height) approximately
        matches the ratio assigned to its batch in ``ratio_list_batch``;
      * boxes/triples are padded to ``cfg.MAX_NUM_GT_BOXES`` /
        ``cfg.MAX_NUM_RELATIONS`` rows so samples of a batch can be stacked.

    Evaluation mode returns the image and its (clamped, unpadded) annotations.
    """

    def __init__(self, roidb, ratio_list, ratio_index, batch_size, num_classes, training=True, normalize=None):
        """
        Args:
            roidb: per-image records. ``get_minibatch`` is called on one record
                at a time; ``roidb[i]['need_crop']`` is read in training mode.
            ratio_list: per-sample aspect ratios (width / height). The
                per-batch logic below assumes it is sorted ascending —
                TODO confirm with the caller.
            ratio_index: maps a dataset index to a roidb index (training mode).
            batch_size: number of consecutive samples sharing one target ratio.
            num_classes: forwarded to ``get_minibatch``.
            training: selects the training (shuffle/crop/pad) or evaluation path.
            normalize: stored on the instance; not used by this class.
        """
        self._roidb = roidb
        self._num_classes = num_classes
        # read from the config; not used elsewhere in this class
        self.trim_height = cfg.TRAIN.TRIM_HEIGHT
        self.trim_width = cfg.TRAIN.TRIM_WIDTH
        self.max_num_box = cfg.MAX_NUM_GT_BOXES
        self.max_num_relation = cfg.MAX_NUM_RELATIONS
        self.training = training
        self.normalize = normalize
        self.ratio_list = ratio_list
        self.ratio_index = ratio_index
        self.batch_size = batch_size
        self.data_size = len(self.ratio_list)

        # Give every sample of a batch the same target ratio so that the padded
        # images of one batch share one aspect ratio.
        self.ratio_list_batch = torch.Tensor(self.data_size).zero_()
        num_batch = int(np.ceil(len(ratio_index) / batch_size))
        for i in range(num_batch):
            left_idx = i * batch_size
            right_idx = min((i + 1) * batch_size - 1, self.data_size - 1)

            if ratio_list[right_idx] < 1:
                # all ratios < 1 (if sorted): keep the leftmost ratio of the batch
                target_ratio = ratio_list[left_idx]
            elif ratio_list[left_idx] > 1:
                # all ratios > 1 (if sorted): keep the rightmost ratio of the batch
                target_ratio = ratio_list[right_idx]
            else:
                # the batch straddles 1: use a square target
                target_ratio = 1

            self.ratio_list_batch[left_idx:(right_idx + 1)] = target_ratio

    @staticmethod
    def _clamp_to_image(boxes, width, height):
        """Clamp columns 0/2 to [0, width] and columns 1/3 to [0, height] in place; return ``boxes``."""
        boxes[:, 0].clamp_(0, width)
        boxes[:, 2].clamp_(0, width)
        boxes[:, 1].clamp_(0, height)
        boxes[:, 3].clamp_(0, height)
        return boxes

    @staticmethod
    def _box_extent(box_sets, lo_col, hi_col):
        """Return (min of ``lo_col``, max of ``hi_col``) as ints over the non-empty tensors in ``box_sets``.

        Returns ``(0, 0)`` when every tensor is empty, which makes the crop start at 0.
        """
        non_empty = [boxes for boxes in box_sets if boxes.size(0) > 0]
        if not non_empty:
            return 0, 0
        lo = min(int(torch.min(boxes[:, lo_col])) for boxes in non_empty)
        hi = max(int(torch.max(boxes[:, hi_col])) for boxes in non_empty)
        return lo, hi

    @staticmethod
    def _sample_crop_start(min_c, max_c, trim_size, extent):
        """Choose the start of a crop window of length ``trim_size`` along an axis of length ``extent``.

        The window is placed so that boxes spanning [min_c, max_c] are kept
        inside it when they fit; otherwise it starts in the first half of the
        overflow past ``min_c``.
        """
        if min_c == 0:
            return 0
        box_region = max_c - min_c + 1
        if box_region < trim_size:
            # boxes fit: sample a start in [start_min, start_max) that keeps them
            # inside the window while keeping the window inside the image
            start_min = max(max_c - trim_size, 0)
            start_max = min(min_c, extent - trim_size)
            if start_min == start_max:
                return start_min
            return np.random.choice(range(start_min, start_max))
        start_add = int((box_region - trim_size) / 2)
        if start_add == 0:
            return min_c
        return np.random.choice(range(min_c, min_c + start_add))

    @staticmethod
    def _shift_and_clip(boxes, cols, offset, upper):
        """Subtract ``offset`` from the given columns of ``boxes`` and clamp them to [0, upper], in place."""
        for col in cols:
            boxes[:, col] = boxes[:, col] - float(offset)
            boxes[:, col].clamp_(0, upper)

    @staticmethod
    def _pad_valid_boxes(boxes, max_num):
        """Drop zero-width/zero-height boxes and zero-pad the rest to ``max_num`` rows.

        Returns:
            (padded tensor of shape (max_num, boxes.size(1)), number of real rows).
        """
        not_keep = (boxes[:, 0] == boxes[:, 2]) | (boxes[:, 1] == boxes[:, 3])
        keep = torch.nonzero(not_keep == 0).view(-1)

        padded = torch.FloatTensor(max_num, boxes.size(1)).zero_()
        num_boxes = 0
        if keep.numel() != 0:
            kept = boxes[keep]
            num_boxes = min(kept.size(0), max_num)
            padded[:num_boxes, :] = kept[:num_boxes]
        return padded, num_boxes

    def __getitem__(self, index):
        """Return sample ``index``.

        Training mode returns ``(padding_data, im_info, obj_gt_boxes_padding,
        obj_num_boxes, region_gt_boxes_padding, region_num_boxes,
        relation_triples_padding, relation_num)``: ``padding_data`` is
        (3, H, W), ``im_info`` has 3 elements, and each ``*_padding`` tensor
        has a fixed number of rows of which the matching count is real.

        Evaluation mode returns ``(data, im_info, obj_gt_boxes, obj_num_boxes,
        region_gt_boxes, region_num_boxes, relation_triples, relation_num)``
        with unpadded annotations and ``data`` of shape (3, H, W).

        If the loaded image has zero height or width, a neighbouring sample is
        loaded instead.
        """
        while True:
            if self.training:
                index_ratio = int(self.ratio_index[index])
            else:
                index_ratio = index

            blobs = get_minibatch([self._roidb[index_ratio]], self._num_classes)
            im_data = torch.from_numpy(blobs['data'])
            im_info = torch.from_numpy(blobs['im_info'])

            # blobs['data'] is laid out as (batch, height, width, channel)
            data_height, data_width = im_data.size(1), im_data.size(2)
            if data_height > 0 and data_width > 0:
                break
            # Empty image: step to the next sample, or jump back one batch when
            # already at the last index (stepping forward would run off the end).
            if index + 1 < len(self.ratio_list):
                index += 1
            else:
                index -= self.batch_size

        if not self.training:
            im_data = im_data.permute(0, 3, 1, 2).contiguous().view(3, data_height, data_width)
            im_info = im_info.view(3)

            obj_gt_boxes = self._clamp_to_image(torch.from_numpy(blobs['obj_gt_boxes']), data_width, data_height)
            region_gt_boxes = self._clamp_to_image(torch.from_numpy(blobs['region_gt_boxes']),
                                                   data_width, data_height)
            relation_triples = torch.from_numpy(blobs['relation_triples'])

            return (im_data, im_info, obj_gt_boxes, obj_gt_boxes.size(0), region_gt_boxes,
                    region_gt_boxes.size(0), relation_triples, relation_triples.size(0))

        # Shuffle annotation order (in place on the blob arrays) and clamp boxes
        # to the image.
        # NOTE(review): boxes and relation triples are shuffled independently; if
        # the triples refer to boxes by row position this breaks the link — confirm.
        np.random.shuffle(blobs['obj_gt_boxes'])
        obj_gt_boxes = self._clamp_to_image(torch.from_numpy(blobs['obj_gt_boxes']), data_width, data_height)
        np.random.shuffle(blobs['region_gt_boxes'])
        region_gt_boxes = self._clamp_to_image(torch.from_numpy(blobs['region_gt_boxes']), data_width, data_height)
        np.random.shuffle(blobs['relation_triples'])
        relation_triples = torch.from_numpy(blobs['relation_triples'])

        # target aspect ratio (width / height) shared by this sample's batch
        ratio = self.ratio_list_batch[index]

        if self._roidb[index_ratio]['need_crop']:
            if ratio < 1:
                # keep a band of rows of height floor(width / ratio)
                min_y, max_y = self._box_extent((obj_gt_boxes, region_gt_boxes), 1, 3)
                trim_size = min(int(np.floor(data_width / ratio)), data_height)
                y_s = self._sample_crop_start(min_y, max_y, trim_size, data_height)
                im_data = im_data[:, y_s:(y_s + trim_size), :, :]
                for boxes in (obj_gt_boxes, region_gt_boxes):
                    self._shift_and_clip(boxes, (1, 3), y_s, trim_size - 1)
            else:
                # keep a band of columns of width ceil(height * ratio)
                min_x, max_x = self._box_extent((obj_gt_boxes, region_gt_boxes), 0, 2)
                trim_size = min(int(np.ceil(data_height * ratio)), data_width)
                x_s = self._sample_crop_start(min_x, max_x, trim_size, data_width)
                im_data = im_data[:, :, x_s:(x_s + trim_size), :]
                for boxes in (obj_gt_boxes, region_gt_boxes):
                    self._shift_and_clip(boxes, (0, 2), x_s, trim_size - 1)
            # A window near the border may keep fewer than trim_size rows/columns,
            # so use the actual cropped size from here on.
            data_height, data_width = im_data.size(1), im_data.size(2)

        if ratio < 1:
            # data_width < data_height: zero-pad rows up to ceil(width / ratio)
            padding_data = torch.FloatTensor(int(np.ceil(data_width / ratio)), data_width, 3).zero_()
            padding_data[:data_height, :, :] = im_data[0]
            im_info[0, 0] = padding_data.size(0)
        elif ratio > 1:
            # data_width > data_height: zero-pad columns up to ceil(height * ratio)
            padding_data = torch.FloatTensor(data_height, int(np.ceil(data_height * ratio)), 3).zero_()
            try:
                padding_data[:, :data_width, :] = im_data[0]
            except RuntimeError:
                # shape mismatch between the image and the padded canvas
                raise ValueError('Padding Error')
            im_info[0, 1] = padding_data.size(1)
        else:
            # square target: take the top-left min(H, W) square without padding
            trim_size = min(data_height, data_width)
            padding_data = im_data[0][:trim_size, :trim_size, :]
            obj_gt_boxes[:, :4].clamp_(0, trim_size)
            region_gt_boxes[:, :4].clamp_(0, trim_size)
            im_info[0, 0] = trim_size
            im_info[0, 1] = trim_size

        obj_gt_boxes_padding, obj_num_boxes = self._pad_valid_boxes(obj_gt_boxes, self.max_num_box)
        region_gt_boxes_padding, region_num_boxes = self._pad_valid_boxes(region_gt_boxes, self.max_num_box)

        relation_triples_padding = torch.FloatTensor(self.max_num_relation, relation_triples.size(1)).zero_()
        relation_num = min(relation_triples.size(0), self.max_num_relation)
        if relation_num > 0:
            relation_triples_padding[:relation_num, :] = relation_triples[:relation_num]

        # (H, W, C) -> (C, H, W) for downstream processing
        padding_data = padding_data.permute(2, 0, 1).contiguous()
        im_info = im_info.view(3)

        return (padding_data, im_info, obj_gt_boxes_padding, obj_num_boxes,
                region_gt_boxes_padding, region_num_boxes, relation_triples_padding, relation_num)

    def __len__(self):
        """Number of roidb entries."""
        return len(self._roidb)
