
import torch
from torchvision.ops.boxes import box_area
import math

# modified from torchvision to also return the union
'''Note that this function only supports shape (N,4)'''


def box_iou(boxes1, boxes2):
    """Compute the element-wise IoU of two row-aligned box sets, plus their union.

    Row i of ``boxes1`` is compared only with row i of ``boxes2`` (no
    pairwise N x M matrix is built).

    :param boxes1: (N, 4) boxes as (x1, y1, x2, y2)
    :param boxes2: (N, 4) boxes as (x1, y1, x2, y2), aligned with boxes1
    :return: tuple ``(iou, union)``, each of shape (N,)
    """
    # Corners of the overlap rectangle for each row.
    top_left = torch.maximum(boxes1[:, :2], boxes2[:, :2])
    bottom_right = torch.minimum(boxes1[:, 2:], boxes2[:, 2:])

    # Non-overlapping rows get a zero-size overlap instead of a negative one.
    overlap_wh = (bottom_right - top_left).clamp(min=0)
    intersection = overlap_wh[:, 0] * overlap_wh[:, 1]

    union = box_area(boxes1) + box_area(boxes2) - intersection
    return intersection / union, union


'''Note that this implementation is different from DETR's'''


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/, computed element-wise.

    Row i of ``boxes1`` is paired with row i of ``boxes2``; boxes are in
    [x0, y0, x1, y1] format.

    boxes1: (N, 4)
    boxes2: (N, 4)

    Returns a tuple ``(giou, iou)``, each of shape (N,).
    Raises AssertionError (when asserts are enabled) if any box has
    x1 < x0 or y1 < y0. If the enclosing box of a row has zero area, the
    division below yields nan for that row.
    """
    # Degenerate boxes give inf / nan results, so reject them up front.
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()

    iou, union = box_iou(boxes1, boxes2)

    # Smallest axis-aligned box that encloses both boxes of each row.
    enclose_lt = torch.minimum(boxes1[:, :2], boxes2[:, :2])
    enclose_rb = torch.maximum(boxes1[:, 2:], boxes2[:, 2:])
    enclose_wh = (enclose_rb - enclose_lt).clamp(min=0)
    enclose_area = enclose_wh[:, 0] * enclose_wh[:, 1]

    # GIoU penalizes the part of the enclosing box not covered by the union.
    giou = iou - (enclose_area - union) / enclose_area
    return giou, iou


def giou_loss(boxes1, boxes2):
    """Mean GIoU loss over row-aligned box pairs.

    :param boxes1: (N, 4) boxes as (x1, y1, x2, y2)
    :param boxes2: (N, 4) boxes as (x1, y1, x2, y2), aligned with boxes1
    :return: tuple ``(loss, iou)`` where ``loss`` is the mean of
        ``1 - giou`` and ``iou`` is the per-row IoU, shape (N,)
    """
    giou_per_box, iou_per_box = generalized_box_iou(boxes1, boxes2)
    loss = torch.mean(1 - giou_per_box)
    return loss, iou_per_box

def ciou_loss(boxa, boxb):
    """Complete-IoU (CIoU) loss between two row-aligned sets of boxes.

    Coordinates are treated as continuous (width = x2 - x1, no ``+ 1``
    pixel offset), consistent with ``box_iou`` / ``generalized_box_iou``
    in this module, so normalized boxes are handled correctly.

    :param boxa: (N, 4) boxes as (x1, y1, x2, y2)
    :param boxb: (N, 4) boxes as (x1, y1, x2, y2), aligned with boxa
    :return: tuple ``(loss, iou)`` where ``loss`` is the mean of
        ``1 - ciou`` (ciou clamped to [-1, 1]) and ``iou`` is the per-row
        IoU, shape (N,)
    """
    # A float32 tensor eps (not a Python float) is kept on purpose: through
    # type promotion it also lifts half-precision inputs to float32.
    eps = torch.tensor([1e-8], device=boxa.device)

    boxa_w, boxa_h = boxa[:, 2] - boxa[:, 0], boxa[:, 3] - boxa[:, 1]
    boxb_w, boxb_h = boxb[:, 2] - boxb[:, 0], boxb[:, 3] - boxb[:, 1]

    # Intersection area; non-overlapping pairs get zero instead of a negative size.
    inter_w = (torch.minimum(boxa[:, 2], boxb[:, 2]) - torch.maximum(boxa[:, 0], boxb[:, 0])).clamp(min=0)
    inter_h = (torch.minimum(boxa[:, 3], boxb[:, 3]) - torch.maximum(boxa[:, 1], boxb[:, 1])).clamp(min=0)
    inter_area = inter_w * inter_h

    # Union area; eps guards against division by zero.
    union_area = boxa_w * boxa_h + boxb_w * boxb_h - inter_area + eps
    iou = inter_area / (union_area + eps)

    # Squared diagonal of the smallest box enclosing both boxes.
    ac_w = torch.maximum(boxa[:, 2], boxb[:, 2]) - torch.minimum(boxa[:, 0], boxb[:, 0])
    ac_h = torch.maximum(boxa[:, 3], boxb[:, 3]) - torch.minimum(boxa[:, 1], boxb[:, 1])
    length_ac = ac_w * ac_w + ac_h * ac_h + eps

    # Squared distance between the two box centers.
    boxa_ctrx, boxa_ctry = boxa[:, 0] + boxa_w / 2, boxa[:, 1] + boxa_h / 2
    boxb_ctrx, boxb_ctry = boxb[:, 0] + boxb_w / 2, boxb[:, 1] + boxb_h / 2
    dx, dy = boxb_ctrx - boxa_ctrx, boxb_ctry - boxa_ctry
    length_box_ctr = dx * dx + dy * dy

    # Aspect-ratio consistency term v and its trade-off weight alpha.
    a = torch.atan(boxa_w / (boxa_h + eps)) - torch.atan(boxb_w / (boxb_h + eps))
    v = (4 / (math.pi * math.pi)) * a * a
    # eps in the denominator avoids 0/0 = nan when the boxes coincide
    # (iou == 1 and v == 0). NOTE(review): reference implementations compute
    # alpha under torch.no_grad(); here gradients still flow through alpha.
    alpha = v / ((1 - iou) + v + eps)

    ciou = iou - length_box_ctr / length_ac - alpha * v
    ciou = torch.clamp(ciou, min=-1.0, max=1.0)
    return torch.mean(1 - ciou), iou