# Copyright 2021 Toyota Research Institute.  All rights reserved.
import torch

from mmcv.layers import cat

from adzoo.bevformer.mmdet3d_plugin.dd3d.structures.boxes3d import Boxes3D

# Sentinel "area" assigned to (location, box) pairs that are not eligible for
# matching; locations whose minimal area equals it become background.
INF = 1e8


class DD3DTargetPreparer:
    """Assign FCOS-style 2D (and optionally 3D) training targets to feature-map locations.

    For every image, each location on every FPN level is matched to at most one
    ground-truth box. A location is eligible for a box when it lies inside the box
    (or inside the box's center-sampling region when ``center_sample`` is on) and the
    box's largest regression distance falls inside the size range of the location's
    level. Among eligible boxes the one with the smallest area wins. Locations with
    no eligible box receive the background label ``num_classes``.
    """

    def __init__(self,
                 num_classes,
                 input_shape,
                 box3d_on=True,
                 center_sample=True,
                 pos_radius=1.5,
                 sizes_of_interest=None):
        """
        Args:
            num_classes: number of foreground classes; this value is also used as
                the background label.
            input_shape: per-level feature descriptions; only their ``stride``
                attribute is read.
            box3d_on: when True, 3D box targets are produced as well.
            center_sample: when True, positives are restricted to a region of
                ``pos_radius * stride`` around each box center (clipped to the box);
                otherwise any location strictly inside the box is eligible.
            pos_radius: radius of the center-sampling region, in units of the level
                stride.
            sizes_of_interest: one ``[min, max]`` pair per FPN level bounding the
                largest regression distance a location on that level may be assigned.
                It is used as-is (no conversion from a list of thresholds), and must
                be provided before ``__call__`` is used.
        """
        self.num_classes = num_classes
        self.center_sample = center_sample
        self.strides = [shape.stride for shape in input_shape]
        self.radius = pos_radius
        self.dd3d_enabled = box3d_on
        self.sizes_of_interest = sizes_of_interest

    def __call__(self, locations, gt_instances, feature_shapes):
        """Build flattened training targets for a batch.

        Args:
            locations: list with one ``(N_l, 2)`` tensor of (x, y) coordinates per
                FPN level.
            gt_instances: list with one ground-truth instances object per image.
            feature_shapes: list of ``(H, W)`` per level; only used to reshape the
                optional dense ``box2d`` targets.

        Returns:
            dict whose tensors are flattened in level-major, then image, then
            location order: ``labels``, ``box2d_reg_targets``, ``locations``,
            ``target_inds``, ``im_inds``, ``fpn_levels`` and ``pos_inds`` (indices of
            non-background labels). When 3D is enabled it also holds
            ``box3d_targets``, and ``batched_box2d`` if dense 2D targets exist.
        """
        num_loc_list = [len(loc) for loc in locations]
        num_images = len(gt_instances)

        # Broadcast each level's [min, max] size range to every location of that level.
        loc_to_size_range = torch.cat([
            loc.new_tensor(self.sizes_of_interest[level])[None].expand(num_loc, -1)
            for level, (loc, num_loc) in enumerate(zip(locations, num_loc_list))
        ], dim=0)
        locations = torch.cat(locations, dim=0)

        training_targets = self.compute_targets_for_locations(
            locations, gt_instances, loc_to_size_range, num_loc_list)

        training_targets["locations"] = [locations.clone() for _ in range(num_images)]
        training_targets["im_inds"] = [
            locations.new_full((locations.size(0),), i, dtype=torch.long) for i in range(num_images)
        ]

        # Optional dense 2D targets. The base compute_targets_for_locations never
        # returns a "box2d" entry, so this is None unless that method is overridden
        # (presumably by a subclass) to supply one.
        box2d = training_targets.pop("box2d", None)

        # Transpose image-first training targets into level-first ones.
        training_targets = {k: self._transpose(v, num_loc_list) for k, v in training_targets.items()}

        training_targets["fpn_levels"] = [
            loc.new_full((len(loc),), level, dtype=torch.long)
            for level, loc in enumerate(training_targets["locations"])
        ]

        # Flatten targets: (L x B x H x W, TARGET_SIZE)
        labels = cat([x.reshape(-1) for x in training_targets["labels"]])
        box2d_reg_targets = cat([x.reshape(-1, 4) for x in training_targets["box2d_reg"]])
        target_inds = cat([x.reshape(-1) for x in training_targets["target_inds"]])
        flat_locations = cat([x.reshape(-1, 2) for x in training_targets["locations"]])
        im_inds = cat([x.reshape(-1) for x in training_targets["im_inds"]])
        fpn_levels = cat([x.reshape(-1) for x in training_targets["fpn_levels"]])

        pos_inds = torch.nonzero(labels != self.num_classes).squeeze(1)

        targets = {
            "labels": labels,
            "box2d_reg_targets": box2d_reg_targets,
            "locations": flat_locations,
            "target_inds": target_inds,
            "im_inds": im_inds,
            "fpn_levels": fpn_levels,
            "pos_inds": pos_inds,
        }

        if self.dd3d_enabled:
            targets["box3d_targets"] = Boxes3D.cat(training_targets["box3d"])

            if box2d is not None:
                # Original format is B x L x (H x W, 4); convert to L x (B, 4, H, W).
                batched_box2d = []
                for lvl, per_lvl_box2d in enumerate(zip(*box2d)):
                    h, w = feature_shapes[lvl]
                    batched_box2d.append(
                        torch.stack([x.T.reshape(4, h, w) for x in per_lvl_box2d], dim=0))
                targets["batched_box2d"] = batched_box2d

        return targets

    def compute_targets_for_locations(self, locations, targets, size_ranges, num_loc_list):
        """Assign a ground-truth box to every location, independently for each image.

        Args:
            locations: ``(K, 2)`` tensor of (x, y) coordinates, all levels concatenated.
            targets: list of per-image ground-truth instances exposing ``gt_boxes``
                (with ``.tensor`` and ``.area()``), ``gt_classes`` and, when 3D is
                enabled, ``gt_boxes3d``.
            size_ranges: ``(K, 2)`` tensor with the [min, max] regression range of
                each location's level.
            num_loc_list: number of locations on each level.

        Returns:
            dict of per-image lists: ``labels`` (K,), ``box2d_reg`` (K, 4) holding
            (l, t, r, b) distances to the assigned box, ``target_inds`` (K,) indexing
            into the concatenation of all images' ground-truth boxes (-1 for images
            without ground truth), and, when 3D is enabled, ``box3d`` with one entry
            per location.
        """
        labels = []
        box2d_reg = []
        target_inds = []
        box3d = []

        num_locs = locations.size(0)
        xs, ys = locations[:, 0], locations[:, 1]
        loc_range = torch.arange(num_locs, device=locations.device)

        num_targets = 0
        for targets_per_im in targets:
            bboxes = targets_per_im.gt_boxes.tensor
            labels_per_im = targets_per_im.gt_classes

            if bboxes.numel() == 0:
                # No ground truth: every location is background with target index -1.
                labels.append(labels_per_im.new_full((num_locs,), self.num_classes))
                box2d_reg.append(locations.new_zeros((num_locs, 4)))
                target_inds.append(labels_per_im.new_full((num_locs,), -1))
                if self.dd3d_enabled:
                    box3d.append(
                        Boxes3D(
                            locations.new_zeros(num_locs, 4),
                            locations.new_zeros(num_locs, 2),
                            locations.new_zeros(num_locs, 1),
                            locations.new_zeros(num_locs, 3),
                            locations.new_zeros(num_locs, 3, 3),
                        ).to(torch.float32)
                    )
                continue

            area = targets_per_im.gt_boxes.area()

            # (K, G, 4): distances from every location to the four sides of every box.
            l = xs[:, None] - bboxes[:, 0][None]
            t = ys[:, None] - bboxes[:, 1][None]
            r = bboxes[:, 2][None] - xs[:, None]
            b = bboxes[:, 3][None] - ys[:, None]
            box2d_reg_per_im = torch.stack([l, t, r, b], dim=2)

            if self.center_sample:
                is_in_boxes = self.get_sample_region(bboxes, num_loc_list, xs, ys)
            else:
                is_in_boxes = box2d_reg_per_im.min(dim=2)[0] > 0

            # Limit the regression range handled by each level.
            max_reg_targets_per_im = box2d_reg_per_im.max(dim=2)[0]
            is_cared_in_the_level = \
                (max_reg_targets_per_im >= size_ranges[:, [0]]) & \
                (max_reg_targets_per_im <= size_ranges[:, [1]])

            locations_to_gt_area = area[None].repeat(num_locs, 1)
            locations_to_gt_area[is_in_boxes.logical_not()] = INF
            locations_to_gt_area[is_cared_in_the_level.logical_not()] = INF

            # If a location still matches several boxes, choose the one with minimal area.
            locations_to_min_area, locations_to_gt_inds = locations_to_gt_area.min(dim=1)

            box2d_reg_per_im = box2d_reg_per_im[loc_range, locations_to_gt_inds]
            # Offset by the number of boxes in previous images so indices are batch-global.
            # NOTE: background locations still carry an (arbitrary) box index from the
            # argmin over an all-INF row; use the labels to tell them apart.
            target_inds_per_im = locations_to_gt_inds + num_targets
            num_targets += len(targets_per_im)

            labels_per_im = labels_per_im[locations_to_gt_inds]
            labels_per_im[locations_to_min_area == INF] = self.num_classes

            labels.append(labels_per_im)
            box2d_reg.append(box2d_reg_per_im)
            target_inds.append(target_inds_per_im)

            if self.dd3d_enabled:
                box3d.append(targets_per_im.gt_boxes3d[locations_to_gt_inds])

        ret = {"labels": labels, "box2d_reg": box2d_reg, "target_inds": target_inds}
        if self.dd3d_enabled:
            ret["box3d"] = box3d
        return ret

    def get_sample_region(self, boxes, num_loc_list, loc_xs, loc_ys):
        """Compute which locations fall inside the center-sampling region of each box.

        A box's region is a square of half-size ``self.radius * stride`` (stride of the
        location's level) centered on the box center, clipped to the box itself.

        Args:
            boxes: ``(G, 4)`` tensor of boxes as (x0, y0, x1, y1).
            num_loc_list: number of locations on each level; locations are expected
                to be ordered level by level.
            loc_xs, loc_ys: ``(K,)`` tensors of location coordinates.

        Returns:
            ``(K, G)`` bool tensor; entry (k, g) is True when location k lies strictly
            inside the sampling region of box g.
        """
        num_gts = boxes.shape[0]
        num_locs = len(loc_xs)
        # Only an empty ground-truth set short-circuits. The result keeps the regular
        # (K, G) bool shape so callers can mask (K, G) tensors with it directly.
        if num_gts == 0:
            return loc_xs.new_zeros((num_locs, 0), dtype=torch.bool)

        center_x = boxes[..., [0, 2]].sum(dim=-1) * 0.5
        center_y = boxes[..., [1, 3]].sum(dim=-1) * 0.5
        boxes = boxes[None].expand(num_locs, num_gts, 4)
        center_x = center_x[None].expand(num_locs, num_gts)
        center_y = center_y[None].expand(num_locs, num_gts)
        center_gt = boxes.new_zeros(boxes.shape)

        beg = 0
        for level, num_loc in enumerate(num_loc_list):
            end = beg + num_loc
            radius = self.strides[level] * self.radius
            xmin = center_x[beg:end] - radius
            ymin = center_y[beg:end] - radius
            xmax = center_x[beg:end] + radius
            ymax = center_y[beg:end] + radius
            # Clip the sampling region to the ground-truth box.
            center_gt[beg:end, :, 0] = torch.where(xmin > boxes[beg:end, :, 0], xmin, boxes[beg:end, :, 0])
            center_gt[beg:end, :, 1] = torch.where(ymin > boxes[beg:end, :, 1], ymin, boxes[beg:end, :, 1])
            center_gt[beg:end, :, 2] = torch.where(xmax > boxes[beg:end, :, 2], boxes[beg:end, :, 2], xmax)
            center_gt[beg:end, :, 3] = torch.where(ymax > boxes[beg:end, :, 3], boxes[beg:end, :, 3], ymax)
            beg = end

        left = loc_xs[:, None] - center_gt[..., 0]
        right = center_gt[..., 2] - loc_xs[:, None]
        top = loc_ys[:, None] - center_gt[..., 1]
        bottom = center_gt[..., 3] - loc_ys[:, None]
        center_bbox = torch.stack((left, top, right, bottom), -1)
        return center_bbox.min(-1)[0] > 0

    def _transpose(self, training_targets, num_loc_list):
        """Transpose image-first training targets into level-first ones.

        Args:
            training_targets: list with one entry per image, each covering all levels
                concatenated along dim 0 (tensors or ``Boxes3D``). Not modified.
            num_loc_list: number of locations on each level.

        Returns:
            list with one entry per level, concatenating that level's slice of every
            image along dim 0.
        """
        if isinstance(training_targets[0], Boxes3D):
            split_per_im = [t.split(num_loc_list, dim=0) for t in training_targets]
            return [Boxes3D.cat(per_level, dim=0) for per_level in zip(*split_per_im)]

        split_per_im = [torch.split(t, num_loc_list, dim=0) for t in training_targets]
        return [torch.cat(per_level, dim=0) for per_level in zip(*split_per_im)]
