class weighted_boxes_fusion():
    """Incrementally fuse 1-D intervals ("boxes") by score-weighted averaging.

    Each added box is read as an interval ``[box[0], box[1]]`` with a label
    and a confidence score.  A new box joins the first existing cluster whose
    current fused interval overlaps it with IoU strictly greater than
    ``iou_thr``; otherwise it starts a new cluster.  For every cluster a fused
    entry ``[start, end, label, mean_score]`` is kept in ``self.fusions``
    (same index as the cluster in ``self.clusters``).
    """

    def __init__(self,):
        # self.clusters[i]: raw members [start, end, label, score] of cluster i.
        # self.fusions[i]: fused entry [start, end, label, mean_score] of cluster i.
        self.clusters = []
        self.fusions = []

    @staticmethod
    def _interval_iou(a_start, a_end, b_start, b_end, eps=1e-6):
        """Return intersection-over-union of two overlapping 1-D intervals.

        Callers must already have ruled out disjoint intervals.  ``eps`` keeps
        the division finite for degenerate (zero-length) intervals.
        """
        inter = min(a_end, b_end) - max(a_start, b_start)
        union = max(a_end, b_end) - min(a_start, b_start)
        return inter / (union + eps)

    def update_fusion(self, cluster):
        """Return the fused entry ``[start, end, label, mean_score]`` for *cluster*.

        Start/end are averages of the members' coordinates weighted by each
        member's score (last element).  If every score is zero, an unweighted
        mean is used instead of dividing by zero.  The label comes from
        ``get_cluster_label``; the fused score is the plain mean of scores.

        Raises:
            ValueError: if *cluster* is empty.
        """
        if not cluster:
            raise ValueError("cannot fuse an empty cluster")
        total = sum(member[-1] for member in cluster)
        if total:
            weights = [member[-1] for member in cluster]
            norm = total
        else:
            # All-zero scores: fall back to equal weights.
            weights = [1] * len(cluster)
            norm = len(cluster)
        start = sum(w * m[0] for w, m in zip(weights, cluster)) / norm
        end = sum(w * m[1] for w, m in zip(weights, cluster)) / norm
        label = self.get_cluster_label(cluster)
        return [start, end, label, total / len(cluster)]

    def add_boxes(self, box, label, score, iou_thr=0.8):
        """Add one box and return the list of fused entries.

        Args:
            box: sequence whose first two items are the interval start and end.
            label: label of the box.
            score: confidence of the box; also its weight when fusing.
            iou_thr: the box merges into a cluster only if its IoU with that
                cluster's fused interval is strictly greater than this value.

        Returns:
            ``self.fusions`` (the same list object on every call).
        """
        entry = [box[0], box[1], label, score]
        for idx, (cluster, fusion) in enumerate(zip(self.clusters, self.fusions)):
            left_bound, right_bound = fusion[0], fusion[1]
            # Disjoint intervals can never match this cluster.
            if box[1] < left_bound or box[0] > right_bound:
                continue
            iou = self._interval_iou(box[0], box[1], left_bound, right_bound)
            if iou > iou_thr:
                cluster.append(entry)
                # Write back by index: rebinding the loop variable alone would
                # leave self.fusions stale.
                self.fusions[idx] = self.update_fusion(cluster)
                return self.fusions
        # No sufficiently overlapping cluster: start a new one.
        self.clusters.append([entry])
        self.fusions.append(list(entry))
        return self.fusions

    def get_cluster_label(self, cluster):
        """Return the label whose members have the largest summed score.

        Members are ``[start, end, label, score]``; ties go to the label
        encountered first in *cluster*.
        """
        totals = {}
        for member in cluster:
            lab = member[2]
            totals[lab] = totals.get(lab, 0) + member[-1]
        return max(totals, key=totals.get)