
"""Utilities to manipulate and convert boxes"""
from collections import defaultdict
from typing import Any, Dict

import torch
from torchvision.ops.boxes import box_iou

from .unionfind import UnionFind


def obj_to_box(obj: Dict[str, Any]):
    """Return an object's bounding box as ``[x, y, w, h]``.

    The values are read from the ``"x"``, ``"y"``, ``"w"`` and ``"h"`` keys, in that
    order. A ``KeyError`` is raised for the first missing key.
    """
    return [obj[key] for key in ("x", "y", "w", "h")]


def region_to_box(obj: Dict[str, Any]):
    """Return a region's bounding box as ``[x, y, width, height]``.

    Regions spell their size keys out in full (``"width"``/``"height"``), unlike
    objects. A ``KeyError`` is raised for the first missing key.
    """
    return [obj[key] for key in ("x", "y", "width", "height")]


def get_boxes_equiv(orig_boxes, iou_threshold):
    """Group boxes into clusters of highly overlapping boxes.

    Two boxes are linked when their IoU is >= ``iou_threshold``. Linked boxes are merged
    with ``UnionFind``, so (assuming standard union-find semantics) a cluster can also
    form through a chain of overlaps.

    Args:
        orig_boxes: boxes as ``[top_left_x, top_left_y, width, height]``, in anything
            accepted by ``torch.as_tensor``. The input is never modified.
        iou_threshold: minimum IoU for two boxes to be linked.

    Returns:
        A pair ``(compo2boxes, compo2id)`` of dicts, both keyed by the representative
        returned by ``uf.find``:

        * ``compo2boxes`` maps to the member boxes, converted to
          ``[x1, y1, x2, y2]`` float tensors.
        * ``compo2id`` maps to the members' indices into ``orig_boxes``.

        As an optimization, ``(None, None)`` is returned when no two boxes are linked
        (every box is its own cluster), including for empty input.
    """
    if len(orig_boxes) == 0:
        # Nothing can overlap; also avoids indexing a shape-(0,) tensor below.
        return None, None

    # clone(): as_tensor shares memory with a float32 tensor input, and the in-place
    # conversion below would otherwise corrupt the caller's boxes.
    boxes = torch.as_tensor(orig_boxes, dtype=torch.float).clone()
    # [x, y, w, h] -> [x1, y1, x2, y2], the corner format box_iou works with.
    boxes[:, 2:] += boxes[:, :2]
    ious = box_iou(boxes, boxes)

    uf = UnionFind(len(boxes))
    # The strict upper triangle is enough: each pair is visited once, i < j, and in
    # the same row-major order as a nested i/j loop.
    linked = torch.triu(ious >= iou_threshold, diagonal=1)
    for i, j in linked.nonzero().tolist():
        uf.unite(i, j)

    if len(boxes) == uf.nb_compo:
        # No box was merged with another: skip building the cluster dicts.
        return None, None

    compo2boxes = defaultdict(list)
    compo2id = defaultdict(list)
    for i, box in enumerate(boxes):
        root = uf.find(i)
        compo2boxes[root].append(box)
        compo2id[root].append(i)
    # Internal consistency check: one dict entry per connected component.
    assert len(compo2boxes) == uf.nb_compo
    return compo2boxes, compo2id


def xyxy_to_xywh(boxes: torch.Tensor):
    """Convert boxes from corner format to corner-plus-size format.

    The last dimension of ``boxes`` must hold ``[x1, y1, x2, y2]`` (top-left and
    bottom-right corners). A new tensor is returned holding
    ``[x1, y1, x2 - x1, y2 - y1]``. Leading dimensions are preserved and the input
    is left untouched.
    """
    assert boxes.shape[-1] == 4
    top_left = boxes[..., :2]
    extent = boxes[..., 2:] - top_left
    return torch.cat((top_left, extent), dim=-1)


def combine_boxes(orig_boxes, iou_threshold=0.7):
    """Merge each cluster of highly overlapping boxes into its average box.

    Clusters come from ``get_boxes_equiv`` using ``iou_threshold``. Boxes are expected
    as ``[top_left_x, top_left_y, width, height]``.

    Returns:
        ``orig_boxes`` itself (same object) when no boxes overlap. Otherwise, one
        ``[x, y, w, h]`` list per cluster, single-box clusters included. Each list is
        the per-coordinate mean of the cluster's corner-format boxes, converted back to
        width/height form.
    """
    clusters, _ = get_boxes_equiv(orig_boxes, iou_threshold)
    if clusters is None:
        return orig_boxes
    return [
        xyxy_to_xywh(torch.stack(members, 0).mean(0)).tolist()
        for members in clusters.values()
    ]


def box_iou_helper(b1, b2):
    """Return the IoU matrix between two sets of boxes.

    Args:
        b1, b2: boxes as ``[top_left_x, top_left_y, w, h]``, in anything accepted by
            ``torch.as_tensor``. Neither input is modified.

    Returns:
        The ``(len(b1), len(b2))`` IoU matrix computed by torchvision's ``box_iou``.
        An empty set of boxes yields a matrix with a zero-sized dimension.
    """

    def _to_xyxy(b):
        # clone(): as_tensor shares memory with a float32 tensor input, and the
        # in-place add below would otherwise overwrite the caller's boxes.
        boxes = torch.as_tensor(b, dtype=torch.float).clone()
        if boxes.numel() == 0:
            # torch.as_tensor([]) has shape (0,); give it the (0, 4) box layout.
            boxes = boxes.reshape(0, 4)
        # [x, y, w, h] -> [x1, y1, x2, y2], the corner format box_iou works with.
        boxes[:, 2:] += boxes[:, :2]
        return boxes

    return box_iou(_to_xyxy(b1), _to_xyxy(b2))
